diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..10af2da8ffbdd403540bc7da8d9810658616cc0d
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,50 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pyc
+*.aux
+*.log
+*.out
+*.synctex.gz
+*.suo
+*__pycache__
+*.idea
+*.ipynb_checkpoints
+*.pickle
+*.npy
+*.blg
+*.bbl
+*.bcf
+*.toc
+*.sh
+encoder/saved_models/*
+synthesizer/saved_models/*
+vocoder/saved_models/*
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4118a9a52b3acf5231981fe0d3ef5887cfc61810
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+
+*.pyc
+*.pt
+*.ipynb
+*.wav
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..96318e9af0f67c1567eaf0b889c328f4548a2228
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: Clone Your Voice
+emoji: 📚
+colorFrom: blue
+colorTo: yellow
+python_version: 3.8.4
+sdk: gradio
+sdk_version: 3.0.4
+app_file: app.py
+pinned: false
+duplicated_from: ruslanmv/Clone-Your-Voice
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cada7b6a760be1f8c426ae0bcd61ee99feb072c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,377 @@
+import gradio as gr
+import os
+from utils.default_models import ensure_default_models
+import sys
+import traceback
+from pathlib import Path
+from time import perf_counter as timer
+import numpy as np
+import torch
+from encoder import inference as encoder
+from synthesizer.inference import Synthesizer
+#from toolbox.utterance import Utterance
+from vocoder import inference as vocoder
+import time
+import librosa
+import numpy as np
+#import sounddevice as sd
+import soundfile as sf
+import argparse
+from utils.argutils import print_args
+
+parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+)
+parser.add_argument("-e", "--enc_model_fpath", type=Path,
+ default="saved_models/default/encoder.pt",
+ help="Path to a saved encoder")
+parser.add_argument("-s", "--syn_model_fpath", type=Path,
+ default="saved_models/default/synthesizer.pt",
+ help="Path to a saved synthesizer")
+parser.add_argument("-v", "--voc_model_fpath", type=Path,
+ default="saved_models/default/vocoder.pt",
+ help="Path to a saved vocoder")
+parser.add_argument("--cpu", action="store_true", help=\
+ "If True, processing is done on CPU, even when a GPU is available.")
+parser.add_argument("--no_sound", action="store_true", help=\
+ "If True, audio won't be played.")
+parser.add_argument("--seed", type=int, default=None, help=\
+ "Optional random number seed value to make toolbox deterministic.")
+args = parser.parse_args()
+arg_dict = vars(args)
+print_args(args, parser)
+
+# Maximum of generated wavs to keep on memory
+MAX_WAVS = 15
+utterances = set()
+current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
+synthesizer = None # type: Synthesizer
+current_wav = None
+waves_list = []
+waves_count = 0
+waves_namelist = []
+
+# Hide GPUs from Pytorch to force CPU processing
+if arg_dict.pop("cpu"):
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+print("Running a test of your configuration...\n")
+
+if torch.cuda.is_available():
+ device_id = torch.cuda.current_device()
+ gpu_properties = torch.cuda.get_device_properties(device_id)
+ ## Print some environment information (for debugging purposes)
+ print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
+ "%.1fGb total memory.\n" %
+ (torch.cuda.device_count(),
+ device_id,
+ gpu_properties.name,
+ gpu_properties.major,
+ gpu_properties.minor,
+ gpu_properties.total_memory / 1e9))
+else:
+ print("Using CPU for inference.\n")
+
+## Load the models one by one.
+print("Preparing the encoder, the synthesizer and the vocoder...")
+ensure_default_models(Path("saved_models"))
+#encoder.load_model(args.enc_model_fpath)
+#synthesizer = Synthesizer(args.syn_model_fpath)
+#vocoder.load_model(args.voc_model_fpath)
+
+def compute_embedding(in_fpath):
+
+ if not encoder.is_loaded():
+ model_fpath = args.enc_model_fpath
+ print("Loading the encoder %s... " % model_fpath)
+ start = time.time()
+ encoder.load_model(model_fpath)
+ print("Done (%dms)." % int(1000 * (time.time() - start)), "append")
+
+
+ ## Computing the embedding
+ # First, we load the wav using the function that the speaker encoder provides. This is
+
+ # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
+ # playback, so as to have a fair comparison with the generated audio
+ wav = Synthesizer.load_preprocess_wav(in_fpath)
+
+ # important: there is preprocessing that must be applied.
+
+ # The following two methods are equivalent:
+ # - Directly load from the filepath:
+ preprocessed_wav = encoder.preprocess_wav(wav)
+
+ # - If the wav is already loaded:
+ #original_wav, sampling_rate = librosa.load(str(in_fpath))
+ #preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
+
+ # Compute the embedding
+ embed, partial_embeds, _ = encoder.embed_utterance(preprocessed_wav, return_partials=True)
+
+
+ print("Loaded file succesfully")
+
+ # Then we derive the embedding. There are many functions and parameters that the
+ # speaker encoder interfaces. These are mostly for in-depth research. You will typically
+ # only use this function (with its default parameters):
+ #embed = encoder.embed_utterance(preprocessed_wav)
+
+ return embed
+def create_spectrogram(text,embed):
+ # If seed is specified, reset torch seed and force synthesizer reload
+ if args.seed is not None:
+ torch.manual_seed(args.seed)
+ synthesizer = Synthesizer(args.syn_model_fpath)
+
+
+ # Synthesize the spectrogram
+ model_fpath = args.syn_model_fpath
+ print("Loading the synthesizer %s... " % model_fpath)
+ start = time.time()
+ synthesizer = Synthesizer(model_fpath)
+ print("Done (%dms)." % int(1000 * (time.time()- start)), "append")
+
+
+ # The synthesizer works in batch, so you need to put your data in a list or numpy array
+ texts = [text]
+ embeds = [embed]
+ # If you know what the attention layer alignments are, you can retrieve them here by
+ # passing return_alignments=True
+ specs = synthesizer.synthesize_spectrograms(texts, embeds)
+ breaks = [spec.shape[1] for spec in specs]
+ spec = np.concatenate(specs, axis=1)
+ sample_rate=synthesizer.sample_rate
+ return spec, breaks , sample_rate
+
+
+def generate_waveform(current_generated):
+
+ speaker_name, spec, breaks = current_generated
+ assert spec is not None
+
+ ## Generating the waveform
+ print("Synthesizing the waveform:")
+ # If seed is specified, reset torch seed and reload vocoder
+ if args.seed is not None:
+ torch.manual_seed(args.seed)
+ vocoder.load_model(args.voc_model_fpath)
+
+ model_fpath = args.voc_model_fpath
+ # Synthesize the waveform
+ if not vocoder.is_loaded():
+ print("Loading the vocoder %s... " % model_fpath)
+ start = time.time()
+ vocoder.load_model(model_fpath)
+ print("Done (%dms)." % int(1000 * (time.time()- start)), "append")
+
+ current_vocoder_fpath= model_fpath
+ def vocoder_progress(i, seq_len, b_size, gen_rate):
+ real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
+ line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
+ % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
+ print(line, "overwrite")
+
+
+ # Synthesizing the waveform is fairly straightforward. Remember that the longer the
+ # spectrogram, the more time-efficient the vocoder.
+ if current_vocoder_fpath is not None:
+ print("")
+ generated_wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
+ else:
+ print("Waveform generation with Griffin-Lim... ")
+ generated_wav = Synthesizer.griffin_lim(spec)
+
+ print(" Done!", "append")
+
+
+ ## Post-generation
+ # There's a bug with sounddevice that makes the audio cut one second earlier, so we
+ # pad it.
+ generated_wav = np.pad(generated_wav, (0, Synthesizer.sample_rate), mode="constant")
+
+ # Add breaks
+ b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
+ b_starts = np.concatenate(([0], b_ends[:-1]))
+ wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)]
+ breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
+ generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
+
+
+ # Trim excess silences to compensate for gaps in spectrograms (issue #53)
+ generated_wav = encoder.preprocess_wav(generated_wav)
+
+
+ return generated_wav
+
+
+def save_on_disk(generated_wav,sample_rate):
+ # Save it on the disk
+ filename = "cloned_voice.wav"
+ print(generated_wav.dtype)
+ #OUT=os.environ['OUT_PATH']
+ # Returns `None` if key doesn't exist
+ #OUT=os.environ.get('OUT_PATH')
+ #result = os.path.join(OUT, filename)
+ result = filename
+ print(" > Saving output to {}".format(result))
+ sf.write(result, generated_wav.astype(np.float32), sample_rate)
+ print("\nSaved output as %s\n\n" % result)
+
+ return result
+def play_audio(generated_wav,sample_rate):
+ # Play the audio (non-blocking)
+ if not args.no_sound:
+
+ try:
+ sd.stop()
+ sd.play(generated_wav, sample_rate)
+ except sd.PortAudioError as e:
+ print("\nCaught exception: %s" % repr(e))
+ print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
+ except:
+ raise
+
+
+def clean_memory():
+ import gc
+ #import GPUtil
+ # To see memory usage
+ print('Before clean ')
+ #GPUtil.showUtilization()
+ #cleaning memory 1
+ gc.collect()
+ torch.cuda.empty_cache()
+ time.sleep(2)
+ print('After Clean GPU')
+ #GPUtil.showUtilization()
+
+def clone_voice(in_fpath, text):
+ try:
+ speaker_name = "output"
+ # Compute embedding
+ embed=compute_embedding(in_fpath)
+ print("Created the embedding")
+ # Generating the spectrogram
+ spec, breaks, sample_rate = create_spectrogram(text,embed)
+ current_generated = (speaker_name, spec, breaks)
+ print("Created the mel spectrogram")
+
+ # Create waveform
+ generated_wav=generate_waveform(current_generated)
+ print("Created the the waveform ")
+
+ # Save it on the disk
+ save_on_disk(generated_wav,sample_rate)
+
+ #Play the audio
+ #play_audio(generated_wav,sample_rate)
+
+ return
+ except Exception as e:
+ print("Caught exception: %s" % repr(e))
+ print("Restarting\n")
+
+# Set environment variables
+home_dir = os.getcwd()
+OUT_PATH=os.path.join(home_dir, "out/")
+os.environ['OUT_PATH'] = OUT_PATH
+
+# create output path
+os.makedirs(OUT_PATH, exist_ok=True)
+
+USE_CUDA = torch.cuda.is_available()
+
+os.system('pip install -q pydub ffmpeg-normalize')
+CONFIG_SE_PATH = "config_se.json"
+CHECKPOINT_SE_PATH = "SE_checkpoint.pth.tar"
+def greet(Text,Voicetoclone ,input_mic=None):
+ text= "%s" % (Text)
+ #reference_files= "%s" % (Voicetoclone)
+
+ clean_memory()
+ print(text,len(text),type(text))
+ print(Voicetoclone,type(Voicetoclone))
+
+ if len(text) == 0 :
+ print("Please add text to the program")
+ Text="Please add text to the program, thank you."
+ is_no_text=True
+ else:
+ is_no_text=False
+
+
+ if Voicetoclone==None and input_mic==None:
+ print("There is no input audio")
+ Text="Please add audio input, to the program, thank you."
+ Voicetoclone='trump.mp3'
+ if is_no_text:
+ Text="Please add text and audio, to the program, thank you."
+
+ if input_mic != "" and input_mic != None :
+ # Get the wav file from the microphone
+ print('The value of MIC IS :',input_mic,type(input_mic))
+ Voicetoclone= input_mic
+
+ text= "%s" % (Text)
+ reference_files= Voicetoclone
+ print("path url")
+ print(Voicetoclone)
+ sample= str(Voicetoclone)
+ os.environ['sample'] = sample
+ size= len(reference_files)*sys.getsizeof(reference_files)
+ size2= size / 1000000
+ if (size2 > 0.012) or len(text)>2000:
+ message="File is greater than 30mb or Text inserted is longer than 2000 characters. Please re-try with smaller sizes."
+ print(message)
+ raise SystemExit("File is greater than 30mb. Please re-try or Text inserted is longer than 2000 characters. Please re-try with smaller sizes.")
+ else:
+
+ env_var = 'sample'
+ if env_var in os.environ:
+ print(f'{env_var} value is {os.environ[env_var]}')
+ else:
+ print(f'{env_var} does not exist')
+ #os.system(f'ffmpeg-normalize {os.environ[env_var]} -nt rms -t=-27 -o {os.environ[env_var]} -ar 16000 -f')
+ in_fpath = Path(Voicetoclone)
+ #in_fpath= in_fpath.replace("\"", "").replace("\'", "")
+
+ out_path=clone_voice(in_fpath, text)
+
+ print(" > text: {}".format(text))
+
+ print("Generated Audio")
+ return "cloned_voice.wav"
+
+demo = gr.Interface(
+ fn=greet,
+ inputs=[gr.inputs.Textbox(label='What would you like the voice to say? (max. 2000 characters per request)'),
+ gr.Audio(
+ type="filepath",
+ source="upload",
+ label='Please upload a voice to clone (max. 30mb)'),
+ gr.inputs.Audio(
+ source="microphone",
+ label='or record',
+ type="filepath",
+ optional=True)
+ ],
+ outputs="audio",
+
+ title = 'Clone Your Voice',
+ description = 'A simple application that Clone Your Voice. Wait one minute to process.',
+ article =
+ '''
+
All you need to do is record your voice, type what you want be say
+ ,then wait for compiling. After that click on Play/Pause for listen the audio. The audio is saved in an wav format.
+ For more information visit ruslanmv.com
+
+
''',
+
+ examples = [["I am the cloned version of Donald Trump. Well. I think what's happening to this country is unbelievably bad. We're no longer a respected country","trump.mp3","trump.mp3"],
+ ["I am the cloned version of Elon Musk. Persistence is very important. You should not give up unless you are forced to give up.","musk.mp3","musk.mp3"] ,
+ ["I am the cloned version of Elizabeth. It has always been easy to hate and destroy. To build and to cherish is much more difficult." ,"queen.mp3","queen.mp3"]
+ ]
+
+ )
+demo.launch()
\ No newline at end of file
diff --git a/encoder/__init__.py b/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/encoder/audio.py b/encoder/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..715f0a47cf845209513398191b1351c5c60a63d1
--- /dev/null
+++ b/encoder/audio.py
@@ -0,0 +1,117 @@
+from scipy.ndimage.morphology import binary_dilation
+from encoder.params_data import *
+from pathlib import Path
+from typing import Optional, Union
+from warnings import warn
+import numpy as np
+import librosa
+import struct
+
+try:
+ import webrtcvad
+except:
+ warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+ webrtcvad=None
+
+int16_max = (2 ** 15) - 1
+
+
+def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+ source_sr: Optional[int] = None,
+ normalize: Optional[bool] = True,
+ trim_silence: Optional[bool] = True):
+ """
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+ just .wav), either the waveform as a numpy array of floats.
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+ this argument will be ignored.
+ """
+ # Load the wav from disk if needed
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+ wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+ else:
+ wav = fpath_or_wav
+
+ # Resample the wav if needed
+ if source_sr is not None and source_sr != sampling_rate:
+ wav = librosa.resample(wav, source_sr, sampling_rate)
+
+ # Apply the preprocessing: normalize volume and shorten long silences
+ if normalize:
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+ if webrtcvad and trim_silence:
+ wav = trim_long_silences(wav)
+
+ return wav
+
+
+def wav_to_mel_spectrogram(wav):
+ """
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+ Note: this not a log-mel spectrogram.
+ """
+ frames = librosa.feature.melspectrogram(
+ wav,
+ sampling_rate,
+ n_fft=int(sampling_rate * mel_window_length / 1000),
+ hop_length=int(sampling_rate * mel_window_step / 1000),
+ n_mels=mel_n_channels
+ )
+ return frames.astype(np.float32).T
+
+
+def trim_long_silences(wav):
+ """
+ Ensures that segments without voice in the waveform remain no longer than a
+ threshold determined by the VAD parameters in params.py.
+
+ :param wav: the raw waveform as a numpy array of floats
+ :return: the same waveform with silences trimmed away (length <= original wav length)
+ """
+ # Compute the voice detection window size
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+ # Trim the end of the audio to have a multiple of the window size
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+ # Convert the float waveform to 16-bit mono PCM
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+ # Perform voice activation detection
+ voice_flags = []
+ vad = webrtcvad.Vad(mode=3)
+ for window_start in range(0, len(wav), samples_per_window):
+ window_end = window_start + samples_per_window
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+ sample_rate=sampling_rate))
+ voice_flags = np.array(voice_flags)
+
+ # Smooth the voice detection with a moving average
+ def moving_average(array, width):
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+ ret = np.cumsum(array_padded, dtype=float)
+ ret[width:] = ret[width:] - ret[:-width]
+ return ret[width - 1:] / width
+
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
+ audio_mask = np.round(audio_mask).astype(np.bool)
+
+ # Dilate the voiced regions
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+ audio_mask = np.repeat(audio_mask, samples_per_window)
+
+ return wav[audio_mask == True]
+
+
+def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+ if increase_only and decrease_only:
+ raise ValueError("Both increase only and decrease only are set")
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+ return wav
+ return wav * (10 ** (dBFS_change / 20))
diff --git a/encoder/config.py b/encoder/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12228c81152487da24a6090e5a736f9de0755b0
--- /dev/null
+++ b/encoder/config.py
@@ -0,0 +1,45 @@
+librispeech_datasets = {
+ "train": {
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+ "other": ["LibriSpeech/train-other-500"]
+ },
+ "test": {
+ "clean": ["LibriSpeech/test-clean"],
+ "other": ["LibriSpeech/test-other"]
+ },
+ "dev": {
+ "clean": ["LibriSpeech/dev-clean"],
+ "other": ["LibriSpeech/dev-other"]
+ },
+}
+libritts_datasets = {
+ "train": {
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+ "other": ["LibriTTS/train-other-500"]
+ },
+ "test": {
+ "clean": ["LibriTTS/test-clean"],
+ "other": ["LibriTTS/test-other"]
+ },
+ "dev": {
+ "clean": ["LibriTTS/dev-clean"],
+ "other": ["LibriTTS/dev-other"]
+ },
+}
+voxceleb_datasets = {
+ "voxceleb1" : {
+ "train": ["VoxCeleb1/wav"],
+ "test": ["VoxCeleb1/test_wav"]
+ },
+ "voxceleb2" : {
+ "train": ["VoxCeleb2/dev/aac"],
+ "test": ["VoxCeleb2/test_wav"]
+ }
+}
+
+other_datasets = [
+ "LJSpeech-1.1",
+ "VCTK-Corpus/wav48",
+]
+
+anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
diff --git a/encoder/data_objects/__init__.py b/encoder/data_objects/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a64bcded446ff52844ab441c4713fbb6006c9d2
--- /dev/null
+++ b/encoder/data_objects/__init__.py
@@ -0,0 +1,2 @@
+from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
diff --git a/encoder/data_objects/random_cycler.py b/encoder/data_objects/random_cycler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e5cf738d3ca5214034ce3babdedf6eaea64c469
--- /dev/null
+++ b/encoder/data_objects/random_cycler.py
@@ -0,0 +1,37 @@
+import random
+
+class RandomCycler:
+ """
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
+ order. For a source sequence of n items and one or several consecutive queries of a total
+ of m items, the following guarantees hold (one implies the other):
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+ """
+
+ def __init__(self, source):
+ if len(source) == 0:
+ raise Exception("Can't create RandomCycler from an empty collection")
+ self.all_items = list(source)
+ self.next_items = []
+
+ def sample(self, count: int):
+ shuffle = lambda l: random.sample(l, len(l))
+
+ out = []
+ while count > 0:
+ if count >= len(self.all_items):
+ out.extend(shuffle(list(self.all_items)))
+ count -= len(self.all_items)
+ continue
+ n = min(count, len(self.next_items))
+ out.extend(self.next_items[:n])
+ count -= n
+ self.next_items = self.next_items[n:]
+ if len(self.next_items) == 0:
+ self.next_items = shuffle(list(self.all_items))
+ return out
+
+ def __next__(self):
+ return self.sample(1)[0]
+
diff --git a/encoder/data_objects/speaker.py b/encoder/data_objects/speaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8527808a4763c6115136d0025368d3f0fd212d43
--- /dev/null
+++ b/encoder/data_objects/speaker.py
@@ -0,0 +1,40 @@
+from encoder.data_objects.random_cycler import RandomCycler
+from encoder.data_objects.utterance import Utterance
+from pathlib import Path
+
+# Contains the set of utterances of a single speaker
+class Speaker:
+ def __init__(self, root: Path):
+ self.root = root
+ self.name = root.name
+ self.utterances = None
+ self.utterance_cycler = None
+
+ def _load_utterances(self):
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
+ sources = [l.split(",") for l in sources_file]
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
+ self.utterance_cycler = RandomCycler(self.utterances)
+
+ def random_partial(self, count, n_frames):
+ """
+ Samples a batch of unique partial utterances from the disk in a way that all
+ utterances come up at least once every two cycles and in a random order every time.
+
+ :param count: The number of partial utterances to sample from the set of utterances from
+ that speaker. Utterances are guaranteed not to be repeated if is not larger than
+ the number of utterances available.
+ :param n_frames: The number of frames in the partial utterance.
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
+ frames are the frames of the partial utterances and range is the range of the partial
+ utterance with regard to the complete utterance.
+ """
+ if self.utterances is None:
+ self._load_utterances()
+
+ utterances = self.utterance_cycler.sample(count)
+
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
+
+ return a
diff --git a/encoder/data_objects/speaker_batch.py b/encoder/data_objects/speaker_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d98ebd8ad6ec56f7a2c9a75d1c1837de4cf9e762
--- /dev/null
+++ b/encoder/data_objects/speaker_batch.py
@@ -0,0 +1,13 @@
+import numpy as np
+from typing import List
+from encoder.data_objects.speaker import Speaker
+
+
+class SpeakerBatch:
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+ self.speakers = speakers
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
diff --git a/encoder/data_objects/speaker_verification_dataset.py b/encoder/data_objects/speaker_verification_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54115504ff2be5b2cb56fa296f55bc701087102
--- /dev/null
+++ b/encoder/data_objects/speaker_verification_dataset.py
@@ -0,0 +1,56 @@
+from encoder.data_objects.random_cycler import RandomCycler
+from encoder.data_objects.speaker_batch import SpeakerBatch
+from encoder.data_objects.speaker import Speaker
+from encoder.params_data import partials_n_frames
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+
+# TODO: improve with a pool of speakers for data efficiency
+
+class SpeakerVerificationDataset(Dataset):
+ def __init__(self, datasets_root: Path):
+ self.root = datasets_root
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+ if len(speaker_dirs) == 0:
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
+ "containing all preprocessed speaker directories.")
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+ self.speaker_cycler = RandomCycler(self.speakers)
+
+ def __len__(self):
+ return int(1e10)
+
+ def __getitem__(self, index):
+ return next(self.speaker_cycler)
+
+ def get_logs(self):
+ log_string = ""
+ for log_fpath in self.root.glob("*.txt"):
+ with log_fpath.open("r") as log_file:
+ log_string += "".join(log_file.readlines())
+ return log_string
+
+
+class SpeakerVerificationDataLoader(DataLoader):
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+ worker_init_fn=None):
+ self.utterances_per_speaker = utterances_per_speaker
+
+ super().__init__(
+ dataset=dataset,
+ batch_size=speakers_per_batch,
+ shuffle=False,
+ sampler=sampler,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ collate_fn=self.collate,
+ pin_memory=pin_memory,
+ drop_last=False,
+ timeout=timeout,
+ worker_init_fn=worker_init_fn
+ )
+
+ def collate(self, speakers):
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
\ No newline at end of file
diff --git a/encoder/data_objects/utterance.py b/encoder/data_objects/utterance.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff3185ec781eaf5be2a58d61c22b32586d366126
--- /dev/null
+++ b/encoder/data_objects/utterance.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+
+class Utterance:
+ def __init__(self, frames_fpath, wave_fpath):
+ self.frames_fpath = frames_fpath
+ self.wave_fpath = wave_fpath
+
+ def get_frames(self):
+ return np.load(self.frames_fpath)
+
+ def random_partial(self, n_frames):
+ """
+ Crops the frames into a partial utterance of n_frames
+
+ :param n_frames: The number of frames of the partial utterance
+ :return: the partial utterance frames and a tuple indicating the start and end of the
+ partial utterance in the complete utterance.
+ """
+ frames = self.get_frames()
+ if frames.shape[0] == n_frames:
+ start = 0
+ else:
+ start = np.random.randint(0, frames.shape[0] - n_frames)
+ end = start + n_frames
+ return frames[start:end], (start, end)
\ No newline at end of file
diff --git a/encoder/inference.py b/encoder/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b312868e6ba72bddb1123dbe090ffe51f2cc4542
--- /dev/null
+++ b/encoder/inference.py
@@ -0,0 +1,178 @@
+from encoder.params_data import *
+from encoder.model import SpeakerEncoder
+from encoder.audio import preprocess_wav # We want to expose this function from here
+from matplotlib import cm
+from encoder import audio
+from pathlib import Path
+import numpy as np
+import torch
+
+_model = None # type: SpeakerEncoder
+_device = None # type: torch.device
+
+
+def load_model(weights_fpath: Path, device=None):
+ """
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
+ first call to embed_frames() with the default weights file.
+
+ :param weights_fpath: the path to saved model weights.
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
+ If None, will default to your GPU if it"s available, otherwise your CPU.
+ """
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
+ # was saved on. Worth investigating.
+ global _model, _device
+ if device is None:
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ elif isinstance(device, str):
+ _device = torch.device(device)
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
+ checkpoint = torch.load(weights_fpath, _device)
+ _model.load_state_dict(checkpoint["model_state"])
+ _model.eval()
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+
+
+def is_loaded():
+ return _model is not None
+
+
+def embed_frames_batch(frames_batch):
+ """
+ Computes embeddings for a batch of mel spectrogram.
+
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
+ (batch_size, n_frames, n_channels)
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+ """
+ if _model is None:
+ raise Exception("Model was not loaded. Call load_model() before inference.")
+
+ frames = torch.from_numpy(frames_batch).to(_device)
+ embed = _model.forward(frames).detach().cpu().numpy()
+ return embed
+
+
+def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+ min_pad_coverage=0.75, overlap=0.5):
+ """
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+ partial utterances of each. Both the waveform and the mel
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
+ defined in params_data.py.
+
+ The returned ranges may be indexing further than the length of the waveform. It is
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+ :param n_samples: the number of samples in the waveform
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+ utterance
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+ enough frames. If at least of are present,
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+ utterances are entirely disjoint.
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
+ utterances.
+ """
+ assert 0 <= overlap < 1
+ assert 0 < min_pad_coverage <= 1
+
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+ # Compute the slices
+ wav_slices, mel_slices = [], []
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+ for i in range(0, steps, frame_step):
+ mel_range = np.array([i, i + partial_utterance_n_frames])
+ wav_range = mel_range * samples_per_frame
+ mel_slices.append(slice(*mel_range))
+ wav_slices.append(slice(*wav_range))
+
+ # Evaluate whether extra padding is warranted or not
+ last_wav_range = wav_slices[-1]
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
+ mel_slices = mel_slices[:-1]
+ wav_slices = wav_slices[:-1]
+
+ return wav_slices, mel_slices
+
+
+def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+ """
+ Computes an embedding for a single utterance.
+
+ # TODO: handle multiple wavs to benefit from batching on GPU
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+ :param using_partials: if True, then the utterance is split in partial utterances of
+ frames and the utterance embedding is computed from their
+ normalized average. If False, the utterance is instead computed from feeding the entire
+ spectogram to the network.
+ :param return_partials: if True, the partial embeddings will also be returned along with the
+ wav slices that correspond to the partial embeddings.
+ :param kwargs: additional arguments to compute_partial_splits()
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+ is True, the partial utterances as a numpy array of float32 of shape
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+ returned. If is simultaneously set to False, both these values will be None
+ instead.
+ """
+ # Process the entire utterance if not using partials
+ if not using_partials:
+ frames = audio.wav_to_mel_spectrogram(wav)
+ embed = embed_frames_batch(frames[None, ...])[0]
+ if return_partials:
+ return embed, None, None
+ return embed
+
+ # Compute where to split the utterance into partials and pad if necessary
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+ max_wave_length = wave_slices[-1].stop
+ if max_wave_length >= len(wav):
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+ # Split the utterance into partials
+ frames = audio.wav_to_mel_spectrogram(wav)
+ frames_batch = np.array([frames[s] for s in mel_slices])
+ partial_embeds = embed_frames_batch(frames_batch)
+
+ # Compute the utterance embedding from the partial embeddings
+ raw_embed = np.mean(partial_embeds, axis=0)
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+ if return_partials:
+ return embed, partial_embeds, wave_slices
+ return embed
+
+
+def embed_speaker(wavs, **kwargs):
+ raise NotImplemented()
+
+
+def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
+ import matplotlib.pyplot as plt
+ if ax is None:
+ ax = plt.gca()
+
+ if shape is None:
+ height = int(np.sqrt(len(embed)))
+ shape = (height, -1)
+ embed = embed.reshape(shape)
+
+ cmap = cm.get_cmap()
+ mappable = ax.imshow(embed, cmap=cmap)
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
+ sm = cm.ScalarMappable(cmap=cmap)
+ sm.set_clim(*color_range)
+
+ ax.set_xticks([]), ax.set_yticks([])
+ ax.set_title(title)
diff --git a/encoder/model.py b/encoder/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd0f2ea2738d1762188d0e274aef478069c4cd98
--- /dev/null
+++ b/encoder/model.py
@@ -0,0 +1,135 @@
+from encoder.params_model import *
+from encoder.params_data import *
+from scipy.interpolate import interp1d
+from sklearn.metrics import roc_curve
+from torch.nn.utils import clip_grad_norm_
+from scipy.optimize import brentq
+from torch import nn
+import numpy as np
+import torch
+
+
+class SpeakerEncoder(nn.Module):
+ def __init__(self, device, loss_device):
+ super().__init__()
+ self.loss_device = loss_device
+
+ # Network defition
+ self.lstm = nn.LSTM(input_size=mel_n_channels,
+ hidden_size=model_hidden_size,
+ num_layers=model_num_layers,
+ batch_first=True).to(device)
+ self.linear = nn.Linear(in_features=model_hidden_size,
+ out_features=model_embedding_size).to(device)
+ self.relu = torch.nn.ReLU().to(device)
+
+ # Cosine similarity scaling (with fixed initial parameter values)
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
+
+ # Loss
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
+
+ def do_gradient_ops(self):
+ # Gradient scale
+ self.similarity_weight.grad *= 0.01
+ self.similarity_bias.grad *= 0.01
+
+ # Gradient clipping
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
+
+ def forward(self, utterances, hidden_init=None):
+ """
+ Computes the embeddings of a batch of utterance spectrograms.
+
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
+ (batch_size, n_frames, n_channels)
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
+ """
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
+ # and the final cell state.
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
+
+ # We take only the hidden state of the last layer
+ embeds_raw = self.relu(self.linear(hidden[-1]))
+
+ # L2-normalize it
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
+
+ return embeds
+
+ def similarity_matrix(self, embeds):
+ """
+ Computes the similarity matrix according the section 2.1 of GE2E.
+
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+ utterances_per_speaker, embedding_size)
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
+ utterances_per_speaker, speakers_per_batch)
+ """
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
+ centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
+
+ # Exclusive centroids (1 per utterance)
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
+ centroids_excl /= (utterances_per_speaker - 1)
+ centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
+
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
+ # We vectorize the computation for efficiency.
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
+ speakers_per_batch).to(self.loss_device)
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
+ for j in range(speakers_per_batch):
+ mask = np.where(mask_matrix[j])[0]
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
+
+ ## Even more vectorized version (slower maybe because of transpose)
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
+ # ).to(self.loss_device)
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
+ # mask = np.where(1 - eye)
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
+ # mask = np.where(eye)
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
+
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
+ return sim_matrix
+
+ def loss(self, embeds):
+ """
+ Computes the softmax loss according the section 2.1 of GE2E.
+
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+ utterances_per_speaker, embedding_size)
+ :return: the loss and the EER for this batch of embeddings.
+ """
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+ # Loss
+ sim_matrix = self.similarity_matrix(embeds)
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
+ speakers_per_batch))
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
+ loss = self.loss_fn(sim_matrix, target)
+
+ # EER (not backpropagated)
+ with torch.no_grad():
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
+ labels = np.array([inv_argmax(i) for i in ground_truth])
+ preds = sim_matrix.detach().cpu().numpy()
+
+ # Snippet from https://yangcha.github.io/EER-ROC/
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+
+ return loss, eer
diff --git a/encoder/params_data.py b/encoder/params_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..676e6dc197faf01648de7a830140172d5594b999
--- /dev/null
+++ b/encoder/params_data.py
@@ -0,0 +1,29 @@
+
+## Mel-filterbank
+mel_window_length = 25 # In milliseconds
+mel_window_step = 10 # In milliseconds
+mel_n_channels = 40
+
+
+## Audio
+sampling_rate = 16000
+# Number of spectrogram frames in a partial utterance
+partials_n_frames = 160 # 1600 ms
+# Number of spectrogram frames at inference
+inference_n_frames = 80 # 800 ms
+
+
+## Voice Activation Detection
+# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+# This sets the granularity of the VAD. Should not need to be changed.
+vad_window_length = 30 # In milliseconds
+# Number of frames to average together when performing the moving average smoothing.
+# The larger this value, the larger the VAD variations must be to not get smoothed out.
+vad_moving_average_width = 8
+# Maximum number of consecutive silent frames a segment can have.
+vad_max_silence_length = 6
+
+
+## Audio volume normalization
+audio_norm_target_dBFS = -30
+
diff --git a/encoder/params_model.py b/encoder/params_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..32731f295b3b26e9e38bb9f9047d5c784649e127
--- /dev/null
+++ b/encoder/params_model.py
@@ -0,0 +1,11 @@
+
+## Model parameters
+model_hidden_size = 256
+model_embedding_size = 256
+model_num_layers = 3
+
+
+## Training parameters
+learning_rate_init = 1e-4
+speakers_per_batch = 64
+utterances_per_speaker = 10
diff --git a/encoder/preprocess.py b/encoder/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab8f8d60f58286ea9c7e711a7d731cd63489a552
--- /dev/null
+++ b/encoder/preprocess.py
@@ -0,0 +1,184 @@
+from datetime import datetime
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+
+from encoder import audio
+from encoder.config import librispeech_datasets, anglophone_nationalites
+from encoder.params_data import *
+
+
+_AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
+
+class DatasetLog:
+ """
+ Registers metadata about the dataset in a text file.
+ """
+ def __init__(self, root, name):
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
+ self.sample_data = dict()
+
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
+ self.write_line("-----")
+ self._log_params()
+
+ def _log_params(self):
+ from encoder import params_data
+ self.write_line("Parameter values:")
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+ value = getattr(params_data, param_name)
+ self.write_line("\t%s: %s" % (param_name, value))
+ self.write_line("-----")
+
+ def write_line(self, line):
+ self.text_file.write("%s\n" % line)
+
+ def add_sample(self, **kwargs):
+ for param_name, value in kwargs.items():
+ if not param_name in self.sample_data:
+ self.sample_data[param_name] = []
+ self.sample_data[param_name].append(value)
+
+ def finalize(self):
+ self.write_line("Statistics:")
+ for param_name, values in self.sample_data.items():
+ self.write_line("\t%s:" % param_name)
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
+ self.write_line("-----")
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
+ self.write_line("Finished on %s" % end_time)
+ self.text_file.close()
+
+
+def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
+ dataset_root = datasets_root.joinpath(dataset_name)
+ if not dataset_root.exists():
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
+ return None, None
+ return dataset_root, DatasetLog(out_dir, dataset_name)
+
+
+def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
+ # Give a name to the speaker that includes its dataset
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
+
+ # Create an output directory with that name, as well as a txt file containing a
+ # reference to each source file.
+ speaker_out_dir = out_dir.joinpath(speaker_name)
+ speaker_out_dir.mkdir(exist_ok=True)
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
+
+ # There's a possibility that the preprocessing was interrupted earlier, check if
+ # there already is a sources file.
+ if sources_fpath.exists():
+ try:
+ with sources_fpath.open("r") as sources_file:
+ existing_fnames = {line.split(",")[0] for line in sources_file}
+ except:
+ existing_fnames = {}
+ else:
+ existing_fnames = {}
+
+ # Gather all audio files for that speaker recursively
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
+ audio_durs = []
+ for extension in _AUDIO_EXTENSIONS:
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
+ # Check if the target output file already exists
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
+ if skip_existing and out_fname in existing_fnames:
+ continue
+
+ # Load and preprocess the waveform
+ wav = audio.preprocess_wav(in_fpath)
+ if len(wav) == 0:
+ continue
+
+ # Create the mel spectrogram, discard those that are too short
+ frames = audio.wav_to_mel_spectrogram(wav)
+ if len(frames) < partials_n_frames:
+ continue
+
+ out_fpath = speaker_out_dir.joinpath(out_fname)
+ np.save(out_fpath, frames)
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
+ audio_durs.append(len(wav) / sampling_rate)
+
+ sources_file.close()
+
+ return audio_durs
+
+
+def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
+
+ # Process the utterances for each speaker
+ work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
+ with Pool(4) as pool:
+ tasks = pool.imap(work_fn, speaker_dirs)
+ for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
+ for sample_dur in sample_durs:
+ logger.add_sample(duration=sample_dur)
+
+ logger.finalize()
+ print("Done preprocessing %s.\n" % dataset_name)
+
+
+def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
+ for dataset_name in librispeech_datasets["train"]["other"]:
+ # Initialize the preprocessing
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+ if not dataset_root:
+ return
+
+ # Preprocess all speakers
+ speaker_dirs = list(dataset_root.glob("*"))
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
+ # Initialize the preprocessing
+ dataset_name = "VoxCeleb1"
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+ if not dataset_root:
+ return
+
+ # Get the contents of the meta file
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
+ metadata = [line.split("\t") for line in metafile][1:]
+
+ # Select the ID and the nationality, filter out non-anglophone speakers
+ nationalities = {line[0]: line[3] for line in metadata}
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
+ nationality.lower() in anglophone_nationalites]
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
+ (len(keep_speaker_ids), len(nationalities)))
+
+ # Get the speaker directories for anglophone speakers only
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
+ speaker_dir.name in keep_speaker_ids]
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
+
+ # Preprocess all speakers
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
+
+
+def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
+ # Initialize the preprocessing
+ dataset_name = "VoxCeleb2"
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
+ if not dataset_root:
+ return
+
+ # Get the speaker directories
+ # Preprocess all speakers
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
diff --git a/encoder/train.py b/encoder/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3abbe8da2df5f286ca395e8dc07a24dabaa45022
--- /dev/null
+++ b/encoder/train.py
@@ -0,0 +1,125 @@
+from pathlib import Path
+
+import torch
+
+from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
+from encoder.model import SpeakerEncoder
+from encoder.params_model import *
+from encoder.visualizations import Visualizations
+from utils.profiler import Profiler
+
+
+def sync(device: torch.device):
+ # For correct profiling (cuda operations are async)
+ if device.type == "cuda":
+ torch.cuda.synchronize(device)
+
+
+def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
+ no_visdom: bool):
+ # Create a dataset and a dataloader
+ dataset = SpeakerVerificationDataset(clean_data_root)
+ loader = SpeakerVerificationDataLoader(
+ dataset,
+ speakers_per_batch,
+ utterances_per_speaker,
+ num_workers=4,
+ )
+
+ # Setup the device on which to run the forward pass and the loss. These can be different,
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
+ # hyperparameters) faster on the CPU.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # FIXME: currently, the gradient is None if loss_device is cuda
+ loss_device = torch.device("cpu")
+
+ # Create the model and the optimizer
+ model = SpeakerEncoder(device, loss_device)
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
+ init_step = 1
+
+ # Configure file path for the model
+ model_dir = models_dir / run_id
+ model_dir.mkdir(exist_ok=True, parents=True)
+ state_fpath = model_dir / "encoder.pt"
+
+ # Load any existing model
+ if not force_restart:
+ if state_fpath.exists():
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
+ checkpoint = torch.load(state_fpath)
+ init_step = checkpoint["step"]
+ model.load_state_dict(checkpoint["model_state"])
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
+ optimizer.param_groups[0]["lr"] = learning_rate_init
+ else:
+ print("No model \"%s\" found, starting training from scratch." % run_id)
+ else:
+ print("Starting the training from scratch.")
+ model.train()
+
+ # Initialize the visualization environment
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
+ vis.log_dataset(dataset)
+ vis.log_params()
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
+ vis.log_implementation({"Device": device_name})
+
+ # Training loop
+ profiler = Profiler(summarize_every=10, disabled=False)
+ for step, speaker_batch in enumerate(loader, init_step):
+ profiler.tick("Blocking, waiting for batch (threaded)")
+
+ # Forward pass
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
+ sync(device)
+ profiler.tick("Data to %s" % device)
+ embeds = model(inputs)
+ sync(device)
+ profiler.tick("Forward pass")
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
+ loss, eer = model.loss(embeds_loss)
+ sync(loss_device)
+ profiler.tick("Loss")
+
+ # Backward pass
+ model.zero_grad()
+ loss.backward()
+ profiler.tick("Backward pass")
+ model.do_gradient_ops()
+ optimizer.step()
+ profiler.tick("Parameter update")
+
+ # Update visualizations
+ # learning_rate = optimizer.param_groups[0]["lr"]
+ vis.update(loss.item(), eer, step)
+
+ # Draw projections and save them to the backup folder
+ if umap_every != 0 and step % umap_every == 0:
+ print("Drawing and saving projections (step %d)" % step)
+ projection_fpath = model_dir / f"umap_{step:06d}.png"
+ embeds = embeds.detach().cpu().numpy()
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
+ vis.save()
+
+ # Overwrite the latest version of the model
+ if save_every != 0 and step % save_every == 0:
+ print("Saving the model (step %d)" % step)
+ torch.save({
+ "step": step + 1,
+ "model_state": model.state_dict(),
+ "optimizer_state": optimizer.state_dict(),
+ }, state_fpath)
+
+ # Make a backup
+ if backup_every != 0 and step % backup_every == 0:
+ print("Making a backup (step %d)" % step)
+ backup_fpath = model_dir / f"encoder_{step:06d}.bak"
+ torch.save({
+ "step": step + 1,
+ "model_state": model.state_dict(),
+ "optimizer_state": optimizer.state_dict(),
+ }, backup_fpath)
+
+ profiler.tick("Extras (visualizations, saving)")
diff --git a/encoder/visualizations.py b/encoder/visualizations.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf685ee140e1d0af66f69cf0b3dc0349adc7c954
--- /dev/null
+++ b/encoder/visualizations.py
@@ -0,0 +1,179 @@
+from datetime import datetime
+from time import perf_counter as timer
+
+import numpy as np
+import umap
+import visdom
+
+from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+
+
+colormap = np.array([
+ [76, 255, 0],
+ [0, 127, 70],
+ [255, 0, 0],
+ [255, 217, 38],
+ [0, 135, 255],
+ [165, 0, 165],
+ [255, 167, 255],
+ [0, 255, 255],
+ [255, 96, 38],
+ [142, 76, 0],
+ [33, 0, 127],
+ [0, 0, 0],
+ [183, 183, 183],
+], dtype=np.float) / 255
+
+
+class Visualizations:
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
+ # Tracking data
+ self.last_update_timestamp = timer()
+ self.update_every = update_every
+ self.step_times = []
+ self.losses = []
+ self.eers = []
+ print("Updating the visualizations every %d steps." % update_every)
+
+ # If visdom is disabled TODO: use a better paradigm for that
+ self.disabled = disabled
+ if self.disabled:
+ return
+
+ # Set the environment name
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
+ if env_name is None:
+ self.env_name = now
+ else:
+ self.env_name = "%s (%s)" % (env_name, now)
+
+ # Connect to visdom and open the corresponding window in the browser
+ try:
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
+ except ConnectionError:
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
+ "start it.")
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
+
+ # Create the windows
+ self.loss_win = None
+ self.eer_win = None
+ # self.lr_win = None
+ self.implementation_win = None
+ self.projection_win = None
+ self.implementation_string = ""
+
+ def log_params(self):
+ if self.disabled:
+ return
+ from encoder import params_data
+ from encoder import params_model
+ param_string = "Model parameters:
"
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
+ value = getattr(params_model, param_name)
+ param_string += "\t%s: %s
" % (param_name, value)
+ param_string += "Data parameters:
"
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
+ value = getattr(params_data, param_name)
+ param_string += "\t%s: %s
" % (param_name, value)
+ self.vis.text(param_string, opts={"title": "Parameters"})
+
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
+ if self.disabled:
+ return
+ dataset_string = ""
+ dataset_string += "Speakers: %s\n" % len(dataset.speakers)
+ dataset_string += "\n" + dataset.get_logs()
+ dataset_string = dataset_string.replace("\n", "
")
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
+
+ def log_implementation(self, params):
+ if self.disabled:
+ return
+ implementation_string = ""
+ for param, value in params.items():
+ implementation_string += "%s: %s\n" % (param, value)
+ implementation_string = implementation_string.replace("\n", "
")
+ self.implementation_string = implementation_string
+ self.implementation_win = self.vis.text(
+ implementation_string,
+ opts={"title": "Training implementation"}
+ )
+
+ def update(self, loss, eer, step):
+ # Update the tracking data
+ now = timer()
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
+ self.last_update_timestamp = now
+ self.losses.append(loss)
+ self.eers.append(eer)
+ print(".", end="")
+
+ # Update the plots every steps
+ if step % self.update_every != 0:
+ return
+ time_string = "Step time: mean: %5dms std: %5dms" % \
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
+ if not self.disabled:
+ self.loss_win = self.vis.line(
+ [np.mean(self.losses)],
+ [step],
+ win=self.loss_win,
+ update="append" if self.loss_win else None,
+ opts=dict(
+ legend=["Avg. loss"],
+ xlabel="Step",
+ ylabel="Loss",
+ title="Loss",
+ )
+ )
+ self.eer_win = self.vis.line(
+ [np.mean(self.eers)],
+ [step],
+ win=self.eer_win,
+ update="append" if self.eer_win else None,
+ opts=dict(
+ legend=["Avg. EER"],
+ xlabel="Step",
+ ylabel="EER",
+ title="Equal error rate"
+ )
+ )
+ if self.implementation_win is not None:
+ self.vis.text(
+ self.implementation_string + ("%s" % time_string),
+ win=self.implementation_win,
+ opts={"title": "Training implementation"},
+ )
+
+ # Reset the tracking
+ self.losses.clear()
+ self.eers.clear()
+ self.step_times.clear()
+
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
+ import matplotlib.pyplot as plt
+
+ max_speakers = min(max_speakers, len(colormap))
+ embeds = embeds[:max_speakers * utterances_per_speaker]
+
+ n_speakers = len(embeds) // utterances_per_speaker
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
+ colors = [colormap[i] for i in ground_truth]
+
+ reducer = umap.UMAP()
+ projected = reducer.fit_transform(embeds)
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
+ plt.gca().set_aspect("equal", "datalim")
+ plt.title("UMAP projection (step %d)" % step)
+ if not self.disabled:
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
+ if out_fpath is not None:
+ plt.savefig(out_fpath)
+ plt.clf()
+
+ def save(self):
+ if not self.disabled:
+ self.vis.save([self.env_name])
diff --git a/musk.mp3 b/musk.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..581a6c532e9ca9ebe73f51cef4592b2a673e9c97
Binary files /dev/null and b/musk.mp3 differ
diff --git a/queen.mp3 b/queen.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..05a4df2e7fede32d3dbf76f5db3e6ad917904bea
Binary files /dev/null and b/queen.mp3 differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..971bf7cdba429f98c17d58993aae6e10afc79d19
Binary files /dev/null and b/requirements.txt differ
diff --git a/synthesizer/LICENSE.txt b/synthesizer/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc7f41f51e3c727335e87262e8b3e0ae760ab3ad
--- /dev/null
+++ b/synthesizer/LICENSE.txt
@@ -0,0 +1,24 @@
+MIT License
+
+Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/synthesizer/__init__.py b/synthesizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4287ca8617970fa8fc025b75cb319c7032706910
--- /dev/null
+++ b/synthesizer/__init__.py
@@ -0,0 +1 @@
+#
\ No newline at end of file
diff --git a/synthesizer/audio.py b/synthesizer/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..94e626965959a121ca83673894e03f541c65fbd9
--- /dev/null
+++ b/synthesizer/audio.py
@@ -0,0 +1,206 @@
+import librosa
+import librosa.filters
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+import soundfile as sf
+
+
+def load_wav(path, sr):
+ return librosa.core.load(path, sr=sr)[0]
+
+def save_wav(wav, path, sr):
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+ #proposed by @dsmiller
+ wavfile.write(path, sr, wav.astype(np.int16))
+
+def save_wavenet_wav(wav, path, sr):
+ sf.write(path, wav.astype(np.float32), sr)
+
+def preemphasis(wav, k, preemphasize=True):
+ if preemphasize:
+ return signal.lfilter([1, -k], [1], wav)
+ return wav
+
+def inv_preemphasis(wav, k, inv_preemphasize=True):
+ if inv_preemphasize:
+ return signal.lfilter([1], [1, -k], wav)
+ return wav
+
+#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
+def start_and_end_indices(quantized, silence_threshold=2):
+ for start in range(quantized.size):
+ if abs(quantized[start] - 127) > silence_threshold:
+ break
+ for end in range(quantized.size - 1, 1, -1):
+ if abs(quantized[end] - 127) > silence_threshold:
+ break
+
+ assert abs(quantized[start] - 127) > silence_threshold
+ assert abs(quantized[end] - 127) > silence_threshold
+
+ return start, end
+
+def get_hop_size(hparams):
+ hop_size = hparams.hop_size
+ if hop_size is None:
+ assert hparams.frame_shift_ms is not None
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
+ return hop_size
+
+def linearspectrogram(wav, hparams):
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
+
+ if hparams.signal_normalization:
+ return _normalize(S, hparams)
+ return S
+
+def melspectrogram(wav, hparams):
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
+
+ if hparams.signal_normalization:
+ return _normalize(S, hparams)
+ return S
+
+def inv_linear_spectrogram(linear_spectrogram, hparams):
+ """Converts linear spectrogram to waveform using librosa"""
+ if hparams.signal_normalization:
+ D = _denormalize(linear_spectrogram, hparams)
+ else:
+ D = linear_spectrogram
+
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
+
+ if hparams.use_lws:
+ processor = _lws_processor(hparams)
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
+ y = processor.istft(D).astype(np.float32)
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
+ else:
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
+
+def inv_mel_spectrogram(mel_spectrogram, hparams):
+ """Converts mel spectrogram to waveform using librosa"""
+ if hparams.signal_normalization:
+ D = _denormalize(mel_spectrogram, hparams)
+ else:
+ D = mel_spectrogram
+
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
+
+ if hparams.use_lws:
+ processor = _lws_processor(hparams)
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
+ y = processor.istft(D).astype(np.float32)
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
+ else:
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
+
+def _lws_processor(hparams):
+ import lws
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
+
+def _griffin_lim(S, hparams):
+ """librosa implementation of Griffin-Lim
+ Based on https://github.com/librosa/librosa/issues/434
+ """
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+ S_complex = np.abs(S).astype(np.complex)
+ y = _istft(S_complex * angles, hparams)
+ for i in range(hparams.griffin_lim_iters):
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
+ y = _istft(S_complex * angles, hparams)
+ return y
+
+def _stft(y, hparams):
+ if hparams.use_lws:
+ return _lws_processor(hparams).stft(y).T
+ else:
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
+
+def _istft(y, hparams):
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
+
+##########################################################
+#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+def num_frames(length, fsize, fshift):
+ """Compute number of time frames of spectrogram
+ """
+ pad = (fsize - fshift)
+ if length % fshift == 0:
+ M = (length + pad * 2 - fsize) // fshift + 1
+ else:
+ M = (length + pad * 2 - fsize) // fshift + 2
+ return M
+
+
+def pad_lr(x, fsize, fshift):
+ """Compute left and right padding
+ """
+ M = num_frames(len(x), fsize, fshift)
+ pad = (fsize - fshift)
+ T = len(x) + 2 * pad
+ r = (M - 1) * fshift + fsize - T
+ return pad, pad + r
+##########################################################
+#Librosa correct padding
+def librosa_pad_lr(x, fsize, fshift):
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+# Conversions
+_mel_basis = None
+_inv_mel_basis = None
+
+def _linear_to_mel(spectogram, hparams):
+ global _mel_basis
+ if _mel_basis is None:
+ _mel_basis = _build_mel_basis(hparams)
+ return np.dot(_mel_basis, spectogram)
+
+def _mel_to_linear(mel_spectrogram, hparams):
+ global _inv_mel_basis
+ if _inv_mel_basis is None:
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
+
+def _build_mel_basis(hparams):
+ assert hparams.fmax <= hparams.sample_rate // 2
+ return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
+ fmin=hparams.fmin, fmax=hparams.fmax)
+
+def _amp_to_db(x, hparams):
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
+ return 20 * np.log10(np.maximum(min_level, x))
+
+def _db_to_amp(x):
+ return np.power(10.0, (x) * 0.05)
+
+def _normalize(S, hparams):
+ if hparams.allow_clipping_in_normalization:
+ if hparams.symmetric_mels:
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
+ -hparams.max_abs_value, hparams.max_abs_value)
+ else:
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
+
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
+ if hparams.symmetric_mels:
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
+ else:
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
+
+def _denormalize(D, hparams):
+ if hparams.allow_clipping_in_normalization:
+ if hparams.symmetric_mels:
+ return (((np.clip(D, -hparams.max_abs_value,
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
+ + hparams.min_level_db)
+ else:
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
+
+ if hparams.symmetric_mels:
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
+ else:
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
diff --git a/synthesizer/hparams.py b/synthesizer/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..3add59df69d04d55e9ef27b8b4685c8f1a13c477
--- /dev/null
+++ b/synthesizer/hparams.py
@@ -0,0 +1,92 @@
+import ast
+import pprint
+
+class HParams(object):
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
+ def __setitem__(self, key, value): setattr(self, key, value)
+ def __getitem__(self, key): return getattr(self, key)
+ def __repr__(self): return pprint.pformat(self.__dict__)
+
+ def parse(self, string):
+ # Overrides hparams from a comma-separated string of name=value pairs
+ if len(string) > 0:
+ overrides = [s.split("=") for s in string.split(",")]
+ keys, values = zip(*overrides)
+ keys = list(map(str.strip, keys))
+ values = list(map(str.strip, values))
+ for k in keys:
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
+ return self
+
+hparams = HParams(
+ ### Signal Processing (used in both synthesizer and vocoder)
+ sample_rate = 16000,
+ n_fft = 800,
+ num_mels = 80,
+ hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
+ win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
+ fmin = 55,
+ min_level_db = -100,
+ ref_level_db = 20,
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
+ preemphasize = True,
+
+ ### Tacotron Text-to-Speech (TTS)
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
+ tts_encoder_dims = 256,
+ tts_decoder_dims = 128,
+ tts_postnet_dims = 512,
+ tts_encoder_K = 5,
+ tts_lstm_dims = 1024,
+ tts_postnet_K = 5,
+ tts_num_highways = 4,
+ tts_dropout = 0.5,
+ tts_cleaner_names = ["english_cleaners"],
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
+ # For example, for a range of [-4, 4], this
+ # will terminate the sequence at the first
+ # frame that has all values < -3.4
+
+ ### Tacotron Training
+ tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
+ (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
+ (2, 2e-4, 80_000, 12), #
+ (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
+ (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
+ (2, 1e-5, 640_000, 12)], # lr = learning rate
+
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
+ tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
+ # Set to -1 to generate after completing epoch, or 0 to disable
+
+ tts_eval_num_samples = 1, # Makes this number of samples
+
+ ### Data Preprocessing
+ max_mel_frames = 900,
+ rescale = True,
+ rescaling_max = 0.9,
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
+
+ ### Mel Visualization and Griffin-Lim
+ signal_normalization = True,
+ power = 1.5,
+ griffin_lim_iters = 60,
+
+ ### Audio processing options
+ fmax = 7600, # Should not exceed (sample_rate // 2)
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
+ # and [0, max_abs_value] if False
+ trim_silence = True, # Use with sample_rate of 16000 for best results
+
+ ### SV2TTS
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
+ utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
+ )
+
+def hparams_debug_string():
+ return str(hparams)
diff --git a/synthesizer/inference.py b/synthesizer/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..db5370b00841ea5757d0cd286f0cdc402a30bf59
--- /dev/null
+++ b/synthesizer/inference.py
@@ -0,0 +1,165 @@
+import torch
+from synthesizer import audio
+from synthesizer.hparams import hparams
+from synthesizer.models.tacotron import Tacotron
+from synthesizer.utils.symbols import symbols
+from synthesizer.utils.text import text_to_sequence
+from vocoder.display import simple_table
+from pathlib import Path
+from typing import Union, List
+import numpy as np
+import librosa
+
+
+class Synthesizer:
+ sample_rate = hparams.sample_rate
+ hparams = hparams
+
+ def __init__(self, model_fpath: Path, verbose=True):
+ """
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
+
+ :param model_fpath: path to the trained model file
+ :param verbose: if False, prints less information when using the model
+ """
+ self.model_fpath = model_fpath
+ self.verbose = verbose
+
+ # Check for GPU
+ if torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ else:
+ self.device = torch.device("cpu")
+ if self.verbose:
+ print("Synthesizer using device:", self.device)
+
+ # Tacotron model will be instantiated later on first use.
+ self._model = None
+
+ def is_loaded(self):
+ """
+ Whether the model is loaded in memory.
+ """
+ return self._model is not None
+
+ def load(self):
+ """
+ Instantiates and loads the model given the weights file that was passed in the constructor.
+ """
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
+ num_chars=len(symbols),
+ encoder_dims=hparams.tts_encoder_dims,
+ decoder_dims=hparams.tts_decoder_dims,
+ n_mels=hparams.num_mels,
+ fft_bins=hparams.num_mels,
+ postnet_dims=hparams.tts_postnet_dims,
+ encoder_K=hparams.tts_encoder_K,
+ lstm_dims=hparams.tts_lstm_dims,
+ postnet_K=hparams.tts_postnet_K,
+ num_highways=hparams.tts_num_highways,
+ dropout=hparams.tts_dropout,
+ stop_threshold=hparams.tts_stop_threshold,
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
+
+ self._model.load(self.model_fpath)
+ self._model.eval()
+
+ if self.verbose:
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
+
+ def synthesize_spectrograms(self, texts: List[str],
+ embeddings: Union[np.ndarray, List[np.ndarray]],
+ return_alignments=False):
+ """
+ Synthesizes mel spectrograms from texts and speaker embeddings.
+
+ :param texts: a list of N text prompts to be synthesized
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
+ :param return_alignments: if True, a matrix representing the alignments between the
+ characters
+ and each decoder output step will be returned for each spectrogram
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
+ sequence length of spectrogram i, and possibly the alignments.
+ """
+ # Load the model on the first request.
+ if not self.is_loaded():
+ self.load()
+
+ # Preprocess text inputs
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
+ if not isinstance(embeddings, list):
+ embeddings = [embeddings]
+
+ # Batch inputs
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
+
+ specs = []
+ for i, batch in enumerate(batched_inputs, 1):
+ if self.verbose:
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
+
+ # Pad texts so they are all the same length
+ text_lens = [len(text) for text in batch]
+ max_text_len = max(text_lens)
+ chars = [pad1d(text, max_text_len) for text in batch]
+ chars = np.stack(chars)
+
+ # Stack speaker embeddings into 2D array for batch processing
+ speaker_embeds = np.stack(batched_embeds[i-1])
+
+ # Convert to tensor
+ chars = torch.tensor(chars).long().to(self.device)
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
+
+ # Inference
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
+ mels = mels.detach().cpu().numpy()
+ for m in mels:
+ # Trim silence from end of each spectrogram
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
+ m = m[:, :-1]
+ specs.append(m)
+
+ if self.verbose:
+ print("\n\nDone.\n")
+ return (specs, alignments) if return_alignments else specs
+
+ @staticmethod
+ def load_preprocess_wav(fpath):
+ """
+ Loads and preprocesses an audio file under the same conditions the audio files were used to
+ train the synthesizer.
+ """
+ wav = librosa.load(str(fpath), hparams.sample_rate)[0]
+ if hparams.rescale:
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
+ return wav
+
+ @staticmethod
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
+ """
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
+ were fed to the synthesizer when training.
+ """
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
+ else:
+ wav = fpath_or_wav
+
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
+ return mel_spectrogram
+
+ @staticmethod
+ def griffin_lim(mel):
+ """
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
+ with the same parameters present in hparams.py.
+ """
+ return audio.inv_mel_spectrogram(mel, hparams)
+
+
+def pad1d(x, max_len, pad_value=0):
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9980d5bdbf5ed4eafdcd4adbe058d42af15fec
--- /dev/null
+++ b/synthesizer/models/tacotron.py
@@ -0,0 +1,519 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pathlib import Path
+from typing import Union
+
+
+class HighwayNetwork(nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.W1 = nn.Linear(size, size)
+ self.W2 = nn.Linear(size, size)
+ self.W1.bias.data.fill_(0.)
+
+ def forward(self, x):
+ x1 = self.W1(x)
+ x2 = self.W2(x)
+ g = torch.sigmoid(x2)
+ y = g * F.relu(x1) + (1. - g) * x
+ return y
+
+
+class Encoder(nn.Module):
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
+ super().__init__()
+ prenet_dims = (encoder_dims, encoder_dims)
+ cbhg_channels = encoder_dims
+ self.embedding = nn.Embedding(num_chars, embed_dims)
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
+ dropout=dropout)
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
+ proj_channels=[cbhg_channels, cbhg_channels],
+ num_highways=num_highways)
+
+ def forward(self, x, speaker_embedding=None):
+ x = self.embedding(x)
+ x = self.pre_net(x)
+ x.transpose_(1, 2)
+ x = self.cbhg(x)
+ if speaker_embedding is not None:
+ x = self.add_speaker_embedding(x, speaker_embedding)
+ return x
+
+ def add_speaker_embedding(self, x, speaker_embedding):
+ # SV2TTS
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
+ # This concats the speaker embedding for each char in the encoder output
+
+ # Save the dimensions as human-readable names
+ batch_size = x.size()[0]
+ num_chars = x.size()[1]
+
+ if speaker_embedding.dim() == 1:
+ idx = 0
+ else:
+ idx = 1
+
+ # Start by making a copy of each speaker embedding to match the input text length
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
+ speaker_embedding_size = speaker_embedding.size()[idx]
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
+
+ # Reshape it and transpose
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
+ e = e.transpose(1, 2)
+
+ # Concatenate the tiled speaker embedding with the encoder output
+ x = torch.cat((x, e), 2)
+ return x
+
+
+class BatchNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
+ super().__init__()
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
+ self.bnorm = nn.BatchNorm1d(out_channels)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(x) if self.relu is True else x
+ return self.bnorm(x)
+
+
+class CBHG(nn.Module):
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
+ super().__init__()
+
+ # List of all rnns to call `flatten_parameters()` on
+ self._to_flatten = []
+
+ self.bank_kernels = [i for i in range(1, K + 1)]
+ self.conv1d_bank = nn.ModuleList()
+ for k in self.bank_kernels:
+ conv = BatchNormConv(in_channels, channels, k)
+ self.conv1d_bank.append(conv)
+
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
+
+ # Fix the highway input if necessary
+ if proj_channels[-1] != channels:
+ self.highway_mismatch = True
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
+ else:
+ self.highway_mismatch = False
+
+ self.highways = nn.ModuleList()
+ for i in range(num_highways):
+ hn = HighwayNetwork(channels)
+ self.highways.append(hn)
+
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
+ self._to_flatten.append(self.rnn)
+
+ # Avoid fragmentation of RNN parameters and associated warning
+ self._flatten_parameters()
+
+ def forward(self, x):
+ # Although we `_flatten_parameters()` on init, when using DataParallel
+ # the model gets replicated, making it no longer guaranteed that the
+ # weights are contiguous in GPU memory. Hence, we must call it again
+ self._flatten_parameters()
+
+ # Save these for later
+ residual = x
+ seq_len = x.size(-1)
+ conv_bank = []
+
+ # Convolution Bank
+ for conv in self.conv1d_bank:
+ c = conv(x) # Convolution
+ conv_bank.append(c[:, :, :seq_len])
+
+ # Stack along the channel axis
+ conv_bank = torch.cat(conv_bank, dim=1)
+
+ # dump the last padding to fit residual
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
+
+ # Conv1d projections
+ x = self.conv_project1(x)
+ x = self.conv_project2(x)
+
+ # Residual Connect
+ x = x + residual
+
+ # Through the highways
+ x = x.transpose(1, 2)
+ if self.highway_mismatch is True:
+ x = self.pre_highway(x)
+ for h in self.highways: x = h(x)
+
+ # And then the RNN
+ x, _ = self.rnn(x)
+ return x
+
+ def _flatten_parameters(self):
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
+ to improve efficiency and avoid PyTorch yelling at us."""
+ [m.flatten_parameters() for m in self._to_flatten]
+
+class PreNet(nn.Module):
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
+ super().__init__()
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
+ self.p = dropout
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=True)
+ x = self.fc2(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=True)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, attn_dims):
+ super().__init__()
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
+ self.v = nn.Linear(attn_dims, 1, bias=False)
+
+ def forward(self, encoder_seq_proj, query, t):
+
+ # print(encoder_seq_proj.shape)
+ # Transform the query vector
+ query_proj = self.W(query).unsqueeze(1)
+
+ # Compute the scores
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
+ scores = F.softmax(u, dim=1)
+
+ return scores.transpose(1, 2)
+
+
+class LSA(nn.Module):
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
+ super().__init__()
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
+ self.L = nn.Linear(filters, attn_dim, bias=False)
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
+ self.v = nn.Linear(attn_dim, 1, bias=False)
+ self.cumulative = None
+ self.attention = None
+
+ def init_attention(self, encoder_seq_proj):
+ device = next(self.parameters()).device # use same device as parameters
+ b, t, c = encoder_seq_proj.size()
+ self.cumulative = torch.zeros(b, t, device=device)
+ self.attention = torch.zeros(b, t, device=device)
+
+ def forward(self, encoder_seq_proj, query, t, chars):
+
+ if t == 0: self.init_attention(encoder_seq_proj)
+
+ processed_query = self.W(query).unsqueeze(1)
+
+ location = self.cumulative.unsqueeze(1)
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
+
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
+ u = u.squeeze(-1)
+
+ # Mask zero padding chars
+ u = u * (chars != 0).float()
+
+ # Smooth Attention
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
+ scores = F.softmax(u, dim=1)
+ self.attention = scores
+ self.cumulative = self.cumulative + self.attention
+
+ return scores.unsqueeze(-1).transpose(1, 2)
+
+
+class Decoder(nn.Module):
+ # Class variable because its value doesn't change between classes
+ # yet ought to be scoped by class because its a property of a Decoder
+ max_r = 20
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
+ dropout, speaker_embedding_size):
+ super().__init__()
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
+ self.n_mels = n_mels
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
+ dropout=dropout)
+ self.attn_net = LSA(decoder_dims)
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
+
+ def zoneout(self, prev, current, p=0.1):
+ device = next(self.parameters()).device # Use same device as parameters
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
+ return prev * mask + current * (1 - mask)
+
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
+ hidden_states, cell_states, context_vec, t, chars):
+
+ # Need this for reshaping mels
+ batch_size = encoder_seq.size(0)
+
+ # Unpack the hidden and cell states
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
+ rnn1_cell, rnn2_cell = cell_states
+
+ # PreNet for the Attention RNN
+ prenet_out = self.prenet(prenet_in)
+
+ # Compute the Attention RNN hidden state
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
+
+ # Compute the attention scores
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
+
+ # Dot product to create the context vector
+ context_vec = scores @ encoder_seq
+ context_vec = context_vec.squeeze(1)
+
+ # Concat Attention RNN output w. Context Vector & project
+ x = torch.cat([context_vec, attn_hidden], dim=1)
+ x = self.rnn_input(x)
+
+ # Compute first Residual RNN
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
+ if self.training:
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
+ else:
+ rnn1_hidden = rnn1_hidden_next
+ x = x + rnn1_hidden
+
+ # Compute second Residual RNN
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
+ if self.training:
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
+ else:
+ rnn2_hidden = rnn2_hidden_next
+ x = x + rnn2_hidden
+
+ # Project Mels
+ mels = self.mel_proj(x)
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
+ cell_states = (rnn1_cell, rnn2_cell)
+
+ # Stop token prediction
+ s = torch.cat((x, context_vec), dim=1)
+ s = self.stop_proj(s)
+ stop_tokens = torch.sigmoid(s)
+
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
+
+
+class Tacotron(nn.Module):
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
+ dropout, stop_threshold, speaker_embedding_size):
+ super().__init__()
+ self.n_mels = n_mels
+ self.lstm_dims = lstm_dims
+ self.encoder_dims = encoder_dims
+ self.decoder_dims = decoder_dims
+ self.speaker_embedding_size = speaker_embedding_size
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
+ encoder_K, num_highways, dropout)
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
+ dropout, speaker_embedding_size)
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
+ [postnet_dims, fft_bins], num_highways)
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
+
+ self.init_model()
+ self.num_params()
+
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
+
+ @property
+ def r(self):
+ return self.decoder.r.item()
+
+ @r.setter
+ def r(self, value):
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
+
+ def forward(self, x, m, speaker_embedding):
+ device = next(self.parameters()).device # use same device as parameters
+
+ self.step += 1
+ batch_size, _, steps = m.size()
+
+ # Initialise all hidden states and pack into tuple
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
+
+ # Initialise all lstm cell states and pack into tuple
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+ cell_states = (rnn1_cell, rnn2_cell)
+
+ # Frame for start of decoder loop
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
+
+ # Need an initial context vector
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
+
+ # SV2TTS: Run the encoder with the speaker embedding
+ # The projection avoids unnecessary matmuls in the decoder loop
+ encoder_seq = self.encoder(x, speaker_embedding)
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
+
+ # Need a couple of lists for outputs
+ mel_outputs, attn_scores, stop_outputs = [], [], []
+
+ # Run the decoder loop
+ for t in range(0, steps, self.r):
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
+ hidden_states, cell_states, context_vec, t, x)
+ mel_outputs.append(mel_frames)
+ attn_scores.append(scores)
+ stop_outputs.extend([stop_tokens] * self.r)
+
+ # Concat the mel outputs into sequence
+ mel_outputs = torch.cat(mel_outputs, dim=2)
+
+ # Post-Process for Linear Spectrograms
+ postnet_out = self.postnet(mel_outputs)
+ linear = self.post_proj(postnet_out)
+ linear = linear.transpose(1, 2)
+
+ # For easy visualisation
+ attn_scores = torch.cat(attn_scores, 1)
+ # attn_scores = attn_scores.cpu().data.numpy()
+ stop_outputs = torch.cat(stop_outputs, 1)
+
+ return mel_outputs, linear, attn_scores, stop_outputs
+
+ def generate(self, x, speaker_embedding=None, steps=2000):
+ self.eval()
+ device = next(self.parameters()).device # use same device as parameters
+
+ batch_size, _ = x.size()
+
+ # Need to initialise all hidden states and pack into tuple for tidyness
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
+
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+ cell_states = (rnn1_cell, rnn2_cell)
+
+ # Need a Frame for start of decoder loop
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
+
+ # Need an initial context vector
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
+
+ # SV2TTS: Run the encoder with the speaker embedding
+ # The projection avoids unnecessary matmuls in the decoder loop
+ encoder_seq = self.encoder(x, speaker_embedding)
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
+
+ # Need a couple of lists for outputs
+ mel_outputs, attn_scores, stop_outputs = [], [], []
+
+ # Run the decoder loop
+ for t in range(0, steps, self.r):
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
+ hidden_states, cell_states, context_vec, t, x)
+ mel_outputs.append(mel_frames)
+ attn_scores.append(scores)
+ stop_outputs.extend([stop_tokens] * self.r)
+ # Stop the loop when all stop tokens in batch exceed threshold
+ if (stop_tokens > 0.5).all() and t > 10: break
+
+ # Concat the mel outputs into sequence
+ mel_outputs = torch.cat(mel_outputs, dim=2)
+
+ # Post-Process for Linear Spectrograms
+ postnet_out = self.postnet(mel_outputs)
+ linear = self.post_proj(postnet_out)
+
+
+ linear = linear.transpose(1, 2)
+
+ # For easy visualisation
+ attn_scores = torch.cat(attn_scores, 1)
+ stop_outputs = torch.cat(stop_outputs, 1)
+
+ self.train()
+
+ return mel_outputs, linear, attn_scores
+
+ def init_model(self):
+ for p in self.parameters():
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
+
+ def get_step(self):
+ return self.step.data.item()
+
+ def reset_step(self):
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
+ self.step = self.step.data.new_tensor(1)
+
+ def log(self, path, msg):
+ with open(path, "a") as f:
+ print(msg, file=f)
+
+ def load(self, path, optimizer=None):
+ # Use device of model params as location for loaded state
+ device = next(self.parameters()).device
+ checkpoint = torch.load(str(path), map_location=device)
+ self.load_state_dict(checkpoint["model_state"])
+
+ if "optimizer_state" in checkpoint and optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+ def save(self, path, optimizer=None):
+ if optimizer is not None:
+ torch.save({
+ "model_state": self.state_dict(),
+ "optimizer_state": optimizer.state_dict(),
+ }, str(path))
+ else:
+ torch.save({
+ "model_state": self.state_dict(),
+ }, str(path))
+
+
+ def num_params(self, print_out=True):
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print("Trainable Parameters: %.3fM" % parameters)
+ return parameters
diff --git a/synthesizer/preprocess.py b/synthesizer/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..c05bff555b46eb41150d4787ab199dc6c0739324
--- /dev/null
+++ b/synthesizer/preprocess.py
@@ -0,0 +1,258 @@
+from multiprocessing.pool import Pool
+from synthesizer import audio
+from functools import partial
+from itertools import chain
+from encoder import inference as encoder
+from pathlib import Path
+from utils import logmmse
+from tqdm import tqdm
+import numpy as np
+import librosa
+
+
+def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
+ no_alignments: bool, datasets_name: str, subfolders: str):
+ # Gather the input directories
+ dataset_root = datasets_root.joinpath(datasets_name)
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
+ assert all(input_dir.exists() for input_dir in input_dirs)
+
+ # Create the output directories for each output file type
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
+
+ # Create a metadata file
+ metadata_fpath = out_dir.joinpath("train.txt")
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
+
+ # Preprocess the dataset
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
+ hparams=hparams, no_alignments=no_alignments)
+ job = Pool(n_processes).imap(func, speaker_dirs)
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
+ for metadatum in speaker_metadata:
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
+ metadata_file.close()
+
+ # Verify the contents of the metadata file
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
+ metadata = [line.split("|") for line in metadata_file]
+ mel_frames = sum([int(m[4]) for m in metadata])
+ timesteps = sum([int(m[3]) for m in metadata])
+ sample_rate = hparams.sample_rate
+ hours = (timesteps / sample_rate) / 3600
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
+ (len(metadata), mel_frames, timesteps, hours))
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
+
+
+def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
+ metadata = []
+ for book_dir in speaker_dir.glob("*"):
+ if no_alignments:
+ # Gather the utterance audios and texts
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
+ extensions = ["*.wav", "*.flac", "*.mp3"]
+ for extension in extensions:
+ wav_fpaths = book_dir.glob(extension)
+
+ for wav_fpath in wav_fpaths:
+ # Load the audio waveform
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
+ if hparams.rescale:
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
+
+ # Get the corresponding text
+ # Check for .txt (for compatibility with other datasets)
+ text_fpath = wav_fpath.with_suffix(".txt")
+ if not text_fpath.exists():
+ # Check for .normalized.txt (LibriTTS)
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
+ assert text_fpath.exists()
+ with text_fpath.open("r") as text_file:
+ text = "".join([line for line in text_file])
+ text = text.replace("\"", "")
+ text = text.strip()
+
+ # Process the utterance
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
+ skip_existing, hparams))
+ else:
+ # Process alignment file (LibriSpeech support)
+ # Gather the utterance audios and texts
+ try:
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
+ with alignments_fpath.open("r") as alignments_file:
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
+ except StopIteration:
+ # A few alignment files will be missing
+ continue
+
+ # Iterate over each entry in the alignments file
+ for wav_fname, words, end_times in alignments:
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
+ assert wav_fpath.exists()
+ words = words.replace("\"", "").split(",")
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
+
+ # Process each sub-utterance
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
+ sub_basename = "%s_%02d" % (wav_fname, i)
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
+ skip_existing, hparams))
+
+ return [m for m in metadata if m is not None]
+
+
+def split_on_silences(wav_fpath, words, end_times, hparams):
+ # Load the audio waveform
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
+ if hparams.rescale:
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
+
+ words = np.array(words)
+ start_times = np.array([0.0] + end_times[:-1])
+ end_times = np.array(end_times)
+ assert len(words) == len(end_times) == len(start_times)
+ assert words[0] == "" and words[-1] == ""
+
+ # Find pauses that are too long
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
+ mask[0] = mask[-1] = True
+ breaks = np.where(mask)[0]
+
+ # Profile the noise from the silences and perform noise reduction on the waveform
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int)
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
+ wav = logmmse.denoise(wav, profile, eta=0)
+
+ # Re-attach segments that are too short
+ segments = list(zip(breaks[:-1], breaks[1:]))
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
+ i = 0
+ while i < len(segments) and len(segments) > 1:
+ if segment_durations[i] < hparams.utterance_min_duration:
+ # See if the segment can be re-attached with the right or the left segment
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
+
+ # Do not re-attach if it causes the joined utterance to be too long
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
+ i += 1
+ continue
+
+ # Re-attach the segment with the neighbour of shortest duration
+ j = i - 1 if left_duration <= right_duration else i
+ segments[j] = (segments[j][0], segments[j + 1][1])
+ segment_durations[j] = joined_duration
+ del segments[j + 1], segment_durations[j + 1]
+ else:
+ i += 1
+
+ # Split the utterance
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int)
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
+ texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
+
+ # # DEBUG: play the audio segments (run with -n=1)
+ # import sounddevice as sd
+ # if len(wavs) > 1:
+ # print("This sentence was split in %d segments:" % len(wavs))
+ # else:
+ # print("There are no silences long enough for this sentence to be split:")
+ # for wav, text in zip(wavs, texts):
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
+ # # when playing them. You shouldn't need to do that in your parsers.
+ # wav = np.concatenate((wav, [0] * 16000))
+ # print("\t%s" % text)
+ # sd.play(wav, 16000, blocking=True)
+ # print("")
+
+ return wavs, texts
+
+
+def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
+ skip_existing: bool, hparams):
+ ## FOR REFERENCE:
+ # For you not to lose your head if you ever wish to change things here or implement your own
+ # synthesizer.
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
+ # - There is no processing done to the audios that will be saved to disk beyond volume
+ # normalization (in split_on_silences)
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
+ # is why we re-apply it on the audio on the side of the vocoder.
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
+ # without extra padding. This means that you won't have an exact relation between the length
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
+
+
+ # Skip existing utterances if needed
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
+ return None
+
+ # Trim silence
+ if hparams.trim_silence:
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
+
+ # Skip utterances that are too short
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
+ return None
+
+ # Compute the mel spectrogram
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
+ mel_frames = mel_spectrogram.shape[1]
+
+ # Skip utterances that are too long
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
+ return None
+
+ # Write the spectrogram, embed and audio to disk
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
+ np.save(wav_fpath, wav, allow_pickle=False)
+
+ # Return a tuple describing this training example
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
+
+
+def embed_utterance(fpaths, encoder_model_fpath):
+ if not encoder.is_loaded():
+ encoder.load_model(encoder_model_fpath)
+
+ # Compute the speaker embedding of the utterance
+ wav_fpath, embed_fpath = fpaths
+ wav = np.load(wav_fpath)
+ wav = encoder.preprocess_wav(wav)
+ embed = encoder.embed_utterance(wav)
+ np.save(embed_fpath, embed, allow_pickle=False)
+
+
+def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
+ wav_dir = synthesizer_root.joinpath("audio")
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
+ assert wav_dir.exists() and metadata_fpath.exists()
+ embed_dir = synthesizer_root.joinpath("embeds")
+ embed_dir.mkdir(exist_ok=True)
+
+ # Gather the input wave filepath and the target output embed filepath
+ with metadata_fpath.open("r") as metadata_file:
+ metadata = [line.split("|") for line in metadata_file]
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
+
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
+ # Embed the utterances in separate threads
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
+ job = Pool(n_processes).imap(func, fpaths)
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
+
diff --git a/synthesizer/synthesize.py b/synthesizer/synthesize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f341c3dba10638c1f5174796e11f26deaf151bf1
--- /dev/null
+++ b/synthesizer/synthesize.py
@@ -0,0 +1,92 @@
+import platform
+from functools import partial
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from synthesizer.hparams import hparams_debug_string
+from synthesizer.models.tacotron import Tacotron
+from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
+from synthesizer.utils import data_parallel_workaround
+from synthesizer.utils.symbols import symbols
+
+
+def run_synthesis(in_dir: Path, out_dir: Path, syn_model_fpath: Path, hparams):
+ # This generates ground truth-aligned mels for vocoder training
+ synth_dir = out_dir / "mels_gta"
+ synth_dir.mkdir(exist_ok=True, parents=True)
+ print(hparams_debug_string())
+
+ # Check for GPU
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
+ raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
+ else:
+ device = torch.device("cpu")
+ print("Synthesizer using device:", device)
+
+ # Instantiate Tacotron model
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
+ num_chars=len(symbols),
+ encoder_dims=hparams.tts_encoder_dims,
+ decoder_dims=hparams.tts_decoder_dims,
+ n_mels=hparams.num_mels,
+ fft_bins=hparams.num_mels,
+ postnet_dims=hparams.tts_postnet_dims,
+ encoder_K=hparams.tts_encoder_K,
+ lstm_dims=hparams.tts_lstm_dims,
+ postnet_K=hparams.tts_postnet_K,
+ num_highways=hparams.tts_num_highways,
+ dropout=0., # Use zero dropout for gta mels
+ stop_threshold=hparams.tts_stop_threshold,
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
+
+ # Load the weights
+ print("\nLoading weights at %s" % syn_model_fpath)
+ model.load(syn_model_fpath)
+ print("Tacotron weights loaded from step %d" % model.step)
+
+ # Synthesize using same reduction factor as the model is currently trained
+ r = np.int32(model.r)
+
+ # Set model to eval mode (disable gradient and zoneout)
+ model.eval()
+
+ # Initialize the dataset
+ metadata_fpath = in_dir.joinpath("train.txt")
+ mel_dir = in_dir.joinpath("mels")
+ embed_dir = in_dir.joinpath("embeds")
+
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
+ data_loader = DataLoader(dataset, hparams.synthesis_batch_size, collate_fn=collate_fn, num_workers=2)
+
+ # Generate GTA mels
+ meta_out_fpath = out_dir / "synthesized.txt"
+ with meta_out_fpath.open("w") as file:
+ for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
+ texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)
+
+ # Parallelize model onto GPUS using workaround due to python bug
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
+ _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
+ else:
+ _, mels_out, _, _ = model(texts, mels, embeds)
+
+ for j, k in enumerate(idx):
+ # Note: outputs mel-spectrogram files and target ones have same names, just different folders
+ mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
+ mel_out = mels_out[j].detach().cpu().numpy().T
+
+ # Use the length of the ground truth mel to remove padding from the generated mels
+ mel_out = mel_out[:int(dataset.metadata[k][4])]
+
+ # Write the spectrogram to disk
+ np.save(mel_filename, mel_out, allow_pickle=False)
+
+ # Write metadata into the synthesized file
+ file.write("|".join(dataset.metadata[k]))
diff --git a/synthesizer/synthesizer_dataset.py b/synthesizer/synthesizer_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6771a5aa61c38b790096a52addf983824e3be698
--- /dev/null
+++ b/synthesizer/synthesizer_dataset.py
@@ -0,0 +1,92 @@
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from pathlib import Path
+from synthesizer.utils.text import text_to_sequence
+
+
+class SynthesizerDataset(Dataset):
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
+
+ with metadata_fpath.open("r") as metadata_file:
+ metadata = [line.split("|") for line in metadata_file]
+
+ mel_fnames = [x[1] for x in metadata if int(x[4])]
+ mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
+ embed_fnames = [x[2] for x in metadata if int(x[4])]
+ embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
+ self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
+ self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
+ self.metadata = metadata
+ self.hparams = hparams
+
+ print("Found %d samples" % len(self.samples_fpaths))
+
+ def __getitem__(self, index):
+ # Sometimes index may be a list of 2 (not sure why this happens)
+ # If that is the case, return a single item corresponding to first element in index
+ if index is list:
+ index = index[0]
+
+ mel_path, embed_path = self.samples_fpaths[index]
+ mel = np.load(mel_path).T.astype(np.float32)
+
+ # Load the embed
+ embed = np.load(embed_path)
+
+ # Get the text and clean it
+ text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
+
+ # Convert the list returned by text_to_sequence to a numpy array
+ text = np.asarray(text).astype(np.int32)
+
+ return text, mel.astype(np.float32), embed.astype(np.float32), index
+
+ def __len__(self):
+ return len(self.samples_fpaths)
+
+
+def collate_synthesizer(batch, r, hparams):
+ # Text
+ x_lens = [len(x[0]) for x in batch]
+ max_x_len = max(x_lens)
+
+ chars = [pad1d(x[0], max_x_len) for x in batch]
+ chars = np.stack(chars)
+
+ # Mel spectrogram
+ spec_lens = [x[1].shape[-1] for x in batch]
+ max_spec_len = max(spec_lens) + 1
+ if max_spec_len % r != 0:
+ max_spec_len += r - max_spec_len % r
+
+ # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
+ # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
+ if hparams.symmetric_mels:
+ mel_pad_value = -1 * hparams.max_abs_value
+ else:
+ mel_pad_value = 0
+
+ mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
+ mel = np.stack(mel)
+
+ # Speaker embedding (SV2TTS)
+ embeds = np.array([x[2] for x in batch])
+
+ # Index (for vocoder preprocessing)
+ indices = [x[3] for x in batch]
+
+
+ # Convert all to tensor
+ chars = torch.tensor(chars).long()
+ mel = torch.tensor(mel)
+ embeds = torch.tensor(embeds)
+
+ return chars, mel, embeds, indices
+
+def pad1d(x, max_len, pad_value=0):
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
+
+def pad2d(x, max_len, pad_value=0):
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
diff --git a/synthesizer/train.py b/synthesizer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a72939ab3e0be5c21b20ba8900f584f508dec99
--- /dev/null
+++ b/synthesizer/train.py
@@ -0,0 +1,258 @@
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from torch import optim
+from torch.utils.data import DataLoader
+
+from synthesizer import audio
+from synthesizer.models.tacotron import Tacotron
+from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
+from synthesizer.utils import ValueWindow, data_parallel_workaround
+from synthesizer.utils.plot import plot_spectrogram
+from synthesizer.utils.symbols import symbols
+from synthesizer.utils.text import sequence_to_text
+from vocoder.display import *
+
+
+def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
+
+
+def time_string():
+ return datetime.now().strftime("%Y-%m-%d %H:%M")
+
+
+def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
+ hparams):
+ models_dir.mkdir(exist_ok=True)
+
+ model_dir = models_dir.joinpath(run_id)
+ plot_dir = model_dir.joinpath("plots")
+ wav_dir = model_dir.joinpath("wavs")
+ mel_output_dir = model_dir.joinpath("mel-spectrograms")
+ meta_folder = model_dir.joinpath("metas")
+ model_dir.mkdir(exist_ok=True)
+ plot_dir.mkdir(exist_ok=True)
+ wav_dir.mkdir(exist_ok=True)
+ mel_output_dir.mkdir(exist_ok=True)
+ meta_folder.mkdir(exist_ok=True)
+
+ weights_fpath = model_dir / f"synthesizer.pt"
+ metadata_fpath = syn_dir.joinpath("train.txt")
+
+ print("Checkpoint path: {}".format(weights_fpath))
+ print("Loading training data from: {}".format(metadata_fpath))
+ print("Using model: Tacotron")
+
+ # Bookkeeping
+ time_window = ValueWindow(100)
+ loss_window = ValueWindow(100)
+
+ # From WaveRNN/train_tacotron.py
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+
+ for session in hparams.tts_schedule:
+ _, _, _, batch_size = session
+ if batch_size % torch.cuda.device_count() != 0:
+ raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
+ else:
+ device = torch.device("cpu")
+ print("Using device:", device)
+
+ # Instantiate Tacotron Model
+ print("\nInitialising Tacotron Model...\n")
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
+ num_chars=len(symbols),
+ encoder_dims=hparams.tts_encoder_dims,
+ decoder_dims=hparams.tts_decoder_dims,
+ n_mels=hparams.num_mels,
+ fft_bins=hparams.num_mels,
+ postnet_dims=hparams.tts_postnet_dims,
+ encoder_K=hparams.tts_encoder_K,
+ lstm_dims=hparams.tts_lstm_dims,
+ postnet_K=hparams.tts_postnet_K,
+ num_highways=hparams.tts_num_highways,
+ dropout=hparams.tts_dropout,
+ stop_threshold=hparams.tts_stop_threshold,
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
+
+ # Initialize the optimizer
+ optimizer = optim.Adam(model.parameters())
+
+ # Load the weights
+ if force_restart or not weights_fpath.exists():
+ print("\nStarting the training of Tacotron from scratch\n")
+ model.save(weights_fpath)
+
+ # Embeddings metadata
+ char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
+ with open(char_embedding_fpath, "w", encoding="utf-8") as f:
+ for symbol in symbols:
+ if symbol == " ":
+ symbol = "\\s" # For visual purposes, swap space with \s
+
+ f.write("{}\n".format(symbol))
+
+ else:
+ print("\nLoading weights at %s" % weights_fpath)
+ model.load(weights_fpath, optimizer)
+ print("Tacotron weights loaded from step %d" % model.step)
+
+ # Initialize the dataset
+ metadata_fpath = syn_dir.joinpath("train.txt")
+ mel_dir = syn_dir.joinpath("mels")
+ embed_dir = syn_dir.joinpath("embeds")
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
+
+ for i, session in enumerate(hparams.tts_schedule):
+ current_step = model.get_step()
+
+ r, lr, max_step, batch_size = session
+
+ training_steps = max_step - current_step
+
+ # Do we need to change to the next session?
+ if current_step >= max_step:
+ # Are there no further sessions than the current one?
+ if i == len(hparams.tts_schedule) - 1:
+ # We have completed training. Save the model and exit
+ model.save(weights_fpath, optimizer)
+ break
+ else:
+ # There is a following session, go to it
+ continue
+
+ model.r = r
+
+ # Begin the training
+ simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
+ ("Batch Size", batch_size),
+ ("Learning Rate", lr),
+ ("Outputs/Step (r)", model.r)])
+
+ for p in optimizer.param_groups:
+ p["lr"] = lr
+
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
+ data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
+
+ total_iters = len(dataset)
+ steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
+ epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
+
+ for epoch in range(1, epochs+1):
+ for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
+ start_time = time.time()
+
+ # Generate stop tokens for training
+ stop = torch.ones(mels.shape[0], mels.shape[2])
+ for j, k in enumerate(idx):
+ stop[j, :int(dataset.metadata[k][4])-1] = 0
+
+ texts = texts.to(device)
+ mels = mels.to(device)
+ embeds = embeds.to(device)
+ stop = stop.to(device)
+
+ # Forward pass
+ # Parallelize model onto GPUS using workaround due to python bug
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
+ m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
+ else:
+ m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
+
+ # Backward pass
+ m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
+ m2_loss = F.mse_loss(m2_hat, mels)
+ stop_loss = F.binary_cross_entropy(stop_pred, stop)
+
+ loss = m1_loss + m2_loss + stop_loss
+
+ optimizer.zero_grad()
+ loss.backward()
+
+ if hparams.tts_clip_grad_norm is not None:
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
+ if np.isnan(grad_norm.cpu()):
+ print("grad_norm was NaN!")
+
+ optimizer.step()
+
+ time_window.append(time.time() - start_time)
+ loss_window.append(loss.item())
+
+ step = model.get_step()
+ k = step // 1000
+
+ msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
+ f"{1./time_window.average:#.2} steps/s | Step: {k}k | "
+ stream(msg)
+
+ # Backup or save model as appropriate
+ if backup_every != 0 and step % backup_every == 0 :
+ backup_fpath = weights_fpath.parent / f"synthesizer_{k:06d}.pt"
+ model.save(backup_fpath, optimizer)
+
+ if save_every != 0 and step % save_every == 0 :
+ # Must save latest optimizer state to ensure that resuming training
+ # doesn't produce artifacts
+ model.save(weights_fpath, optimizer)
+
+ # Evaluate model to generate samples
+ epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
+ step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
+ if epoch_eval or step_eval:
+ for sample_idx in range(hparams.tts_eval_num_samples):
+ # At most, generate samples equal to number in the batch
+ if sample_idx + 1 <= len(texts):
+ # Remove padding from mels using frame length in metadata
+ mel_length = int(dataset.metadata[idx[sample_idx]][4])
+ mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
+ target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
+ attention_len = mel_length // model.r
+
+ eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
+ mel_prediction=mel_prediction,
+ target_spectrogram=target_spectrogram,
+ input_seq=np_now(texts[sample_idx]),
+ step=step,
+ plot_dir=plot_dir,
+ mel_output_dir=mel_output_dir,
+ wav_dir=wav_dir,
+ sample_num=sample_idx + 1,
+ loss=loss,
+ hparams=hparams)
+
+ # Break out of loop to update training schedule
+ if step >= max_step:
+ break
+
+ # Add line break after every epoch
+ print("")
+
+
+def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
+ plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
+ # Save some results for evaluation
+ attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
+ save_attention(attention, attention_path)
+
+ # save predicted mel spectrogram to disk (debug)
+ mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
+ np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
+
+ # save griffin lim inverted wav for debug (mel -> wav)
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
+ wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
+ audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
+
+ # save real and predicted mel-spectrogram plot to disk (control purposes)
+ spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
+ title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
+ plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
+ target_spectrogram=target_spectrogram,
+ max_len=target_spectrogram.size // hparams.num_mels)
+ print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
diff --git a/synthesizer/utils/__init__.py b/synthesizer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d3024302c084c8a3a446b74a64cec5f27bc19cb
--- /dev/null
+++ b/synthesizer/utils/__init__.py
@@ -0,0 +1,45 @@
+import torch
+
+
+_output_ref = None
+_replicas_ref = None
+
+def data_parallel_workaround(model, *input):
+ global _output_ref
+ global _replicas_ref
+ device_ids = list(range(torch.cuda.device_count()))
+ output_device = device_ids[0]
+ replicas = torch.nn.parallel.replicate(model, device_ids)
+ # input.shape = (num_args, batch, ...)
+ inputs = torch.nn.parallel.scatter(input, device_ids)
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
+ replicas = replicas[:len(inputs)]
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
+ _output_ref = outputs
+ _replicas_ref = replicas
+ return y_hat
+
+
+class ValueWindow():
+ def __init__(self, window_size=100):
+ self._window_size = window_size
+ self._values = []
+
+ def append(self, x):
+ self._values = self._values[-(self._window_size - 1):] + [x]
+
+ @property
+ def sum(self):
+ return sum(self._values)
+
+ @property
+ def count(self):
+ return len(self._values)
+
+ @property
+ def average(self):
+ return self.sum / max(1, self.count)
+
+ def reset(self):
+ self._values = []
diff --git a/synthesizer/utils/_cmudict.py b/synthesizer/utils/_cmudict.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfddda34744516b9407f847bdbcc9f7be30df062
--- /dev/null
+++ b/synthesizer/utils/_cmudict.py
@@ -0,0 +1,62 @@
+import re
+
+valid_symbols = [
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
+]
+
+_valid_symbol_set = set(valid_symbols)
+
+
+class CMUDict:
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
+ def __init__(self, file_or_path, keep_ambiguous=True):
+ if isinstance(file_or_path, str):
+ with open(file_or_path, encoding="latin-1") as f:
+ entries = _parse_cmudict(f)
+ else:
+ entries = _parse_cmudict(file_or_path)
+ if not keep_ambiguous:
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
+ self._entries = entries
+
+
+ def __len__(self):
+ return len(self._entries)
+
+
+ def lookup(self, word):
+ """Returns list of ARPAbet pronunciations of the given word."""
+ return self._entries.get(word.upper())
+
+
+
+_alt_re = re.compile(r"\([0-9]+\)")
+
+
+def _parse_cmudict(file):
+ cmudict = {}
+ for line in file:
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
+ parts = line.split(" ")
+ word = re.sub(_alt_re, "", parts[0])
+ pronunciation = _get_pronunciation(parts[1])
+ if pronunciation:
+ if word in cmudict:
+ cmudict[word].append(pronunciation)
+ else:
+ cmudict[word] = [pronunciation]
+ return cmudict
+
+
+def _get_pronunciation(s):
+ parts = s.strip().split(" ")
+ for part in parts:
+ if part not in _valid_symbol_set:
+ return None
+ return " ".join(parts)
diff --git a/synthesizer/utils/cleaners.py b/synthesizer/utils/cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..8df9a096eb9836c196fabc474b4ae59c4270a282
--- /dev/null
+++ b/synthesizer/utils/cleaners.py
@@ -0,0 +1,88 @@
+"""
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
+ 1. "english_cleaners" for English text
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+ the symbols in symbols.py to match your data).
+"""
+import re
+from unidecode import unidecode
+from synthesizer.utils.numbers import normalize_numbers
+
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r"\s+")
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
+ ("mrs", "misess"),
+ ("mr", "mister"),
+ ("dr", "doctor"),
+ ("st", "saint"),
+ ("co", "company"),
+ ("jr", "junior"),
+ ("maj", "major"),
+ ("gen", "general"),
+ ("drs", "doctors"),
+ ("rev", "reverend"),
+ ("lt", "lieutenant"),
+ ("hon", "honorable"),
+ ("sgt", "sergeant"),
+ ("capt", "captain"),
+ ("esq", "esquire"),
+ ("ltd", "limited"),
+ ("col", "colonel"),
+ ("ft", "fort"),
+]]
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def expand_numbers(text):
+ return normalize_numbers(text)
+
+
+def lowercase(text):
+ """lowercase input tokens."""
+ return text.lower()
+
+
+def collapse_whitespace(text):
+ return re.sub(_whitespace_re, " ", text)
+
+
+def convert_to_ascii(text):
+ return unidecode(text)
+
+
+def basic_cleaners(text):
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
+
+
+def transliteration_cleaners(text):
+ """Pipeline for non-English text that transliterates to ASCII."""
+ text = convert_to_ascii(text)
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
+
+
+def english_cleaners(text):
+ """Pipeline for English text, including number and abbreviation expansion."""
+ text = convert_to_ascii(text)
+ text = lowercase(text)
+ text = expand_numbers(text)
+ text = expand_abbreviations(text)
+ text = collapse_whitespace(text)
+ return text
diff --git a/synthesizer/utils/numbers.py b/synthesizer/utils/numbers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8265fcf97e327d192a0726c6a5370183d0880ef
--- /dev/null
+++ b/synthesizer/utils/numbers.py
@@ -0,0 +1,69 @@
+import re
+import inflect
+
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")
+
+
+def _remove_commas(m):
+ return m.group(1).replace(",", "")
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace(".", " point ")
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split(".")
+ if len(parts) > 2:
+ return match + " dollars" # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ cent_unit = "cent" if cents == 1 else "cents"
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ return "%s %s" % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = "cent" if cents == 1 else "cents"
+ return "%s %s" % (cents, cent_unit)
+ else:
+ return "zero dollars"
+
+
+def _expand_ordinal(m):
+ return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return "two thousand"
+ elif num > 2000 and num < 2010:
+ return "two thousand " + _inflect.number_to_words(num % 100)
+ elif num % 100 == 0:
+ return _inflect.number_to_words(num // 100) + " hundred"
+ else:
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+ else:
+ return _inflect.number_to_words(num, andword="")
+
+
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r"\1 pounds", text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
diff --git a/synthesizer/utils/plot.py b/synthesizer/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..5499e9c7f93d41229c13a767e9e55f6a389d24ff
--- /dev/null
+++ b/synthesizer/utils/plot.py
@@ -0,0 +1,82 @@
+import numpy as np
+
+
+def split_title_line(title_text, max_words=5):
+ """
+ A function that splits any string based on specific character
+ (returning it with the string), with maximum number of words on it
+ """
+ seq = title_text.split()
+ return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
+
+
+def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+
+ if max_len is not None:
+ alignment = alignment[:, :max_len]
+
+ fig = plt.figure(figsize=(8, 6))
+ ax = fig.add_subplot(111)
+
+ im = ax.imshow(
+ alignment,
+ aspect="auto",
+ origin="lower",
+ interpolation="none")
+ fig.colorbar(im, ax=ax)
+ xlabel = "Decoder timestep"
+
+ if split_title:
+ title = split_title_line(title)
+
+ plt.xlabel(xlabel)
+ plt.title(title)
+ plt.ylabel("Encoder timestep")
+ plt.tight_layout()
+ plt.savefig(path, format="png")
+ plt.close()
+
+
+def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+
+ if max_len is not None:
+ target_spectrogram = target_spectrogram[:max_len]
+ pred_spectrogram = pred_spectrogram[:max_len]
+
+ if split_title:
+ title = split_title_line(title)
+
+ fig = plt.figure(figsize=(10, 8))
+ # Set common labels
+ fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
+
+ #target spectrogram subplot
+ if target_spectrogram is not None:
+ ax1 = fig.add_subplot(311)
+ ax2 = fig.add_subplot(312)
+
+ if auto_aspect:
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
+ else:
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
+ ax1.set_title("Target Mel-Spectrogram")
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
+ ax2.set_title("Predicted Mel-Spectrogram")
+ else:
+ ax2 = fig.add_subplot(211)
+
+ if auto_aspect:
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
+ else:
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
+
+ plt.tight_layout()
+ plt.savefig(path, format="png")
+ plt.close()
diff --git a/synthesizer/utils/symbols.py b/synthesizer/utils/symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..14f629b4a047173b82e6f9e87f10bec9e067ff31
--- /dev/null
+++ b/synthesizer/utils/symbols.py
@@ -0,0 +1,17 @@
+"""
+Defines the set of symbols used in text input to the model.
+
+The default is a set of ASCII characters that works well for English or text that has been run
+through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
+"""
+# from . import cmudict
+
+_pad = "_"
+_eos = "~"
+_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
+
+# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
+#_arpabet = ["@' + s for s in cmudict.valid_symbols]
+
+# Export all symbols:
+symbols = [_pad, _eos] + list(_characters) #+ _arpabet
diff --git a/synthesizer/utils/text.py b/synthesizer/utils/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12d211c060a895a47d7eae6890020557f88183e
--- /dev/null
+++ b/synthesizer/utils/text.py
@@ -0,0 +1,75 @@
+from synthesizer.utils.symbols import symbols
+from synthesizer.utils import cleaners
+import re
+
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+# Regular expression matching text enclosed in curly braces:
+_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
+
+
+def text_to_sequence(text, cleaner_names):
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+ Args:
+ text: string to convert to a sequence
+ cleaner_names: names of the cleaner functions to run the text through
+
+ Returns:
+ List of integers corresponding to the symbols in the text
+ """
+ sequence = []
+
+ # Check for curly braces and treat their contents as ARPAbet:
+ while len(text):
+ m = _curly_re.match(text)
+ if not m:
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
+ break
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
+ sequence += _arpabet_to_sequence(m.group(2))
+ text = m.group(3)
+
+ # Append EOS token
+ sequence.append(_symbol_to_id["~"])
+ return sequence
+
+
+def sequence_to_text(sequence):
+ """Converts a sequence of IDs back to a string"""
+ result = ""
+ for symbol_id in sequence:
+ if symbol_id in _id_to_symbol:
+ s = _id_to_symbol[symbol_id]
+ # Enclose ARPAbet back in curly braces:
+ if len(s) > 1 and s[0] == "@":
+ s = "{%s}" % s[1:]
+ result += s
+ return result.replace("}{", " ")
+
+
+def _clean_text(text, cleaner_names):
+ for name in cleaner_names:
+ cleaner = getattr(cleaners, name)
+ if not cleaner:
+ raise Exception("Unknown cleaner: %s" % name)
+ text = cleaner(text)
+ return text
+
+
+def _symbols_to_sequence(symbols):
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
+
+
+def _arpabet_to_sequence(text):
+ return _symbols_to_sequence(["@" + s for s in text.split()])
+
+
+def _should_keep_symbol(s):
+ return s in _symbol_to_id and s not in ("_", "~")
diff --git a/trump.mp3 b/trump.mp3
new file mode 100644
index 0000000000000000000000000000000000000000..eee45db641638e3faffd318fa07f9863f9fa54d4
Binary files /dev/null and b/trump.mp3 differ
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/argutils.py b/utils/argutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..376f5abff314ecb1366060e62b553b236efef8c2
--- /dev/null
+++ b/utils/argutils.py
@@ -0,0 +1,40 @@
+from pathlib import Path
+import numpy as np
+import argparse
+
+_type_priorities = [ # In decreasing order
+ Path,
+ str,
+ int,
+ float,
+ bool,
+]
+
+def _priority(o):
+ p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None)
+ if p is not None:
+ return p
+ p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None)
+ if p is not None:
+ return p
+ return len(_type_priorities)
+
+def print_args(args: argparse.Namespace, parser=None):
+ args = vars(args)
+ if parser is None:
+ priorities = list(map(_priority, args.values()))
+ else:
+ all_params = [a.dest for g in parser._action_groups for a in g._group_actions ]
+ priority = lambda p: all_params.index(p) if p in all_params else len(all_params)
+ priorities = list(map(priority, args.keys()))
+
+ pad = max(map(len, args.keys())) + 3
+ indices = np.lexsort((list(args.keys()), priorities))
+ items = list(args.items())
+
+ print("Arguments:")
+ for i in indices:
+ param, value = items[i]
+ print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value))
+ print("")
+
\ No newline at end of file
diff --git a/utils/default_models.py b/utils/default_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..1200028a302cae1e2f05162ea111f5ec545c3db4
--- /dev/null
+++ b/utils/default_models.py
@@ -0,0 +1,56 @@
+import urllib.request
+from pathlib import Path
+from threading import Thread
+from urllib.error import HTTPError
+
+from tqdm import tqdm
+
+
+default_models = {
+ "encoder": ("https://drive.google.com/uc?export=download&id=1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1", 17090379),
+ "synthesizer": ("https://drive.google.com/u/0/uc?id=1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s&export=download&confirm=t", 370554559),
+ "vocoder": ("https://drive.google.com/uc?export=download&id=1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu", 53845290),
+}
+
+
+class DownloadProgressBar(tqdm):
+ def update_to(self, b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n)
+
+
+def download(url: str, target: Path, bar_pos=0):
+ # Ensure the directory exists
+ target.parent.mkdir(exist_ok=True, parents=True)
+
+ desc = f"Downloading {target.name}"
+ with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=desc, position=bar_pos, leave=False) as t:
+ try:
+ urllib.request.urlretrieve(url, filename=target, reporthook=t.update_to)
+ except HTTPError:
+ return
+
+
+def ensure_default_models(models_dir: Path):
+ # Define download tasks
+ jobs = []
+ for model_name, (url, size) in default_models.items():
+ target_path = models_dir / "default" / f"{model_name}.pt"
+ if target_path.exists():
+ if target_path.stat().st_size != size:
+ print(f"File {target_path} is not of expected size, redownloading...")
+ else:
+ continue
+
+ thread = Thread(target=download, args=(url, target_path, len(jobs)))
+ thread.start()
+ jobs.append((thread, target_path, size))
+
+ # Run and join threads
+ for thread, target_path, size in jobs:
+ thread.join()
+
+ assert target_path.exists() and target_path.stat().st_size == size, \
+ f"Download for {target_path.name} failed. You may download models manually instead.\n" \
+ f"https://drive.google.com/drive/folders/1fU6umc5uQAVR2udZdHX-lDgXYzTyqG_j"
diff --git a/utils/logmmse.py b/utils/logmmse.py
new file mode 100644
index 0000000000000000000000000000000000000000..db82e400e9a1c9a435d87db0ea8b101a01c03a32
--- /dev/null
+++ b/utils/logmmse.py
@@ -0,0 +1,247 @@
+# The MIT License (MIT)
+#
+# Copyright (c) 2015 braindead
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+#
+# This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
+# simply modified the interface to meet my needs.
+
+
+import numpy as np
+import math
+from scipy.special import expn
+from collections import namedtuple
+
+NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
+
+
+def profile_noise(noise, sampling_rate, window_size=0):
+ """
+ Creates a profile of the noise in a given waveform.
+
+ :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
+ :param sampling_rate: the sampling rate of the audio
+ :param window_size: the size of the window the logmmse algorithm operates on. A default value
+ will be picked if left as 0.
+ :return: a NoiseProfile object
+ """
+ noise, dtype = to_float(noise)
+ noise += np.finfo(np.float64).eps
+
+ if window_size == 0:
+ window_size = int(math.floor(0.02 * sampling_rate))
+
+ if window_size % 2 == 1:
+ window_size = window_size + 1
+
+ perc = 50
+ len1 = int(math.floor(window_size * perc / 100))
+ len2 = int(window_size - len1)
+
+ win = np.hanning(window_size)
+ win = win * len2 / np.sum(win)
+ n_fft = 2 * window_size
+
+ noise_mean = np.zeros(n_fft)
+ n_frames = len(noise) // window_size
+ for j in range(0, window_size * n_frames, window_size):
+ noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
+ noise_mu2 = (noise_mean / n_frames) ** 2
+
+ return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
+
+
+def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
+ """
+ Cleans the noise from a speech waveform given a noise profile. The waveform must have the
+ same sampling rate as the one used to create the noise profile.
+
+ :param wav: a speech waveform as a numpy array of floats or ints.
+ :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
+ the same) waveform.
+ :param eta: voice threshold for noise update. While the voice activation detection value is
+ below this threshold, the noise profile will be continuously updated throughout the audio.
+ Set to 0 to disable updating the noise profile.
+ :return: the clean wav as a numpy array of floats or ints of the same length.
+ """
+ wav, dtype = to_float(wav)
+ wav += np.finfo(np.float64).eps
+ p = noise_profile
+
+ nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
+ x_final = np.zeros(nframes * p.len2)
+
+ aa = 0.98
+ mu = 0.98
+ ksi_min = 10 ** (-25 / 10)
+
+ x_old = np.zeros(p.len1)
+ xk_prev = np.zeros(p.len1)
+ noise_mu2 = p.noise_mu2
+ for k in range(0, nframes * p.len2, p.len2):
+ insign = p.win * wav[k:k + p.window_size]
+
+ spec = np.fft.fft(insign, p.n_fft, axis=0)
+ sig = np.absolute(spec)
+ sig2 = sig ** 2
+
+ gammak = np.minimum(sig2 / noise_mu2, 40)
+
+ if xk_prev.all() == 0:
+ ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
+ else:
+ ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
+ ksi = np.maximum(ksi_min, ksi)
+
+ log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
+ vad_decision = np.sum(log_sigma_k) / p.window_size
+ if vad_decision < eta:
+ noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
+
+ a = ksi / (1 + ksi)
+ vk = a * gammak
+ ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
+ hw = a * np.exp(ei_vk)
+ sig = sig * hw
+ xk_prev = sig ** 2
+ xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
+ xi_w = np.real(xi_w)
+
+ x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
+ x_old = xi_w[p.len1:p.window_size]
+
+ output = from_float(x_final, dtype)
+ output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
+ return output
+
+
+## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that
+## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
+## webrctvad
+# def vad(wav, sampling_rate, eta=0.15, window_size=0):
+# """
+# TODO: fix doc
+# Creates a profile of the noise in a given waveform.
+#
+# :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
+# :param sampling_rate: the sampling rate of the audio
+# :param window_size: the size of the window the logmmse algorithm operates on. A default value
+# will be picked if left as 0.
+# :param eta: voice threshold for noise update. While the voice activation detection value is
+# below this threshold, the noise profile will be continuously updated throughout the audio.
+# Set to 0 to disable updating the noise profile.
+# """
+# wav, dtype = to_float(wav)
+# wav += np.finfo(np.float64).eps
+#
+# if window_size == 0:
+# window_size = int(math.floor(0.02 * sampling_rate))
+#
+# if window_size % 2 == 1:
+# window_size = window_size + 1
+#
+# perc = 50
+# len1 = int(math.floor(window_size * perc / 100))
+# len2 = int(window_size - len1)
+#
+# win = np.hanning(window_size)
+# win = win * len2 / np.sum(win)
+# n_fft = 2 * window_size
+#
+# wav_mean = np.zeros(n_fft)
+# n_frames = len(wav) // window_size
+# for j in range(0, window_size * n_frames, window_size):
+# wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
+# noise_mu2 = (wav_mean / n_frames) ** 2
+#
+# wav, dtype = to_float(wav)
+# wav += np.finfo(np.float64).eps
+#
+# nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
+# vad = np.zeros(nframes * len2, dtype=np.bool)
+#
+# aa = 0.98
+# mu = 0.98
+# ksi_min = 10 ** (-25 / 10)
+#
+# xk_prev = np.zeros(len1)
+# noise_mu2 = noise_mu2
+# for k in range(0, nframes * len2, len2):
+# insign = win * wav[k:k + window_size]
+#
+# spec = np.fft.fft(insign, n_fft, axis=0)
+# sig = np.absolute(spec)
+# sig2 = sig ** 2
+#
+# gammak = np.minimum(sig2 / noise_mu2, 40)
+#
+# if xk_prev.all() == 0:
+# ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
+# else:
+# ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
+# ksi = np.maximum(ksi_min, ksi)
+#
+# log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
+# vad_decision = np.sum(log_sigma_k) / window_size
+# if vad_decision < eta:
+# noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
+# print(vad_decision)
+#
+# a = ksi / (1 + ksi)
+# vk = a * gammak
+# ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
+# hw = a * np.exp(ei_vk)
+# sig = sig * hw
+# xk_prev = sig ** 2
+#
+# vad[k:k + len2] = vad_decision >= eta
+#
+# vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
+# return vad
+
+
+def to_float(_input):
+ if _input.dtype == np.float64:
+ return _input, _input.dtype
+ elif _input.dtype == np.float32:
+ return _input.astype(np.float64), _input.dtype
+ elif _input.dtype == np.uint8:
+ return (_input - 128) / 128., _input.dtype
+ elif _input.dtype == np.int16:
+ return _input / 32768., _input.dtype
+ elif _input.dtype == np.int32:
+ return _input / 2147483648., _input.dtype
+ raise ValueError('Unsupported wave file format')
+
+
+def from_float(_input, dtype):
+ if dtype == np.float64:
+ return _input, np.float64
+ elif dtype == np.float32:
+ return _input.astype(np.float32)
+ elif dtype == np.uint8:
+ return ((_input * 128) + 128).astype(np.uint8)
+ elif dtype == np.int16:
+ return (_input * 32768).astype(np.int16)
+ elif dtype == np.int32:
+ print(_input)
+ return (_input * 2147483648).astype(np.int32)
+ raise ValueError('Unsupported wave file format')
diff --git a/utils/profiler.py b/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47de0b326de68358bb9bbdd356f18dfb4f6a8f0
--- /dev/null
+++ b/utils/profiler.py
@@ -0,0 +1,45 @@
+from time import perf_counter as timer
+from collections import OrderedDict
+import numpy as np
+
+
+class Profiler:
+ def __init__(self, summarize_every=5, disabled=False):
+ self.last_tick = timer()
+ self.logs = OrderedDict()
+ self.summarize_every = summarize_every
+ self.disabled = disabled
+
+ def tick(self, name):
+ if self.disabled:
+ return
+
+ # Log the time needed to execute that function
+ if not name in self.logs:
+ self.logs[name] = []
+ if len(self.logs[name]) >= self.summarize_every:
+ self.summarize()
+ self.purge_logs()
+ self.logs[name].append(timer() - self.last_tick)
+
+ self.reset_timer()
+
+ def purge_logs(self):
+ for name in self.logs:
+ self.logs[name].clear()
+
+ def reset_timer(self):
+ self.last_tick = timer()
+
+ def summarize(self):
+ n = max(map(len, self.logs.values()))
+ assert n == self.summarize_every
+ print("\nAverage execution time over %d steps:" % n)
+
+ name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()]
+ pad = max(map(len, name_msgs))
+ for name_msg, deltas in zip(name_msgs, self.logs.values()):
+ print(" %s mean: %4.0fms std: %4.0fms" %
+ (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000))
+ print("", flush=True)
+
\ No newline at end of file
diff --git a/vocoder/LICENSE.txt b/vocoder/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4510117373ebe19b60a78c908381a6db51532672
--- /dev/null
+++ b/vocoder/LICENSE.txt
@@ -0,0 +1,22 @@
+MIT License
+
+Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/vocoder/audio.py b/vocoder/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..993544e909c5df731ef7b14bd031024c07f4e638
--- /dev/null
+++ b/vocoder/audio.py
@@ -0,0 +1,108 @@
+import math
+import numpy as np
+import librosa
+import vocoder.hparams as hp
+from scipy.signal import lfilter
+import soundfile as sf
+
+
+def label_2_float(x, bits) :
+ return 2 * x / (2**bits - 1.) - 1.
+
+
+def float_2_label(x, bits) :
+ assert abs(x).max() <= 1.0
+ x = (x + 1.) * (2**bits - 1) / 2
+ return x.clip(0, 2**bits - 1)
+
+
+def load_wav(path) :
+ return librosa.load(str(path), sr=hp.sample_rate)[0]
+
+
+def save_wav(x, path) :
+ sf.write(path, x.astype(np.float32), hp.sample_rate)
+
+
+def split_signal(x) :
+ unsigned = x + 2**15
+ coarse = unsigned // 256
+ fine = unsigned % 256
+ return coarse, fine
+
+
+def combine_signal(coarse, fine) :
+ return coarse * 256 + fine - 2**15
+
+
+def encode_16bits(x) :
+ return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
+
+
+mel_basis = None
+
+
+def linear_to_mel(spectrogram):
+ global mel_basis
+ if mel_basis is None:
+ mel_basis = build_mel_basis()
+ return np.dot(mel_basis, spectrogram)
+
+
+def build_mel_basis():
+ return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin)
+
+
+def normalize(S):
+ return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)
+
+
+def denormalize(S):
+ return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db
+
+
+def amp_to_db(x):
+ return 20 * np.log10(np.maximum(1e-5, x))
+
+
+def db_to_amp(x):
+ return np.power(10.0, x * 0.05)
+
+
+def spectrogram(y):
+ D = stft(y)
+ S = amp_to_db(np.abs(D)) - hp.ref_level_db
+ return normalize(S)
+
+
+def melspectrogram(y):
+ D = stft(y)
+ S = amp_to_db(linear_to_mel(np.abs(D)))
+ return normalize(S)
+
+
+def stft(y):
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
+
+
+def pre_emphasis(x):
+ return lfilter([1, -hp.preemphasis], [1], x)
+
+
+def de_emphasis(x):
+ return lfilter([1], [1, -hp.preemphasis], x)
+
+
+def encode_mu_law(x, mu) :
+ mu = mu - 1
+ fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
+ return np.floor((fx + 1) / 2 * mu + 0.5)
+
+
+def decode_mu_law(y, mu, from_labels=True) :
+ if from_labels:
+ y = label_2_float(y, math.log2(mu))
+ mu = mu - 1
+ x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
+ return x
+
diff --git a/vocoder/display.py b/vocoder/display.py
new file mode 100644
index 0000000000000000000000000000000000000000..5044714333542fea8d8dea06ca72574226925358
--- /dev/null
+++ b/vocoder/display.py
@@ -0,0 +1,127 @@
+import time
+import numpy as np
+import sys
+
+
+def progbar(i, n, size=16):
+ done = (i * size) // n
+ bar = ''
+ for i in range(size):
+ bar += 'â–ˆ' if i <= done else 'â–‘'
+ return bar
+
+
+def stream(message) :
+ try:
+ sys.stdout.write("\r{%s}" % message)
+ except:
+ #Remove non-ASCII characters from message
+ message = ''.join(i for i in message if ord(i)<128)
+ sys.stdout.write("\r{%s}" % message)
+
+
+def simple_table(item_tuples) :
+
+ border_pattern = '+---------------------------------------'
+ whitespace = ' '
+
+ headings, cells, = [], []
+
+ for item in item_tuples :
+
+ heading, cell = str(item[0]), str(item[1])
+
+ pad_head = True if len(heading) < len(cell) else False
+
+ pad = abs(len(heading) - len(cell))
+ pad = whitespace[:pad]
+
+ pad_left = pad[:len(pad)//2]
+ pad_right = pad[len(pad)//2:]
+
+ if pad_head :
+ heading = pad_left + heading + pad_right
+ else :
+ cell = pad_left + cell + pad_right
+
+ headings += [heading]
+ cells += [cell]
+
+ border, head, body = '', '', ''
+
+ for i in range(len(item_tuples)) :
+
+ temp_head = f'| {headings[i]} '
+ temp_body = f'| {cells[i]} '
+
+ border += border_pattern[:len(temp_head)]
+ head += temp_head
+ body += temp_body
+
+ if i == len(item_tuples) - 1 :
+ head += '|'
+ body += '|'
+ border += '+'
+
+ print(border)
+ print(head)
+ print(border)
+ print(body)
+ print(border)
+ print(' ')
+
+
+def time_since(started) :
+ elapsed = time.time() - started
+ m = int(elapsed // 60)
+ s = int(elapsed % 60)
+ if m >= 60 :
+ h = int(m // 60)
+ m = m % 60
+ return f'{h}h {m}m {s}s'
+ else :
+ return f'{m}m {s}s'
+
+
+def save_attention(attn, path):
+ import matplotlib.pyplot as plt
+
+ fig = plt.figure(figsize=(12, 6))
+ plt.imshow(attn.T, interpolation='nearest', aspect='auto')
+ fig.savefig(f'{path}.png', bbox_inches='tight')
+ plt.close(fig)
+
+
+def save_spectrogram(M, path, length=None):
+ import matplotlib.pyplot as plt
+
+ M = np.flip(M, axis=0)
+ if length : M = M[:, :length]
+ fig = plt.figure(figsize=(12, 6))
+ plt.imshow(M, interpolation='nearest', aspect='auto')
+ fig.savefig(f'{path}.png', bbox_inches='tight')
+ plt.close(fig)
+
+
+def plot(array):
+ import matplotlib.pyplot as plt
+
+ fig = plt.figure(figsize=(30, 5))
+ ax = fig.add_subplot(111)
+ ax.xaxis.label.set_color('grey')
+ ax.yaxis.label.set_color('grey')
+ ax.xaxis.label.set_fontsize(23)
+ ax.yaxis.label.set_fontsize(23)
+ ax.tick_params(axis='x', colors='grey', labelsize=23)
+ ax.tick_params(axis='y', colors='grey', labelsize=23)
+ plt.plot(array)
+
+
+def plot_spec(M):
+ import matplotlib.pyplot as plt
+
+ M = np.flip(M, axis=0)
+ plt.figure(figsize=(18,4))
+ plt.imshow(M, interpolation='nearest', aspect='auto')
+ plt.show()
+
diff --git a/vocoder/distribution.py b/vocoder/distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..4863035672b323cc49536c065a3b32fbc6c8aeac
--- /dev/null
+++ b/vocoder/distribution.py
@@ -0,0 +1,132 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def log_sum_exp(x):
+ """ numerically stable log_sum_exp implementation that prevents overflow """
+ # TF ordering
+ axis = len(x.size()) - 1
+ m, _ = torch.max(x, dim=axis)
+ m2, _ = torch.max(x, dim=axis, keepdim=True)
+ return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
+
+
+# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
+def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
+ log_scale_min=None, reduce=True):
+ if log_scale_min is None:
+ log_scale_min = float(np.log(1e-14))
+ y_hat = y_hat.permute(0,2,1)
+ assert y_hat.dim() == 3
+ assert y_hat.size(1) % 3 == 0
+ nr_mix = y_hat.size(1) // 3
+
+ # (B x T x C)
+ y_hat = y_hat.transpose(1, 2)
+
+ # unpack parameters. (B, T, num_mixtures) x 3
+ logit_probs = y_hat[:, :, :nr_mix]
+ means = y_hat[:, :, nr_mix:2 * nr_mix]
+ log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
+
+ # B x T x 1 -> B x T x num_mixtures
+ y = y.expand_as(means)
+
+ centered_y = y - means
+ inv_stdv = torch.exp(-log_scales)
+ plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
+ cdf_plus = torch.sigmoid(plus_in)
+ min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
+ cdf_min = torch.sigmoid(min_in)
+
+ # log probability for edge case of 0 (before scaling)
+ # equivalent: torch.log(F.sigmoid(plus_in))
+ log_cdf_plus = plus_in - F.softplus(plus_in)
+
+ # log probability for edge case of 255 (before scaling)
+ # equivalent: (1 - F.sigmoid(min_in)).log()
+ log_one_minus_cdf_min = -F.softplus(min_in)
+
+ # probability for all other cases
+ cdf_delta = cdf_plus - cdf_min
+
+ mid_in = inv_stdv * centered_y
+ # log probability in the center of the bin, to be used in extreme cases
+ # (not actually used in our code)
+ log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
+
+ # tf equivalent
+ """
+ log_probs = tf.where(x < -0.999, log_cdf_plus,
+ tf.where(x > 0.999, log_one_minus_cdf_min,
+ tf.where(cdf_delta > 1e-5,
+ tf.log(tf.maximum(cdf_delta, 1e-12)),
+ log_pdf_mid - np.log(127.5))))
+ """
+ # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
+ # for num_classes=65536 case? 1e-7? not sure..
+ inner_inner_cond = (cdf_delta > 1e-5).float()
+
+ inner_inner_out = inner_inner_cond * \
+ torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
+ (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
+ inner_cond = (y > 0.999).float()
+ inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
+ cond = (y < -0.999).float()
+ log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
+
+ log_probs = log_probs + F.log_softmax(logit_probs, -1)
+
+ if reduce:
+ return -torch.mean(log_sum_exp(log_probs))
+ else:
+ return -log_sum_exp(log_probs).unsqueeze(-1)
+
+
+def sample_from_discretized_mix_logistic(y, log_scale_min=None):
+ """
+ Sample from discretized mixture of logistic distributions
+ Args:
+ y (Tensor): B x C x T
+ log_scale_min (float): Log scale minimum value
+ Returns:
+ Tensor: sample in range of [-1, 1].
+ """
+ if log_scale_min is None:
+ log_scale_min = float(np.log(1e-14))
+ assert y.size(1) % 3 == 0
+ nr_mix = y.size(1) // 3
+
+ # B x T x C
+ y = y.transpose(1, 2)
+ logit_probs = y[:, :, :nr_mix]
+
+ # sample mixture indicator from softmax
+ temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
+ temp = logit_probs.data - torch.log(- torch.log(temp))
+ _, argmax = temp.max(dim=-1)
+
+ # (B, T) -> (B, T, nr_mix)
+ one_hot = to_one_hot(argmax, nr_mix)
+ # select logistic parameters
+ means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
+ log_scales = torch.clamp(torch.sum(
+ y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
+ # sample from logistic & clip to interval
+ # we don't actually round to the nearest 8bit value when sampling
+ u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
+ x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
+
+ x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
+
+ return x
+
+
+def to_one_hot(tensor, n, fill_with=1.):
+ # we perform one hot encore with respect to the last axis
+ one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
+ if tensor.is_cuda:
+ one_hot = one_hot.cuda()
+ one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
+ return one_hot
diff --git a/vocoder/gen_wavernn.py b/vocoder/gen_wavernn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5358160d3e90311438bc0aae5689a4aa1f6d887a
--- /dev/null
+++ b/vocoder/gen_wavernn.py
@@ -0,0 +1,31 @@
+from vocoder.models.fatchord_version import WaveRNN
+from vocoder.audio import *
+
+
+def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path):
+ k = model.get_step() // 1000
+
+ for i, (m, x) in enumerate(test_set, 1):
+ if i > samples:
+ break
+
+ print('\n| Generating: %i/%i' % (i, samples))
+
+ x = x[0].numpy()
+
+ bits = 16 if hp.voc_mode == 'MOL' else hp.bits
+
+ if hp.mu_law and hp.voc_mode != 'MOL' :
+ x = decode_mu_law(x, 2**bits, from_labels=True)
+ else :
+ x = label_2_float(x, bits)
+
+ save_wav(x, save_path.joinpath("%dk_steps_%d_target.wav" % (k, i)))
+
+ batch_str = "gen_batched_target%d_overlap%d" % (target, overlap) if batched else \
+ "gen_not_batched"
+ save_str = save_path.joinpath("%dk_steps_%d_%s.wav" % (k, i, batch_str))
+
+ wav = model.generate(m, batched, target, overlap, hp.mu_law)
+ save_wav(wav, save_str)
+
diff --git a/vocoder/hparams.py b/vocoder/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb2806c1605e1c297c67d08c5c39098d3de6310e
--- /dev/null
+++ b/vocoder/hparams.py
@@ -0,0 +1,44 @@
+from synthesizer.hparams import hparams as _syn_hp
+
+
+# Audio settings------------------------------------------------------------------------
+# Match the values of the synthesizer
+sample_rate = _syn_hp.sample_rate
+n_fft = _syn_hp.n_fft
+num_mels = _syn_hp.num_mels
+hop_length = _syn_hp.hop_size
+win_length = _syn_hp.win_size
+fmin = _syn_hp.fmin
+min_level_db = _syn_hp.min_level_db
+ref_level_db = _syn_hp.ref_level_db
+mel_max_abs_value = _syn_hp.max_abs_value
+preemphasis = _syn_hp.preemphasis
+apply_preemphasis = _syn_hp.preemphasize
+
+bits = 9 # bit depth of signal
+mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode
+ # below
+
+
+# WAVERNN / VOCODER --------------------------------------------------------------------------------
+voc_mode = 'RAW' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from
+# mixture of logistics)
+voc_upsample_factors = (5, 5, 8) # NB - this needs to correctly factorise hop_length
+voc_rnn_dims = 512
+voc_fc_dims = 512
+voc_compute_dims = 128
+voc_res_out_dims = 128
+voc_res_blocks = 10
+
+# Training
+voc_batch_size = 100
+voc_lr = 1e-4
+voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint
+voc_pad = 2 # this will pad the input so that the resnet can 'see' wider
+ # than input length
+voc_seq_len = hop_length * 5 # must be a multiple of hop_length
+
+# Generating / Synthesizing
+voc_gen_batched = True # very fast (realtime+) single utterance batched generation
+voc_target = 8000 # target number of samples to be generated in each batch entry
+voc_overlap = 400 # number of samples for crossfading between batches
diff --git a/vocoder/inference.py b/vocoder/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d06ad45774d3fab635f9460b58df9476cb72aa
--- /dev/null
+++ b/vocoder/inference.py
@@ -0,0 +1,64 @@
+from vocoder.models.fatchord_version import WaveRNN
+from vocoder import hparams as hp
+import torch
+
+
+_model = None # type: WaveRNN
+
+def load_model(weights_fpath, verbose=True):
+ global _model, _device
+
+ if verbose:
+ print("Building Wave-RNN")
+ _model = WaveRNN(
+ rnn_dims=hp.voc_rnn_dims,
+ fc_dims=hp.voc_fc_dims,
+ bits=hp.bits,
+ pad=hp.voc_pad,
+ upsample_factors=hp.voc_upsample_factors,
+ feat_dims=hp.num_mels,
+ compute_dims=hp.voc_compute_dims,
+ res_out_dims=hp.voc_res_out_dims,
+ res_blocks=hp.voc_res_blocks,
+ hop_length=hp.hop_length,
+ sample_rate=hp.sample_rate,
+ mode=hp.voc_mode
+ )
+
+ if torch.cuda.is_available():
+ _model = _model.cuda()
+ _device = torch.device('cuda')
+ else:
+ _device = torch.device('cpu')
+
+ if verbose:
+ print("Loading model weights at %s" % weights_fpath)
+ checkpoint = torch.load(weights_fpath, _device)
+ _model.load_state_dict(checkpoint['model_state'])
+ _model.eval()
+
+
+def is_loaded():
+ return _model is not None
+
+
+def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
+ progress_callback=None):
+ """
+ Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
+ that of the synthesizer!)
+
+ :param normalize:
+ :param batched:
+ :param target:
+ :param overlap:
+ :return:
+ """
+ if _model is None:
+ raise Exception("Please load Wave-RNN in memory before using it")
+
+ if normalize:
+ mel = mel / hp.mel_max_abs_value
+ mel = torch.from_numpy(mel[None, ...])
+ wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
+ return wav
diff --git a/vocoder/models/__pycache__/fatchord_version.cpython-38.pyc b/vocoder/models/__pycache__/fatchord_version.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..497fa58e42605c720561396bca42960738fc371b
Binary files /dev/null and b/vocoder/models/__pycache__/fatchord_version.cpython-38.pyc differ
diff --git a/vocoder/models/deepmind_version.py b/vocoder/models/deepmind_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..b98157f1039a5f6ba9c7b5f3d9b988173f0da85e
--- /dev/null
+++ b/vocoder/models/deepmind_version.py
@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.display import *
+from utils.dsp import *
+
+
+class WaveRNN(nn.Module) :
+ def __init__(self, hidden_size=896, quantisation=256) :
+ super(WaveRNN, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.split_size = hidden_size // 2
+
+ # The main matmul
+ self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+
+ # Output fc layers
+ self.O1 = nn.Linear(self.split_size, self.split_size)
+ self.O2 = nn.Linear(self.split_size, quantisation)
+ self.O3 = nn.Linear(self.split_size, self.split_size)
+ self.O4 = nn.Linear(self.split_size, quantisation)
+
+ # Input fc layers
+ self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
+ self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)
+
+ # biases for the gates
+ self.bias_u = nn.Parameter(torch.zeros(self.hidden_size))
+ self.bias_r = nn.Parameter(torch.zeros(self.hidden_size))
+ self.bias_e = nn.Parameter(torch.zeros(self.hidden_size))
+
+ # display num params
+ self.num_params()
+
+
+ def forward(self, prev_y, prev_hidden, current_coarse) :
+
+ # Main matmul - the projection is split 3 ways
+ R_hidden = self.R(prev_hidden)
+ R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1)
+
+ # Project the prev input
+ coarse_input_proj = self.I_coarse(prev_y)
+ I_coarse_u, I_coarse_r, I_coarse_e = \
+ torch.split(coarse_input_proj, self.split_size, dim=1)
+
+ # Project the prev input and current coarse sample
+ fine_input = torch.cat([prev_y, current_coarse], dim=1)
+ fine_input_proj = self.I_fine(fine_input)
+ I_fine_u, I_fine_r, I_fine_e = \
+ torch.split(fine_input_proj, self.split_size, dim=1)
+
+ # concatenate for the gates
+ I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
+ I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
+ I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)
+
+ # Compute all gates for coarse and fine
+ u = F.sigmoid(R_u + I_u + self.bias_u)
+ r = F.sigmoid(R_r + I_r + self.bias_r)
+ e = F.tanh(r * R_e + I_e + self.bias_e)
+ hidden = u * prev_hidden + (1. - u) * e
+
+ # Split the hidden state
+ hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
+
+ # Compute outputs
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
+
+ return out_coarse, out_fine, hidden
+
+
+ def generate(self, seq_len):
+ with torch.no_grad():
+ # First split up the biases for the gates
+ b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
+ b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
+ b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)
+
+ # Lists for the two output seqs
+ c_outputs, f_outputs = [], []
+
+ # Some initial inputs
+ out_coarse = torch.LongTensor([0]).cuda()
+ out_fine = torch.LongTensor([0]).cuda()
+
+ # We'll meed a hidden state
+ hidden = self.init_hidden()
+
+ # Need a clock for display
+ start = time.time()
+
+ # Loop for generation
+ for i in range(seq_len) :
+
+ # Split into two hidden states
+ hidden_coarse, hidden_fine = \
+ torch.split(hidden, self.split_size, dim=1)
+
+ # Scale and concat previous predictions
+ out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.
+ out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.
+ prev_outputs = torch.cat([out_coarse, out_fine], dim=1)
+
+ # Project input
+ coarse_input_proj = self.I_coarse(prev_outputs)
+ I_coarse_u, I_coarse_r, I_coarse_e = \
+ torch.split(coarse_input_proj, self.split_size, dim=1)
+
+ # Project hidden state and split 6 ways
+ R_hidden = self.R(hidden)
+ R_coarse_u , R_fine_u, \
+ R_coarse_r, R_fine_r, \
+ R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1)
+
+ # Compute the coarse gates
+ u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
+ r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
+ e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
+ hidden_coarse = u * hidden_coarse + (1. - u) * e
+
+ # Compute the coarse output
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
+ posterior = F.softmax(out_coarse, dim=1)
+ distrib = torch.distributions.Categorical(posterior)
+ out_coarse = distrib.sample()
+ c_outputs.append(out_coarse)
+
+ # Project the [prev outputs and predicted coarse sample]
+ coarse_pred = out_coarse.float() / 127.5 - 1.
+ fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
+ fine_input_proj = self.I_fine(fine_input)
+ I_fine_u, I_fine_r, I_fine_e = \
+ torch.split(fine_input_proj, self.split_size, dim=1)
+
+ # Compute the fine gates
+ u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
+ r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
+ e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
+ hidden_fine = u * hidden_fine + (1. - u) * e
+
+ # Compute the fine output
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
+ posterior = F.softmax(out_fine, dim=1)
+ distrib = torch.distributions.Categorical(posterior)
+ out_fine = distrib.sample()
+ f_outputs.append(out_fine)
+
+ # Put the hidden state back together
+ hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)
+
+ # Display progress
+ speed = (i + 1) / (time.time() - start)
+ stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed))
+
+ coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
+ fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
+ output = combine_signal(coarse, fine)
+
+ return output, coarse, fine
+
+ def init_hidden(self, batch_size=1) :
+ return torch.zeros(batch_size, self.hidden_size).cuda()
+
+ def num_params(self) :
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ print('Trainable Parameters: %.3f million' % parameters)
\ No newline at end of file
diff --git a/vocoder/models/fatchord_version.py b/vocoder/models/fatchord_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd95645b73de12194c1d00e0c82ddb4c29ecc86
--- /dev/null
+++ b/vocoder/models/fatchord_version.py
@@ -0,0 +1,434 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from vocoder.distribution import sample_from_discretized_mix_logistic
+from vocoder.display import *
+from vocoder.audio import *
+
+
+class ResBlock(nn.Module):
+ def __init__(self, dims):
+ super().__init__()
+ self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+ self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+ self.batch_norm1 = nn.BatchNorm1d(dims)
+ self.batch_norm2 = nn.BatchNorm1d(dims)
+
+ def forward(self, x):
+ residual = x
+ x = self.conv1(x)
+ x = self.batch_norm1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = self.batch_norm2(x)
+ return x + residual
+
+
+class MelResNet(nn.Module):
+ def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
+ super().__init__()
+ k_size = pad * 2 + 1
+ self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
+ self.batch_norm = nn.BatchNorm1d(compute_dims)
+ self.layers = nn.ModuleList()
+ for i in range(res_blocks):
+ self.layers.append(ResBlock(compute_dims))
+ self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ x = self.batch_norm(x)
+ x = F.relu(x)
+ for f in self.layers: x = f(x)
+ x = self.conv_out(x)
+ return x
+
+
+class Stretch2d(nn.Module):
+ def __init__(self, x_scale, y_scale):
+ super().__init__()
+ self.x_scale = x_scale
+ self.y_scale = y_scale
+
+ def forward(self, x):
+ b, c, h, w = x.size()
+ x = x.unsqueeze(-1).unsqueeze(3)
+ x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
+ return x.view(b, c, h * self.y_scale, w * self.x_scale)
+
+
+class UpsampleNetwork(nn.Module):
+ def __init__(self, feat_dims, upsample_scales, compute_dims,
+ res_blocks, res_out_dims, pad):
+ super().__init__()
+ total_scale = np.cumproduct(upsample_scales)[-1]
+ self.indent = pad * total_scale
+ self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
+ self.resnet_stretch = Stretch2d(total_scale, 1)
+ self.up_layers = nn.ModuleList()
+ for scale in upsample_scales:
+ k_size = (1, scale * 2 + 1)
+ padding = (0, scale)
+ stretch = Stretch2d(scale, 1)
+ conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
+ conv.weight.data.fill_(1. / k_size[1])
+ self.up_layers.append(stretch)
+ self.up_layers.append(conv)
+
+ def forward(self, m):
+ aux = self.resnet(m).unsqueeze(1)
+ aux = self.resnet_stretch(aux)
+ aux = aux.squeeze(1)
+ m = m.unsqueeze(1)
+ for f in self.up_layers: m = f(m)
+ m = m.squeeze(1)[:, :, self.indent:-self.indent]
+ return m.transpose(1, 2), aux.transpose(1, 2)
+
+
+class WaveRNN(nn.Module):
+ def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
+ feat_dims, compute_dims, res_out_dims, res_blocks,
+ hop_length, sample_rate, mode='RAW'):
+ super().__init__()
+ self.mode = mode
+ self.pad = pad
+ if self.mode == 'RAW' :
+ self.n_classes = 2 ** bits
+ elif self.mode == 'MOL' :
+ self.n_classes = 30
+ else :
+ RuntimeError("Unknown model mode value - ", self.mode)
+
+ self.rnn_dims = rnn_dims
+ self.aux_dims = res_out_dims // 4
+ self.hop_length = hop_length
+ self.sample_rate = sample_rate
+
+ self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
+ self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
+ self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
+ self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
+ self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
+ self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
+ self.fc3 = nn.Linear(fc_dims, self.n_classes)
+
+ self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
+ self.num_params()
+
+ def forward(self, x, mels):
+ self.step += 1
+ bsize = x.size(0)
+ if torch.cuda.is_available():
+ h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+ h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+ else:
+ h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
+ h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
+ mels, aux = self.upsample(mels)
+
+ aux_idx = [self.aux_dims * i for i in range(5)]
+ a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
+ a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
+ a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
+ a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
+
+ x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
+ x = self.I(x)
+ res = x
+ x, _ = self.rnn1(x, h1)
+
+ x = x + res
+ res = x
+ x = torch.cat([x, a2], dim=2)
+ x, _ = self.rnn2(x, h2)
+
+ x = x + res
+ x = torch.cat([x, a3], dim=2)
+ x = F.relu(self.fc1(x))
+
+ x = torch.cat([x, a4], dim=2)
+ x = F.relu(self.fc2(x))
+ return self.fc3(x)
+
+ def generate(self, mels, batched, target, overlap, mu_law, progress_callback=None):
+ mu_law = mu_law if self.mode == 'RAW' else False
+ progress_callback = progress_callback or self.gen_display
+
+ self.eval()
+ output = []
+ start = time.time()
+ rnn1 = self.get_gru_cell(self.rnn1)
+ rnn2 = self.get_gru_cell(self.rnn2)
+
+ with torch.no_grad():
+ if torch.cuda.is_available():
+ mels = mels.cuda()
+ else:
+ mels = mels.cpu()
+ wave_len = (mels.size(-1) - 1) * self.hop_length
+ mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both')
+ mels, aux = self.upsample(mels.transpose(1, 2))
+
+ if batched:
+ mels = self.fold_with_overlap(mels, target, overlap)
+ aux = self.fold_with_overlap(aux, target, overlap)
+
+ b_size, seq_len, _ = mels.size()
+
+ if torch.cuda.is_available():
+ h1 = torch.zeros(b_size, self.rnn_dims).cuda()
+ h2 = torch.zeros(b_size, self.rnn_dims).cuda()
+ x = torch.zeros(b_size, 1).cuda()
+ else:
+ h1 = torch.zeros(b_size, self.rnn_dims).cpu()
+ h2 = torch.zeros(b_size, self.rnn_dims).cpu()
+ x = torch.zeros(b_size, 1).cpu()
+
+ d = self.aux_dims
+ aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]
+
+ for i in range(seq_len):
+
+ m_t = mels[:, i, :]
+
+ a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
+
+ x = torch.cat([x, m_t, a1_t], dim=1)
+ x = self.I(x)
+ h1 = rnn1(x, h1)
+
+ x = x + h1
+ inp = torch.cat([x, a2_t], dim=1)
+ h2 = rnn2(inp, h2)
+
+ x = x + h2
+ x = torch.cat([x, a3_t], dim=1)
+ x = F.relu(self.fc1(x))
+
+ x = torch.cat([x, a4_t], dim=1)
+ x = F.relu(self.fc2(x))
+
+ logits = self.fc3(x)
+
+ if self.mode == 'MOL':
+ sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
+ output.append(sample.view(-1))
+ if torch.cuda.is_available():
+ # x = torch.FloatTensor([[sample]]).cuda()
+ x = sample.transpose(0, 1).cuda()
+ else:
+ x = sample.transpose(0, 1)
+
+ elif self.mode == 'RAW' :
+ posterior = F.softmax(logits, dim=1)
+ distrib = torch.distributions.Categorical(posterior)
+
+ sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
+ output.append(sample)
+ x = sample.unsqueeze(-1)
+ else:
+ raise RuntimeError("Unknown model mode value - ", self.mode)
+
+ if i % 100 == 0:
+ gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
+ progress_callback(i, seq_len, b_size, gen_rate)
+
+ output = torch.stack(output).transpose(0, 1)
+ output = output.cpu().numpy()
+ output = output.astype(np.float64)
+
+ if batched:
+ output = self.xfade_and_unfold(output, target, overlap)
+ else:
+ output = output[0]
+
+ if mu_law:
+ output = decode_mu_law(output, self.n_classes, False)
+ if hp.apply_preemphasis:
+ output = de_emphasis(output)
+
+ # Fade-out at the end to avoid signal cutting out suddenly
+ fade_out = np.linspace(1, 0, 20 * self.hop_length)
+ output = output[:wave_len]
+ output[-20 * self.hop_length:] *= fade_out
+
+ self.train()
+
+ return output
+
+
+ def gen_display(self, i, seq_len, b_size, gen_rate):
+ pbar = progbar(i, seq_len)
+ msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | '
+ stream(msg)
+
+ def get_gru_cell(self, gru):
+ gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
+ gru_cell.weight_hh.data = gru.weight_hh_l0.data
+ gru_cell.weight_ih.data = gru.weight_ih_l0.data
+ gru_cell.bias_hh.data = gru.bias_hh_l0.data
+ gru_cell.bias_ih.data = gru.bias_ih_l0.data
+ return gru_cell
+
+ def pad_tensor(self, x, pad, side='both'):
+ # NB - this is just a quick method i need right now
+ # i.e., it won't generalise to other shapes/dims
+ b, t, c = x.size()
+ total = t + 2 * pad if side == 'both' else t + pad
+ if torch.cuda.is_available():
+ padded = torch.zeros(b, total, c).cuda()
+ else:
+ padded = torch.zeros(b, total, c).cpu()
+ if side == 'before' or side == 'both':
+ padded[:, pad:pad + t, :] = x
+ elif side == 'after':
+ padded[:, :t, :] = x
+ return padded
+
+ def fold_with_overlap(self, x, target, overlap):
+
+ ''' Fold the tensor with overlap for quick batched inference.
+ Overlap will be used for crossfading in xfade_and_unfold()
+
+ Args:
+ x (tensor) : Upsampled conditioning features.
+ shape=(1, timesteps, features)
+ target (int) : Target timesteps for each index of batch
+ overlap (int) : Timesteps for both xfade and rnn warmup
+
+ Return:
+ (tensor) : shape=(num_folds, target + 2 * overlap, features)
+
+ Details:
+ x = [[h1, h2, ... hn]]
+
+ Where each h is a vector of conditioning features
+
+ Eg: target=2, overlap=1 with x.size(1)=10
+
+ folded = [[h1, h2, h3, h4],
+ [h4, h5, h6, h7],
+ [h7, h8, h9, h10]]
+ '''
+
+ _, total_len, features = x.size()
+
+ # Calculate variables needed
+ num_folds = (total_len - overlap) // (target + overlap)
+ extended_len = num_folds * (overlap + target) + overlap
+ remaining = total_len - extended_len
+
+ # Pad if some time steps poking out
+ if remaining != 0:
+ num_folds += 1
+ padding = target + 2 * overlap - remaining
+ x = self.pad_tensor(x, padding, side='after')
+
+ if torch.cuda.is_available():
+ folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
+ else:
+ folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu()
+
+ # Get the values for the folded tensor
+ for i in range(num_folds):
+ start = i * (target + overlap)
+ end = start + target + 2 * overlap
+ folded[i] = x[:, start:end, :]
+
+ return folded
+
+ def xfade_and_unfold(self, y, target, overlap):
+
+ ''' Applies a crossfade and unfolds into a 1d array.
+
+ Args:
+ y (ndarry) : Batched sequences of audio samples
+ shape=(num_folds, target + 2 * overlap)
+ dtype=np.float64
+ overlap (int) : Timesteps for both xfade and rnn warmup
+
+ Return:
+ (ndarry) : audio samples in a 1d array
+ shape=(total_len)
+ dtype=np.float64
+
+ Details:
+ y = [[seq1],
+ [seq2],
+ [seq3]]
+
+ Apply a gain envelope at both ends of the sequences
+
+ y = [[seq1_in, seq1_target, seq1_out],
+ [seq2_in, seq2_target, seq2_out],
+ [seq3_in, seq3_target, seq3_out]]
+
+ Stagger and add up the groups of samples:
+
+ [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
+
+ '''
+
+ num_folds, length = y.shape
+ target = length - 2 * overlap
+ total_len = num_folds * (target + overlap) + overlap
+
+ # Need some silence for the rnn warmup
+ silence_len = overlap // 2
+ fade_len = overlap - silence_len
+ silence = np.zeros((silence_len), dtype=np.float64)
+
+ # Equal power crossfade
+ t = np.linspace(-1, 1, fade_len, dtype=np.float64)
+ fade_in = np.sqrt(0.5 * (1 + t))
+ fade_out = np.sqrt(0.5 * (1 - t))
+
+ # Concat the silence to the fades
+ fade_in = np.concatenate([silence, fade_in])
+ fade_out = np.concatenate([fade_out, silence])
+
+ # Apply the gain to the overlap samples
+ y[:, :overlap] *= fade_in
+ y[:, -overlap:] *= fade_out
+
+ unfolded = np.zeros((total_len), dtype=np.float64)
+
+ # Loop to add up all the samples
+ for i in range(num_folds):
+ start = i * (target + overlap)
+ end = start + target + 2 * overlap
+ unfolded[start:end] += y[i]
+
+ return unfolded
+
+ def get_step(self) :
+ return self.step.data.item()
+
+ def checkpoint(self, model_dir, optimizer) :
+ k_steps = self.get_step() // 1000
+ self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer)
+
+ def log(self, path, msg) :
+ with open(path, 'a') as f:
+ print(msg, file=f)
+
+ def load(self, path, optimizer) :
+ checkpoint = torch.load(path)
+ if "optimizer_state" in checkpoint:
+ self.load_state_dict(checkpoint["model_state"])
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
+ else:
+ # Backwards compatibility
+ self.load_state_dict(checkpoint)
+
+ def save(self, path, optimizer) :
+ torch.save({
+ "model_state": self.state_dict(),
+ "optimizer_state": optimizer.state_dict(),
+ }, path)
+
+ def num_params(self, print_out=True):
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out :
+ print('Trainable Parameters: %.3fM' % parameters)
diff --git a/vocoder/train.py b/vocoder/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c517ecc36c0dd4b8af257d3bda0e77da155686a
--- /dev/null
+++ b/vocoder/train.py
@@ -0,0 +1,118 @@
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import optim
+from torch.utils.data import DataLoader
+
+import vocoder.hparams as hp
+from vocoder.display import stream, simple_table
+from vocoder.distribution import discretized_mix_logistic_loss
+from vocoder.gen_wavernn import gen_testset
+from vocoder.models.fatchord_version import WaveRNN
+from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder
+
+
+def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool, save_every: int,
+ backup_every: int, force_restart: bool):
+ # Check to make sure the hop length is correctly factorised
+ assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length
+
+ # Instantiate the model
+ print("Initializing the model...")
+ model = WaveRNN(
+ rnn_dims=hp.voc_rnn_dims,
+ fc_dims=hp.voc_fc_dims,
+ bits=hp.bits,
+ pad=hp.voc_pad,
+ upsample_factors=hp.voc_upsample_factors,
+ feat_dims=hp.num_mels,
+ compute_dims=hp.voc_compute_dims,
+ res_out_dims=hp.voc_res_out_dims,
+ res_blocks=hp.voc_res_blocks,
+ hop_length=hp.hop_length,
+ sample_rate=hp.sample_rate,
+ mode=hp.voc_mode
+ )
+
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ # Initialize the optimizer
+ optimizer = optim.Adam(model.parameters())
+ for p in optimizer.param_groups:
+ p["lr"] = hp.voc_lr
+ loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss
+
+ # Load the weights
+ model_dir = models_dir / run_id
+ model_dir.mkdir(exist_ok=True)
+ weights_fpath = model_dir / "vocoder.pt"
+ if force_restart or not weights_fpath.exists():
+ print("\nStarting the training of WaveRNN from scratch\n")
+ model.save(weights_fpath, optimizer)
+ else:
+ print("\nLoading weights at %s" % weights_fpath)
+ model.load(weights_fpath, optimizer)
+ print("WaveRNN weights loaded from step %d" % model.step)
+
+ # Initialize the dataset
+ metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
+ voc_dir.joinpath("synthesized.txt")
+ mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
+ wav_dir = syn_dir.joinpath("audio")
+ dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
+ test_loader = DataLoader(dataset, batch_size=1, shuffle=True)
+
+ # Begin the training
+ simple_table([('Batch size', hp.voc_batch_size),
+ ('LR', hp.voc_lr),
+ ('Sequence Len', hp.voc_seq_len)])
+
+ for epoch in range(1, 350):
+ data_loader = DataLoader(dataset, hp.voc_batch_size, shuffle=True, num_workers=2, collate_fn=collate_vocoder)
+ start = time.time()
+ running_loss = 0.
+
+ for i, (x, y, m) in enumerate(data_loader, 1):
+ if torch.cuda.is_available():
+ x, m, y = x.cuda(), m.cuda(), y.cuda()
+
+ # Forward pass
+ y_hat = model(x, m)
+ if model.mode == 'RAW':
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+ elif model.mode == 'MOL':
+ y = y.float()
+ y = y.unsqueeze(-1)
+
+ # Backward pass
+ loss = loss_func(y_hat, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ running_loss += loss.item()
+ speed = i / (time.time() - start)
+ avg_loss = running_loss / i
+
+ step = model.get_step()
+ k = step // 1000
+
+ if backup_every != 0 and step % backup_every == 0 :
+ model.checkpoint(model_dir, optimizer)
+
+ if save_every != 0 and step % save_every == 0 :
+ model.save(weights_fpath, optimizer)
+
+ msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
+ f"Loss: {avg_loss:.4f} | {speed:.1f} " \
+ f"steps/s | Step: {k}k | "
+ stream(msg)
+
+
+ gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
+ hp.voc_target, hp.voc_overlap, model_dir)
+ print("")
diff --git a/vocoder/vocoder_dataset.py b/vocoder/vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f5561a9f7f9a592ab7b4e05d194cfc22e5d6e3b
--- /dev/null
+++ b/vocoder/vocoder_dataset.py
@@ -0,0 +1,84 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from vocoder import audio
+import vocoder.hparams as hp
+import numpy as np
+import torch
+
+
+class VocoderDataset(Dataset):
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path):
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir))
+
+ with metadata_fpath.open("r") as metadata_file:
+ metadata = [line.split("|") for line in metadata_file]
+
+ gta_fnames = [x[1] for x in metadata if int(x[4])]
+ gta_fpaths = [mel_dir.joinpath(fname) for fname in gta_fnames]
+ wav_fnames = [x[0] for x in metadata if int(x[4])]
+ wav_fpaths = [wav_dir.joinpath(fname) for fname in wav_fnames]
+ self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths))
+
+ print("Found %d samples" % len(self.samples_fpaths))
+
+ def __getitem__(self, index):
+ mel_path, wav_path = self.samples_fpaths[index]
+
+ # Load the mel spectrogram and adjust its range to [-1, 1]
+ mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value
+
+ # Load the wav
+ wav = np.load(wav_path)
+ if hp.apply_preemphasis:
+ wav = audio.pre_emphasis(wav)
+ wav = np.clip(wav, -1, 1)
+
+ # Fix for missing padding # TODO: settle on whether this is any useful
+ r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav)
+ wav = np.pad(wav, (0, r_pad), mode='constant')
+ assert len(wav) >= mel.shape[1] * hp.hop_length
+ wav = wav[:mel.shape[1] * hp.hop_length]
+ assert len(wav) % hp.hop_length == 0
+
+ # Quantize the wav
+ if hp.voc_mode == 'RAW':
+ if hp.mu_law:
+ quant = audio.encode_mu_law(wav, mu=2 ** hp.bits)
+ else:
+ quant = audio.float_2_label(wav, bits=hp.bits)
+ elif hp.voc_mode == 'MOL':
+ quant = audio.float_2_label(wav, bits=16)
+
+ return mel.astype(np.float32), quant.astype(np.int64)
+
+ def __len__(self):
+ return len(self.samples_fpaths)
+
+
+def collate_vocoder(batch):
+ mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
+ max_offsets = [x[0].shape[-1] -2 - (mel_win + 2 * hp.voc_pad) for x in batch]
+ mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
+ sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]
+
+ mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
+
+ labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]
+
+ mels = np.stack(mels).astype(np.float32)
+ labels = np.stack(labels).astype(np.int64)
+
+ mels = torch.tensor(mels)
+ labels = torch.tensor(labels).long()
+
+ x = labels[:, :hp.voc_seq_len]
+ y = labels[:, 1:]
+
+ bits = 16 if hp.voc_mode == 'MOL' else hp.bits
+
+ x = audio.label_2_float(x.float(), bits)
+
+ if hp.voc_mode == 'MOL' :
+ y = audio.label_2_float(y.float(), bits)
+
+ return x, y, mels
\ No newline at end of file