Spaces:
Runtime error
Runtime error
Add application file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +19 -0
- README.md +2 -1
- app.py +234 -0
- encoder/__init__.py +0 -0
- encoder/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/__pycache__/audio.cpython-38.pyc +0 -0
- encoder/__pycache__/inference.cpython-38.pyc +0 -0
- encoder/__pycache__/model.cpython-38.pyc +0 -0
- encoder/__pycache__/params_data.cpython-38.pyc +0 -0
- encoder/__pycache__/params_model.cpython-38.pyc +0 -0
- encoder/audio.py +117 -0
- encoder/config.py +45 -0
- encoder/data_objects/__init__.py +2 -0
- encoder/data_objects/random_cycler.py +37 -0
- encoder/data_objects/speaker.py +40 -0
- encoder/data_objects/speaker_batch.py +13 -0
- encoder/data_objects/speaker_verification_dataset.py +56 -0
- encoder/data_objects/utterance.py +26 -0
- encoder/inference.py +178 -0
- encoder/model.py +135 -0
- encoder/params_data.py +29 -0
- encoder/params_model.py +11 -0
- encoder/preprocess.py +184 -0
- encoder/train.py +125 -0
- encoder/visualizations.py +179 -0
- requirements.txt +0 -0
- synthesizer/LICENSE.txt +24 -0
- synthesizer/__init__.py +1 -0
- synthesizer/__pycache__/__init__.cpython-38.pyc +0 -0
- synthesizer/__pycache__/audio.cpython-38.pyc +0 -0
- synthesizer/__pycache__/hparams.cpython-38.pyc +0 -0
- synthesizer/__pycache__/inference.cpython-38.pyc +0 -0
- synthesizer/audio.py +206 -0
- synthesizer/hparams.py +92 -0
- synthesizer/inference.py +165 -0
- synthesizer/models/__pycache__/tacotron.cpython-38.pyc +0 -0
- synthesizer/models/tacotron.py +519 -0
- synthesizer/preprocess.py +258 -0
- synthesizer/synthesize.py +92 -0
- synthesizer/synthesizer_dataset.py +92 -0
- synthesizer/train.py +258 -0
- synthesizer/utils/__init__.py +45 -0
- synthesizer/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- synthesizer/utils/__pycache__/cleaners.cpython-38.pyc +0 -0
- synthesizer/utils/__pycache__/numbers.cpython-38.pyc +0 -0
- synthesizer/utils/__pycache__/symbols.cpython-38.pyc +0 -0
- synthesizer/utils/__pycache__/text.cpython-38.pyc +0 -0
- synthesizer/utils/_cmudict.py +62 -0
- synthesizer/utils/cleaners.py +88 -0
- synthesizer/utils/numbers.py +69 -0
.gitattributes
CHANGED
@@ -29,3 +29,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.pyc
|
33 |
+
*.aux
|
34 |
+
*.log
|
35 |
+
*.out
|
36 |
+
*.synctex.gz
|
37 |
+
*.suo
|
38 |
+
*__pycache__
|
39 |
+
*.idea
|
40 |
+
*.ipynb_checkpoints
|
41 |
+
*.pickle
|
42 |
+
*.npy
|
43 |
+
*.blg
|
44 |
+
*.bbl
|
45 |
+
*.bcf
|
46 |
+
*.toc
|
47 |
+
*.sh
|
48 |
+
encoder/saved_models/*
|
49 |
+
synthesizer/saved_models/*
|
50 |
+
vocoder/saved_models/*
|
README.md
CHANGED
@@ -3,8 +3,9 @@ title: Clone Your Voice
|
|
3 |
emoji: 📚
|
4 |
colorFrom: blue
|
5 |
colorTo: yellow
|
|
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
3 |
emoji: 📚
|
4 |
colorFrom: blue
|
5 |
colorTo: yellow
|
6 |
+
python_version: 3.8.9
|
7 |
sdk: gradio
|
8 |
+
sdk_version: 3.0.4
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
---
|
app.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import string
|
6 |
+
import numpy as np
|
7 |
+
import IPython
|
8 |
+
from IPython.display import Audio
|
9 |
+
import torch
|
10 |
+
import argparse
|
11 |
+
import os
|
12 |
+
from pathlib import Path
|
13 |
+
import librosa
|
14 |
+
import numpy as np
|
15 |
+
import soundfile as sf
|
16 |
+
import torch
|
17 |
+
from encoder import inference as encoder
|
18 |
+
from encoder.params_model import model_embedding_size as speaker_embedding_size
|
19 |
+
from synthesizer.inference import Synthesizer
|
20 |
+
from utils.argutils import print_args
|
21 |
+
from utils.default_models import ensure_default_models
|
22 |
+
from vocoder import inference as vocoder
|
23 |
+
import sounddevice as sd
|
24 |
+
parser = argparse.ArgumentParser(
|
25 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
26 |
+
)
|
27 |
+
parser.add_argument("-e", "--enc_model_fpath", type=Path,
|
28 |
+
default="saved_models/default/encoder.pt",
|
29 |
+
help="Path to a saved encoder")
|
30 |
+
parser.add_argument("-s", "--syn_model_fpath", type=Path,
|
31 |
+
default="saved_models/default/synthesizer.pt",
|
32 |
+
help="Path to a saved synthesizer")
|
33 |
+
parser.add_argument("-v", "--voc_model_fpath", type=Path,
|
34 |
+
default="saved_models/default/vocoder.pt",
|
35 |
+
help="Path to a saved vocoder")
|
36 |
+
parser.add_argument("--cpu", action="store_true", help=\
|
37 |
+
"If True, processing is done on CPU, even when a GPU is available.")
|
38 |
+
parser.add_argument("--no_sound", action="store_true", help=\
|
39 |
+
"If True, audio won't be played.")
|
40 |
+
parser.add_argument("--seed", type=int, default=None, help=\
|
41 |
+
"Optional random number seed value to make toolbox deterministic.")
|
42 |
+
args = parser.parse_args()
|
43 |
+
arg_dict = vars(args)
|
44 |
+
print_args(args, parser)
|
45 |
+
|
46 |
+
# Hide GPUs from Pytorch to force CPU processing
|
47 |
+
if arg_dict.pop("cpu"):
|
48 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
49 |
+
|
50 |
+
print("Running a test of your configuration...\n")
|
51 |
+
|
52 |
+
if torch.cuda.is_available():
|
53 |
+
device_id = torch.cuda.current_device()
|
54 |
+
gpu_properties = torch.cuda.get_device_properties(device_id)
|
55 |
+
## Print some environment information (for debugging purposes)
|
56 |
+
print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
|
57 |
+
"%.1fGb total memory.\n" %
|
58 |
+
(torch.cuda.device_count(),
|
59 |
+
device_id,
|
60 |
+
gpu_properties.name,
|
61 |
+
gpu_properties.major,
|
62 |
+
gpu_properties.minor,
|
63 |
+
gpu_properties.total_memory / 1e9))
|
64 |
+
else:
|
65 |
+
print("Using CPU for inference.\n")
|
66 |
+
|
67 |
+
## Load the models one by one.
|
68 |
+
print("Preparing the encoder, the synthesizer and the vocoder...")
|
69 |
+
ensure_default_models(Path("saved_models"))
|
70 |
+
encoder.load_model(args.enc_model_fpath)
|
71 |
+
synthesizer = Synthesizer(args.syn_model_fpath)
|
72 |
+
vocoder.load_model(args.voc_model_fpath)
|
73 |
+
|
74 |
+
def compute_embedding(in_fpath):
|
75 |
+
## Computing the embedding
|
76 |
+
# First, we load the wav using the function that the speaker encoder provides. This is
|
77 |
+
# important: there is preprocessing that must be applied.
|
78 |
+
|
79 |
+
# The following two methods are equivalent:
|
80 |
+
# - Directly load from the filepath:
|
81 |
+
preprocessed_wav = encoder.preprocess_wav(in_fpath)
|
82 |
+
# - If the wav is already loaded:
|
83 |
+
original_wav, sampling_rate = librosa.load(str(in_fpath))
|
84 |
+
preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
|
85 |
+
print("Loaded file succesfully")
|
86 |
+
|
87 |
+
# Then we derive the embedding. There are many functions and parameters that the
|
88 |
+
# speaker encoder interfaces. These are mostly for in-depth research. You will typically
|
89 |
+
# only use this function (with its default parameters):
|
90 |
+
embed = encoder.embed_utterance(preprocessed_wav)
|
91 |
+
|
92 |
+
return embed
|
93 |
+
def create_spectrogram(text,embed, synthesizer ):
|
94 |
+
# If seed is specified, reset torch seed and force synthesizer reload
|
95 |
+
if args.seed is not None:
|
96 |
+
torch.manual_seed(args.seed)
|
97 |
+
synthesizer = Synthesizer(args.syn_model_fpath)
|
98 |
+
# The synthesizer works in batch, so you need to put your data in a list or numpy array
|
99 |
+
texts = [text]
|
100 |
+
embeds = [embed]
|
101 |
+
# If you know what the attention layer alignments are, you can retrieve them here by
|
102 |
+
# passing return_alignments=True
|
103 |
+
specs = synthesizer.synthesize_spectrograms(texts, embeds)
|
104 |
+
spec = specs[0]
|
105 |
+
return spec
|
106 |
+
|
107 |
+
def generate_waveform(spec):
|
108 |
+
## Generating the waveform
|
109 |
+
print("Synthesizing the waveform:")
|
110 |
+
# If seed is specified, reset torch seed and reload vocoder
|
111 |
+
if args.seed is not None:
|
112 |
+
torch.manual_seed(args.seed)
|
113 |
+
vocoder.load_model(args.voc_model_fpath)
|
114 |
+
# Synthesizing the waveform is fairly straightforward. Remember that the longer the
|
115 |
+
# spectrogram, the more time-efficient the vocoder.
|
116 |
+
generated_wav = vocoder.infer_waveform(spec)
|
117 |
+
|
118 |
+
## Post-generation
|
119 |
+
# There's a bug with sounddevice that makes the audio cut one second earlier, so we
|
120 |
+
# pad it.
|
121 |
+
generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")
|
122 |
+
|
123 |
+
# Trim excess silences to compensate for gaps in spectrograms (issue #53)
|
124 |
+
generated_wav = encoder.preprocess_wav(generated_wav)
|
125 |
+
return generated_wav
|
126 |
+
|
127 |
+
|
128 |
+
def save_on_disk(generated_wav,synthesizer):
|
129 |
+
# Save it on the disk
|
130 |
+
filename = "cloned_voice.wav"
|
131 |
+
print(generated_wav.dtype)
|
132 |
+
#OUT=os.environ['OUT_PATH']
|
133 |
+
# Returns `None` if key doesn't exist
|
134 |
+
#OUT=os.environ.get('OUT_PATH')
|
135 |
+
#result = os.path.join(OUT, filename)
|
136 |
+
result = filename
|
137 |
+
print(" > Saving output to {}".format(result))
|
138 |
+
sf.write(result, generated_wav.astype(np.float32), synthesizer.sample_rate)
|
139 |
+
print("\nSaved output as %s\n\n" % result)
|
140 |
+
|
141 |
+
return result
|
142 |
+
def play_audio(generated_wav,synthesizer):
|
143 |
+
# Play the audio (non-blocking)
|
144 |
+
if not args.no_sound:
|
145 |
+
|
146 |
+
try:
|
147 |
+
sd.stop()
|
148 |
+
sd.play(generated_wav, synthesizer.sample_rate)
|
149 |
+
except sd.PortAudioError as e:
|
150 |
+
print("\nCaught exception: %s" % repr(e))
|
151 |
+
print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
|
152 |
+
except:
|
153 |
+
raise
|
154 |
+
|
155 |
+
def clone_voice(in_fpath, text,synthesizer):
|
156 |
+
try:
|
157 |
+
# Compute embedding
|
158 |
+
embed=compute_embedding(in_fpath)
|
159 |
+
print("Created the embedding")
|
160 |
+
# Generating the spectrogram
|
161 |
+
spec = create_spectrogram(text,embed,synthesizer)
|
162 |
+
print("Created the mel spectrogram")
|
163 |
+
|
164 |
+
# Create waveform
|
165 |
+
generated_wav=generate_waveform(spec)
|
166 |
+
print("Created the the waveform ")
|
167 |
+
|
168 |
+
# Save it on the disk
|
169 |
+
save_on_disk(generated_wav,synthesizer)
|
170 |
+
|
171 |
+
#Play the audio
|
172 |
+
play_audio(generated_wav,synthesizer)
|
173 |
+
|
174 |
+
return
|
175 |
+
except Exception as e:
|
176 |
+
print("Caught exception: %s" % repr(e))
|
177 |
+
print("Restarting\n")
|
178 |
+
|
179 |
+
# Set environment variables
|
180 |
+
home_dir = os.getcwd()
|
181 |
+
OUT_PATH=os.path.join(home_dir, "out/")
|
182 |
+
os.environ['OUT_PATH'] = OUT_PATH
|
183 |
+
|
184 |
+
# create output path
|
185 |
+
os.makedirs(OUT_PATH, exist_ok=True)
|
186 |
+
|
187 |
+
USE_CUDA = torch.cuda.is_available()
|
188 |
+
|
189 |
+
os.system('pip install -q pydub ffmpeg-normalize')
|
190 |
+
CONFIG_SE_PATH = "config_se.json"
|
191 |
+
CHECKPOINT_SE_PATH = "SE_checkpoint.pth.tar"
|
192 |
+
def greet(Text,Voicetoclone):
|
193 |
+
text= "%s" % (Text)
|
194 |
+
#reference_files= "%s" % (Voicetoclone)
|
195 |
+
reference_files= Voicetoclone
|
196 |
+
print("path url")
|
197 |
+
print(Voicetoclone)
|
198 |
+
sample= str(Voicetoclone)
|
199 |
+
os.environ['sample'] = sample
|
200 |
+
size= len(reference_files)*sys.getsizeof(reference_files)
|
201 |
+
size2= size / 1000000
|
202 |
+
if (size2 > 0.012) or len(text)>2000:
|
203 |
+
message="File is greater than 30mb or Text inserted is longer than 2000 characters. Please re-try with smaller sizes."
|
204 |
+
print(message)
|
205 |
+
raise SystemExit("File is greater than 30mb. Please re-try or Text inserted is longer than 2000 characters. Please re-try with smaller sizes.")
|
206 |
+
else:
|
207 |
+
|
208 |
+
env_var = 'sample'
|
209 |
+
if env_var in os.environ:
|
210 |
+
print(f'{env_var} value is {os.environ[env_var]}')
|
211 |
+
else:
|
212 |
+
print(f'{env_var} does not exist')
|
213 |
+
#os.system(f'ffmpeg-normalize {os.environ[env_var]} -nt rms -t=-27 -o {os.environ[env_var]} -ar 16000 -f')
|
214 |
+
in_fpath = Path(sample)
|
215 |
+
#in_fpath= in_fpath.replace("\"", "").replace("\'", "")
|
216 |
+
|
217 |
+
out_path=clone_voice(in_fpath, text,synthesizer)
|
218 |
+
|
219 |
+
print(" > text: {}".format(text))
|
220 |
+
|
221 |
+
print("Generated Audio")
|
222 |
+
return "cloned_voice.wav"
|
223 |
+
|
224 |
+
demo = gr.Interface(
|
225 |
+
fn=greet,
|
226 |
+
inputs=[gr.inputs.Textbox(label='What would you like the voice to say? (max. 2000 characters per request)'),
|
227 |
+
gr.Audio(
|
228 |
+
type="filepath",
|
229 |
+
source="upload",
|
230 |
+
label='Please upload a voice to clone (max. 30mb)')
|
231 |
+
],
|
232 |
+
outputs="audio",
|
233 |
+
)
|
234 |
+
demo.launch()
|
encoder/__init__.py
ADDED
File without changes
|
encoder/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (200 Bytes). View file
|
|
encoder/__pycache__/audio.cpython-38.pyc
ADDED
Binary file (4.05 kB). View file
|
|
encoder/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (7.25 kB). View file
|
|
encoder/__pycache__/model.cpython-38.pyc
ADDED
Binary file (4.84 kB). View file
|
|
encoder/__pycache__/params_data.cpython-38.pyc
ADDED
Binary file (507 Bytes). View file
|
|
encoder/__pycache__/params_model.cpython-38.pyc
ADDED
Binary file (387 Bytes). View file
|
|
encoder/audio.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.ndimage.morphology import binary_dilation
|
2 |
+
from encoder.params_data import *
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Union
|
5 |
+
from warnings import warn
|
6 |
+
import numpy as np
|
7 |
+
import librosa
|
8 |
+
import struct
|
9 |
+
|
10 |
+
try:
|
11 |
+
import webrtcvad
|
12 |
+
except:
|
13 |
+
warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
|
14 |
+
webrtcvad=None
|
15 |
+
|
16 |
+
int16_max = (2 ** 15) - 1
|
17 |
+
|
18 |
+
|
19 |
+
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
|
20 |
+
source_sr: Optional[int] = None,
|
21 |
+
normalize: Optional[bool] = True,
|
22 |
+
trim_silence: Optional[bool] = True):
|
23 |
+
"""
|
24 |
+
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
|
25 |
+
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
|
26 |
+
|
27 |
+
:param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
|
28 |
+
just .wav), either the waveform as a numpy array of floats.
|
29 |
+
:param source_sr: if passing an audio waveform, the sampling rate of the waveform before
|
30 |
+
preprocessing. After preprocessing, the waveform's sampling rate will match the data
|
31 |
+
hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
|
32 |
+
this argument will be ignored.
|
33 |
+
"""
|
34 |
+
# Load the wav from disk if needed
|
35 |
+
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
|
36 |
+
wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
|
37 |
+
else:
|
38 |
+
wav = fpath_or_wav
|
39 |
+
|
40 |
+
# Resample the wav if needed
|
41 |
+
if source_sr is not None and source_sr != sampling_rate:
|
42 |
+
wav = librosa.resample(wav, source_sr, sampling_rate)
|
43 |
+
|
44 |
+
# Apply the preprocessing: normalize volume and shorten long silences
|
45 |
+
if normalize:
|
46 |
+
wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
|
47 |
+
if webrtcvad and trim_silence:
|
48 |
+
wav = trim_long_silences(wav)
|
49 |
+
|
50 |
+
return wav
|
51 |
+
|
52 |
+
|
53 |
+
def wav_to_mel_spectrogram(wav):
|
54 |
+
"""
|
55 |
+
Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
|
56 |
+
Note: this not a log-mel spectrogram.
|
57 |
+
"""
|
58 |
+
frames = librosa.feature.melspectrogram(
|
59 |
+
wav,
|
60 |
+
sampling_rate,
|
61 |
+
n_fft=int(sampling_rate * mel_window_length / 1000),
|
62 |
+
hop_length=int(sampling_rate * mel_window_step / 1000),
|
63 |
+
n_mels=mel_n_channels
|
64 |
+
)
|
65 |
+
return frames.astype(np.float32).T
|
66 |
+
|
67 |
+
|
68 |
+
def trim_long_silences(wav):
|
69 |
+
"""
|
70 |
+
Ensures that segments without voice in the waveform remain no longer than a
|
71 |
+
threshold determined by the VAD parameters in params.py.
|
72 |
+
|
73 |
+
:param wav: the raw waveform as a numpy array of floats
|
74 |
+
:return: the same waveform with silences trimmed away (length <= original wav length)
|
75 |
+
"""
|
76 |
+
# Compute the voice detection window size
|
77 |
+
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
78 |
+
|
79 |
+
# Trim the end of the audio to have a multiple of the window size
|
80 |
+
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
81 |
+
|
82 |
+
# Convert the float waveform to 16-bit mono PCM
|
83 |
+
pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
|
84 |
+
|
85 |
+
# Perform voice activation detection
|
86 |
+
voice_flags = []
|
87 |
+
vad = webrtcvad.Vad(mode=3)
|
88 |
+
for window_start in range(0, len(wav), samples_per_window):
|
89 |
+
window_end = window_start + samples_per_window
|
90 |
+
voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
|
91 |
+
sample_rate=sampling_rate))
|
92 |
+
voice_flags = np.array(voice_flags)
|
93 |
+
|
94 |
+
# Smooth the voice detection with a moving average
|
95 |
+
def moving_average(array, width):
|
96 |
+
array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
|
97 |
+
ret = np.cumsum(array_padded, dtype=float)
|
98 |
+
ret[width:] = ret[width:] - ret[:-width]
|
99 |
+
return ret[width - 1:] / width
|
100 |
+
|
101 |
+
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
102 |
+
audio_mask = np.round(audio_mask).astype(np.bool)
|
103 |
+
|
104 |
+
# Dilate the voiced regions
|
105 |
+
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
106 |
+
audio_mask = np.repeat(audio_mask, samples_per_window)
|
107 |
+
|
108 |
+
return wav[audio_mask == True]
|
109 |
+
|
110 |
+
|
111 |
+
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
|
112 |
+
if increase_only and decrease_only:
|
113 |
+
raise ValueError("Both increase only and decrease only are set")
|
114 |
+
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
|
115 |
+
if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
|
116 |
+
return wav
|
117 |
+
return wav * (10 ** (dBFS_change / 20))
|
encoder/config.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
librispeech_datasets = {
|
2 |
+
"train": {
|
3 |
+
"clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
|
4 |
+
"other": ["LibriSpeech/train-other-500"]
|
5 |
+
},
|
6 |
+
"test": {
|
7 |
+
"clean": ["LibriSpeech/test-clean"],
|
8 |
+
"other": ["LibriSpeech/test-other"]
|
9 |
+
},
|
10 |
+
"dev": {
|
11 |
+
"clean": ["LibriSpeech/dev-clean"],
|
12 |
+
"other": ["LibriSpeech/dev-other"]
|
13 |
+
},
|
14 |
+
}
|
15 |
+
libritts_datasets = {
|
16 |
+
"train": {
|
17 |
+
"clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
|
18 |
+
"other": ["LibriTTS/train-other-500"]
|
19 |
+
},
|
20 |
+
"test": {
|
21 |
+
"clean": ["LibriTTS/test-clean"],
|
22 |
+
"other": ["LibriTTS/test-other"]
|
23 |
+
},
|
24 |
+
"dev": {
|
25 |
+
"clean": ["LibriTTS/dev-clean"],
|
26 |
+
"other": ["LibriTTS/dev-other"]
|
27 |
+
},
|
28 |
+
}
|
29 |
+
voxceleb_datasets = {
|
30 |
+
"voxceleb1" : {
|
31 |
+
"train": ["VoxCeleb1/wav"],
|
32 |
+
"test": ["VoxCeleb1/test_wav"]
|
33 |
+
},
|
34 |
+
"voxceleb2" : {
|
35 |
+
"train": ["VoxCeleb2/dev/aac"],
|
36 |
+
"test": ["VoxCeleb2/test_wav"]
|
37 |
+
}
|
38 |
+
}
|
39 |
+
|
40 |
+
other_datasets = [
|
41 |
+
"LJSpeech-1.1",
|
42 |
+
"VCTK-Corpus/wav48",
|
43 |
+
]
|
44 |
+
|
45 |
+
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
|
encoder/data_objects/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
2 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
|
encoder/data_objects/random_cycler.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
class RandomCycler:
|
4 |
+
"""
|
5 |
+
Creates an internal copy of a sequence and allows access to its items in a constrained random
|
6 |
+
order. For a source sequence of n items and one or several consecutive queries of a total
|
7 |
+
of m items, the following guarantees hold (one implies the other):
|
8 |
+
- Each item will be returned between m // n and ((m - 1) // n) + 1 times.
|
9 |
+
- Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, source):
|
13 |
+
if len(source) == 0:
|
14 |
+
raise Exception("Can't create RandomCycler from an empty collection")
|
15 |
+
self.all_items = list(source)
|
16 |
+
self.next_items = []
|
17 |
+
|
18 |
+
def sample(self, count: int):
|
19 |
+
shuffle = lambda l: random.sample(l, len(l))
|
20 |
+
|
21 |
+
out = []
|
22 |
+
while count > 0:
|
23 |
+
if count >= len(self.all_items):
|
24 |
+
out.extend(shuffle(list(self.all_items)))
|
25 |
+
count -= len(self.all_items)
|
26 |
+
continue
|
27 |
+
n = min(count, len(self.next_items))
|
28 |
+
out.extend(self.next_items[:n])
|
29 |
+
count -= n
|
30 |
+
self.next_items = self.next_items[n:]
|
31 |
+
if len(self.next_items) == 0:
|
32 |
+
self.next_items = shuffle(list(self.all_items))
|
33 |
+
return out
|
34 |
+
|
35 |
+
def __next__(self):
|
36 |
+
return self.sample(1)[0]
|
37 |
+
|
encoder/data_objects/speaker.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.random_cycler import RandomCycler
|
2 |
+
from encoder.data_objects.utterance import Utterance
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
# Contains the set of utterances of a single speaker
|
6 |
+
class Speaker:
|
7 |
+
def __init__(self, root: Path):
|
8 |
+
self.root = root
|
9 |
+
self.name = root.name
|
10 |
+
self.utterances = None
|
11 |
+
self.utterance_cycler = None
|
12 |
+
|
13 |
+
def _load_utterances(self):
|
14 |
+
with self.root.joinpath("_sources.txt").open("r") as sources_file:
|
15 |
+
sources = [l.split(",") for l in sources_file]
|
16 |
+
sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
|
17 |
+
self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
|
18 |
+
self.utterance_cycler = RandomCycler(self.utterances)
|
19 |
+
|
20 |
+
def random_partial(self, count, n_frames):
|
21 |
+
"""
|
22 |
+
Samples a batch of <count> unique partial utterances from the disk in a way that all
|
23 |
+
utterances come up at least once every two cycles and in a random order every time.
|
24 |
+
|
25 |
+
:param count: The number of partial utterances to sample from the set of utterances from
|
26 |
+
that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
|
27 |
+
the number of utterances available.
|
28 |
+
:param n_frames: The number of frames in the partial utterance.
|
29 |
+
:return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
|
30 |
+
frames are the frames of the partial utterances and range is the range of the partial
|
31 |
+
utterance with regard to the complete utterance.
|
32 |
+
"""
|
33 |
+
if self.utterances is None:
|
34 |
+
self._load_utterances()
|
35 |
+
|
36 |
+
utterances = self.utterance_cycler.sample(count)
|
37 |
+
|
38 |
+
a = [(u,) + u.random_partial(n_frames) for u in utterances]
|
39 |
+
|
40 |
+
return a
|
encoder/data_objects/speaker_batch.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List
|
3 |
+
from encoder.data_objects.speaker import Speaker
|
4 |
+
|
5 |
+
|
6 |
+
class SpeakerBatch:
|
7 |
+
def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
|
8 |
+
self.speakers = speakers
|
9 |
+
self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
|
10 |
+
|
11 |
+
# Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
|
12 |
+
# 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
|
13 |
+
self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
|
encoder/data_objects/speaker_verification_dataset.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.random_cycler import RandomCycler
|
2 |
+
from encoder.data_objects.speaker_batch import SpeakerBatch
|
3 |
+
from encoder.data_objects.speaker import Speaker
|
4 |
+
from encoder.params_data import partials_n_frames
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# TODO: improve with a pool of speakers for data efficiency
|
9 |
+
|
10 |
+
class SpeakerVerificationDataset(Dataset):
|
11 |
+
def __init__(self, datasets_root: Path):
|
12 |
+
self.root = datasets_root
|
13 |
+
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
14 |
+
if len(speaker_dirs) == 0:
|
15 |
+
raise Exception("No speakers found. Make sure you are pointing to the directory "
|
16 |
+
"containing all preprocessed speaker directories.")
|
17 |
+
self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
|
18 |
+
self.speaker_cycler = RandomCycler(self.speakers)
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return int(1e10)
|
22 |
+
|
23 |
+
def __getitem__(self, index):
|
24 |
+
return next(self.speaker_cycler)
|
25 |
+
|
26 |
+
def get_logs(self):
|
27 |
+
log_string = ""
|
28 |
+
for log_fpath in self.root.glob("*.txt"):
|
29 |
+
with log_fpath.open("r") as log_file:
|
30 |
+
log_string += "".join(log_file.readlines())
|
31 |
+
return log_string
|
32 |
+
|
33 |
+
|
34 |
+
class SpeakerVerificationDataLoader(DataLoader):
|
35 |
+
def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
|
36 |
+
batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
|
37 |
+
worker_init_fn=None):
|
38 |
+
self.utterances_per_speaker = utterances_per_speaker
|
39 |
+
|
40 |
+
super().__init__(
|
41 |
+
dataset=dataset,
|
42 |
+
batch_size=speakers_per_batch,
|
43 |
+
shuffle=False,
|
44 |
+
sampler=sampler,
|
45 |
+
batch_sampler=batch_sampler,
|
46 |
+
num_workers=num_workers,
|
47 |
+
collate_fn=self.collate,
|
48 |
+
pin_memory=pin_memory,
|
49 |
+
drop_last=False,
|
50 |
+
timeout=timeout,
|
51 |
+
worker_init_fn=worker_init_fn
|
52 |
+
)
|
53 |
+
|
54 |
+
def collate(self, speakers):
|
55 |
+
return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
|
56 |
+
|
encoder/data_objects/utterance.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class Utterance:
|
5 |
+
def __init__(self, frames_fpath, wave_fpath):
|
6 |
+
self.frames_fpath = frames_fpath
|
7 |
+
self.wave_fpath = wave_fpath
|
8 |
+
|
9 |
+
def get_frames(self):
|
10 |
+
return np.load(self.frames_fpath)
|
11 |
+
|
12 |
+
def random_partial(self, n_frames):
|
13 |
+
"""
|
14 |
+
Crops the frames into a partial utterance of n_frames
|
15 |
+
|
16 |
+
:param n_frames: The number of frames of the partial utterance
|
17 |
+
:return: the partial utterance frames and a tuple indicating the start and end of the
|
18 |
+
partial utterance in the complete utterance.
|
19 |
+
"""
|
20 |
+
frames = self.get_frames()
|
21 |
+
if frames.shape[0] == n_frames:
|
22 |
+
start = 0
|
23 |
+
else:
|
24 |
+
start = np.random.randint(0, frames.shape[0] - n_frames)
|
25 |
+
end = start + n_frames
|
26 |
+
return frames[start:end], (start, end)
|
encoder/inference.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.params_data import *
|
2 |
+
from encoder.model import SpeakerEncoder
|
3 |
+
from encoder.audio import preprocess_wav # We want to expose this function from here
|
4 |
+
from matplotlib import cm
|
5 |
+
from encoder import audio
|
6 |
+
from pathlib import Path
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
_model = None # type: SpeakerEncoder
|
11 |
+
_device = None # type: torch.device
|
12 |
+
|
13 |
+
|
14 |
+
def load_model(weights_fpath: Path, device=None):
|
15 |
+
"""
|
16 |
+
Loads the model in memory. If this function is not explicitely called, it will be run on the
|
17 |
+
first call to embed_frames() with the default weights file.
|
18 |
+
|
19 |
+
:param weights_fpath: the path to saved model weights.
|
20 |
+
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
|
21 |
+
model will be loaded and will run on this device. Outputs will however always be on the cpu.
|
22 |
+
If None, will default to your GPU if it"s available, otherwise your CPU.
|
23 |
+
"""
|
24 |
+
# TODO: I think the slow loading of the encoder might have something to do with the device it
|
25 |
+
# was saved on. Worth investigating.
|
26 |
+
global _model, _device
|
27 |
+
if device is None:
|
28 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
elif isinstance(device, str):
|
30 |
+
_device = torch.device(device)
|
31 |
+
_model = SpeakerEncoder(_device, torch.device("cpu"))
|
32 |
+
checkpoint = torch.load(weights_fpath, _device)
|
33 |
+
_model.load_state_dict(checkpoint["model_state"])
|
34 |
+
_model.eval()
|
35 |
+
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
36 |
+
|
37 |
+
|
38 |
+
def is_loaded():
|
39 |
+
return _model is not None
|
40 |
+
|
41 |
+
|
42 |
+
def embed_frames_batch(frames_batch):
|
43 |
+
"""
|
44 |
+
Computes embeddings for a batch of mel spectrogram.
|
45 |
+
|
46 |
+
:param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
|
47 |
+
(batch_size, n_frames, n_channels)
|
48 |
+
:return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
|
49 |
+
"""
|
50 |
+
if _model is None:
|
51 |
+
raise Exception("Model was not loaded. Call load_model() before inference.")
|
52 |
+
|
53 |
+
frames = torch.from_numpy(frames_batch).to(_device)
|
54 |
+
embed = _model.forward(frames).detach().cpu().numpy()
|
55 |
+
return embed
|
56 |
+
|
57 |
+
|
58 |
+
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
59 |
+
min_pad_coverage=0.75, overlap=0.5):
|
60 |
+
"""
|
61 |
+
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
62 |
+
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
63 |
+
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
|
64 |
+
its spectrogram. This function assumes that the mel spectrogram parameters used are those
|
65 |
+
defined in params_data.py.
|
66 |
+
|
67 |
+
The returned ranges may be indexing further than the length of the waveform. It is
|
68 |
+
recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
|
69 |
+
|
70 |
+
:param n_samples: the number of samples in the waveform
|
71 |
+
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
|
72 |
+
utterance
|
73 |
+
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
|
74 |
+
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
|
75 |
+
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
|
76 |
+
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
|
77 |
+
utterance, this parameter is ignored so that the function always returns at least 1 slice.
|
78 |
+
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
|
79 |
+
utterances are entirely disjoint.
|
80 |
+
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
81 |
+
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
82 |
+
utterances.
|
83 |
+
"""
|
84 |
+
assert 0 <= overlap < 1
|
85 |
+
assert 0 < min_pad_coverage <= 1
|
86 |
+
|
87 |
+
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
88 |
+
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
89 |
+
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
90 |
+
|
91 |
+
# Compute the slices
|
92 |
+
wav_slices, mel_slices = [], []
|
93 |
+
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
|
94 |
+
for i in range(0, steps, frame_step):
|
95 |
+
mel_range = np.array([i, i + partial_utterance_n_frames])
|
96 |
+
wav_range = mel_range * samples_per_frame
|
97 |
+
mel_slices.append(slice(*mel_range))
|
98 |
+
wav_slices.append(slice(*wav_range))
|
99 |
+
|
100 |
+
# Evaluate whether extra padding is warranted or not
|
101 |
+
last_wav_range = wav_slices[-1]
|
102 |
+
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
|
103 |
+
if coverage < min_pad_coverage and len(mel_slices) > 1:
|
104 |
+
mel_slices = mel_slices[:-1]
|
105 |
+
wav_slices = wav_slices[:-1]
|
106 |
+
|
107 |
+
return wav_slices, mel_slices
|
108 |
+
|
109 |
+
|
110 |
+
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
|
111 |
+
"""
|
112 |
+
Computes an embedding for a single utterance.
|
113 |
+
|
114 |
+
# TODO: handle multiple wavs to benefit from batching on GPU
|
115 |
+
:param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
|
116 |
+
:param using_partials: if True, then the utterance is split in partial utterances of
|
117 |
+
<partial_utterance_n_frames> frames and the utterance embedding is computed from their
|
118 |
+
normalized average. If False, the utterance is instead computed from feeding the entire
|
119 |
+
spectogram to the network.
|
120 |
+
:param return_partials: if True, the partial embeddings will also be returned along with the
|
121 |
+
wav slices that correspond to the partial embeddings.
|
122 |
+
:param kwargs: additional arguments to compute_partial_splits()
|
123 |
+
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
|
124 |
+
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
|
125 |
+
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
|
126 |
+
returned. If <using_partials> is simultaneously set to False, both these values will be None
|
127 |
+
instead.
|
128 |
+
"""
|
129 |
+
# Process the entire utterance if not using partials
|
130 |
+
if not using_partials:
|
131 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
132 |
+
embed = embed_frames_batch(frames[None, ...])[0]
|
133 |
+
if return_partials:
|
134 |
+
return embed, None, None
|
135 |
+
return embed
|
136 |
+
|
137 |
+
# Compute where to split the utterance into partials and pad if necessary
|
138 |
+
wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
|
139 |
+
max_wave_length = wave_slices[-1].stop
|
140 |
+
if max_wave_length >= len(wav):
|
141 |
+
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
142 |
+
|
143 |
+
# Split the utterance into partials
|
144 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
145 |
+
frames_batch = np.array([frames[s] for s in mel_slices])
|
146 |
+
partial_embeds = embed_frames_batch(frames_batch)
|
147 |
+
|
148 |
+
# Compute the utterance embedding from the partial embeddings
|
149 |
+
raw_embed = np.mean(partial_embeds, axis=0)
|
150 |
+
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
151 |
+
|
152 |
+
if return_partials:
|
153 |
+
return embed, partial_embeds, wave_slices
|
154 |
+
return embed
|
155 |
+
|
156 |
+
|
157 |
+
def embed_speaker(wavs, **kwargs):
|
158 |
+
raise NotImplemented()
|
159 |
+
|
160 |
+
|
161 |
+
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
|
162 |
+
import matplotlib.pyplot as plt
|
163 |
+
if ax is None:
|
164 |
+
ax = plt.gca()
|
165 |
+
|
166 |
+
if shape is None:
|
167 |
+
height = int(np.sqrt(len(embed)))
|
168 |
+
shape = (height, -1)
|
169 |
+
embed = embed.reshape(shape)
|
170 |
+
|
171 |
+
cmap = cm.get_cmap()
|
172 |
+
mappable = ax.imshow(embed, cmap=cmap)
|
173 |
+
cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
|
174 |
+
sm = cm.ScalarMappable(cmap=cmap)
|
175 |
+
sm.set_clim(*color_range)
|
176 |
+
|
177 |
+
ax.set_xticks([]), ax.set_yticks([])
|
178 |
+
ax.set_title(title)
|
encoder/model.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.params_model import *
|
2 |
+
from encoder.params_data import *
|
3 |
+
from scipy.interpolate import interp1d
|
4 |
+
from sklearn.metrics import roc_curve
|
5 |
+
from torch.nn.utils import clip_grad_norm_
|
6 |
+
from scipy.optimize import brentq
|
7 |
+
from torch import nn
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class SpeakerEncoder(nn.Module):
|
13 |
+
def __init__(self, device, loss_device):
|
14 |
+
super().__init__()
|
15 |
+
self.loss_device = loss_device
|
16 |
+
|
17 |
+
# Network defition
|
18 |
+
self.lstm = nn.LSTM(input_size=mel_n_channels,
|
19 |
+
hidden_size=model_hidden_size,
|
20 |
+
num_layers=model_num_layers,
|
21 |
+
batch_first=True).to(device)
|
22 |
+
self.linear = nn.Linear(in_features=model_hidden_size,
|
23 |
+
out_features=model_embedding_size).to(device)
|
24 |
+
self.relu = torch.nn.ReLU().to(device)
|
25 |
+
|
26 |
+
# Cosine similarity scaling (with fixed initial parameter values)
|
27 |
+
self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
|
28 |
+
self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
|
29 |
+
|
30 |
+
# Loss
|
31 |
+
self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
|
32 |
+
|
33 |
+
def do_gradient_ops(self):
|
34 |
+
# Gradient scale
|
35 |
+
self.similarity_weight.grad *= 0.01
|
36 |
+
self.similarity_bias.grad *= 0.01
|
37 |
+
|
38 |
+
# Gradient clipping
|
39 |
+
clip_grad_norm_(self.parameters(), 3, norm_type=2)
|
40 |
+
|
41 |
+
def forward(self, utterances, hidden_init=None):
|
42 |
+
"""
|
43 |
+
Computes the embeddings of a batch of utterance spectrograms.
|
44 |
+
|
45 |
+
:param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
|
46 |
+
(batch_size, n_frames, n_channels)
|
47 |
+
:param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
|
48 |
+
batch_size, hidden_size). Will default to a tensor of zeros if None.
|
49 |
+
:return: the embeddings as a tensor of shape (batch_size, embedding_size)
|
50 |
+
"""
|
51 |
+
# Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
|
52 |
+
# and the final cell state.
|
53 |
+
out, (hidden, cell) = self.lstm(utterances, hidden_init)
|
54 |
+
|
55 |
+
# We take only the hidden state of the last layer
|
56 |
+
embeds_raw = self.relu(self.linear(hidden[-1]))
|
57 |
+
|
58 |
+
# L2-normalize it
|
59 |
+
embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
|
60 |
+
|
61 |
+
return embeds
|
62 |
+
|
63 |
+
def similarity_matrix(self, embeds):
|
64 |
+
"""
|
65 |
+
Computes the similarity matrix according the section 2.1 of GE2E.
|
66 |
+
|
67 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
68 |
+
utterances_per_speaker, embedding_size)
|
69 |
+
:return: the similarity matrix as a tensor of shape (speakers_per_batch,
|
70 |
+
utterances_per_speaker, speakers_per_batch)
|
71 |
+
"""
|
72 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
73 |
+
|
74 |
+
# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
|
75 |
+
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
|
76 |
+
centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
|
77 |
+
|
78 |
+
# Exclusive centroids (1 per utterance)
|
79 |
+
centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
|
80 |
+
centroids_excl /= (utterances_per_speaker - 1)
|
81 |
+
centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
|
82 |
+
|
83 |
+
# Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
|
84 |
+
# product of these vectors (which is just an element-wise multiplication reduced by a sum).
|
85 |
+
# We vectorize the computation for efficiency.
|
86 |
+
sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
|
87 |
+
speakers_per_batch).to(self.loss_device)
|
88 |
+
mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
|
89 |
+
for j in range(speakers_per_batch):
|
90 |
+
mask = np.where(mask_matrix[j])[0]
|
91 |
+
sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
|
92 |
+
sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
|
93 |
+
|
94 |
+
## Even more vectorized version (slower maybe because of transpose)
|
95 |
+
# sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
|
96 |
+
# ).to(self.loss_device)
|
97 |
+
# eye = np.eye(speakers_per_batch, dtype=np.int)
|
98 |
+
# mask = np.where(1 - eye)
|
99 |
+
# sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
|
100 |
+
# mask = np.where(eye)
|
101 |
+
# sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
|
102 |
+
# sim_matrix2 = sim_matrix2.transpose(1, 2)
|
103 |
+
|
104 |
+
sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
|
105 |
+
return sim_matrix
|
106 |
+
|
107 |
+
def loss(self, embeds):
|
108 |
+
"""
|
109 |
+
Computes the softmax loss according the section 2.1 of GE2E.
|
110 |
+
|
111 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
112 |
+
utterances_per_speaker, embedding_size)
|
113 |
+
:return: the loss and the EER for this batch of embeddings.
|
114 |
+
"""
|
115 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
116 |
+
|
117 |
+
# Loss
|
118 |
+
sim_matrix = self.similarity_matrix(embeds)
|
119 |
+
sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
|
120 |
+
speakers_per_batch))
|
121 |
+
ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
|
122 |
+
target = torch.from_numpy(ground_truth).long().to(self.loss_device)
|
123 |
+
loss = self.loss_fn(sim_matrix, target)
|
124 |
+
|
125 |
+
# EER (not backpropagated)
|
126 |
+
with torch.no_grad():
|
127 |
+
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
|
128 |
+
labels = np.array([inv_argmax(i) for i in ground_truth])
|
129 |
+
preds = sim_matrix.detach().cpu().numpy()
|
130 |
+
|
131 |
+
# Snippet from https://yangcha.github.io/EER-ROC/
|
132 |
+
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
|
133 |
+
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
|
134 |
+
|
135 |
+
return loss, eer
|
encoder/params_data.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## Mel-filterbank
|
3 |
+
mel_window_length = 25 # In milliseconds
|
4 |
+
mel_window_step = 10 # In milliseconds
|
5 |
+
mel_n_channels = 40
|
6 |
+
|
7 |
+
|
8 |
+
## Audio
|
9 |
+
sampling_rate = 16000
|
10 |
+
# Number of spectrogram frames in a partial utterance
|
11 |
+
partials_n_frames = 160 # 1600 ms
|
12 |
+
# Number of spectrogram frames at inference
|
13 |
+
inference_n_frames = 80 # 800 ms
|
14 |
+
|
15 |
+
|
16 |
+
## Voice Activation Detection
|
17 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
18 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
19 |
+
vad_window_length = 30 # In milliseconds
|
20 |
+
# Number of frames to average together when performing the moving average smoothing.
|
21 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
22 |
+
vad_moving_average_width = 8
|
23 |
+
# Maximum number of consecutive silent frames a segment can have.
|
24 |
+
vad_max_silence_length = 6
|
25 |
+
|
26 |
+
|
27 |
+
## Audio volume normalization
|
28 |
+
audio_norm_target_dBFS = -30
|
29 |
+
|
encoder/params_model.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## Model parameters
|
3 |
+
model_hidden_size = 256
|
4 |
+
model_embedding_size = 256
|
5 |
+
model_num_layers = 3
|
6 |
+
|
7 |
+
|
8 |
+
## Training parameters
|
9 |
+
learning_rate_init = 1e-4
|
10 |
+
speakers_per_batch = 64
|
11 |
+
utterances_per_speaker = 10
|
encoder/preprocess.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from functools import partial
|
3 |
+
from multiprocessing import Pool
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from encoder import audio
|
10 |
+
from encoder.config import librispeech_datasets, anglophone_nationalites
|
11 |
+
from encoder.params_data import *
|
12 |
+
|
13 |
+
|
14 |
+
_AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
|
15 |
+
|
16 |
+
class DatasetLog:
|
17 |
+
"""
|
18 |
+
Registers metadata about the dataset in a text file.
|
19 |
+
"""
|
20 |
+
def __init__(self, root, name):
|
21 |
+
self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
|
22 |
+
self.sample_data = dict()
|
23 |
+
|
24 |
+
start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
25 |
+
self.write_line("Creating dataset %s on %s" % (name, start_time))
|
26 |
+
self.write_line("-----")
|
27 |
+
self._log_params()
|
28 |
+
|
29 |
+
def _log_params(self):
|
30 |
+
from encoder import params_data
|
31 |
+
self.write_line("Parameter values:")
|
32 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
33 |
+
value = getattr(params_data, param_name)
|
34 |
+
self.write_line("\t%s: %s" % (param_name, value))
|
35 |
+
self.write_line("-----")
|
36 |
+
|
37 |
+
def write_line(self, line):
|
38 |
+
self.text_file.write("%s\n" % line)
|
39 |
+
|
40 |
+
def add_sample(self, **kwargs):
|
41 |
+
for param_name, value in kwargs.items():
|
42 |
+
if not param_name in self.sample_data:
|
43 |
+
self.sample_data[param_name] = []
|
44 |
+
self.sample_data[param_name].append(value)
|
45 |
+
|
46 |
+
def finalize(self):
|
47 |
+
self.write_line("Statistics:")
|
48 |
+
for param_name, values in self.sample_data.items():
|
49 |
+
self.write_line("\t%s:" % param_name)
|
50 |
+
self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
|
51 |
+
self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
|
52 |
+
self.write_line("-----")
|
53 |
+
end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
54 |
+
self.write_line("Finished on %s" % end_time)
|
55 |
+
self.text_file.close()
|
56 |
+
|
57 |
+
|
58 |
+
def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
|
59 |
+
dataset_root = datasets_root.joinpath(dataset_name)
|
60 |
+
if not dataset_root.exists():
|
61 |
+
print("Couldn\'t find %s, skipping this dataset." % dataset_root)
|
62 |
+
return None, None
|
63 |
+
return dataset_root, DatasetLog(out_dir, dataset_name)
|
64 |
+
|
65 |
+
|
66 |
+
def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
|
67 |
+
# Give a name to the speaker that includes its dataset
|
68 |
+
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
69 |
+
|
70 |
+
# Create an output directory with that name, as well as a txt file containing a
|
71 |
+
# reference to each source file.
|
72 |
+
speaker_out_dir = out_dir.joinpath(speaker_name)
|
73 |
+
speaker_out_dir.mkdir(exist_ok=True)
|
74 |
+
sources_fpath = speaker_out_dir.joinpath("_sources.txt")
|
75 |
+
|
76 |
+
# There's a possibility that the preprocessing was interrupted earlier, check if
|
77 |
+
# there already is a sources file.
|
78 |
+
if sources_fpath.exists():
|
79 |
+
try:
|
80 |
+
with sources_fpath.open("r") as sources_file:
|
81 |
+
existing_fnames = {line.split(",")[0] for line in sources_file}
|
82 |
+
except:
|
83 |
+
existing_fnames = {}
|
84 |
+
else:
|
85 |
+
existing_fnames = {}
|
86 |
+
|
87 |
+
# Gather all audio files for that speaker recursively
|
88 |
+
sources_file = sources_fpath.open("a" if skip_existing else "w")
|
89 |
+
audio_durs = []
|
90 |
+
for extension in _AUDIO_EXTENSIONS:
|
91 |
+
for in_fpath in speaker_dir.glob("**/*.%s" % extension):
|
92 |
+
# Check if the target output file already exists
|
93 |
+
out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
|
94 |
+
out_fname = out_fname.replace(".%s" % extension, ".npy")
|
95 |
+
if skip_existing and out_fname in existing_fnames:
|
96 |
+
continue
|
97 |
+
|
98 |
+
# Load and preprocess the waveform
|
99 |
+
wav = audio.preprocess_wav(in_fpath)
|
100 |
+
if len(wav) == 0:
|
101 |
+
continue
|
102 |
+
|
103 |
+
# Create the mel spectrogram, discard those that are too short
|
104 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
105 |
+
if len(frames) < partials_n_frames:
|
106 |
+
continue
|
107 |
+
|
108 |
+
out_fpath = speaker_out_dir.joinpath(out_fname)
|
109 |
+
np.save(out_fpath, frames)
|
110 |
+
sources_file.write("%s,%s\n" % (out_fname, in_fpath))
|
111 |
+
audio_durs.append(len(wav) / sampling_rate)
|
112 |
+
|
113 |
+
sources_file.close()
|
114 |
+
|
115 |
+
return audio_durs
|
116 |
+
|
117 |
+
|
118 |
+
def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
|
119 |
+
print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
|
120 |
+
|
121 |
+
# Process the utterances for each speaker
|
122 |
+
work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
|
123 |
+
with Pool(4) as pool:
|
124 |
+
tasks = pool.imap(work_fn, speaker_dirs)
|
125 |
+
for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
|
126 |
+
for sample_dur in sample_durs:
|
127 |
+
logger.add_sample(duration=sample_dur)
|
128 |
+
|
129 |
+
logger.finalize()
|
130 |
+
print("Done preprocessing %s.\n" % dataset_name)
|
131 |
+
|
132 |
+
|
133 |
+
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
|
134 |
+
for dataset_name in librispeech_datasets["train"]["other"]:
|
135 |
+
# Initialize the preprocessing
|
136 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
137 |
+
if not dataset_root:
|
138 |
+
return
|
139 |
+
|
140 |
+
# Preprocess all speakers
|
141 |
+
speaker_dirs = list(dataset_root.glob("*"))
|
142 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
|
143 |
+
|
144 |
+
|
145 |
+
def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
|
146 |
+
# Initialize the preprocessing
|
147 |
+
dataset_name = "VoxCeleb1"
|
148 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
149 |
+
if not dataset_root:
|
150 |
+
return
|
151 |
+
|
152 |
+
# Get the contents of the meta file
|
153 |
+
with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
|
154 |
+
metadata = [line.split("\t") for line in metafile][1:]
|
155 |
+
|
156 |
+
# Select the ID and the nationality, filter out non-anglophone speakers
|
157 |
+
nationalities = {line[0]: line[3] for line in metadata}
|
158 |
+
keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
|
159 |
+
nationality.lower() in anglophone_nationalites]
|
160 |
+
print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
|
161 |
+
(len(keep_speaker_ids), len(nationalities)))
|
162 |
+
|
163 |
+
# Get the speaker directories for anglophone speakers only
|
164 |
+
speaker_dirs = dataset_root.joinpath("wav").glob("*")
|
165 |
+
speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
|
166 |
+
speaker_dir.name in keep_speaker_ids]
|
167 |
+
print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
|
168 |
+
(len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
|
169 |
+
|
170 |
+
# Preprocess all speakers
|
171 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
|
172 |
+
|
173 |
+
|
174 |
+
def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
|
175 |
+
# Initialize the preprocessing
|
176 |
+
dataset_name = "VoxCeleb2"
|
177 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
178 |
+
if not dataset_root:
|
179 |
+
return
|
180 |
+
|
181 |
+
# Get the speaker directories
|
182 |
+
# Preprocess all speakers
|
183 |
+
speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
|
184 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
|
encoder/train.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
6 |
+
from encoder.model import SpeakerEncoder
|
7 |
+
from encoder.params_model import *
|
8 |
+
from encoder.visualizations import Visualizations
|
9 |
+
from utils.profiler import Profiler
|
10 |
+
|
11 |
+
|
12 |
+
def sync(device: torch.device):
|
13 |
+
# For correct profiling (cuda operations are async)
|
14 |
+
if device.type == "cuda":
|
15 |
+
torch.cuda.synchronize(device)
|
16 |
+
|
17 |
+
|
18 |
+
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
|
19 |
+
backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
|
20 |
+
no_visdom: bool):
|
21 |
+
# Create a dataset and a dataloader
|
22 |
+
dataset = SpeakerVerificationDataset(clean_data_root)
|
23 |
+
loader = SpeakerVerificationDataLoader(
|
24 |
+
dataset,
|
25 |
+
speakers_per_batch,
|
26 |
+
utterances_per_speaker,
|
27 |
+
num_workers=4,
|
28 |
+
)
|
29 |
+
|
30 |
+
# Setup the device on which to run the forward pass and the loss. These can be different,
|
31 |
+
# because the forward pass is faster on the GPU whereas the loss is often (depending on your
|
32 |
+
# hyperparameters) faster on the CPU.
|
33 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
34 |
+
# FIXME: currently, the gradient is None if loss_device is cuda
|
35 |
+
loss_device = torch.device("cpu")
|
36 |
+
|
37 |
+
# Create the model and the optimizer
|
38 |
+
model = SpeakerEncoder(device, loss_device)
|
39 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
|
40 |
+
init_step = 1
|
41 |
+
|
42 |
+
# Configure file path for the model
|
43 |
+
model_dir = models_dir / run_id
|
44 |
+
model_dir.mkdir(exist_ok=True, parents=True)
|
45 |
+
state_fpath = model_dir / "encoder.pt"
|
46 |
+
|
47 |
+
# Load any existing model
|
48 |
+
if not force_restart:
|
49 |
+
if state_fpath.exists():
|
50 |
+
print("Found existing model \"%s\", loading it and resuming training." % run_id)
|
51 |
+
checkpoint = torch.load(state_fpath)
|
52 |
+
init_step = checkpoint["step"]
|
53 |
+
model.load_state_dict(checkpoint["model_state"])
|
54 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
55 |
+
optimizer.param_groups[0]["lr"] = learning_rate_init
|
56 |
+
else:
|
57 |
+
print("No model \"%s\" found, starting training from scratch." % run_id)
|
58 |
+
else:
|
59 |
+
print("Starting the training from scratch.")
|
60 |
+
model.train()
|
61 |
+
|
62 |
+
# Initialize the visualization environment
|
63 |
+
vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
|
64 |
+
vis.log_dataset(dataset)
|
65 |
+
vis.log_params()
|
66 |
+
device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
|
67 |
+
vis.log_implementation({"Device": device_name})
|
68 |
+
|
69 |
+
# Training loop
|
70 |
+
profiler = Profiler(summarize_every=10, disabled=False)
|
71 |
+
for step, speaker_batch in enumerate(loader, init_step):
|
72 |
+
profiler.tick("Blocking, waiting for batch (threaded)")
|
73 |
+
|
74 |
+
# Forward pass
|
75 |
+
inputs = torch.from_numpy(speaker_batch.data).to(device)
|
76 |
+
sync(device)
|
77 |
+
profiler.tick("Data to %s" % device)
|
78 |
+
embeds = model(inputs)
|
79 |
+
sync(device)
|
80 |
+
profiler.tick("Forward pass")
|
81 |
+
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
|
82 |
+
loss, eer = model.loss(embeds_loss)
|
83 |
+
sync(loss_device)
|
84 |
+
profiler.tick("Loss")
|
85 |
+
|
86 |
+
# Backward pass
|
87 |
+
model.zero_grad()
|
88 |
+
loss.backward()
|
89 |
+
profiler.tick("Backward pass")
|
90 |
+
model.do_gradient_ops()
|
91 |
+
optimizer.step()
|
92 |
+
profiler.tick("Parameter update")
|
93 |
+
|
94 |
+
# Update visualizations
|
95 |
+
# learning_rate = optimizer.param_groups[0]["lr"]
|
96 |
+
vis.update(loss.item(), eer, step)
|
97 |
+
|
98 |
+
# Draw projections and save them to the backup folder
|
99 |
+
if umap_every != 0 and step % umap_every == 0:
|
100 |
+
print("Drawing and saving projections (step %d)" % step)
|
101 |
+
projection_fpath = model_dir / f"umap_{step:06d}.png"
|
102 |
+
embeds = embeds.detach().cpu().numpy()
|
103 |
+
vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
|
104 |
+
vis.save()
|
105 |
+
|
106 |
+
# Overwrite the latest version of the model
|
107 |
+
if save_every != 0 and step % save_every == 0:
|
108 |
+
print("Saving the model (step %d)" % step)
|
109 |
+
torch.save({
|
110 |
+
"step": step + 1,
|
111 |
+
"model_state": model.state_dict(),
|
112 |
+
"optimizer_state": optimizer.state_dict(),
|
113 |
+
}, state_fpath)
|
114 |
+
|
115 |
+
# Make a backup
|
116 |
+
if backup_every != 0 and step % backup_every == 0:
|
117 |
+
print("Making a backup (step %d)" % step)
|
118 |
+
backup_fpath = model_dir / f"encoder_{step:06d}.bak"
|
119 |
+
torch.save({
|
120 |
+
"step": step + 1,
|
121 |
+
"model_state": model.state_dict(),
|
122 |
+
"optimizer_state": optimizer.state_dict(),
|
123 |
+
}, backup_fpath)
|
124 |
+
|
125 |
+
profiler.tick("Extras (visualizations, saving)")
|
encoder/visualizations.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from time import perf_counter as timer
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import umap
|
6 |
+
import visdom
|
7 |
+
|
8 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
9 |
+
|
10 |
+
|
11 |
+
colormap = np.array([
|
12 |
+
[76, 255, 0],
|
13 |
+
[0, 127, 70],
|
14 |
+
[255, 0, 0],
|
15 |
+
[255, 217, 38],
|
16 |
+
[0, 135, 255],
|
17 |
+
[165, 0, 165],
|
18 |
+
[255, 167, 255],
|
19 |
+
[0, 255, 255],
|
20 |
+
[255, 96, 38],
|
21 |
+
[142, 76, 0],
|
22 |
+
[33, 0, 127],
|
23 |
+
[0, 0, 0],
|
24 |
+
[183, 183, 183],
|
25 |
+
], dtype=np.float) / 255
|
26 |
+
|
27 |
+
|
28 |
+
class Visualizations:
|
29 |
+
def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
|
30 |
+
# Tracking data
|
31 |
+
self.last_update_timestamp = timer()
|
32 |
+
self.update_every = update_every
|
33 |
+
self.step_times = []
|
34 |
+
self.losses = []
|
35 |
+
self.eers = []
|
36 |
+
print("Updating the visualizations every %d steps." % update_every)
|
37 |
+
|
38 |
+
# If visdom is disabled TODO: use a better paradigm for that
|
39 |
+
self.disabled = disabled
|
40 |
+
if self.disabled:
|
41 |
+
return
|
42 |
+
|
43 |
+
# Set the environment name
|
44 |
+
now = str(datetime.now().strftime("%d-%m %Hh%M"))
|
45 |
+
if env_name is None:
|
46 |
+
self.env_name = now
|
47 |
+
else:
|
48 |
+
self.env_name = "%s (%s)" % (env_name, now)
|
49 |
+
|
50 |
+
# Connect to visdom and open the corresponding window in the browser
|
51 |
+
try:
|
52 |
+
self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
|
53 |
+
except ConnectionError:
|
54 |
+
raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
|
55 |
+
"start it.")
|
56 |
+
# webbrowser.open("http://localhost:8097/env/" + self.env_name)
|
57 |
+
|
58 |
+
# Create the windows
|
59 |
+
self.loss_win = None
|
60 |
+
self.eer_win = None
|
61 |
+
# self.lr_win = None
|
62 |
+
self.implementation_win = None
|
63 |
+
self.projection_win = None
|
64 |
+
self.implementation_string = ""
|
65 |
+
|
66 |
+
def log_params(self):
|
67 |
+
if self.disabled:
|
68 |
+
return
|
69 |
+
from encoder import params_data
|
70 |
+
from encoder import params_model
|
71 |
+
param_string = "<b>Model parameters</b>:<br>"
|
72 |
+
for param_name in (p for p in dir(params_model) if not p.startswith("__")):
|
73 |
+
value = getattr(params_model, param_name)
|
74 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
75 |
+
param_string += "<b>Data parameters</b>:<br>"
|
76 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
77 |
+
value = getattr(params_data, param_name)
|
78 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
79 |
+
self.vis.text(param_string, opts={"title": "Parameters"})
|
80 |
+
|
81 |
+
def log_dataset(self, dataset: SpeakerVerificationDataset):
|
82 |
+
if self.disabled:
|
83 |
+
return
|
84 |
+
dataset_string = ""
|
85 |
+
dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
|
86 |
+
dataset_string += "\n" + dataset.get_logs()
|
87 |
+
dataset_string = dataset_string.replace("\n", "<br>")
|
88 |
+
self.vis.text(dataset_string, opts={"title": "Dataset"})
|
89 |
+
|
90 |
+
def log_implementation(self, params):
|
91 |
+
if self.disabled:
|
92 |
+
return
|
93 |
+
implementation_string = ""
|
94 |
+
for param, value in params.items():
|
95 |
+
implementation_string += "<b>%s</b>: %s\n" % (param, value)
|
96 |
+
implementation_string = implementation_string.replace("\n", "<br>")
|
97 |
+
self.implementation_string = implementation_string
|
98 |
+
self.implementation_win = self.vis.text(
|
99 |
+
implementation_string,
|
100 |
+
opts={"title": "Training implementation"}
|
101 |
+
)
|
102 |
+
|
103 |
+
def update(self, loss, eer, step):
|
104 |
+
# Update the tracking data
|
105 |
+
now = timer()
|
106 |
+
self.step_times.append(1000 * (now - self.last_update_timestamp))
|
107 |
+
self.last_update_timestamp = now
|
108 |
+
self.losses.append(loss)
|
109 |
+
self.eers.append(eer)
|
110 |
+
print(".", end="")
|
111 |
+
|
112 |
+
# Update the plots every <update_every> steps
|
113 |
+
if step % self.update_every != 0:
|
114 |
+
return
|
115 |
+
time_string = "Step time: mean: %5dms std: %5dms" % \
|
116 |
+
(int(np.mean(self.step_times)), int(np.std(self.step_times)))
|
117 |
+
print("\nStep %6d Loss: %.4f EER: %.4f %s" %
|
118 |
+
(step, np.mean(self.losses), np.mean(self.eers), time_string))
|
119 |
+
if not self.disabled:
|
120 |
+
self.loss_win = self.vis.line(
|
121 |
+
[np.mean(self.losses)],
|
122 |
+
[step],
|
123 |
+
win=self.loss_win,
|
124 |
+
update="append" if self.loss_win else None,
|
125 |
+
opts=dict(
|
126 |
+
legend=["Avg. loss"],
|
127 |
+
xlabel="Step",
|
128 |
+
ylabel="Loss",
|
129 |
+
title="Loss",
|
130 |
+
)
|
131 |
+
)
|
132 |
+
self.eer_win = self.vis.line(
|
133 |
+
[np.mean(self.eers)],
|
134 |
+
[step],
|
135 |
+
win=self.eer_win,
|
136 |
+
update="append" if self.eer_win else None,
|
137 |
+
opts=dict(
|
138 |
+
legend=["Avg. EER"],
|
139 |
+
xlabel="Step",
|
140 |
+
ylabel="EER",
|
141 |
+
title="Equal error rate"
|
142 |
+
)
|
143 |
+
)
|
144 |
+
if self.implementation_win is not None:
|
145 |
+
self.vis.text(
|
146 |
+
self.implementation_string + ("<b>%s</b>" % time_string),
|
147 |
+
win=self.implementation_win,
|
148 |
+
opts={"title": "Training implementation"},
|
149 |
+
)
|
150 |
+
|
151 |
+
# Reset the tracking
|
152 |
+
self.losses.clear()
|
153 |
+
self.eers.clear()
|
154 |
+
self.step_times.clear()
|
155 |
+
|
156 |
+
def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
|
157 |
+
import matplotlib.pyplot as plt
|
158 |
+
|
159 |
+
max_speakers = min(max_speakers, len(colormap))
|
160 |
+
embeds = embeds[:max_speakers * utterances_per_speaker]
|
161 |
+
|
162 |
+
n_speakers = len(embeds) // utterances_per_speaker
|
163 |
+
ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
|
164 |
+
colors = [colormap[i] for i in ground_truth]
|
165 |
+
|
166 |
+
reducer = umap.UMAP()
|
167 |
+
projected = reducer.fit_transform(embeds)
|
168 |
+
plt.scatter(projected[:, 0], projected[:, 1], c=colors)
|
169 |
+
plt.gca().set_aspect("equal", "datalim")
|
170 |
+
plt.title("UMAP projection (step %d)" % step)
|
171 |
+
if not self.disabled:
|
172 |
+
self.projection_win = self.vis.matplot(plt, win=self.projection_win)
|
173 |
+
if out_fpath is not None:
|
174 |
+
plt.savefig(out_fpath)
|
175 |
+
plt.clf()
|
176 |
+
|
177 |
+
def save(self):
|
178 |
+
if not self.disabled:
|
179 |
+
self.vis.save([self.env_name])
|
requirements.txt
ADDED
Binary file (450 Bytes). View file
|
|
synthesizer/LICENSE.txt
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
|
4 |
+
Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
|
5 |
+
Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
|
6 |
+
Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
|
7 |
+
|
8 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
of this software and associated documentation files (the "Software"), to deal
|
10 |
+
in the Software without restriction, including without limitation the rights
|
11 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
copies of the Software, and to permit persons to whom the Software is
|
13 |
+
furnished to do so, subject to the following conditions:
|
14 |
+
|
15 |
+
The above copyright notice and this permission notice shall be included in all
|
16 |
+
copies or substantial portions of the Software.
|
17 |
+
|
18 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
19 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
20 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
21 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
22 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
23 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
24 |
+
SOFTWARE.
|
synthesizer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
#
|
synthesizer/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (204 Bytes). View file
|
|
synthesizer/__pycache__/audio.cpython-38.pyc
ADDED
Binary file (6.85 kB). View file
|
|
synthesizer/__pycache__/hparams.cpython-38.pyc
ADDED
Binary file (2.85 kB). View file
|
|
synthesizer/__pycache__/inference.cpython-38.pyc
ADDED
Binary file (6.39 kB). View file
|
|
synthesizer/audio.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import librosa.filters
|
3 |
+
import numpy as np
|
4 |
+
from scipy import signal
|
5 |
+
from scipy.io import wavfile
|
6 |
+
import soundfile as sf
|
7 |
+
|
8 |
+
|
9 |
+
def load_wav(path, sr):
|
10 |
+
return librosa.core.load(path, sr=sr)[0]
|
11 |
+
|
12 |
+
def save_wav(wav, path, sr):
|
13 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
14 |
+
#proposed by @dsmiller
|
15 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
16 |
+
|
17 |
+
def save_wavenet_wav(wav, path, sr):
|
18 |
+
sf.write(path, wav.astype(np.float32), sr)
|
19 |
+
|
20 |
+
def preemphasis(wav, k, preemphasize=True):
|
21 |
+
if preemphasize:
|
22 |
+
return signal.lfilter([1, -k], [1], wav)
|
23 |
+
return wav
|
24 |
+
|
25 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
26 |
+
if inv_preemphasize:
|
27 |
+
return signal.lfilter([1], [1, -k], wav)
|
28 |
+
return wav
|
29 |
+
|
30 |
+
#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
|
31 |
+
def start_and_end_indices(quantized, silence_threshold=2):
|
32 |
+
for start in range(quantized.size):
|
33 |
+
if abs(quantized[start] - 127) > silence_threshold:
|
34 |
+
break
|
35 |
+
for end in range(quantized.size - 1, 1, -1):
|
36 |
+
if abs(quantized[end] - 127) > silence_threshold:
|
37 |
+
break
|
38 |
+
|
39 |
+
assert abs(quantized[start] - 127) > silence_threshold
|
40 |
+
assert abs(quantized[end] - 127) > silence_threshold
|
41 |
+
|
42 |
+
return start, end
|
43 |
+
|
44 |
+
def get_hop_size(hparams):
|
45 |
+
hop_size = hparams.hop_size
|
46 |
+
if hop_size is None:
|
47 |
+
assert hparams.frame_shift_ms is not None
|
48 |
+
hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
|
49 |
+
return hop_size
|
50 |
+
|
51 |
+
def linearspectrogram(wav, hparams):
|
52 |
+
D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
|
53 |
+
S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
|
54 |
+
|
55 |
+
if hparams.signal_normalization:
|
56 |
+
return _normalize(S, hparams)
|
57 |
+
return S
|
58 |
+
|
59 |
+
def melspectrogram(wav, hparams):
|
60 |
+
D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
|
61 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
|
62 |
+
|
63 |
+
if hparams.signal_normalization:
|
64 |
+
return _normalize(S, hparams)
|
65 |
+
return S
|
66 |
+
|
67 |
+
def inv_linear_spectrogram(linear_spectrogram, hparams):
|
68 |
+
"""Converts linear spectrogram to waveform using librosa"""
|
69 |
+
if hparams.signal_normalization:
|
70 |
+
D = _denormalize(linear_spectrogram, hparams)
|
71 |
+
else:
|
72 |
+
D = linear_spectrogram
|
73 |
+
|
74 |
+
S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
|
75 |
+
|
76 |
+
if hparams.use_lws:
|
77 |
+
processor = _lws_processor(hparams)
|
78 |
+
D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
|
79 |
+
y = processor.istft(D).astype(np.float32)
|
80 |
+
return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
|
81 |
+
else:
|
82 |
+
return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
|
83 |
+
|
84 |
+
def inv_mel_spectrogram(mel_spectrogram, hparams):
|
85 |
+
"""Converts mel spectrogram to waveform using librosa"""
|
86 |
+
if hparams.signal_normalization:
|
87 |
+
D = _denormalize(mel_spectrogram, hparams)
|
88 |
+
else:
|
89 |
+
D = mel_spectrogram
|
90 |
+
|
91 |
+
S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
|
92 |
+
|
93 |
+
if hparams.use_lws:
|
94 |
+
processor = _lws_processor(hparams)
|
95 |
+
D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
|
96 |
+
y = processor.istft(D).astype(np.float32)
|
97 |
+
return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
|
98 |
+
else:
|
99 |
+
return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
|
100 |
+
|
101 |
+
def _lws_processor(hparams):
|
102 |
+
import lws
|
103 |
+
return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
|
104 |
+
|
105 |
+
def _griffin_lim(S, hparams):
|
106 |
+
"""librosa implementation of Griffin-Lim
|
107 |
+
Based on https://github.com/librosa/librosa/issues/434
|
108 |
+
"""
|
109 |
+
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
110 |
+
S_complex = np.abs(S).astype(np.complex)
|
111 |
+
y = _istft(S_complex * angles, hparams)
|
112 |
+
for i in range(hparams.griffin_lim_iters):
|
113 |
+
angles = np.exp(1j * np.angle(_stft(y, hparams)))
|
114 |
+
y = _istft(S_complex * angles, hparams)
|
115 |
+
return y
|
116 |
+
|
117 |
+
def _stft(y, hparams):
|
118 |
+
if hparams.use_lws:
|
119 |
+
return _lws_processor(hparams).stft(y).T
|
120 |
+
else:
|
121 |
+
return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
|
122 |
+
|
123 |
+
def _istft(y, hparams):
|
124 |
+
return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
|
125 |
+
|
126 |
+
##########################################################
|
127 |
+
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
128 |
+
def num_frames(length, fsize, fshift):
|
129 |
+
"""Compute number of time frames of spectrogram
|
130 |
+
"""
|
131 |
+
pad = (fsize - fshift)
|
132 |
+
if length % fshift == 0:
|
133 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
134 |
+
else:
|
135 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
136 |
+
return M
|
137 |
+
|
138 |
+
|
139 |
+
def pad_lr(x, fsize, fshift):
|
140 |
+
"""Compute left and right padding
|
141 |
+
"""
|
142 |
+
M = num_frames(len(x), fsize, fshift)
|
143 |
+
pad = (fsize - fshift)
|
144 |
+
T = len(x) + 2 * pad
|
145 |
+
r = (M - 1) * fshift + fsize - T
|
146 |
+
return pad, pad + r
|
147 |
+
##########################################################
|
148 |
+
#Librosa correct padding
|
149 |
+
def librosa_pad_lr(x, fsize, fshift):
|
150 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
151 |
+
|
152 |
+
# Conversions
|
153 |
+
_mel_basis = None
|
154 |
+
_inv_mel_basis = None
|
155 |
+
|
156 |
+
def _linear_to_mel(spectogram, hparams):
|
157 |
+
global _mel_basis
|
158 |
+
if _mel_basis is None:
|
159 |
+
_mel_basis = _build_mel_basis(hparams)
|
160 |
+
return np.dot(_mel_basis, spectogram)
|
161 |
+
|
162 |
+
def _mel_to_linear(mel_spectrogram, hparams):
|
163 |
+
global _inv_mel_basis
|
164 |
+
if _inv_mel_basis is None:
|
165 |
+
_inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
|
166 |
+
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
|
167 |
+
|
168 |
+
def _build_mel_basis(hparams):
|
169 |
+
assert hparams.fmax <= hparams.sample_rate // 2
|
170 |
+
return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
|
171 |
+
fmin=hparams.fmin, fmax=hparams.fmax)
|
172 |
+
|
173 |
+
def _amp_to_db(x, hparams):
|
174 |
+
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
|
175 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
176 |
+
|
177 |
+
def _db_to_amp(x):
|
178 |
+
return np.power(10.0, (x) * 0.05)
|
179 |
+
|
180 |
+
def _normalize(S, hparams):
|
181 |
+
if hparams.allow_clipping_in_normalization:
|
182 |
+
if hparams.symmetric_mels:
|
183 |
+
return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
|
184 |
+
-hparams.max_abs_value, hparams.max_abs_value)
|
185 |
+
else:
|
186 |
+
return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
|
187 |
+
|
188 |
+
assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
|
189 |
+
if hparams.symmetric_mels:
|
190 |
+
return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
|
191 |
+
else:
|
192 |
+
return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
|
193 |
+
|
194 |
+
def _denormalize(D, hparams):
|
195 |
+
if hparams.allow_clipping_in_normalization:
|
196 |
+
if hparams.symmetric_mels:
|
197 |
+
return (((np.clip(D, -hparams.max_abs_value,
|
198 |
+
hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
|
199 |
+
+ hparams.min_level_db)
|
200 |
+
else:
|
201 |
+
return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
|
202 |
+
|
203 |
+
if hparams.symmetric_mels:
|
204 |
+
return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
|
205 |
+
else:
|
206 |
+
return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
|
synthesizer/hparams.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import pprint
|
3 |
+
|
4 |
+
class HParams(object):
|
5 |
+
def __init__(self, **kwargs): self.__dict__.update(kwargs)
|
6 |
+
def __setitem__(self, key, value): setattr(self, key, value)
|
7 |
+
def __getitem__(self, key): return getattr(self, key)
|
8 |
+
def __repr__(self): return pprint.pformat(self.__dict__)
|
9 |
+
|
10 |
+
def parse(self, string):
|
11 |
+
# Overrides hparams from a comma-separated string of name=value pairs
|
12 |
+
if len(string) > 0:
|
13 |
+
overrides = [s.split("=") for s in string.split(",")]
|
14 |
+
keys, values = zip(*overrides)
|
15 |
+
keys = list(map(str.strip, keys))
|
16 |
+
values = list(map(str.strip, values))
|
17 |
+
for k in keys:
|
18 |
+
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
|
19 |
+
return self
|
20 |
+
|
21 |
+
hparams = HParams(
|
22 |
+
### Signal Processing (used in both synthesizer and vocoder)
|
23 |
+
sample_rate = 16000,
|
24 |
+
n_fft = 800,
|
25 |
+
num_mels = 80,
|
26 |
+
hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
|
27 |
+
win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
|
28 |
+
fmin = 55,
|
29 |
+
min_level_db = -100,
|
30 |
+
ref_level_db = 20,
|
31 |
+
max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
|
32 |
+
preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
|
33 |
+
preemphasize = True,
|
34 |
+
|
35 |
+
### Tacotron Text-to-Speech (TTS)
|
36 |
+
tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
|
37 |
+
tts_encoder_dims = 256,
|
38 |
+
tts_decoder_dims = 128,
|
39 |
+
tts_postnet_dims = 512,
|
40 |
+
tts_encoder_K = 5,
|
41 |
+
tts_lstm_dims = 1024,
|
42 |
+
tts_postnet_K = 5,
|
43 |
+
tts_num_highways = 4,
|
44 |
+
tts_dropout = 0.5,
|
45 |
+
tts_cleaner_names = ["english_cleaners"],
|
46 |
+
tts_stop_threshold = -3.4, # Value below which audio generation ends.
|
47 |
+
# For example, for a range of [-4, 4], this
|
48 |
+
# will terminate the sequence at the first
|
49 |
+
# frame that has all values < -3.4
|
50 |
+
|
51 |
+
### Tacotron Training
|
52 |
+
tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
|
53 |
+
(2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
|
54 |
+
(2, 2e-4, 80_000, 12), #
|
55 |
+
(2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
|
56 |
+
(2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
|
57 |
+
(2, 1e-5, 640_000, 12)], # lr = learning rate
|
58 |
+
|
59 |
+
tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
|
60 |
+
tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
|
61 |
+
# Set to -1 to generate after completing epoch, or 0 to disable
|
62 |
+
|
63 |
+
tts_eval_num_samples = 1, # Makes this number of samples
|
64 |
+
|
65 |
+
### Data Preprocessing
|
66 |
+
max_mel_frames = 900,
|
67 |
+
rescale = True,
|
68 |
+
rescaling_max = 0.9,
|
69 |
+
synthesis_batch_size = 16, # For vocoder preprocessing and inference.
|
70 |
+
|
71 |
+
### Mel Visualization and Griffin-Lim
|
72 |
+
signal_normalization = True,
|
73 |
+
power = 1.5,
|
74 |
+
griffin_lim_iters = 60,
|
75 |
+
|
76 |
+
### Audio processing options
|
77 |
+
fmax = 7600, # Should not exceed (sample_rate // 2)
|
78 |
+
allow_clipping_in_normalization = True, # Used when signal_normalization = True
|
79 |
+
clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
|
80 |
+
use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
|
81 |
+
symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
|
82 |
+
# and [0, max_abs_value] if False
|
83 |
+
trim_silence = True, # Use with sample_rate of 16000 for best results
|
84 |
+
|
85 |
+
### SV2TTS
|
86 |
+
speaker_embedding_size = 256, # Dimension for the speaker embedding
|
87 |
+
silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
|
88 |
+
utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
|
89 |
+
)
|
90 |
+
|
91 |
+
def hparams_debug_string():
|
92 |
+
return str(hparams)
|
synthesizer/inference.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from synthesizer import audio
|
3 |
+
from synthesizer.hparams import hparams
|
4 |
+
from synthesizer.models.tacotron import Tacotron
|
5 |
+
from synthesizer.utils.symbols import symbols
|
6 |
+
from synthesizer.utils.text import text_to_sequence
|
7 |
+
from vocoder.display import simple_table
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Union, List
|
10 |
+
import numpy as np
|
11 |
+
import librosa
|
12 |
+
|
13 |
+
|
14 |
+
class Synthesizer:
|
15 |
+
sample_rate = hparams.sample_rate
|
16 |
+
hparams = hparams
|
17 |
+
|
18 |
+
def __init__(self, model_fpath: Path, verbose=True):
|
19 |
+
"""
|
20 |
+
The model isn't instantiated and loaded in memory until needed or until load() is called.
|
21 |
+
|
22 |
+
:param model_fpath: path to the trained model file
|
23 |
+
:param verbose: if False, prints less information when using the model
|
24 |
+
"""
|
25 |
+
self.model_fpath = model_fpath
|
26 |
+
self.verbose = verbose
|
27 |
+
|
28 |
+
# Check for GPU
|
29 |
+
if torch.cuda.is_available():
|
30 |
+
self.device = torch.device("cuda")
|
31 |
+
else:
|
32 |
+
self.device = torch.device("cpu")
|
33 |
+
if self.verbose:
|
34 |
+
print("Synthesizer using device:", self.device)
|
35 |
+
|
36 |
+
# Tacotron model will be instantiated later on first use.
|
37 |
+
self._model = None
|
38 |
+
|
39 |
+
def is_loaded(self):
|
40 |
+
"""
|
41 |
+
Whether the model is loaded in memory.
|
42 |
+
"""
|
43 |
+
return self._model is not None
|
44 |
+
|
45 |
+
def load(self):
|
46 |
+
"""
|
47 |
+
Instantiates and loads the model given the weights file that was passed in the constructor.
|
48 |
+
"""
|
49 |
+
self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
50 |
+
num_chars=len(symbols),
|
51 |
+
encoder_dims=hparams.tts_encoder_dims,
|
52 |
+
decoder_dims=hparams.tts_decoder_dims,
|
53 |
+
n_mels=hparams.num_mels,
|
54 |
+
fft_bins=hparams.num_mels,
|
55 |
+
postnet_dims=hparams.tts_postnet_dims,
|
56 |
+
encoder_K=hparams.tts_encoder_K,
|
57 |
+
lstm_dims=hparams.tts_lstm_dims,
|
58 |
+
postnet_K=hparams.tts_postnet_K,
|
59 |
+
num_highways=hparams.tts_num_highways,
|
60 |
+
dropout=hparams.tts_dropout,
|
61 |
+
stop_threshold=hparams.tts_stop_threshold,
|
62 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
|
63 |
+
|
64 |
+
self._model.load(self.model_fpath)
|
65 |
+
self._model.eval()
|
66 |
+
|
67 |
+
if self.verbose:
|
68 |
+
print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
|
69 |
+
|
70 |
+
def synthesize_spectrograms(self, texts: List[str],
|
71 |
+
embeddings: Union[np.ndarray, List[np.ndarray]],
|
72 |
+
return_alignments=False):
|
73 |
+
"""
|
74 |
+
Synthesizes mel spectrograms from texts and speaker embeddings.
|
75 |
+
|
76 |
+
:param texts: a list of N text prompts to be synthesized
|
77 |
+
:param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
|
78 |
+
:param return_alignments: if True, a matrix representing the alignments between the
|
79 |
+
characters
|
80 |
+
and each decoder output step will be returned for each spectrogram
|
81 |
+
:return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
|
82 |
+
sequence length of spectrogram i, and possibly the alignments.
|
83 |
+
"""
|
84 |
+
# Load the model on the first request.
|
85 |
+
if not self.is_loaded():
|
86 |
+
self.load()
|
87 |
+
|
88 |
+
# Preprocess text inputs
|
89 |
+
inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
|
90 |
+
if not isinstance(embeddings, list):
|
91 |
+
embeddings = [embeddings]
|
92 |
+
|
93 |
+
# Batch inputs
|
94 |
+
batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
|
95 |
+
for i in range(0, len(inputs), hparams.synthesis_batch_size)]
|
96 |
+
batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
|
97 |
+
for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
|
98 |
+
|
99 |
+
specs = []
|
100 |
+
for i, batch in enumerate(batched_inputs, 1):
|
101 |
+
if self.verbose:
|
102 |
+
print(f"\n| Generating {i}/{len(batched_inputs)}")
|
103 |
+
|
104 |
+
# Pad texts so they are all the same length
|
105 |
+
text_lens = [len(text) for text in batch]
|
106 |
+
max_text_len = max(text_lens)
|
107 |
+
chars = [pad1d(text, max_text_len) for text in batch]
|
108 |
+
chars = np.stack(chars)
|
109 |
+
|
110 |
+
# Stack speaker embeddings into 2D array for batch processing
|
111 |
+
speaker_embeds = np.stack(batched_embeds[i-1])
|
112 |
+
|
113 |
+
# Convert to tensor
|
114 |
+
chars = torch.tensor(chars).long().to(self.device)
|
115 |
+
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
116 |
+
|
117 |
+
# Inference
|
118 |
+
_, mels, alignments = self._model.generate(chars, speaker_embeddings)
|
119 |
+
mels = mels.detach().cpu().numpy()
|
120 |
+
for m in mels:
|
121 |
+
# Trim silence from end of each spectrogram
|
122 |
+
while np.max(m[:, -1]) < hparams.tts_stop_threshold:
|
123 |
+
m = m[:, :-1]
|
124 |
+
specs.append(m)
|
125 |
+
|
126 |
+
if self.verbose:
|
127 |
+
print("\n\nDone.\n")
|
128 |
+
return (specs, alignments) if return_alignments else specs
|
129 |
+
|
130 |
+
@staticmethod
|
131 |
+
def load_preprocess_wav(fpath):
|
132 |
+
"""
|
133 |
+
Loads and preprocesses an audio file under the same conditions the audio files were used to
|
134 |
+
train the synthesizer.
|
135 |
+
"""
|
136 |
+
wav = librosa.load(str(fpath), hparams.sample_rate)[0]
|
137 |
+
if hparams.rescale:
|
138 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
139 |
+
return wav
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
|
143 |
+
"""
|
144 |
+
Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
|
145 |
+
were fed to the synthesizer when training.
|
146 |
+
"""
|
147 |
+
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
|
148 |
+
wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
|
149 |
+
else:
|
150 |
+
wav = fpath_or_wav
|
151 |
+
|
152 |
+
mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
|
153 |
+
return mel_spectrogram
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def griffin_lim(mel):
|
157 |
+
"""
|
158 |
+
Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
|
159 |
+
with the same parameters present in hparams.py.
|
160 |
+
"""
|
161 |
+
return audio.inv_mel_spectrogram(mel, hparams)
|
162 |
+
|
163 |
+
|
164 |
+
def pad1d(x, max_len, pad_value=0):
|
165 |
+
return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
|
synthesizer/models/__pycache__/tacotron.cpython-38.pyc
ADDED
Binary file (14.3 kB). View file
|
|
synthesizer/models/tacotron.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
|
10 |
+
class HighwayNetwork(nn.Module):
|
11 |
+
def __init__(self, size):
|
12 |
+
super().__init__()
|
13 |
+
self.W1 = nn.Linear(size, size)
|
14 |
+
self.W2 = nn.Linear(size, size)
|
15 |
+
self.W1.bias.data.fill_(0.)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x1 = self.W1(x)
|
19 |
+
x2 = self.W2(x)
|
20 |
+
g = torch.sigmoid(x2)
|
21 |
+
y = g * F.relu(x1) + (1. - g) * x
|
22 |
+
return y
|
23 |
+
|
24 |
+
|
25 |
+
class Encoder(nn.Module):
|
26 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
|
27 |
+
super().__init__()
|
28 |
+
prenet_dims = (encoder_dims, encoder_dims)
|
29 |
+
cbhg_channels = encoder_dims
|
30 |
+
self.embedding = nn.Embedding(num_chars, embed_dims)
|
31 |
+
self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
32 |
+
dropout=dropout)
|
33 |
+
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
|
34 |
+
proj_channels=[cbhg_channels, cbhg_channels],
|
35 |
+
num_highways=num_highways)
|
36 |
+
|
37 |
+
def forward(self, x, speaker_embedding=None):
|
38 |
+
x = self.embedding(x)
|
39 |
+
x = self.pre_net(x)
|
40 |
+
x.transpose_(1, 2)
|
41 |
+
x = self.cbhg(x)
|
42 |
+
if speaker_embedding is not None:
|
43 |
+
x = self.add_speaker_embedding(x, speaker_embedding)
|
44 |
+
return x
|
45 |
+
|
46 |
+
def add_speaker_embedding(self, x, speaker_embedding):
|
47 |
+
# SV2TTS
|
48 |
+
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
49 |
+
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
50 |
+
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
51 |
+
# This concats the speaker embedding for each char in the encoder output
|
52 |
+
|
53 |
+
# Save the dimensions as human-readable names
|
54 |
+
batch_size = x.size()[0]
|
55 |
+
num_chars = x.size()[1]
|
56 |
+
|
57 |
+
if speaker_embedding.dim() == 1:
|
58 |
+
idx = 0
|
59 |
+
else:
|
60 |
+
idx = 1
|
61 |
+
|
62 |
+
# Start by making a copy of each speaker embedding to match the input text length
|
63 |
+
# The output of this has size (batch_size, num_chars * tts_embed_dims)
|
64 |
+
speaker_embedding_size = speaker_embedding.size()[idx]
|
65 |
+
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
66 |
+
|
67 |
+
# Reshape it and transpose
|
68 |
+
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
69 |
+
e = e.transpose(1, 2)
|
70 |
+
|
71 |
+
# Concatenate the tiled speaker embedding with the encoder output
|
72 |
+
x = torch.cat((x, e), 2)
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class BatchNormConv(nn.Module):
|
77 |
+
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
78 |
+
super().__init__()
|
79 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
80 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
81 |
+
self.relu = relu
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = self.conv(x)
|
85 |
+
x = F.relu(x) if self.relu is True else x
|
86 |
+
return self.bnorm(x)
|
87 |
+
|
88 |
+
|
89 |
+
class CBHG(nn.Module):
|
90 |
+
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
# List of all rnns to call `flatten_parameters()` on
|
94 |
+
self._to_flatten = []
|
95 |
+
|
96 |
+
self.bank_kernels = [i for i in range(1, K + 1)]
|
97 |
+
self.conv1d_bank = nn.ModuleList()
|
98 |
+
for k in self.bank_kernels:
|
99 |
+
conv = BatchNormConv(in_channels, channels, k)
|
100 |
+
self.conv1d_bank.append(conv)
|
101 |
+
|
102 |
+
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
103 |
+
|
104 |
+
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
105 |
+
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
106 |
+
|
107 |
+
# Fix the highway input if necessary
|
108 |
+
if proj_channels[-1] != channels:
|
109 |
+
self.highway_mismatch = True
|
110 |
+
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
111 |
+
else:
|
112 |
+
self.highway_mismatch = False
|
113 |
+
|
114 |
+
self.highways = nn.ModuleList()
|
115 |
+
for i in range(num_highways):
|
116 |
+
hn = HighwayNetwork(channels)
|
117 |
+
self.highways.append(hn)
|
118 |
+
|
119 |
+
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
120 |
+
self._to_flatten.append(self.rnn)
|
121 |
+
|
122 |
+
# Avoid fragmentation of RNN parameters and associated warning
|
123 |
+
self._flatten_parameters()
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
# Although we `_flatten_parameters()` on init, when using DataParallel
|
127 |
+
# the model gets replicated, making it no longer guaranteed that the
|
128 |
+
# weights are contiguous in GPU memory. Hence, we must call it again
|
129 |
+
self._flatten_parameters()
|
130 |
+
|
131 |
+
# Save these for later
|
132 |
+
residual = x
|
133 |
+
seq_len = x.size(-1)
|
134 |
+
conv_bank = []
|
135 |
+
|
136 |
+
# Convolution Bank
|
137 |
+
for conv in self.conv1d_bank:
|
138 |
+
c = conv(x) # Convolution
|
139 |
+
conv_bank.append(c[:, :, :seq_len])
|
140 |
+
|
141 |
+
# Stack along the channel axis
|
142 |
+
conv_bank = torch.cat(conv_bank, dim=1)
|
143 |
+
|
144 |
+
# dump the last padding to fit residual
|
145 |
+
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
146 |
+
|
147 |
+
# Conv1d projections
|
148 |
+
x = self.conv_project1(x)
|
149 |
+
x = self.conv_project2(x)
|
150 |
+
|
151 |
+
# Residual Connect
|
152 |
+
x = x + residual
|
153 |
+
|
154 |
+
# Through the highways
|
155 |
+
x = x.transpose(1, 2)
|
156 |
+
if self.highway_mismatch is True:
|
157 |
+
x = self.pre_highway(x)
|
158 |
+
for h in self.highways: x = h(x)
|
159 |
+
|
160 |
+
# And then the RNN
|
161 |
+
x, _ = self.rnn(x)
|
162 |
+
return x
|
163 |
+
|
164 |
+
def _flatten_parameters(self):
|
165 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
166 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
167 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
168 |
+
|
169 |
+
class PreNet(nn.Module):
|
170 |
+
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
171 |
+
super().__init__()
|
172 |
+
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
173 |
+
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
174 |
+
self.p = dropout
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
x = self.fc1(x)
|
178 |
+
x = F.relu(x)
|
179 |
+
x = F.dropout(x, self.p, training=True)
|
180 |
+
x = self.fc2(x)
|
181 |
+
x = F.relu(x)
|
182 |
+
x = F.dropout(x, self.p, training=True)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class Attention(nn.Module):
|
187 |
+
def __init__(self, attn_dims):
|
188 |
+
super().__init__()
|
189 |
+
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
190 |
+
self.v = nn.Linear(attn_dims, 1, bias=False)
|
191 |
+
|
192 |
+
def forward(self, encoder_seq_proj, query, t):
|
193 |
+
|
194 |
+
# print(encoder_seq_proj.shape)
|
195 |
+
# Transform the query vector
|
196 |
+
query_proj = self.W(query).unsqueeze(1)
|
197 |
+
|
198 |
+
# Compute the scores
|
199 |
+
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
200 |
+
scores = F.softmax(u, dim=1)
|
201 |
+
|
202 |
+
return scores.transpose(1, 2)
|
203 |
+
|
204 |
+
|
205 |
+
class LSA(nn.Module):
|
206 |
+
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
207 |
+
super().__init__()
|
208 |
+
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
209 |
+
self.L = nn.Linear(filters, attn_dim, bias=False)
|
210 |
+
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
211 |
+
self.v = nn.Linear(attn_dim, 1, bias=False)
|
212 |
+
self.cumulative = None
|
213 |
+
self.attention = None
|
214 |
+
|
215 |
+
def init_attention(self, encoder_seq_proj):
|
216 |
+
device = next(self.parameters()).device # use same device as parameters
|
217 |
+
b, t, c = encoder_seq_proj.size()
|
218 |
+
self.cumulative = torch.zeros(b, t, device=device)
|
219 |
+
self.attention = torch.zeros(b, t, device=device)
|
220 |
+
|
221 |
+
def forward(self, encoder_seq_proj, query, t, chars):
|
222 |
+
|
223 |
+
if t == 0: self.init_attention(encoder_seq_proj)
|
224 |
+
|
225 |
+
processed_query = self.W(query).unsqueeze(1)
|
226 |
+
|
227 |
+
location = self.cumulative.unsqueeze(1)
|
228 |
+
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
229 |
+
|
230 |
+
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
231 |
+
u = u.squeeze(-1)
|
232 |
+
|
233 |
+
# Mask zero padding chars
|
234 |
+
u = u * (chars != 0).float()
|
235 |
+
|
236 |
+
# Smooth Attention
|
237 |
+
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
238 |
+
scores = F.softmax(u, dim=1)
|
239 |
+
self.attention = scores
|
240 |
+
self.cumulative = self.cumulative + self.attention
|
241 |
+
|
242 |
+
return scores.unsqueeze(-1).transpose(1, 2)
|
243 |
+
|
244 |
+
|
245 |
+
class Decoder(nn.Module):
|
246 |
+
# Class variable because its value doesn't change between classes
|
247 |
+
# yet ought to be scoped by class because its a property of a Decoder
|
248 |
+
max_r = 20
|
249 |
+
def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
|
250 |
+
dropout, speaker_embedding_size):
|
251 |
+
super().__init__()
|
252 |
+
self.register_buffer("r", torch.tensor(1, dtype=torch.int))
|
253 |
+
self.n_mels = n_mels
|
254 |
+
prenet_dims = (decoder_dims * 2, decoder_dims * 2)
|
255 |
+
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
256 |
+
dropout=dropout)
|
257 |
+
self.attn_net = LSA(decoder_dims)
|
258 |
+
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
259 |
+
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
260 |
+
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
261 |
+
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
262 |
+
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
263 |
+
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
264 |
+
|
265 |
+
def zoneout(self, prev, current, p=0.1):
|
266 |
+
device = next(self.parameters()).device # Use same device as parameters
|
267 |
+
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
|
268 |
+
return prev * mask + current * (1 - mask)
|
269 |
+
|
270 |
+
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
|
271 |
+
hidden_states, cell_states, context_vec, t, chars):
|
272 |
+
|
273 |
+
# Need this for reshaping mels
|
274 |
+
batch_size = encoder_seq.size(0)
|
275 |
+
|
276 |
+
# Unpack the hidden and cell states
|
277 |
+
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
|
278 |
+
rnn1_cell, rnn2_cell = cell_states
|
279 |
+
|
280 |
+
# PreNet for the Attention RNN
|
281 |
+
prenet_out = self.prenet(prenet_in)
|
282 |
+
|
283 |
+
# Compute the Attention RNN hidden state
|
284 |
+
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
|
285 |
+
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
|
286 |
+
|
287 |
+
# Compute the attention scores
|
288 |
+
scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
|
289 |
+
|
290 |
+
# Dot product to create the context vector
|
291 |
+
context_vec = scores @ encoder_seq
|
292 |
+
context_vec = context_vec.squeeze(1)
|
293 |
+
|
294 |
+
# Concat Attention RNN output w. Context Vector & project
|
295 |
+
x = torch.cat([context_vec, attn_hidden], dim=1)
|
296 |
+
x = self.rnn_input(x)
|
297 |
+
|
298 |
+
# Compute first Residual RNN
|
299 |
+
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
300 |
+
if self.training:
|
301 |
+
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
|
302 |
+
else:
|
303 |
+
rnn1_hidden = rnn1_hidden_next
|
304 |
+
x = x + rnn1_hidden
|
305 |
+
|
306 |
+
# Compute second Residual RNN
|
307 |
+
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
308 |
+
if self.training:
|
309 |
+
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
|
310 |
+
else:
|
311 |
+
rnn2_hidden = rnn2_hidden_next
|
312 |
+
x = x + rnn2_hidden
|
313 |
+
|
314 |
+
# Project Mels
|
315 |
+
mels = self.mel_proj(x)
|
316 |
+
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
|
317 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
318 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
319 |
+
|
320 |
+
# Stop token prediction
|
321 |
+
s = torch.cat((x, context_vec), dim=1)
|
322 |
+
s = self.stop_proj(s)
|
323 |
+
stop_tokens = torch.sigmoid(s)
|
324 |
+
|
325 |
+
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
326 |
+
|
327 |
+
|
328 |
+
class Tacotron(nn.Module):
|
329 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
|
330 |
+
fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
|
331 |
+
dropout, stop_threshold, speaker_embedding_size):
|
332 |
+
super().__init__()
|
333 |
+
self.n_mels = n_mels
|
334 |
+
self.lstm_dims = lstm_dims
|
335 |
+
self.encoder_dims = encoder_dims
|
336 |
+
self.decoder_dims = decoder_dims
|
337 |
+
self.speaker_embedding_size = speaker_embedding_size
|
338 |
+
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
339 |
+
encoder_K, num_highways, dropout)
|
340 |
+
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
|
341 |
+
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
342 |
+
dropout, speaker_embedding_size)
|
343 |
+
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
344 |
+
[postnet_dims, fft_bins], num_highways)
|
345 |
+
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
|
346 |
+
|
347 |
+
self.init_model()
|
348 |
+
self.num_params()
|
349 |
+
|
350 |
+
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
351 |
+
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
352 |
+
|
353 |
+
@property
|
354 |
+
def r(self):
|
355 |
+
return self.decoder.r.item()
|
356 |
+
|
357 |
+
@r.setter
|
358 |
+
def r(self, value):
|
359 |
+
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
360 |
+
|
361 |
+
def forward(self, x, m, speaker_embedding):
|
362 |
+
device = next(self.parameters()).device # use same device as parameters
|
363 |
+
|
364 |
+
self.step += 1
|
365 |
+
batch_size, _, steps = m.size()
|
366 |
+
|
367 |
+
# Initialise all hidden states and pack into tuple
|
368 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
369 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
370 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
371 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
372 |
+
|
373 |
+
# Initialise all lstm cell states and pack into tuple
|
374 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
375 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
376 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
377 |
+
|
378 |
+
# <GO> Frame for start of decoder loop
|
379 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
380 |
+
|
381 |
+
# Need an initial context vector
|
382 |
+
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
383 |
+
|
384 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
385 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
386 |
+
encoder_seq = self.encoder(x, speaker_embedding)
|
387 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
388 |
+
|
389 |
+
# Need a couple of lists for outputs
|
390 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
391 |
+
|
392 |
+
# Run the decoder loop
|
393 |
+
for t in range(0, steps, self.r):
|
394 |
+
prenet_in = m[:, :, t - 1] if t > 0 else go_frame
|
395 |
+
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
396 |
+
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
397 |
+
hidden_states, cell_states, context_vec, t, x)
|
398 |
+
mel_outputs.append(mel_frames)
|
399 |
+
attn_scores.append(scores)
|
400 |
+
stop_outputs.extend([stop_tokens] * self.r)
|
401 |
+
|
402 |
+
# Concat the mel outputs into sequence
|
403 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
404 |
+
|
405 |
+
# Post-Process for Linear Spectrograms
|
406 |
+
postnet_out = self.postnet(mel_outputs)
|
407 |
+
linear = self.post_proj(postnet_out)
|
408 |
+
linear = linear.transpose(1, 2)
|
409 |
+
|
410 |
+
# For easy visualisation
|
411 |
+
attn_scores = torch.cat(attn_scores, 1)
|
412 |
+
# attn_scores = attn_scores.cpu().data.numpy()
|
413 |
+
stop_outputs = torch.cat(stop_outputs, 1)
|
414 |
+
|
415 |
+
return mel_outputs, linear, attn_scores, stop_outputs
|
416 |
+
|
417 |
+
def generate(self, x, speaker_embedding=None, steps=2000):
|
418 |
+
self.eval()
|
419 |
+
device = next(self.parameters()).device # use same device as parameters
|
420 |
+
|
421 |
+
batch_size, _ = x.size()
|
422 |
+
|
423 |
+
# Need to initialise all hidden states and pack into tuple for tidyness
|
424 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
425 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
426 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
427 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
428 |
+
|
429 |
+
# Need to initialise all lstm cell states and pack into tuple for tidyness
|
430 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
431 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
432 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
433 |
+
|
434 |
+
# Need a <GO> Frame for start of decoder loop
|
435 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
436 |
+
|
437 |
+
# Need an initial context vector
|
438 |
+
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
439 |
+
|
440 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
441 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
442 |
+
encoder_seq = self.encoder(x, speaker_embedding)
|
443 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
444 |
+
|
445 |
+
# Need a couple of lists for outputs
|
446 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
447 |
+
|
448 |
+
# Run the decoder loop
|
449 |
+
for t in range(0, steps, self.r):
|
450 |
+
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
451 |
+
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
452 |
+
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
453 |
+
hidden_states, cell_states, context_vec, t, x)
|
454 |
+
mel_outputs.append(mel_frames)
|
455 |
+
attn_scores.append(scores)
|
456 |
+
stop_outputs.extend([stop_tokens] * self.r)
|
457 |
+
# Stop the loop when all stop tokens in batch exceed threshold
|
458 |
+
if (stop_tokens > 0.5).all() and t > 10: break
|
459 |
+
|
460 |
+
# Concat the mel outputs into sequence
|
461 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
462 |
+
|
463 |
+
# Post-Process for Linear Spectrograms
|
464 |
+
postnet_out = self.postnet(mel_outputs)
|
465 |
+
linear = self.post_proj(postnet_out)
|
466 |
+
|
467 |
+
|
468 |
+
linear = linear.transpose(1, 2)
|
469 |
+
|
470 |
+
# For easy visualisation
|
471 |
+
attn_scores = torch.cat(attn_scores, 1)
|
472 |
+
stop_outputs = torch.cat(stop_outputs, 1)
|
473 |
+
|
474 |
+
self.train()
|
475 |
+
|
476 |
+
return mel_outputs, linear, attn_scores
|
477 |
+
|
478 |
+
def init_model(self):
|
479 |
+
for p in self.parameters():
|
480 |
+
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
481 |
+
|
482 |
+
def get_step(self):
|
483 |
+
return self.step.data.item()
|
484 |
+
|
485 |
+
def reset_step(self):
|
486 |
+
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
487 |
+
self.step = self.step.data.new_tensor(1)
|
488 |
+
|
489 |
+
def log(self, path, msg):
|
490 |
+
with open(path, "a") as f:
|
491 |
+
print(msg, file=f)
|
492 |
+
|
493 |
+
def load(self, path, optimizer=None):
|
494 |
+
# Use device of model params as location for loaded state
|
495 |
+
device = next(self.parameters()).device
|
496 |
+
checkpoint = torch.load(str(path), map_location=device)
|
497 |
+
self.load_state_dict(checkpoint["model_state"])
|
498 |
+
|
499 |
+
if "optimizer_state" in checkpoint and optimizer is not None:
|
500 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
501 |
+
|
502 |
+
def save(self, path, optimizer=None):
|
503 |
+
if optimizer is not None:
|
504 |
+
torch.save({
|
505 |
+
"model_state": self.state_dict(),
|
506 |
+
"optimizer_state": optimizer.state_dict(),
|
507 |
+
}, str(path))
|
508 |
+
else:
|
509 |
+
torch.save({
|
510 |
+
"model_state": self.state_dict(),
|
511 |
+
}, str(path))
|
512 |
+
|
513 |
+
|
514 |
+
def num_params(self, print_out=True):
|
515 |
+
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
516 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
517 |
+
if print_out:
|
518 |
+
print("Trainable Parameters: %.3fM" % parameters)
|
519 |
+
return parameters
|
synthesizer/preprocess.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from multiprocessing.pool import Pool
|
2 |
+
from synthesizer import audio
|
3 |
+
from functools import partial
|
4 |
+
from itertools import chain
|
5 |
+
from encoder import inference as encoder
|
6 |
+
from pathlib import Path
|
7 |
+
from utils import logmmse
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import librosa
|
11 |
+
|
12 |
+
|
13 |
+
def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
|
14 |
+
no_alignments: bool, datasets_name: str, subfolders: str):
|
15 |
+
# Gather the input directories
|
16 |
+
dataset_root = datasets_root.joinpath(datasets_name)
|
17 |
+
input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
|
18 |
+
print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
|
19 |
+
assert all(input_dir.exists() for input_dir in input_dirs)
|
20 |
+
|
21 |
+
# Create the output directories for each output file type
|
22 |
+
out_dir.joinpath("mels").mkdir(exist_ok=True)
|
23 |
+
out_dir.joinpath("audio").mkdir(exist_ok=True)
|
24 |
+
|
25 |
+
# Create a metadata file
|
26 |
+
metadata_fpath = out_dir.joinpath("train.txt")
|
27 |
+
metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
|
28 |
+
|
29 |
+
# Preprocess the dataset
|
30 |
+
speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
|
31 |
+
func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
|
32 |
+
hparams=hparams, no_alignments=no_alignments)
|
33 |
+
job = Pool(n_processes).imap(func, speaker_dirs)
|
34 |
+
for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
|
35 |
+
for metadatum in speaker_metadata:
|
36 |
+
metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
|
37 |
+
metadata_file.close()
|
38 |
+
|
39 |
+
# Verify the contents of the metadata file
|
40 |
+
with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
|
41 |
+
metadata = [line.split("|") for line in metadata_file]
|
42 |
+
mel_frames = sum([int(m[4]) for m in metadata])
|
43 |
+
timesteps = sum([int(m[3]) for m in metadata])
|
44 |
+
sample_rate = hparams.sample_rate
|
45 |
+
hours = (timesteps / sample_rate) / 3600
|
46 |
+
print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
|
47 |
+
(len(metadata), mel_frames, timesteps, hours))
|
48 |
+
print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
|
49 |
+
print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
|
50 |
+
print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
|
51 |
+
|
52 |
+
|
53 |
+
def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
|
54 |
+
metadata = []
|
55 |
+
for book_dir in speaker_dir.glob("*"):
|
56 |
+
if no_alignments:
|
57 |
+
# Gather the utterance audios and texts
|
58 |
+
# LibriTTS uses .wav but we will include extensions for compatibility with other datasets
|
59 |
+
extensions = ["*.wav", "*.flac", "*.mp3"]
|
60 |
+
for extension in extensions:
|
61 |
+
wav_fpaths = book_dir.glob(extension)
|
62 |
+
|
63 |
+
for wav_fpath in wav_fpaths:
|
64 |
+
# Load the audio waveform
|
65 |
+
wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
|
66 |
+
if hparams.rescale:
|
67 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
68 |
+
|
69 |
+
# Get the corresponding text
|
70 |
+
# Check for .txt (for compatibility with other datasets)
|
71 |
+
text_fpath = wav_fpath.with_suffix(".txt")
|
72 |
+
if not text_fpath.exists():
|
73 |
+
# Check for .normalized.txt (LibriTTS)
|
74 |
+
text_fpath = wav_fpath.with_suffix(".normalized.txt")
|
75 |
+
assert text_fpath.exists()
|
76 |
+
with text_fpath.open("r") as text_file:
|
77 |
+
text = "".join([line for line in text_file])
|
78 |
+
text = text.replace("\"", "")
|
79 |
+
text = text.strip()
|
80 |
+
|
81 |
+
# Process the utterance
|
82 |
+
metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
|
83 |
+
skip_existing, hparams))
|
84 |
+
else:
|
85 |
+
# Process alignment file (LibriSpeech support)
|
86 |
+
# Gather the utterance audios and texts
|
87 |
+
try:
|
88 |
+
alignments_fpath = next(book_dir.glob("*.alignment.txt"))
|
89 |
+
with alignments_fpath.open("r") as alignments_file:
|
90 |
+
alignments = [line.rstrip().split(" ") for line in alignments_file]
|
91 |
+
except StopIteration:
|
92 |
+
# A few alignment files will be missing
|
93 |
+
continue
|
94 |
+
|
95 |
+
# Iterate over each entry in the alignments file
|
96 |
+
for wav_fname, words, end_times in alignments:
|
97 |
+
wav_fpath = book_dir.joinpath(wav_fname + ".flac")
|
98 |
+
assert wav_fpath.exists()
|
99 |
+
words = words.replace("\"", "").split(",")
|
100 |
+
end_times = list(map(float, end_times.replace("\"", "").split(",")))
|
101 |
+
|
102 |
+
# Process each sub-utterance
|
103 |
+
wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
|
104 |
+
for i, (wav, text) in enumerate(zip(wavs, texts)):
|
105 |
+
sub_basename = "%s_%02d" % (wav_fname, i)
|
106 |
+
metadata.append(process_utterance(wav, text, out_dir, sub_basename,
|
107 |
+
skip_existing, hparams))
|
108 |
+
|
109 |
+
return [m for m in metadata if m is not None]
|
110 |
+
|
111 |
+
|
112 |
+
def split_on_silences(wav_fpath, words, end_times, hparams):
|
113 |
+
# Load the audio waveform
|
114 |
+
wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
|
115 |
+
if hparams.rescale:
|
116 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
117 |
+
|
118 |
+
words = np.array(words)
|
119 |
+
start_times = np.array([0.0] + end_times[:-1])
|
120 |
+
end_times = np.array(end_times)
|
121 |
+
assert len(words) == len(end_times) == len(start_times)
|
122 |
+
assert words[0] == "" and words[-1] == ""
|
123 |
+
|
124 |
+
# Find pauses that are too long
|
125 |
+
mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
|
126 |
+
mask[0] = mask[-1] = True
|
127 |
+
breaks = np.where(mask)[0]
|
128 |
+
|
129 |
+
# Profile the noise from the silences and perform noise reduction on the waveform
|
130 |
+
silence_times = [[start_times[i], end_times[i]] for i in breaks]
|
131 |
+
silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int)
|
132 |
+
noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
|
133 |
+
if len(noisy_wav) > hparams.sample_rate * 0.02:
|
134 |
+
profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
|
135 |
+
wav = logmmse.denoise(wav, profile, eta=0)
|
136 |
+
|
137 |
+
# Re-attach segments that are too short
|
138 |
+
segments = list(zip(breaks[:-1], breaks[1:]))
|
139 |
+
segment_durations = [start_times[end] - end_times[start] for start, end in segments]
|
140 |
+
i = 0
|
141 |
+
while i < len(segments) and len(segments) > 1:
|
142 |
+
if segment_durations[i] < hparams.utterance_min_duration:
|
143 |
+
# See if the segment can be re-attached with the right or the left segment
|
144 |
+
left_duration = float("inf") if i == 0 else segment_durations[i - 1]
|
145 |
+
right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
|
146 |
+
joined_duration = segment_durations[i] + min(left_duration, right_duration)
|
147 |
+
|
148 |
+
# Do not re-attach if it causes the joined utterance to be too long
|
149 |
+
if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
|
150 |
+
i += 1
|
151 |
+
continue
|
152 |
+
|
153 |
+
# Re-attach the segment with the neighbour of shortest duration
|
154 |
+
j = i - 1 if left_duration <= right_duration else i
|
155 |
+
segments[j] = (segments[j][0], segments[j + 1][1])
|
156 |
+
segment_durations[j] = joined_duration
|
157 |
+
del segments[j + 1], segment_durations[j + 1]
|
158 |
+
else:
|
159 |
+
i += 1
|
160 |
+
|
161 |
+
# Split the utterance
|
162 |
+
segment_times = [[end_times[start], start_times[end]] for start, end in segments]
|
163 |
+
segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int)
|
164 |
+
wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
|
165 |
+
texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
|
166 |
+
|
167 |
+
# # DEBUG: play the audio segments (run with -n=1)
|
168 |
+
# import sounddevice as sd
|
169 |
+
# if len(wavs) > 1:
|
170 |
+
# print("This sentence was split in %d segments:" % len(wavs))
|
171 |
+
# else:
|
172 |
+
# print("There are no silences long enough for this sentence to be split:")
|
173 |
+
# for wav, text in zip(wavs, texts):
|
174 |
+
# # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
|
175 |
+
# # when playing them. You shouldn't need to do that in your parsers.
|
176 |
+
# wav = np.concatenate((wav, [0] * 16000))
|
177 |
+
# print("\t%s" % text)
|
178 |
+
# sd.play(wav, 16000, blocking=True)
|
179 |
+
# print("")
|
180 |
+
|
181 |
+
return wavs, texts
|
182 |
+
|
183 |
+
|
184 |
+
def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
|
185 |
+
skip_existing: bool, hparams):
|
186 |
+
## FOR REFERENCE:
|
187 |
+
# For you not to lose your head if you ever wish to change things here or implement your own
|
188 |
+
# synthesizer.
|
189 |
+
# - Both the audios and the mel spectrograms are saved as numpy arrays
|
190 |
+
# - There is no processing done to the audios that will be saved to disk beyond volume
|
191 |
+
# normalization (in split_on_silences)
|
192 |
+
# - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
|
193 |
+
# is why we re-apply it on the audio on the side of the vocoder.
|
194 |
+
# - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
|
195 |
+
# without extra padding. This means that you won't have an exact relation between the length
|
196 |
+
# of the wav and of the mel spectrogram. See the vocoder data loader.
|
197 |
+
|
198 |
+
|
199 |
+
# Skip existing utterances if needed
|
200 |
+
mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
|
201 |
+
wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
|
202 |
+
if skip_existing and mel_fpath.exists() and wav_fpath.exists():
|
203 |
+
return None
|
204 |
+
|
205 |
+
# Trim silence
|
206 |
+
if hparams.trim_silence:
|
207 |
+
wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
|
208 |
+
|
209 |
+
# Skip utterances that are too short
|
210 |
+
if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
|
211 |
+
return None
|
212 |
+
|
213 |
+
# Compute the mel spectrogram
|
214 |
+
mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
|
215 |
+
mel_frames = mel_spectrogram.shape[1]
|
216 |
+
|
217 |
+
# Skip utterances that are too long
|
218 |
+
if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
|
219 |
+
return None
|
220 |
+
|
221 |
+
# Write the spectrogram, embed and audio to disk
|
222 |
+
np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
|
223 |
+
np.save(wav_fpath, wav, allow_pickle=False)
|
224 |
+
|
225 |
+
# Return a tuple describing this training example
|
226 |
+
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
|
227 |
+
|
228 |
+
|
229 |
+
def embed_utterance(fpaths, encoder_model_fpath):
|
230 |
+
if not encoder.is_loaded():
|
231 |
+
encoder.load_model(encoder_model_fpath)
|
232 |
+
|
233 |
+
# Compute the speaker embedding of the utterance
|
234 |
+
wav_fpath, embed_fpath = fpaths
|
235 |
+
wav = np.load(wav_fpath)
|
236 |
+
wav = encoder.preprocess_wav(wav)
|
237 |
+
embed = encoder.embed_utterance(wav)
|
238 |
+
np.save(embed_fpath, embed, allow_pickle=False)
|
239 |
+
|
240 |
+
|
241 |
+
def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
|
242 |
+
wav_dir = synthesizer_root.joinpath("audio")
|
243 |
+
metadata_fpath = synthesizer_root.joinpath("train.txt")
|
244 |
+
assert wav_dir.exists() and metadata_fpath.exists()
|
245 |
+
embed_dir = synthesizer_root.joinpath("embeds")
|
246 |
+
embed_dir.mkdir(exist_ok=True)
|
247 |
+
|
248 |
+
# Gather the input wave filepath and the target output embed filepath
|
249 |
+
with metadata_fpath.open("r") as metadata_file:
|
250 |
+
metadata = [line.split("|") for line in metadata_file]
|
251 |
+
fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
|
252 |
+
|
253 |
+
# TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
|
254 |
+
# Embed the utterances in separate threads
|
255 |
+
func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
|
256 |
+
job = Pool(n_processes).imap(func, fpaths)
|
257 |
+
list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
|
258 |
+
|
synthesizer/synthesize.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import platform
|
2 |
+
from functools import partial
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from synthesizer.hparams import hparams_debug_string
|
11 |
+
from synthesizer.models.tacotron import Tacotron
|
12 |
+
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
13 |
+
from synthesizer.utils import data_parallel_workaround
|
14 |
+
from synthesizer.utils.symbols import symbols
|
15 |
+
|
16 |
+
|
17 |
+
def run_synthesis(in_dir: Path, out_dir: Path, syn_model_fpath: Path, hparams):
|
18 |
+
# This generates ground truth-aligned mels for vocoder training
|
19 |
+
synth_dir = out_dir / "mels_gta"
|
20 |
+
synth_dir.mkdir(exist_ok=True, parents=True)
|
21 |
+
print(hparams_debug_string())
|
22 |
+
|
23 |
+
# Check for GPU
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
device = torch.device("cuda")
|
26 |
+
if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
|
27 |
+
raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
|
28 |
+
else:
|
29 |
+
device = torch.device("cpu")
|
30 |
+
print("Synthesizer using device:", device)
|
31 |
+
|
32 |
+
# Instantiate Tacotron model
|
33 |
+
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
34 |
+
num_chars=len(symbols),
|
35 |
+
encoder_dims=hparams.tts_encoder_dims,
|
36 |
+
decoder_dims=hparams.tts_decoder_dims,
|
37 |
+
n_mels=hparams.num_mels,
|
38 |
+
fft_bins=hparams.num_mels,
|
39 |
+
postnet_dims=hparams.tts_postnet_dims,
|
40 |
+
encoder_K=hparams.tts_encoder_K,
|
41 |
+
lstm_dims=hparams.tts_lstm_dims,
|
42 |
+
postnet_K=hparams.tts_postnet_K,
|
43 |
+
num_highways=hparams.tts_num_highways,
|
44 |
+
dropout=0., # Use zero dropout for gta mels
|
45 |
+
stop_threshold=hparams.tts_stop_threshold,
|
46 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
47 |
+
|
48 |
+
# Load the weights
|
49 |
+
print("\nLoading weights at %s" % syn_model_fpath)
|
50 |
+
model.load(syn_model_fpath)
|
51 |
+
print("Tacotron weights loaded from step %d" % model.step)
|
52 |
+
|
53 |
+
# Synthesize using same reduction factor as the model is currently trained
|
54 |
+
r = np.int32(model.r)
|
55 |
+
|
56 |
+
# Set model to eval mode (disable gradient and zoneout)
|
57 |
+
model.eval()
|
58 |
+
|
59 |
+
# Initialize the dataset
|
60 |
+
metadata_fpath = in_dir.joinpath("train.txt")
|
61 |
+
mel_dir = in_dir.joinpath("mels")
|
62 |
+
embed_dir = in_dir.joinpath("embeds")
|
63 |
+
|
64 |
+
dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
|
65 |
+
collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
|
66 |
+
data_loader = DataLoader(dataset, hparams.synthesis_batch_size, collate_fn=collate_fn, num_workers=2)
|
67 |
+
|
68 |
+
# Generate GTA mels
|
69 |
+
meta_out_fpath = out_dir / "synthesized.txt"
|
70 |
+
with meta_out_fpath.open("w") as file:
|
71 |
+
for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
|
72 |
+
texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)
|
73 |
+
|
74 |
+
# Parallelize model onto GPUS using workaround due to python bug
|
75 |
+
if device.type == "cuda" and torch.cuda.device_count() > 1:
|
76 |
+
_, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
|
77 |
+
else:
|
78 |
+
_, mels_out, _, _ = model(texts, mels, embeds)
|
79 |
+
|
80 |
+
for j, k in enumerate(idx):
|
81 |
+
# Note: outputs mel-spectrogram files and target ones have same names, just different folders
|
82 |
+
mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
|
83 |
+
mel_out = mels_out[j].detach().cpu().numpy().T
|
84 |
+
|
85 |
+
# Use the length of the ground truth mel to remove padding from the generated mels
|
86 |
+
mel_out = mel_out[:int(dataset.metadata[k][4])]
|
87 |
+
|
88 |
+
# Write the spectrogram to disk
|
89 |
+
np.save(mel_filename, mel_out, allow_pickle=False)
|
90 |
+
|
91 |
+
# Write metadata into the synthesized file
|
92 |
+
file.write("|".join(dataset.metadata[k]))
|
synthesizer/synthesizer_dataset.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
from pathlib import Path
|
5 |
+
from synthesizer.utils.text import text_to_sequence
|
6 |
+
|
7 |
+
|
8 |
+
class SynthesizerDataset(Dataset):
|
9 |
+
def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
|
10 |
+
print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
|
11 |
+
|
12 |
+
with metadata_fpath.open("r") as metadata_file:
|
13 |
+
metadata = [line.split("|") for line in metadata_file]
|
14 |
+
|
15 |
+
mel_fnames = [x[1] for x in metadata if int(x[4])]
|
16 |
+
mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
|
17 |
+
embed_fnames = [x[2] for x in metadata if int(x[4])]
|
18 |
+
embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
|
19 |
+
self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
|
20 |
+
self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
|
21 |
+
self.metadata = metadata
|
22 |
+
self.hparams = hparams
|
23 |
+
|
24 |
+
print("Found %d samples" % len(self.samples_fpaths))
|
25 |
+
|
26 |
+
def __getitem__(self, index):
|
27 |
+
# Sometimes index may be a list of 2 (not sure why this happens)
|
28 |
+
# If that is the case, return a single item corresponding to first element in index
|
29 |
+
if index is list:
|
30 |
+
index = index[0]
|
31 |
+
|
32 |
+
mel_path, embed_path = self.samples_fpaths[index]
|
33 |
+
mel = np.load(mel_path).T.astype(np.float32)
|
34 |
+
|
35 |
+
# Load the embed
|
36 |
+
embed = np.load(embed_path)
|
37 |
+
|
38 |
+
# Get the text and clean it
|
39 |
+
text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
|
40 |
+
|
41 |
+
# Convert the list returned by text_to_sequence to a numpy array
|
42 |
+
text = np.asarray(text).astype(np.int32)
|
43 |
+
|
44 |
+
return text, mel.astype(np.float32), embed.astype(np.float32), index
|
45 |
+
|
46 |
+
def __len__(self):
|
47 |
+
return len(self.samples_fpaths)
|
48 |
+
|
49 |
+
|
50 |
+
def collate_synthesizer(batch, r, hparams):
|
51 |
+
# Text
|
52 |
+
x_lens = [len(x[0]) for x in batch]
|
53 |
+
max_x_len = max(x_lens)
|
54 |
+
|
55 |
+
chars = [pad1d(x[0], max_x_len) for x in batch]
|
56 |
+
chars = np.stack(chars)
|
57 |
+
|
58 |
+
# Mel spectrogram
|
59 |
+
spec_lens = [x[1].shape[-1] for x in batch]
|
60 |
+
max_spec_len = max(spec_lens) + 1
|
61 |
+
if max_spec_len % r != 0:
|
62 |
+
max_spec_len += r - max_spec_len % r
|
63 |
+
|
64 |
+
# WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
|
65 |
+
# By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
|
66 |
+
if hparams.symmetric_mels:
|
67 |
+
mel_pad_value = -1 * hparams.max_abs_value
|
68 |
+
else:
|
69 |
+
mel_pad_value = 0
|
70 |
+
|
71 |
+
mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
|
72 |
+
mel = np.stack(mel)
|
73 |
+
|
74 |
+
# Speaker embedding (SV2TTS)
|
75 |
+
embeds = np.array([x[2] for x in batch])
|
76 |
+
|
77 |
+
# Index (for vocoder preprocessing)
|
78 |
+
indices = [x[3] for x in batch]
|
79 |
+
|
80 |
+
|
81 |
+
# Convert all to tensor
|
82 |
+
chars = torch.tensor(chars).long()
|
83 |
+
mel = torch.tensor(mel)
|
84 |
+
embeds = torch.tensor(embeds)
|
85 |
+
|
86 |
+
return chars, mel, embeds, indices
|
87 |
+
|
88 |
+
def pad1d(x, max_len, pad_value=0):
|
89 |
+
return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
|
90 |
+
|
91 |
+
def pad2d(x, max_len, pad_value=0):
|
92 |
+
return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
|
synthesizer/train.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from functools import partial
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import optim
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
from synthesizer import audio
|
11 |
+
from synthesizer.models.tacotron import Tacotron
|
12 |
+
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
13 |
+
from synthesizer.utils import ValueWindow, data_parallel_workaround
|
14 |
+
from synthesizer.utils.plot import plot_spectrogram
|
15 |
+
from synthesizer.utils.symbols import symbols
|
16 |
+
from synthesizer.utils.text import sequence_to_text
|
17 |
+
from vocoder.display import *
|
18 |
+
|
19 |
+
|
20 |
+
def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
|
21 |
+
|
22 |
+
|
23 |
+
def time_string():
|
24 |
+
return datetime.now().strftime("%Y-%m-%d %H:%M")
|
25 |
+
|
26 |
+
|
27 |
+
def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
|
28 |
+
hparams):
|
29 |
+
models_dir.mkdir(exist_ok=True)
|
30 |
+
|
31 |
+
model_dir = models_dir.joinpath(run_id)
|
32 |
+
plot_dir = model_dir.joinpath("plots")
|
33 |
+
wav_dir = model_dir.joinpath("wavs")
|
34 |
+
mel_output_dir = model_dir.joinpath("mel-spectrograms")
|
35 |
+
meta_folder = model_dir.joinpath("metas")
|
36 |
+
model_dir.mkdir(exist_ok=True)
|
37 |
+
plot_dir.mkdir(exist_ok=True)
|
38 |
+
wav_dir.mkdir(exist_ok=True)
|
39 |
+
mel_output_dir.mkdir(exist_ok=True)
|
40 |
+
meta_folder.mkdir(exist_ok=True)
|
41 |
+
|
42 |
+
weights_fpath = model_dir / f"synthesizer.pt"
|
43 |
+
metadata_fpath = syn_dir.joinpath("train.txt")
|
44 |
+
|
45 |
+
print("Checkpoint path: {}".format(weights_fpath))
|
46 |
+
print("Loading training data from: {}".format(metadata_fpath))
|
47 |
+
print("Using model: Tacotron")
|
48 |
+
|
49 |
+
# Bookkeeping
|
50 |
+
time_window = ValueWindow(100)
|
51 |
+
loss_window = ValueWindow(100)
|
52 |
+
|
53 |
+
# From WaveRNN/train_tacotron.py
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
device = torch.device("cuda")
|
56 |
+
|
57 |
+
for session in hparams.tts_schedule:
|
58 |
+
_, _, _, batch_size = session
|
59 |
+
if batch_size % torch.cuda.device_count() != 0:
|
60 |
+
raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
|
61 |
+
else:
|
62 |
+
device = torch.device("cpu")
|
63 |
+
print("Using device:", device)
|
64 |
+
|
65 |
+
# Instantiate Tacotron Model
|
66 |
+
print("\nInitialising Tacotron Model...\n")
|
67 |
+
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
68 |
+
num_chars=len(symbols),
|
69 |
+
encoder_dims=hparams.tts_encoder_dims,
|
70 |
+
decoder_dims=hparams.tts_decoder_dims,
|
71 |
+
n_mels=hparams.num_mels,
|
72 |
+
fft_bins=hparams.num_mels,
|
73 |
+
postnet_dims=hparams.tts_postnet_dims,
|
74 |
+
encoder_K=hparams.tts_encoder_K,
|
75 |
+
lstm_dims=hparams.tts_lstm_dims,
|
76 |
+
postnet_K=hparams.tts_postnet_K,
|
77 |
+
num_highways=hparams.tts_num_highways,
|
78 |
+
dropout=hparams.tts_dropout,
|
79 |
+
stop_threshold=hparams.tts_stop_threshold,
|
80 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
81 |
+
|
82 |
+
# Initialize the optimizer
|
83 |
+
optimizer = optim.Adam(model.parameters())
|
84 |
+
|
85 |
+
# Load the weights
|
86 |
+
if force_restart or not weights_fpath.exists():
|
87 |
+
print("\nStarting the training of Tacotron from scratch\n")
|
88 |
+
model.save(weights_fpath)
|
89 |
+
|
90 |
+
# Embeddings metadata
|
91 |
+
char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
|
92 |
+
with open(char_embedding_fpath, "w", encoding="utf-8") as f:
|
93 |
+
for symbol in symbols:
|
94 |
+
if symbol == " ":
|
95 |
+
symbol = "\\s" # For visual purposes, swap space with \s
|
96 |
+
|
97 |
+
f.write("{}\n".format(symbol))
|
98 |
+
|
99 |
+
else:
|
100 |
+
print("\nLoading weights at %s" % weights_fpath)
|
101 |
+
model.load(weights_fpath, optimizer)
|
102 |
+
print("Tacotron weights loaded from step %d" % model.step)
|
103 |
+
|
104 |
+
# Initialize the dataset
|
105 |
+
metadata_fpath = syn_dir.joinpath("train.txt")
|
106 |
+
mel_dir = syn_dir.joinpath("mels")
|
107 |
+
embed_dir = syn_dir.joinpath("embeds")
|
108 |
+
dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
|
109 |
+
|
110 |
+
for i, session in enumerate(hparams.tts_schedule):
|
111 |
+
current_step = model.get_step()
|
112 |
+
|
113 |
+
r, lr, max_step, batch_size = session
|
114 |
+
|
115 |
+
training_steps = max_step - current_step
|
116 |
+
|
117 |
+
# Do we need to change to the next session?
|
118 |
+
if current_step >= max_step:
|
119 |
+
# Are there no further sessions than the current one?
|
120 |
+
if i == len(hparams.tts_schedule) - 1:
|
121 |
+
# We have completed training. Save the model and exit
|
122 |
+
model.save(weights_fpath, optimizer)
|
123 |
+
break
|
124 |
+
else:
|
125 |
+
# There is a following session, go to it
|
126 |
+
continue
|
127 |
+
|
128 |
+
model.r = r
|
129 |
+
|
130 |
+
# Begin the training
|
131 |
+
simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
|
132 |
+
("Batch Size", batch_size),
|
133 |
+
("Learning Rate", lr),
|
134 |
+
("Outputs/Step (r)", model.r)])
|
135 |
+
|
136 |
+
for p in optimizer.param_groups:
|
137 |
+
p["lr"] = lr
|
138 |
+
|
139 |
+
collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
|
140 |
+
data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
|
141 |
+
|
142 |
+
total_iters = len(dataset)
|
143 |
+
steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
|
144 |
+
epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
|
145 |
+
|
146 |
+
for epoch in range(1, epochs+1):
|
147 |
+
for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
|
148 |
+
start_time = time.time()
|
149 |
+
|
150 |
+
# Generate stop tokens for training
|
151 |
+
stop = torch.ones(mels.shape[0], mels.shape[2])
|
152 |
+
for j, k in enumerate(idx):
|
153 |
+
stop[j, :int(dataset.metadata[k][4])-1] = 0
|
154 |
+
|
155 |
+
texts = texts.to(device)
|
156 |
+
mels = mels.to(device)
|
157 |
+
embeds = embeds.to(device)
|
158 |
+
stop = stop.to(device)
|
159 |
+
|
160 |
+
# Forward pass
|
161 |
+
# Parallelize model onto GPUS using workaround due to python bug
|
162 |
+
if device.type == "cuda" and torch.cuda.device_count() > 1:
|
163 |
+
m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
|
164 |
+
else:
|
165 |
+
m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
|
166 |
+
|
167 |
+
# Backward pass
|
168 |
+
m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
|
169 |
+
m2_loss = F.mse_loss(m2_hat, mels)
|
170 |
+
stop_loss = F.binary_cross_entropy(stop_pred, stop)
|
171 |
+
|
172 |
+
loss = m1_loss + m2_loss + stop_loss
|
173 |
+
|
174 |
+
optimizer.zero_grad()
|
175 |
+
loss.backward()
|
176 |
+
|
177 |
+
if hparams.tts_clip_grad_norm is not None:
|
178 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
|
179 |
+
if np.isnan(grad_norm.cpu()):
|
180 |
+
print("grad_norm was NaN!")
|
181 |
+
|
182 |
+
optimizer.step()
|
183 |
+
|
184 |
+
time_window.append(time.time() - start_time)
|
185 |
+
loss_window.append(loss.item())
|
186 |
+
|
187 |
+
step = model.get_step()
|
188 |
+
k = step // 1000
|
189 |
+
|
190 |
+
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
|
191 |
+
f"{1./time_window.average:#.2} steps/s | Step: {k}k | "
|
192 |
+
stream(msg)
|
193 |
+
|
194 |
+
# Backup or save model as appropriate
|
195 |
+
if backup_every != 0 and step % backup_every == 0 :
|
196 |
+
backup_fpath = weights_fpath.parent / f"synthesizer_{k:06d}.pt"
|
197 |
+
model.save(backup_fpath, optimizer)
|
198 |
+
|
199 |
+
if save_every != 0 and step % save_every == 0 :
|
200 |
+
# Must save latest optimizer state to ensure that resuming training
|
201 |
+
# doesn't produce artifacts
|
202 |
+
model.save(weights_fpath, optimizer)
|
203 |
+
|
204 |
+
# Evaluate model to generate samples
|
205 |
+
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
|
206 |
+
step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
|
207 |
+
if epoch_eval or step_eval:
|
208 |
+
for sample_idx in range(hparams.tts_eval_num_samples):
|
209 |
+
# At most, generate samples equal to number in the batch
|
210 |
+
if sample_idx + 1 <= len(texts):
|
211 |
+
# Remove padding from mels using frame length in metadata
|
212 |
+
mel_length = int(dataset.metadata[idx[sample_idx]][4])
|
213 |
+
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
|
214 |
+
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
|
215 |
+
attention_len = mel_length // model.r
|
216 |
+
|
217 |
+
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
|
218 |
+
mel_prediction=mel_prediction,
|
219 |
+
target_spectrogram=target_spectrogram,
|
220 |
+
input_seq=np_now(texts[sample_idx]),
|
221 |
+
step=step,
|
222 |
+
plot_dir=plot_dir,
|
223 |
+
mel_output_dir=mel_output_dir,
|
224 |
+
wav_dir=wav_dir,
|
225 |
+
sample_num=sample_idx + 1,
|
226 |
+
loss=loss,
|
227 |
+
hparams=hparams)
|
228 |
+
|
229 |
+
# Break out of loop to update training schedule
|
230 |
+
if step >= max_step:
|
231 |
+
break
|
232 |
+
|
233 |
+
# Add line break after every epoch
|
234 |
+
print("")
|
235 |
+
|
236 |
+
|
237 |
+
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
238 |
+
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
|
239 |
+
# Save some results for evaluation
|
240 |
+
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
|
241 |
+
save_attention(attention, attention_path)
|
242 |
+
|
243 |
+
# save predicted mel spectrogram to disk (debug)
|
244 |
+
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
|
245 |
+
np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
|
246 |
+
|
247 |
+
# save griffin lim inverted wav for debug (mel -> wav)
|
248 |
+
wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
|
249 |
+
wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
|
250 |
+
audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
|
251 |
+
|
252 |
+
# save real and predicted mel-spectrogram plot to disk (control purposes)
|
253 |
+
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
|
254 |
+
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
|
255 |
+
plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
256 |
+
target_spectrogram=target_spectrogram,
|
257 |
+
max_len=target_spectrogram.size // hparams.num_mels)
|
258 |
+
print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
|
synthesizer/utils/__init__.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
_output_ref = None
|
5 |
+
_replicas_ref = None
|
6 |
+
|
7 |
+
def data_parallel_workaround(model, *input):
|
8 |
+
global _output_ref
|
9 |
+
global _replicas_ref
|
10 |
+
device_ids = list(range(torch.cuda.device_count()))
|
11 |
+
output_device = device_ids[0]
|
12 |
+
replicas = torch.nn.parallel.replicate(model, device_ids)
|
13 |
+
# input.shape = (num_args, batch, ...)
|
14 |
+
inputs = torch.nn.parallel.scatter(input, device_ids)
|
15 |
+
# inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
|
16 |
+
replicas = replicas[:len(inputs)]
|
17 |
+
outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
|
18 |
+
y_hat = torch.nn.parallel.gather(outputs, output_device)
|
19 |
+
_output_ref = outputs
|
20 |
+
_replicas_ref = replicas
|
21 |
+
return y_hat
|
22 |
+
|
23 |
+
|
24 |
+
class ValueWindow():
|
25 |
+
def __init__(self, window_size=100):
|
26 |
+
self._window_size = window_size
|
27 |
+
self._values = []
|
28 |
+
|
29 |
+
def append(self, x):
|
30 |
+
self._values = self._values[-(self._window_size - 1):] + [x]
|
31 |
+
|
32 |
+
@property
|
33 |
+
def sum(self):
|
34 |
+
return sum(self._values)
|
35 |
+
|
36 |
+
@property
|
37 |
+
def count(self):
|
38 |
+
return len(self._values)
|
39 |
+
|
40 |
+
@property
|
41 |
+
def average(self):
|
42 |
+
return self.sum / max(1, self.count)
|
43 |
+
|
44 |
+
def reset(self):
|
45 |
+
self._values = []
|
synthesizer/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.76 kB). View file
|
|
synthesizer/utils/__pycache__/cleaners.cpython-38.pyc
ADDED
Binary file (2.87 kB). View file
|
|
synthesizer/utils/__pycache__/numbers.cpython-38.pyc
ADDED
Binary file (2.23 kB). View file
|
|
synthesizer/utils/__pycache__/symbols.cpython-38.pyc
ADDED
Binary file (623 Bytes). View file
|
|
synthesizer/utils/__pycache__/text.cpython-38.pyc
ADDED
Binary file (2.77 kB). View file
|
|
synthesizer/utils/_cmudict.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
valid_symbols = [
|
4 |
+
"AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
|
5 |
+
"AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
|
6 |
+
"B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
|
7 |
+
"EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
|
8 |
+
"IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
|
9 |
+
"OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
|
10 |
+
"UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
|
11 |
+
]
|
12 |
+
|
13 |
+
_valid_symbol_set = set(valid_symbols)
|
14 |
+
|
15 |
+
|
16 |
+
class CMUDict:
|
17 |
+
"""Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
|
18 |
+
def __init__(self, file_or_path, keep_ambiguous=True):
|
19 |
+
if isinstance(file_or_path, str):
|
20 |
+
with open(file_or_path, encoding="latin-1") as f:
|
21 |
+
entries = _parse_cmudict(f)
|
22 |
+
else:
|
23 |
+
entries = _parse_cmudict(file_or_path)
|
24 |
+
if not keep_ambiguous:
|
25 |
+
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
26 |
+
self._entries = entries
|
27 |
+
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return len(self._entries)
|
31 |
+
|
32 |
+
|
33 |
+
def lookup(self, word):
|
34 |
+
"""Returns list of ARPAbet pronunciations of the given word."""
|
35 |
+
return self._entries.get(word.upper())
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
_alt_re = re.compile(r"\([0-9]+\)")
|
40 |
+
|
41 |
+
|
42 |
+
def _parse_cmudict(file):
|
43 |
+
cmudict = {}
|
44 |
+
for line in file:
|
45 |
+
if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
|
46 |
+
parts = line.split(" ")
|
47 |
+
word = re.sub(_alt_re, "", parts[0])
|
48 |
+
pronunciation = _get_pronunciation(parts[1])
|
49 |
+
if pronunciation:
|
50 |
+
if word in cmudict:
|
51 |
+
cmudict[word].append(pronunciation)
|
52 |
+
else:
|
53 |
+
cmudict[word] = [pronunciation]
|
54 |
+
return cmudict
|
55 |
+
|
56 |
+
|
57 |
+
def _get_pronunciation(s):
|
58 |
+
parts = s.strip().split(" ")
|
59 |
+
for part in parts:
|
60 |
+
if part not in _valid_symbol_set:
|
61 |
+
return None
|
62 |
+
return " ".join(parts)
|
synthesizer/utils/cleaners.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
3 |
+
|
4 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
5 |
+
hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
|
6 |
+
1. "english_cleaners" for English text
|
7 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
8 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
9 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
10 |
+
the symbols in symbols.py to match your data).
|
11 |
+
"""
|
12 |
+
import re
|
13 |
+
from unidecode import unidecode
|
14 |
+
from synthesizer.utils.numbers import normalize_numbers
|
15 |
+
|
16 |
+
|
17 |
+
# Regular expression matching whitespace:
|
18 |
+
_whitespace_re = re.compile(r"\s+")
|
19 |
+
|
20 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
21 |
+
_abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
|
22 |
+
("mrs", "misess"),
|
23 |
+
("mr", "mister"),
|
24 |
+
("dr", "doctor"),
|
25 |
+
("st", "saint"),
|
26 |
+
("co", "company"),
|
27 |
+
("jr", "junior"),
|
28 |
+
("maj", "major"),
|
29 |
+
("gen", "general"),
|
30 |
+
("drs", "doctors"),
|
31 |
+
("rev", "reverend"),
|
32 |
+
("lt", "lieutenant"),
|
33 |
+
("hon", "honorable"),
|
34 |
+
("sgt", "sergeant"),
|
35 |
+
("capt", "captain"),
|
36 |
+
("esq", "esquire"),
|
37 |
+
("ltd", "limited"),
|
38 |
+
("col", "colonel"),
|
39 |
+
("ft", "fort"),
|
40 |
+
]]
|
41 |
+
|
42 |
+
|
43 |
+
def expand_abbreviations(text):
|
44 |
+
for regex, replacement in _abbreviations:
|
45 |
+
text = re.sub(regex, replacement, text)
|
46 |
+
return text
|
47 |
+
|
48 |
+
|
49 |
+
def expand_numbers(text):
|
50 |
+
return normalize_numbers(text)
|
51 |
+
|
52 |
+
|
53 |
+
def lowercase(text):
|
54 |
+
"""lowercase input tokens."""
|
55 |
+
return text.lower()
|
56 |
+
|
57 |
+
|
58 |
+
def collapse_whitespace(text):
|
59 |
+
return re.sub(_whitespace_re, " ", text)
|
60 |
+
|
61 |
+
|
62 |
+
def convert_to_ascii(text):
|
63 |
+
return unidecode(text)
|
64 |
+
|
65 |
+
|
66 |
+
def basic_cleaners(text):
|
67 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
68 |
+
text = lowercase(text)
|
69 |
+
text = collapse_whitespace(text)
|
70 |
+
return text
|
71 |
+
|
72 |
+
|
73 |
+
def transliteration_cleaners(text):
|
74 |
+
"""Pipeline for non-English text that transliterates to ASCII."""
|
75 |
+
text = convert_to_ascii(text)
|
76 |
+
text = lowercase(text)
|
77 |
+
text = collapse_whitespace(text)
|
78 |
+
return text
|
79 |
+
|
80 |
+
|
81 |
+
def english_cleaners(text):
|
82 |
+
"""Pipeline for English text, including number and abbreviation expansion."""
|
83 |
+
text = convert_to_ascii(text)
|
84 |
+
text = lowercase(text)
|
85 |
+
text = expand_numbers(text)
|
86 |
+
text = expand_abbreviations(text)
|
87 |
+
text = collapse_whitespace(text)
|
88 |
+
return text
|
synthesizer/utils/numbers.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import inflect
|
3 |
+
|
4 |
+
|
5 |
+
_inflect = inflect.engine()
|
6 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
7 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
8 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
9 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
10 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
11 |
+
_number_re = re.compile(r"[0-9]+")
|
12 |
+
|
13 |
+
|
14 |
+
def _remove_commas(m):
|
15 |
+
return m.group(1).replace(",", "")
|
16 |
+
|
17 |
+
|
18 |
+
def _expand_decimal_point(m):
|
19 |
+
return m.group(1).replace(".", " point ")
|
20 |
+
|
21 |
+
|
22 |
+
def _expand_dollars(m):
|
23 |
+
match = m.group(1)
|
24 |
+
parts = match.split(".")
|
25 |
+
if len(parts) > 2:
|
26 |
+
return match + " dollars" # Unexpected format
|
27 |
+
dollars = int(parts[0]) if parts[0] else 0
|
28 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
29 |
+
if dollars and cents:
|
30 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
31 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
32 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
33 |
+
elif dollars:
|
34 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
35 |
+
return "%s %s" % (dollars, dollar_unit)
|
36 |
+
elif cents:
|
37 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
38 |
+
return "%s %s" % (cents, cent_unit)
|
39 |
+
else:
|
40 |
+
return "zero dollars"
|
41 |
+
|
42 |
+
|
43 |
+
def _expand_ordinal(m):
|
44 |
+
return _inflect.number_to_words(m.group(0))
|
45 |
+
|
46 |
+
|
47 |
+
def _expand_number(m):
|
48 |
+
num = int(m.group(0))
|
49 |
+
if num > 1000 and num < 3000:
|
50 |
+
if num == 2000:
|
51 |
+
return "two thousand"
|
52 |
+
elif num > 2000 and num < 2010:
|
53 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
54 |
+
elif num % 100 == 0:
|
55 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
56 |
+
else:
|
57 |
+
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
58 |
+
else:
|
59 |
+
return _inflect.number_to_words(num, andword="")
|
60 |
+
|
61 |
+
|
62 |
+
def normalize_numbers(text):
|
63 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
64 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
65 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
66 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
67 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
68 |
+
text = re.sub(_number_re, _expand_number, text)
|
69 |
+
return text
|