diff --git a/8230_00000.mp3 b/8230_00000.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..b7c5620095db916a11b06864b3bcfc3df3611f21 Binary files /dev/null and b/8230_00000.mp3 differ diff --git a/demo_cli.py b/demo_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..d43f04d729a2bae22ecd1b1878cae59a2fe5ffad --- /dev/null +++ b/demo_cli.py @@ -0,0 +1,225 @@ +from encoder.params_model import model_embedding_size as speaker_embedding_size +from utils.argutils import print_args +from utils.modelutils import check_model_paths +from synthesizer.inference import Synthesizer +from encoder import inference as encoder +from vocoder import inference as vocoder +from pathlib import Path +import numpy as np +import soundfile as sf +import librosa +import argparse +import torch +import sys +import os +from audioread.exceptions import NoBackendError + +if __name__ == '__main__': + ## Info & args + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("-e", "--enc_model_fpath", type=Path, + default="encoder/saved_models/pretrained.pt", + help="Path to a saved encoder") + parser.add_argument("-s", "--syn_model_fpath", type=Path, + default="synthesizer/saved_models/pretrained/pretrained.pt", + help="Path to a saved synthesizer") + parser.add_argument("-v", "--voc_model_fpath", type=Path, + default="vocoder/saved_models/pretrained/pretrained.pt", + help="Path to a saved vocoder") + parser.add_argument("--cpu", action="store_true", help=\ + "If True, processing is done on CPU, even when a GPU is available.") + parser.add_argument("--no_sound", action="store_true", help=\ + "If True, audio won't be played.") + parser.add_argument("--seed", type=int, default=None, help=\ + "Optional random number seed value to make toolbox deterministic.") + parser.add_argument("--no_mp3_support", action="store_true", help=\ + "If True, disallows loading mp3 files to prevent audioread errors when ffmpeg is not installed.") + args = parser.parse_args() + print_args(args, parser) + if not args.no_sound: + import sounddevice as sd + + if args.cpu: + # Hide GPUs from Pytorch to force CPU processing + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + if not args.no_mp3_support: + try: + librosa.load("samples/1320_00000.mp3") + except NoBackendError: + print("Librosa will be unable to open mp3 files if additional software is not installed.\n" + "Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.") + exit(-1) + + print("Running a test of your configuration...\n") + + if torch.cuda.is_available(): + device_id = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device_id) + ## Print some environment information (for debugging purposes) + print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with " + "%.1fGb total memory.\n" % + (torch.cuda.device_count(), + device_id, + gpu_properties.name, + gpu_properties.major, + gpu_properties.minor, + gpu_properties.total_memory / 1e9)) + else: + print("Using CPU for inference.\n") + + ## Remind the user to download pretrained models if needed + check_model_paths(encoder_path=args.enc_model_fpath, + synthesizer_path=args.syn_model_fpath, + vocoder_path=args.voc_model_fpath) + + ## Load the models one by one. 
+ print("Preparing the encoder, the synthesizer and the vocoder...") + encoder.load_model(args.enc_model_fpath) + synthesizer = Synthesizer(args.syn_model_fpath) + vocoder.load_model(args.voc_model_fpath) + + + ## Run a test + print("Testing your configuration with small inputs.") + # Forward an audio waveform of zeroes that lasts 1 second. Notice how we can get the encoder's + # sampling rate, which may differ. + # If you're unfamiliar with digital audio, know that it is encoded as an array of floats + # (or sometimes integers, but mostly floats in this projects) ranging from -1 to 1. + # The sampling rate is the number of values (samples) recorded per second, it is set to + # 16000 for the encoder. Creating an array of length will always correspond + # to an audio of 1 second. + print("\tTesting the encoder...") + encoder.embed_utterance(np.zeros(encoder.sampling_rate)) + + # Create a dummy embedding. You would normally use the embedding that encoder.embed_utterance + # returns, but here we're going to make one ourselves just for the sake of showing that it's + # possible. + embed = np.random.rand(speaker_embedding_size) + # Embeddings are L2-normalized (this isn't important here, but if you want to make your own + # embeddings it will be). + embed /= np.linalg.norm(embed) + # The synthesizer can handle multiple inputs with batching. Let's create another embedding to + # illustrate that + embeds = [embed, np.zeros(speaker_embedding_size)] + texts = ["test 1", "test 2"] + print("\tTesting the synthesizer... (loading the model will output a lot of text)") + mels = synthesizer.synthesize_spectrograms(texts, embeds) + + # The vocoder synthesizes one waveform at a time, but it's more efficient for long ones. We + # can concatenate the mel spectrograms to a single one. + mel = np.concatenate(mels, axis=1) + # The vocoder can take a callback function to display the generation. More on that later. For + # now we'll simply hide it like this: + no_action = lambda *args: None + print("\tTesting the vocoder...") + # For the sake of making this test short, we'll pass a short target length. The target length + # is the length of the wav segments that are processed in parallel. E.g. for audio sampled + # at 16000 Hertz, a target length of 8000 means that the target audio will be cut in chunks of + # 0.5 seconds which will all be generated together. The parameters here are absurdly short, and + # that has a detrimental effect on the quality of the audio. The default parameters are + # recommended in general. + vocoder.infer_waveform(mel, target=200, overlap=50, progress_callback=no_action) + + print("All test passed! You can now synthesize speech.\n\n") + + + ## Interactive speech generation + print("This is a GUI-less example of interface to SV2TTS. The purpose of this script is to " + "show how you can interface this project easily with your own. See the source code for " + "an explanation of what is happening.\n") + + print("Interactive generation loop") + num_generated = 0 + while True: + try: + # Get the reference audio filepath + message = "Reference voice: enter an audio filepath of a voice to be cloned (mp3, " \ + "wav, m4a, flac, ...):\n" + in_fpath = Path(input(message).replace("\"", "").replace("\'", "")) + + if in_fpath.suffix.lower() == ".mp3" and args.no_mp3_support: + print("Can't Use mp3 files please try again:") + continue + ## Computing the embedding + # First, we load the wav using the function that the speaker encoder provides. 
This is + # important: there is preprocessing that must be applied. + + # The following two methods are equivalent: + # - Directly load from the filepath: + preprocessed_wav = encoder.preprocess_wav(in_fpath) + # - If the wav is already loaded: + original_wav, sampling_rate = librosa.load(str(in_fpath)) + preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate) + print("Loaded file succesfully") + + # Then we derive the embedding. There are many functions and parameters that the + # speaker encoder interfaces. These are mostly for in-depth research. You will typically + # only use this function (with its default parameters): + embed = encoder.embed_utterance(preprocessed_wav) + print("Created the embedding") + + + ## Generating the spectrogram + text = input("Write a sentence (+-20 words) to be synthesized:\n") + + # If seed is specified, reset torch seed and force synthesizer reload + if args.seed is not None: + torch.manual_seed(args.seed) + synthesizer = Synthesizer(args.syn_model_fpath) + + # The synthesizer works in batch, so you need to put your data in a list or numpy array + texts = [text] + embeds = [embed] + # If you know what the attention layer alignments are, you can retrieve them here by + # passing return_alignments=True + specs = synthesizer.synthesize_spectrograms(texts, embeds) + spec = specs[0] + print("Created the mel spectrogram") + + + ## Generating the waveform + print("Synthesizing the waveform:") + + # If seed is specified, reset torch seed and reload vocoder + if args.seed is not None: + torch.manual_seed(args.seed) + vocoder.load_model(args.voc_model_fpath) + + # Synthesizing the waveform is fairly straightforward. Remember that the longer the + # spectrogram, the more time-efficient the vocoder. + generated_wav = vocoder.infer_waveform(spec) + + + ## Post-generation + # There's a bug with sounddevice that makes the audio cut one second earlier, so we + # pad it. + generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant") + + # Trim excess silences to compensate for gaps in spectrograms (issue #53) + generated_wav = encoder.preprocess_wav(generated_wav) + + # Play the audio (non-blocking) + if not args.no_sound: + try: + sd.stop() + sd.play(generated_wav, synthesizer.sample_rate) + except sd.PortAudioError as e: + print("\nCaught exception: %s" % repr(e)) + print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n") + except: + raise + + # Save it on the disk + filename = "demo_output_%02d.wav" % num_generated + print(generated_wav.dtype) + sf.write(filename, generated_wav.astype(np.float32), synthesizer.sample_rate) + num_generated += 1 + print("\nSaved output as %s\n\n" % filename) + + + except Exception as e: + print("Caught exception: %s" % repr(e)) + print("Restarting\n") diff --git a/demo_toolbox.py b/demo_toolbox.py new file mode 100644 index 0000000000000000000000000000000000000000..ea30a29275965c7e2b815cd703e891a5ca53e97b --- /dev/null +++ b/demo_toolbox.py @@ -0,0 +1,43 @@ +from pathlib import Path +from toolbox import Toolbox +from utils.argutils import print_args +from utils.modelutils import check_model_paths +import argparse +import os + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="Runs the toolbox", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("-d", "--datasets_root", type=Path, help= \ + "Path to the directory containing your datasets. 
See toolbox/__init__.py for a list of " + "supported datasets.", default=None) + parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models", + help="Directory containing saved encoder models") + parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models", + help="Directory containing saved synthesizer models") + parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models", + help="Directory containing saved vocoder models") + parser.add_argument("--cpu", action="store_true", help=\ + "If True, processing is done on CPU, even when a GPU is available.") + parser.add_argument("--seed", type=int, default=None, help=\ + "Optional random number seed value to make toolbox deterministic.") + parser.add_argument("--no_mp3_support", action="store_true", help=\ + "If True, no mp3 files are allowed.") + args = parser.parse_args() + print_args(args, parser) + + if args.cpu: + # Hide GPUs from Pytorch to force CPU processing + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + del args.cpu + + ## Remind the user to download pretrained models if needed + check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir, + vocoder_path=args.voc_models_dir) + + # Launch the toolbox + Toolbox(**vars(args)) diff --git a/encoder/__init__.py b/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoder/audio.py b/encoder/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..799aa835499ce8b839290f28b2c8ffb629f37565 --- /dev/null +++ b/encoder/audio.py @@ -0,0 +1,117 @@ +from scipy.ndimage.morphology import binary_dilation +from encoder.params_data import * +from pathlib import Path +from typing import Optional, Union +from warnings import warn +import numpy as np +import librosa +import struct + +try: + import webrtcvad +except: + warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.") + webrtcvad=None + +int16_max = (2 ** 15) - 1 + + +def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], + source_sr: Optional[int] = None, + normalize: Optional[bool] = True, + trim_silence: Optional[bool] = True): + """ + Applies the preprocessing operations used in training the Speaker Encoder to a waveform + either on disk or in memory. The waveform will be resampled to match the data hyperparameters. + + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + just .wav), either the waveform as a numpy array of floats. + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. If passing a filepath, the sampling rate will be automatically detected and + this argument will be ignored. 
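    Example (illustrative usage, not part of the original docstring; "speech.flac" and raw_wav are
    placeholder names):
        wav = preprocess_wav("speech.flac")              # load from disk, resample, normalize, trim silences
        wav = preprocess_wav(raw_wav, source_sr=44100)   # same pipeline for a waveform already in memory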
+ """ + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(str(fpath_or_wav), sr=None) + else: + wav = fpath_or_wav + + # Resample the wav if needed + if source_sr is not None and source_sr != sampling_rate: + wav = librosa.resample(wav, source_sr, sampling_rate) + + # Apply the preprocessing: normalize volume and shorten long silences + if normalize: + wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) + if webrtcvad and trim_silence: + wav = trim_long_silences(wav) + + return wav + + +def wav_to_mel_spectrogram(wav): + """ + Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. + Note: this not a log-mel spectrogram. + """ + frames = librosa.feature.melspectrogram( + wav, + sampling_rate, + n_fft=int(sampling_rate * mel_window_length / 1000), + hop_length=int(sampling_rate * mel_window_step / 1000), + n_mels=mel_n_channels + ) + return frames.astype(np.float32).T + + +def trim_long_silences(wav): + """ + Ensures that segments without voice in the waveform remain no longer than a + threshold determined by the VAD parameters in params.py. + + :param wav: the raw waveform as a numpy array of floats + :return: the same waveform with silences trimmed away (length <= original wav length) + """ + # Compute the voice detection window size + samples_per_window = (vad_window_length * sampling_rate) // 1000 + + # Trim the end of the audio to have a multiple of the window size + wav = wav[:len(wav) - (len(wav) % samples_per_window)] + + # Convert the float waveform to 16-bit mono PCM + pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) + + # Perform voice activation detection + voice_flags = [] + vad = webrtcvad.Vad(mode=3) + for window_start in range(0, len(wav), samples_per_window): + window_end = window_start + samples_per_window + voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], + sample_rate=sampling_rate)) + voice_flags = np.array(voice_flags) + + # Smooth the voice detection with a moving average + def moving_average(array, width): + array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) + ret = np.cumsum(array_padded, dtype=float) + ret[width:] = ret[width:] - ret[:-width] + return ret[width - 1:] / width + + audio_mask = moving_average(voice_flags, vad_moving_average_width) + audio_mask = np.round(audio_mask).astype(np.bool) + + # Dilate the voiced regions + audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) + audio_mask = np.repeat(audio_mask, samples_per_window) + + return wav[audio_mask == True] + + +def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) + if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): + return wav + return wav * (10 ** (dBFS_change / 20)) diff --git a/encoder/config.py b/encoder/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c21312f3de971bfa008254c6035cebc09f05e4c --- /dev/null +++ b/encoder/config.py @@ -0,0 +1,45 @@ +librispeech_datasets = { + "train": { + "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], + "other": ["LibriSpeech/train-other-500"] + }, + "test": { + "clean": ["LibriSpeech/test-clean"], + "other": 
["LibriSpeech/test-other"] + }, + "dev": { + "clean": ["LibriSpeech/dev-clean"], + "other": ["LibriSpeech/dev-other"] + }, +} +libritts_datasets = { + "train": { + "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], + "other": ["LibriTTS/train-other-500"] + }, + "test": { + "clean": ["LibriTTS/test-clean"], + "other": ["LibriTTS/test-other"] + }, + "dev": { + "clean": ["LibriTTS/dev-clean"], + "other": ["LibriTTS/dev-other"] + }, +} +voxceleb_datasets = { + "voxceleb1" : { + "train": ["VoxCeleb1/wav"], + "test": ["VoxCeleb1/test_wav"] + }, + "voxceleb2" : { + "train": ["VoxCeleb2/dev/aac"], + "test": ["VoxCeleb2/test_wav"] + } +} + +other_datasets = [ + "LJSpeech-1.1", + "VCTK-Corpus/wav48", +] + +anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] diff --git a/encoder/data_objects/__init__.py b/encoder/data_objects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef04ade68544d0477a7f6deb4e7d51e97f592910 --- /dev/null +++ b/encoder/data_objects/__init__.py @@ -0,0 +1,2 @@ +from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset +from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader diff --git a/encoder/data_objects/random_cycler.py b/encoder/data_objects/random_cycler.py new file mode 100644 index 0000000000000000000000000000000000000000..c405db6b27f46d874d8feb37e3f9c1e12c251109 --- /dev/null +++ b/encoder/data_objects/random_cycler.py @@ -0,0 +1,37 @@ +import random + +class RandomCycler: + """ + Creates an internal copy of a sequence and allows access to its items in a constrained random + order. For a source sequence of n items and one or several consecutive queries of a total + of m items, the following guarantees hold (one implies the other): + - Each item will be returned between m // n and ((m - 1) // n) + 1 times. + - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. 
+ """ + + def __init__(self, source): + if len(source) == 0: + raise Exception("Can't create RandomCycler from an empty collection") + self.all_items = list(source) + self.next_items = [] + + def sample(self, count: int): + shuffle = lambda l: random.sample(l, len(l)) + + out = [] + while count > 0: + if count >= len(self.all_items): + out.extend(shuffle(list(self.all_items))) + count -= len(self.all_items) + continue + n = min(count, len(self.next_items)) + out.extend(self.next_items[:n]) + count -= n + self.next_items = self.next_items[n:] + if len(self.next_items) == 0: + self.next_items = shuffle(list(self.all_items)) + return out + + def __next__(self): + return self.sample(1)[0] + diff --git a/encoder/data_objects/speaker.py b/encoder/data_objects/speaker.py new file mode 100644 index 0000000000000000000000000000000000000000..494e882fe34fc38dcc793ab8c74a6cc2376bb7b5 --- /dev/null +++ b/encoder/data_objects/speaker.py @@ -0,0 +1,40 @@ +from encoder.data_objects.random_cycler import RandomCycler +from encoder.data_objects.utterance import Utterance +from pathlib import Path + +# Contains the set of utterances of a single speaker +class Speaker: + def __init__(self, root: Path): + self.root = root + self.name = root.name + self.utterances = None + self.utterance_cycler = None + + def _load_utterances(self): + with self.root.joinpath("_sources.txt").open("r") as sources_file: + sources = [l.split(",") for l in sources_file] + sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} + self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] + self.utterance_cycler = RandomCycler(self.utterances) + + def random_partial(self, count, n_frames): + """ + Samples a batch of unique partial utterances from the disk in a way that all + utterances come up at least once every two cycles and in a random order every time. + + :param count: The number of partial utterances to sample from the set of utterances from + that speaker. Utterances are guaranteed not to be repeated if is not larger than + the number of utterances available. + :param n_frames: The number of frames in the partial utterance. + :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, + frames are the frames of the partial utterances and range is the range of the partial + utterance with regard to the complete utterance. + """ + if self.utterances is None: + self._load_utterances() + + utterances = self.utterance_cycler.sample(count) + + a = [(u,) + u.random_partial(n_frames) for u in utterances] + + return a diff --git a/encoder/data_objects/speaker_batch.py b/encoder/data_objects/speaker_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..56651dba5804a0c59c334e49ac18f8f5a4bfa444 --- /dev/null +++ b/encoder/data_objects/speaker_batch.py @@ -0,0 +1,12 @@ +import numpy as np +from typing import List +from encoder.data_objects.speaker import Speaker + +class SpeakerBatch: + def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): + self.speakers = speakers + self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} + + # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. 
for 3 speakers with + # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) + self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) diff --git a/encoder/data_objects/speaker_verification_dataset.py b/encoder/data_objects/speaker_verification_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..77a6e05eae6a939ae7575ae70b7173644141fffe --- /dev/null +++ b/encoder/data_objects/speaker_verification_dataset.py @@ -0,0 +1,56 @@ +from encoder.data_objects.random_cycler import RandomCycler +from encoder.data_objects.speaker_batch import SpeakerBatch +from encoder.data_objects.speaker import Speaker +from encoder.params_data import partials_n_frames +from torch.utils.data import Dataset, DataLoader +from pathlib import Path + +# TODO: improve with a pool of speakers for data efficiency + +class SpeakerVerificationDataset(Dataset): + def __init__(self, datasets_root: Path): + self.root = datasets_root + speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] + if len(speaker_dirs) == 0: + raise Exception("No speakers found. Make sure you are pointing to the directory " + "containing all preprocessed speaker directories.") + self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] + self.speaker_cycler = RandomCycler(self.speakers) + + def __len__(self): + return int(1e10) + + def __getitem__(self, index): + return next(self.speaker_cycler) + + def get_logs(self): + log_string = "" + for log_fpath in self.root.glob("*.txt"): + with log_fpath.open("r") as log_file: + log_string += "".join(log_file.readlines()) + return log_string + + +class SpeakerVerificationDataLoader(DataLoader): + def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, + batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, + worker_init_fn=None): + self.utterances_per_speaker = utterances_per_speaker + + super().__init__( + dataset=dataset, + batch_size=speakers_per_batch, + shuffle=False, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=self.collate, + pin_memory=pin_memory, + drop_last=False, + timeout=timeout, + worker_init_fn=worker_init_fn + ) + + def collate(self, speakers): + return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) + \ No newline at end of file diff --git a/encoder/data_objects/utterance.py b/encoder/data_objects/utterance.py new file mode 100644 index 0000000000000000000000000000000000000000..0768c3420f422a7464f305b4c1fb6752c57ceda7 --- /dev/null +++ b/encoder/data_objects/utterance.py @@ -0,0 +1,26 @@ +import numpy as np + + +class Utterance: + def __init__(self, frames_fpath, wave_fpath): + self.frames_fpath = frames_fpath + self.wave_fpath = wave_fpath + + def get_frames(self): + return np.load(self.frames_fpath) + + def random_partial(self, n_frames): + """ + Crops the frames into a partial utterance of n_frames + + :param n_frames: The number of frames of the partial utterance + :return: the partial utterance frames and a tuple indicating the start and end of the + partial utterance in the complete utterance. 
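    Example (illustrative, not part of the original docstring): for an utterance of 250 frames and
    n_frames=160, a possible return value is (frames[37:197], (37, 197)); the returned frames always
    have shape (n_frames, mel_n_channels).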
+ """ + frames = self.get_frames() + if frames.shape[0] == n_frames: + start = 0 + else: + start = np.random.randint(0, frames.shape[0] - n_frames) + end = start + n_frames + return frames[start:end], (start, end) \ No newline at end of file diff --git a/encoder/inference.py b/encoder/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca417b63406b9280418069d28d2877308be9bc2 --- /dev/null +++ b/encoder/inference.py @@ -0,0 +1,178 @@ +from encoder.params_data import * +from encoder.model import SpeakerEncoder +from encoder.audio import preprocess_wav # We want to expose this function from here +from matplotlib import cm +from encoder import audio +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +import torch + +_model = None # type: SpeakerEncoder +_device = None # type: torch.device + + +def load_model(weights_fpath: Path, device=None): + """ + Loads the model in memory. If this function is not explicitely called, it will be run on the + first call to embed_frames() with the default weights file. + + :param weights_fpath: the path to saved model weights. + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The + model will be loaded and will run on this device. Outputs will however always be on the cpu. + If None, will default to your GPU if it"s available, otherwise your CPU. + """ + # TODO: I think the slow loading of the encoder might have something to do with the device it + # was saved on. Worth investigating. + global _model, _device + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + _device = torch.device(device) + _model = SpeakerEncoder(_device, torch.device("cpu")) + checkpoint = torch.load(weights_fpath, _device) + _model.load_state_dict(checkpoint["model_state"]) + _model.eval() + print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) + + +def is_loaded(): + return _model is not None + + +def embed_frames_batch(frames_batch): + """ + Computes embeddings for a batch of mel spectrogram. + + :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) + """ + if _model is None: + raise Exception("Model was not loaded. Call load_model() before inference.") + + frames = torch.from_numpy(frames_batch).to(_device) + embed = _model.forward(frames).detach().cpu().numpy() + return embed + + +def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, + min_pad_coverage=0.75, overlap=0.5): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain + partial utterances of each. Both the waveform and the mel + spectrogram slices are returned, so as to make each partial utterance waveform correspond to + its spectrogram. This function assumes that the mel spectrogram parameters used are those + defined in params_data.py. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wave_slices[-1].stop. + + :param n_samples: the number of samples in the waveform + :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial + utterance + :param min_pad_coverage: when reaching the last partial utterance, it may or may not have + enough frames. 
If at least of are present, + then the last partial utterance will be considered, as if we padded the audio. Otherwise, + it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial + utterance, this parameter is ignored so that the function always returns at least 1 slice. + :param overlap: by how much the partial utterance should overlap. If set to 0, the partial + utterances are entirely disjoint. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. + """ + assert 0 <= overlap < 1 + assert 0 < min_pad_coverage <= 1 + + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partial_utterance_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + if coverage < min_pad_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + +def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): + """ + Computes an embedding for a single utterance. + + # TODO: handle multiple wavs to benefit from batching on GPU + :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 + :param using_partials: if True, then the utterance is split in partial utterances of + frames and the utterance embedding is computed from their + normalized average. If False, the utterance is instead computed from feeding the entire + spectogram to the network. + :param return_partials: if True, the partial embeddings will also be returned along with the + wav slices that correspond to the partial embeddings. + :param kwargs: additional arguments to compute_partial_splits() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. If is simultaneously set to False, both these values will be None + instead. 
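    Example (illustrative usage, not part of the original docstring; assumes load_model() has already
    been called and wav is a preprocessed waveform):
        embed = embed_utterance(wav)                                      # shape: (model_embedding_size,)
        embed, partial_embeds, wav_slices = embed_utterance(wav, return_partials=True)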
+ """ + # Process the entire utterance if not using partials + if not using_partials: + frames = audio.wav_to_mel_spectrogram(wav) + embed = embed_frames_batch(frames[None, ...])[0] + if return_partials: + return embed, None, None + return embed + + # Compute where to split the utterance into partials and pad if necessary + wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) + max_wave_length = wave_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = audio.wav_to_mel_spectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + partial_embeds = embed_frames_batch(frames_batch) + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wave_slices + return embed + + +def embed_speaker(wavs, **kwargs): + raise NotImplemented() + + +def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): + if ax is None: + ax = plt.gca() + + if shape is None: + height = int(np.sqrt(len(embed))) + shape = (height, -1) + embed = embed.reshape(shape) + + cmap = cm.get_cmap() + mappable = ax.imshow(embed, cmap=cmap) + cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) + sm = cm.ScalarMappable(cmap=cmap) + sm.set_clim(*color_range) + + ax.set_xticks([]), ax.set_yticks([]) + ax.set_title(title) diff --git a/encoder/model.py b/encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e050d3204d8f1becdf0f8b3133470708e5420cea --- /dev/null +++ b/encoder/model.py @@ -0,0 +1,135 @@ +from encoder.params_model import * +from encoder.params_data import * +from scipy.interpolate import interp1d +from sklearn.metrics import roc_curve +from torch.nn.utils import clip_grad_norm_ +from scipy.optimize import brentq +from torch import nn +import numpy as np +import torch + + +class SpeakerEncoder(nn.Module): + def __init__(self, device, loss_device): + super().__init__() + self.loss_device = loss_device + + # Network defition + self.lstm = nn.LSTM(input_size=mel_n_channels, + hidden_size=model_hidden_size, + num_layers=model_num_layers, + batch_first=True).to(device) + self.linear = nn.Linear(in_features=model_hidden_size, + out_features=model_embedding_size).to(device) + self.relu = torch.nn.ReLU().to(device) + + # Cosine similarity scaling (with fixed initial parameter values) + self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) + self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) + + # Loss + self.loss_fn = nn.CrossEntropyLoss().to(loss_device) + + def do_gradient_ops(self): + # Gradient scale + self.similarity_weight.grad *= 0.01 + self.similarity_bias.grad *= 0.01 + + # Gradient clipping + clip_grad_norm_(self.parameters(), 3, norm_type=2) + + def forward(self, utterances, hidden_init=None): + """ + Computes the embeddings of a batch of utterance spectrograms. + + :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape + (batch_size, n_frames, n_channels) + :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, + batch_size, hidden_size). Will default to a tensor of zeros if None. 
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size) + """ + # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state + # and the final cell state. + out, (hidden, cell) = self.lstm(utterances, hidden_init) + + # We take only the hidden state of the last layer + embeds_raw = self.relu(self.linear(hidden[-1])) + + # L2-normalize it + embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5) + + return embeds + + def similarity_matrix(self, embeds): + """ + Computes the similarity matrix according the section 2.1 of GE2E. + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + utterances_per_speaker, embedding_size) + :return: the similarity matrix as a tensor of shape (speakers_per_batch, + utterances_per_speaker, speakers_per_batch) + """ + speakers_per_batch, utterances_per_speaker = embeds.shape[:2] + + # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation + centroids_incl = torch.mean(embeds, dim=1, keepdim=True) + centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5) + + # Exclusive centroids (1 per utterance) + centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) + centroids_excl /= (utterances_per_speaker - 1) + centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5) + + # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot + # product of these vectors (which is just an element-wise multiplication reduced by a sum). + # We vectorize the computation for efficiency. + sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, + speakers_per_batch).to(self.loss_device) + mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) + for j in range(speakers_per_batch): + mask = np.where(mask_matrix[j])[0] + sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) + sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) + + ## Even more vectorized version (slower maybe because of transpose) + # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker + # ).to(self.loss_device) + # eye = np.eye(speakers_per_batch, dtype=np.int) + # mask = np.where(1 - eye) + # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) + # mask = np.where(eye) + # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) + # sim_matrix2 = sim_matrix2.transpose(1, 2) + + sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias + return sim_matrix + + def loss(self, embeds): + """ + Computes the softmax loss according the section 2.1 of GE2E. + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + utterances_per_speaker, embedding_size) + :return: the loss and the EER for this batch of embeddings. 
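    Shape walk-through (illustrative, not part of the original docstring): with the default
    speakers_per_batch=64 and utterances_per_speaker=10, the (64, 10, 64) similarity matrix is
    reshaped to (640, 64) and scored against a target vector of length 640 holding the speaker
    index of each utterance.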
+ """ + speakers_per_batch, utterances_per_speaker = embeds.shape[:2] + + # Loss + sim_matrix = self.similarity_matrix(embeds) + sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, + speakers_per_batch)) + ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) + target = torch.from_numpy(ground_truth).long().to(self.loss_device) + loss = self.loss_fn(sim_matrix, target) + + # EER (not backpropagated) + with torch.no_grad(): + inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] + labels = np.array([inv_argmax(i) for i in ground_truth]) + preds = sim_matrix.detach().cpu().numpy() + + # Snippet from https://yangcha.github.io/EER-ROC/ + fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) + eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) + + return loss, eer diff --git a/encoder/params_data.py b/encoder/params_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb1716ed45617f2b127a7fb8885afe6cc74fb71 --- /dev/null +++ b/encoder/params_data.py @@ -0,0 +1,29 @@ + +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms +# Number of spectrogram frames at inference +inference_n_frames = 80 # 800 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + diff --git a/encoder/params_model.py b/encoder/params_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3e356472fb5a27f370cb3920976a11d12a76c1b7 --- /dev/null +++ b/encoder/params_model.py @@ -0,0 +1,11 @@ + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 + + +## Training parameters +learning_rate_init = 1e-4 +speakers_per_batch = 64 +utterances_per_speaker = 10 diff --git a/encoder/preprocess.py b/encoder/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..551a8b29c4d84c0e1430f285a1c8b5e10c98ee5f --- /dev/null +++ b/encoder/preprocess.py @@ -0,0 +1,175 @@ +from multiprocess.pool import ThreadPool +from encoder.params_data import * +from encoder.config import librispeech_datasets, anglophone_nationalites +from datetime import datetime +from encoder import audio +from pathlib import Path +from tqdm import tqdm +import numpy as np + + +class DatasetLog: + """ + Registers metadata about the dataset in a text file. 
+ """ + def __init__(self, root, name): + self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") + self.sample_data = dict() + + start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) + self.write_line("Creating dataset %s on %s" % (name, start_time)) + self.write_line("-----") + self._log_params() + + def _log_params(self): + from encoder import params_data + self.write_line("Parameter values:") + for param_name in (p for p in dir(params_data) if not p.startswith("__")): + value = getattr(params_data, param_name) + self.write_line("\t%s: %s" % (param_name, value)) + self.write_line("-----") + + def write_line(self, line): + self.text_file.write("%s\n" % line) + + def add_sample(self, **kwargs): + for param_name, value in kwargs.items(): + if not param_name in self.sample_data: + self.sample_data[param_name] = [] + self.sample_data[param_name].append(value) + + def finalize(self): + self.write_line("Statistics:") + for param_name, values in self.sample_data.items(): + self.write_line("\t%s:" % param_name) + self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) + self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values))) + self.write_line("-----") + end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) + self.write_line("Finished on %s" % end_time) + self.text_file.close() + + +def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog): + dataset_root = datasets_root.joinpath(dataset_name) + if not dataset_root.exists(): + print("Couldn\'t find %s, skipping this dataset." % dataset_root) + return None, None + return dataset_root, DatasetLog(out_dir, dataset_name) + + +def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension, + skip_existing, logger): + print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs))) + + # Function to preprocess utterances for one speaker + def preprocess_speaker(speaker_dir: Path): + # Give a name to the speaker that includes its dataset + speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) + + # Create an output directory with that name, as well as a txt file containing a + # reference to each source file. + speaker_out_dir = out_dir.joinpath(speaker_name) + speaker_out_dir.mkdir(exist_ok=True) + sources_fpath = speaker_out_dir.joinpath("_sources.txt") + + # There's a possibility that the preprocessing was interrupted earlier, check if + # there already is a sources file. 
+ if sources_fpath.exists(): + try: + with sources_fpath.open("r") as sources_file: + existing_fnames = {line.split(",")[0] for line in sources_file} + except: + existing_fnames = {} + else: + existing_fnames = {} + + # Gather all audio files for that speaker recursively + sources_file = sources_fpath.open("a" if skip_existing else "w") + for in_fpath in speaker_dir.glob("**/*.%s" % extension): + # Check if the target output file already exists + out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) + out_fname = out_fname.replace(".%s" % extension, ".npy") + if skip_existing and out_fname in existing_fnames: + continue + + # Load and preprocess the waveform + wav = audio.preprocess_wav(in_fpath) + if len(wav) == 0: + continue + + # Create the mel spectrogram, discard those that are too short + frames = audio.wav_to_mel_spectrogram(wav) + if len(frames) < partials_n_frames: + continue + + out_fpath = speaker_out_dir.joinpath(out_fname) + np.save(out_fpath, frames) + logger.add_sample(duration=len(wav) / sampling_rate) + sources_file.write("%s,%s\n" % (out_fname, in_fpath)) + + sources_file.close() + + # Process the utterances for each speaker + with ThreadPool(8) as pool: + list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), + unit="speakers")) + logger.finalize() + print("Done preprocessing %s.\n" % dataset_name) + + +def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False): + for dataset_name in librispeech_datasets["train"]["other"]: + # Initialize the preprocessing + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + + # Preprocess all speakers + speaker_dirs = list(dataset_root.glob("*")) + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac", + skip_existing, logger) + + +def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False): + # Initialize the preprocessing + dataset_name = "VoxCeleb1" + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + + # Get the contents of the meta file + with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: + metadata = [line.split("\t") for line in metafile][1:] + + # Select the ID and the nationality, filter out non-anglophone speakers + nationalities = {line[0]: line[3] for line in metadata} + keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if + nationality.lower() in anglophone_nationalites] + print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." % + (len(keep_speaker_ids), len(nationalities))) + + # Get the speaker directories for anglophone speakers only + speaker_dirs = dataset_root.joinpath("wav").glob("*") + speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if + speaker_dir.name in keep_speaker_ids] + print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." 
% + (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs))) + + # Preprocess all speakers + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav", + skip_existing, logger) + + +def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False): + # Initialize the preprocessing + dataset_name = "VoxCeleb2" + dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) + if not dataset_root: + return + + # Get the speaker directories + # Preprocess all speakers + speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*")) + _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", + skip_existing, logger) diff --git a/encoder/train.py b/encoder/train.py new file mode 100644 index 0000000000000000000000000000000000000000..619952e8de6c390912fe341403a39169592e585d --- /dev/null +++ b/encoder/train.py @@ -0,0 +1,123 @@ +from encoder.visualizations import Visualizations +from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset +from encoder.params_model import * +from encoder.model import SpeakerEncoder +from utils.profiler import Profiler +from pathlib import Path +import torch + +def sync(device: torch.device): + # For correct profiling (cuda operations are async) + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, + backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, + no_visdom: bool): + # Create a dataset and a dataloader + dataset = SpeakerVerificationDataset(clean_data_root) + loader = SpeakerVerificationDataLoader( + dataset, + speakers_per_batch, + utterances_per_speaker, + num_workers=8, + ) + + # Setup the device on which to run the forward pass and the loss. These can be different, + # because the forward pass is faster on the GPU whereas the loss is often (depending on your + # hyperparameters) faster on the CPU. + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # FIXME: currently, the gradient is None if loss_device is cuda + loss_device = torch.device("cpu") + + # Create the model and the optimizer + model = SpeakerEncoder(device, loss_device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) + init_step = 1 + + # Configure file path for the model + state_fpath = models_dir.joinpath(run_id + ".pt") + backup_dir = models_dir.joinpath(run_id + "_backups") + + # Load any existing model + if not force_restart: + if state_fpath.exists(): + print("Found existing model \"%s\", loading it and resuming training." % run_id) + checkpoint = torch.load(state_fpath) + init_step = checkpoint["step"] + model.load_state_dict(checkpoint["model_state"]) + optimizer.load_state_dict(checkpoint["optimizer_state"]) + optimizer.param_groups[0]["lr"] = learning_rate_init + else: + print("No model \"%s\" found, starting training from scratch." 
% run_id) + else: + print("Starting the training from scratch.") + model.train() + + # Initialize the visualization environment + vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) + vis.log_dataset(dataset) + vis.log_params() + device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") + vis.log_implementation({"Device": device_name}) + + # Training loop + profiler = Profiler(summarize_every=10, disabled=False) + for step, speaker_batch in enumerate(loader, init_step): + profiler.tick("Blocking, waiting for batch (threaded)") + + # Forward pass + inputs = torch.from_numpy(speaker_batch.data).to(device) + sync(device) + profiler.tick("Data to %s" % device) + embeds = model(inputs) + sync(device) + profiler.tick("Forward pass") + embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) + loss, eer = model.loss(embeds_loss) + sync(loss_device) + profiler.tick("Loss") + + # Backward pass + model.zero_grad() + loss.backward() + profiler.tick("Backward pass") + model.do_gradient_ops() + optimizer.step() + profiler.tick("Parameter update") + + # Update visualizations + # learning_rate = optimizer.param_groups[0]["lr"] + vis.update(loss.item(), eer, step) + + # Draw projections and save them to the backup folder + if umap_every != 0 and step % umap_every == 0: + print("Drawing and saving projections (step %d)" % step) + backup_dir.mkdir(exist_ok=True) + projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) + embeds = embeds.detach().cpu().numpy() + vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) + vis.save() + + # Overwrite the latest version of the model + if save_every != 0 and step % save_every == 0: + print("Saving the model (step %d)" % step) + torch.save({ + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, state_fpath) + + # Make a backup + if backup_every != 0 and step % backup_every == 0: + print("Making a backup (step %d)" % step) + backup_dir.mkdir(exist_ok=True) + backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) + torch.save({ + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, backup_fpath) + + profiler.tick("Extras (visualizations, saving)") diff --git a/encoder/visualizations.py b/encoder/visualizations.py new file mode 100644 index 0000000000000000000000000000000000000000..980c74f95f1f7df41ebccc983600b2713c0b0502 --- /dev/null +++ b/encoder/visualizations.py @@ -0,0 +1,178 @@ +from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset +from datetime import datetime +from time import perf_counter as timer +import matplotlib.pyplot as plt +import numpy as np +# import webbrowser +import visdom +import umap + +colormap = np.array([ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], +], dtype=np.float) / 255 + + +class Visualizations: + def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False): + # Tracking data + self.last_update_timestamp = timer() + self.update_every = update_every + self.step_times = [] + self.losses = [] + self.eers = [] + print("Updating the visualizations every %d steps." 
% update_every) + + # If visdom is disabled TODO: use a better paradigm for that + self.disabled = disabled + if self.disabled: + return + + # Set the environment name + now = str(datetime.now().strftime("%d-%m %Hh%M")) + if env_name is None: + self.env_name = now + else: + self.env_name = "%s (%s)" % (env_name, now) + + # Connect to visdom and open the corresponding window in the browser + try: + self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) + except ConnectionError: + raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to " + "start it.") + # webbrowser.open("http://localhost:8097/env/" + self.env_name) + + # Create the windows + self.loss_win = None + self.eer_win = None + # self.lr_win = None + self.implementation_win = None + self.projection_win = None + self.implementation_string = "" + + def log_params(self): + if self.disabled: + return + from encoder import params_data + from encoder import params_model + param_string = "Model parameters:
" + for param_name in (p for p in dir(params_model) if not p.startswith("__")): + value = getattr(params_model, param_name) + param_string += "\t%s: %s
" % (param_name, value) + param_string += "Data parameters:
" + for param_name in (p for p in dir(params_data) if not p.startswith("__")): + value = getattr(params_data, param_name) + param_string += "\t%s: %s
" % (param_name, value) + self.vis.text(param_string, opts={"title": "Parameters"}) + + def log_dataset(self, dataset: SpeakerVerificationDataset): + if self.disabled: + return + dataset_string = "" + dataset_string += "Speakers: %s\n" % len(dataset.speakers) + dataset_string += "\n" + dataset.get_logs() + dataset_string = dataset_string.replace("\n", "
") + self.vis.text(dataset_string, opts={"title": "Dataset"}) + + def log_implementation(self, params): + if self.disabled: + return + implementation_string = "" + for param, value in params.items(): + implementation_string += "%s: %s\n" % (param, value) + implementation_string = implementation_string.replace("\n", "
") + self.implementation_string = implementation_string + self.implementation_win = self.vis.text( + implementation_string, + opts={"title": "Training implementation"} + ) + + def update(self, loss, eer, step): + # Update the tracking data + now = timer() + self.step_times.append(1000 * (now - self.last_update_timestamp)) + self.last_update_timestamp = now + self.losses.append(loss) + self.eers.append(eer) + print(".", end="") + + # Update the plots every steps + if step % self.update_every != 0: + return + time_string = "Step time: mean: %5dms std: %5dms" % \ + (int(np.mean(self.step_times)), int(np.std(self.step_times))) + print("\nStep %6d Loss: %.4f EER: %.4f %s" % + (step, np.mean(self.losses), np.mean(self.eers), time_string)) + if not self.disabled: + self.loss_win = self.vis.line( + [np.mean(self.losses)], + [step], + win=self.loss_win, + update="append" if self.loss_win else None, + opts=dict( + legend=["Avg. loss"], + xlabel="Step", + ylabel="Loss", + title="Loss", + ) + ) + self.eer_win = self.vis.line( + [np.mean(self.eers)], + [step], + win=self.eer_win, + update="append" if self.eer_win else None, + opts=dict( + legend=["Avg. EER"], + xlabel="Step", + ylabel="EER", + title="Equal error rate" + ) + ) + if self.implementation_win is not None: + self.vis.text( + self.implementation_string + ("%s" % time_string), + win=self.implementation_win, + opts={"title": "Training implementation"}, + ) + + # Reset the tracking + self.losses.clear() + self.eers.clear() + self.step_times.clear() + + def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, + max_speakers=10): + max_speakers = min(max_speakers, len(colormap)) + embeds = embeds[:max_speakers * utterances_per_speaker] + + n_speakers = len(embeds) // utterances_per_speaker + ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) + colors = [colormap[i] for i in ground_truth] + + reducer = umap.UMAP() + projected = reducer.fit_transform(embeds) + plt.scatter(projected[:, 0], projected[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection (step %d)" % step) + if not self.disabled: + self.projection_win = self.vis.matplot(plt, win=self.projection_win) + if out_fpath is not None: + plt.savefig(out_fpath) + plt.clf() + + def save(self): + if not self.disabled: + self.vis.save([self.env_name]) + \ No newline at end of file diff --git a/encoder_preprocess.py b/encoder_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..11502013c8d75d4652fb0ffdcdc49d55e8fb8bc9 --- /dev/null +++ b/encoder_preprocess.py @@ -0,0 +1,70 @@ +from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2 +from utils.argutils import print_args +from pathlib import Path +import argparse + +if __name__ == "__main__": + class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): + pass + + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, encodes them as mel spectrograms and " + "writes them to the disk. This will allow you to train the encoder. The " + "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. " + "Ideally, you should have all three. 
You should extract them as they are " + "after having downloaded them and put them in a same directory, e.g.:\n" + "-[datasets_root]\n" + " -LibriSpeech\n" + " -train-other-500\n" + " -VoxCeleb1\n" + " -wav\n" + " -vox1_meta.csv\n" + " -VoxCeleb2\n" + " -dev", + formatter_class=MyFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms. If left out, " + "defaults to /SV2TTS/encoder/") + parser.add_argument("-d", "--datasets", type=str, + default="librispeech_other,voxceleb1,voxceleb2", help=\ + "Comma-separated list of the name of the datasets you want to preprocess. Only the train " + "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, " + "voxceleb2.") + parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + "Whether to skip existing output files with the same name. Useful if this script was " + "interrupted.") + parser.add_argument("--no_trim", action="store_true", help=\ + "Preprocess audio without trimming silences (not recommended).") + args = parser.parse_args() + + # Verify webrtcvad is available + if not args.no_trim: + try: + import webrtcvad + except: + raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables " + "noise removal and is recommended. Please install and try again. If installation fails, " + "use --no_trim to disable this error message.") + del args.no_trim + + # Process the arguments + args.datasets = args.datasets.split(",") + if not hasattr(args, "out_dir"): + args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder") + assert args.datasets_root.exists() + args.out_dir.mkdir(exist_ok=True, parents=True) + + # Preprocess the datasets + print_args(args, parser) + preprocess_func = { + "librispeech_other": preprocess_librispeech, + "voxceleb1": preprocess_voxceleb1, + "voxceleb2": preprocess_voxceleb2, + } + args = vars(args) + for dataset in args.pop("datasets"): + print("Preprocessing %s" % dataset) + preprocess_func[dataset](**args) diff --git a/encoder_train.py b/encoder_train.py new file mode 100644 index 0000000000000000000000000000000000000000..b8740a894d615aadfe529cb36068fc8e3496125f --- /dev/null +++ b/encoder_train.py @@ -0,0 +1,47 @@ +from utils.argutils import print_args +from encoder.train import train +from pathlib import Path +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Trains the speaker encoder. You must have run encoder_preprocess.py first.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("run_id", type=str, help= \ + "Name for this model instance. If a model state from the same run ID was previously " + "saved, the training will restart from there. Pass -f to overwrite saved states and " + "restart from scratch.") + parser.add_argument("clean_data_root", type=Path, help= \ + "Path to the output directory of encoder_preprocess.py. 
If you left the default " + "output directory when preprocessing, it should be /SV2TTS/encoder/.") + parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\ + "Path to the output directory that will contain the saved model weights, as well as " + "backups of those weights and plots generated during training.") + parser.add_argument("-v", "--vis_every", type=int, default=10, help= \ + "Number of steps between updates of the loss and the plots.") + parser.add_argument("-u", "--umap_every", type=int, default=100, help= \ + "Number of steps between updates of the umap projection. Set to 0 to never update the " + "projections.") + parser.add_argument("-s", "--save_every", type=int, default=500, help= \ + "Number of steps between updates of the model on the disk. Set to 0 to never save the " + "model.") + parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \ + "Number of steps between backups of the model. Set to 0 to never make backups of the " + "model.") + parser.add_argument("-f", "--force_restart", action="store_true", help= \ + "Do not load any saved model.") + parser.add_argument("--visdom_server", type=str, default="http://localhost") + parser.add_argument("--no_visdom", action="store_true", help= \ + "Disable visdom.") + args = parser.parse_args() + + # Process the arguments + args.models_dir.mkdir(exist_ok=True) + + # Run the training + print_args(args, parser) + train(**vars(args)) + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f4148d69216957aea00d4bd933abd11b21f9b1c6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +umap-learn +visdom +librosa>=0.8.0 +matplotlib>=3.3.0 +numpy==1.19.3; platform_system == "Windows" +numpy==1.19.4; platform_system != "Windows" +scipy>=1.0.0 +tqdm +sounddevice +SoundFile +Unidecode +inflect +PyQt5 +multiprocess +numba +webrtcvad; platform_system != "Windows" diff --git a/synthesizer/LICENSE.txt b/synthesizer/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..3337d453404ea63d5a5919d3922045374bea3da1 --- /dev/null +++ b/synthesizer/LICENSE.txt @@ -0,0 +1,24 @@ +MIT License + +Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah) +Original work Copyright (c) 2019 fatchord (https://github.com/fatchord) +Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ) +Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/synthesizer/__init__.py b/synthesizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4287ca8617970fa8fc025b75cb319c7032706910 --- /dev/null +++ b/synthesizer/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/synthesizer/audio.py b/synthesizer/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..83dc96c63c962bc8e13c446d05e27c009fb3239f --- /dev/null +++ b/synthesizer/audio.py @@ -0,0 +1,206 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile +import soundfile as sf + + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + sf.write(path, wav.astype(np.float32), sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py +def start_and_end_indices(quantized, silence_threshold=2): + for start in range(quantized.size): + if abs(quantized[start] - 127) > silence_threshold: + break + for end in range(quantized.size - 1, 1, -1): + if abs(quantized[end] - 127) > silence_threshold: + break + + assert abs(quantized[start] - 127) > silence_threshold + assert abs(quantized[end] - 127) > silence_threshold + + return start, end + +def get_hop_size(hparams): + hop_size = hparams.hop_size + if hop_size is None: + assert hparams.frame_shift_ms is not None + hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) + return hop_size + +def linearspectrogram(wav, hparams): + D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams) + S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db + + if hparams.signal_normalization: + return _normalize(S, hparams) + return S + +def melspectrogram(wav, hparams): + D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams) + S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db + + if hparams.signal_normalization: + return _normalize(S, hparams) + return S + +def inv_linear_spectrogram(linear_spectrogram, hparams): + """Converts linear spectrogram to waveform using librosa""" + if hparams.signal_normalization: + D = _denormalize(linear_spectrogram, hparams) + else: + D = linear_spectrogram + + S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear + + if hparams.use_lws: + processor = _lws_processor(hparams) + D = processor.run_lws(S.astype(np.float64).T ** hparams.power) + y = processor.istft(D).astype(np.float32) + return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize) + else: + return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize) + +def inv_mel_spectrogram(mel_spectrogram, hparams): + """Converts mel spectrogram to waveform using librosa""" + if hparams.signal_normalization: + D = _denormalize(mel_spectrogram, hparams) + 
else: + D = mel_spectrogram + + S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear + + if hparams.use_lws: + processor = _lws_processor(hparams) + D = processor.run_lws(S.astype(np.float64).T ** hparams.power) + y = processor.istft(D).astype(np.float32) + return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize) + else: + return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize) + +def _lws_processor(hparams): + import lws + return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech") + +def _griffin_lim(S, hparams): + """librosa implementation of Griffin-Lim + Based on https://github.com/librosa/librosa/issues/434 + """ + angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) + S_complex = np.abs(S).astype(np.complex) + y = _istft(S_complex * angles, hparams) + for i in range(hparams.griffin_lim_iters): + angles = np.exp(1j * np.angle(_stft(y, hparams))) + y = _istft(S_complex * angles, hparams) + return y + +def _stft(y, hparams): + if hparams.use_lws: + return _lws_processor(hparams).stft(y).T + else: + return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size) + +def _istft(y, hparams): + return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) +def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None +_inv_mel_basis = None + +def _linear_to_mel(spectogram, hparams): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis(hparams) + return np.dot(_mel_basis, spectogram) + +def _mel_to_linear(mel_spectrogram, hparams): + global _inv_mel_basis + if _inv_mel_basis is None: + _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams)) + return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) + +def _build_mel_basis(hparams): + assert hparams.fmax <= hparams.sample_rate // 2 + return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels, + fmin=hparams.fmin, fmax=hparams.fmax) + +def _amp_to_db(x, hparams): + min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S, hparams): + if hparams.allow_clipping_in_normalization: + if hparams.symmetric_mels: + return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value, + -hparams.max_abs_value, hparams.max_abs_value) + else: + return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value) + + assert S.max() <= 0 and S.min() - 
hparams.min_level_db >= 0 + if hparams.symmetric_mels: + return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value + else: + return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)) + +def _denormalize(D, hparams): + if hparams.allow_clipping_in_normalization: + if hparams.symmetric_mels: + return (((np.clip(D, -hparams.max_abs_value, + hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + + hparams.min_level_db) + else: + return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) + + if hparams.symmetric_mels: + return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db) + else: + return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) diff --git a/synthesizer/hparams.py b/synthesizer/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d38f0aa4c34d11349e40dbb9861b1aec2dcb8b --- /dev/null +++ b/synthesizer/hparams.py @@ -0,0 +1,92 @@ +import ast +import pprint + +class HParams(object): + def __init__(self, **kwargs): self.__dict__.update(kwargs) + def __setitem__(self, key, value): setattr(self, key, value) + def __getitem__(self, key): return getattr(self, key) + def __repr__(self): return pprint.pformat(self.__dict__) + + def parse(self, string): + # Overrides hparams from a comma-separated string of name=value pairs + if len(string) > 0: + overrides = [s.split("=") for s in string.split(",")] + keys, values = zip(*overrides) + keys = list(map(str.strip, keys)) + values = list(map(str.strip, values)) + for k in keys: + self.__dict__[k] = ast.literal_eval(values[keys.index(k)]) + return self + +hparams = HParams( + ### Signal Processing (used in both synthesizer and vocoder) + sample_rate = 16000, + n_fft = 800, + num_mels = 80, + hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125) + win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050) + fmin = 55, + min_level_db = -100, + ref_level_db = 20, + max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small. + preemphasis = 0.97, # Filter coefficient to use if preemphasize is True + preemphasize = True, + + ### Tacotron Text-to-Speech (TTS) + tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs + tts_encoder_dims = 256, + tts_decoder_dims = 128, + tts_postnet_dims = 512, + tts_encoder_K = 5, + tts_lstm_dims = 1024, + tts_postnet_K = 5, + tts_num_highways = 4, + tts_dropout = 0.5, + tts_cleaner_names = ["english_cleaners"], + tts_stop_threshold = -3.4, # Value below which audio generation ends. 
+ # For example, for a range of [-4, 4], this + # will terminate the sequence at the first + # frame that has all values < -3.4 + + ### Tacotron Training + tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule + (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size) + (2, 2e-4, 80_000, 12), # + (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames + (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration) + (2, 1e-5, 640_000, 12)], # lr = learning rate + + tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed + tts_eval_interval = 500, # Number of steps between model evaluation (sample generation) + # Set to -1 to generate after completing epoch, or 0 to disable + + tts_eval_num_samples = 1, # Makes this number of samples + + ### Data Preprocessing + max_mel_frames = 900, + rescale = True, + rescaling_max = 0.9, + synthesis_batch_size = 16, # For vocoder preprocessing and inference. + + ### Mel Visualization and Griffin-Lim + signal_normalization = True, + power = 1.5, + griffin_lim_iters = 60, + + ### Audio processing options + fmax = 7600, # Should not exceed (sample_rate // 2) + allow_clipping_in_normalization = True, # Used when signal_normalization = True + clip_mels_length = True, # If true, discards samples exceeding max_mel_frames + use_lws = False, # "Fast spectrogram phase recovery using local weighted sums" + symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True, + # and [0, max_abs_value] if False + trim_silence = True, # Use with sample_rate of 16000 for best results + + ### SV2TTS + speaker_embedding_size = 256, # Dimension for the speaker embedding + silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split + utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded + ) + +def hparams_debug_string(): + return str(hparams) diff --git a/synthesizer/inference.py b/synthesizer/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..af7bf083ffc9bed33ea6e2c77cb7f69e6b5c0475 --- /dev/null +++ b/synthesizer/inference.py @@ -0,0 +1,171 @@ +import torch +from synthesizer import audio +from synthesizer.hparams import hparams +from synthesizer.models.tacotron import Tacotron +from synthesizer.utils.symbols import symbols +from synthesizer.utils.text import text_to_sequence +from vocoder.display import simple_table +from pathlib import Path +from typing import Union, List +import numpy as np +import librosa + + +class Synthesizer: + sample_rate = hparams.sample_rate + hparams = hparams + + def __init__(self, model_fpath: Path, verbose=True): + """ + The model isn't instantiated and loaded in memory until needed or until load() is called. + + :param model_fpath: path to the trained model file + :param verbose: if False, prints less information when using the model + """ + self.model_fpath = model_fpath + self.verbose = verbose + + # Check for GPU + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + if self.verbose: + print("Synthesizer using device:", self.device) + + # Tacotron model will be instantiated later on first use. + self._model = None + + def is_loaded(self): + """ + Whether the model is loaded in memory. + """ + return self._model is not None + + def load(self): + """ + Instantiates and loads the model given the weights file that was passed in the constructor. 
+ """ + self._model = Tacotron(embed_dims=hparams.tts_embed_dims, + num_chars=len(symbols), + encoder_dims=hparams.tts_encoder_dims, + decoder_dims=hparams.tts_decoder_dims, + n_mels=hparams.num_mels, + fft_bins=hparams.num_mels, + postnet_dims=hparams.tts_postnet_dims, + encoder_K=hparams.tts_encoder_K, + lstm_dims=hparams.tts_lstm_dims, + postnet_K=hparams.tts_postnet_K, + num_highways=hparams.tts_num_highways, + dropout=hparams.tts_dropout, + stop_threshold=hparams.tts_stop_threshold, + speaker_embedding_size=hparams.speaker_embedding_size).to(self.device) + + self._model.load(self.model_fpath) + self._model.eval() + + if self.verbose: + print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"])) + + def synthesize_spectrograms(self, texts: List[str], + embeddings: Union[np.ndarray, List[np.ndarray]], + return_alignments=False): + """ + Synthesizes mel spectrograms from texts and speaker embeddings. + + :param texts: a list of N text prompts to be synthesized + :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256) + :param return_alignments: if True, a matrix representing the alignments between the + characters + and each decoder output step will be returned for each spectrogram + :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the + sequence length of spectrogram i, and possibly the alignments. + """ + # Load the model on the first request. + if not self.is_loaded(): + self.load() + + # Print some info about the model when it is loaded + tts_k = self._model.get_step() // 1000 + + simple_table([("Tacotron", str(tts_k) + "k"), + ("r", self._model.r)]) + + # Preprocess text inputs + inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts] + if not isinstance(embeddings, list): + embeddings = [embeddings] + + # Batch inputs + batched_inputs = [inputs[i:i+hparams.synthesis_batch_size] + for i in range(0, len(inputs), hparams.synthesis_batch_size)] + batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size] + for i in range(0, len(embeddings), hparams.synthesis_batch_size)] + + specs = [] + for i, batch in enumerate(batched_inputs, 1): + if self.verbose: + print(f"\n| Generating {i}/{len(batched_inputs)}") + + # Pad texts so they are all the same length + text_lens = [len(text) for text in batch] + max_text_len = max(text_lens) + chars = [pad1d(text, max_text_len) for text in batch] + chars = np.stack(chars) + + # Stack speaker embeddings into 2D array for batch processing + speaker_embeds = np.stack(batched_embeds[i-1]) + + # Convert to tensor + chars = torch.tensor(chars).long().to(self.device) + speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device) + + # Inference + _, mels, alignments = self._model.generate(chars, speaker_embeddings) + mels = mels.detach().cpu().numpy() + for m in mels: + # Trim silence from end of each spectrogram + while np.max(m[:, -1]) < hparams.tts_stop_threshold: + m = m[:, :-1] + specs.append(m) + + if self.verbose: + print("\n\nDone.\n") + return (specs, alignments) if return_alignments else specs + + @staticmethod + def load_preprocess_wav(fpath): + """ + Loads and preprocesses an audio file under the same conditions the audio files were used to + train the synthesizer. 
+ """ + wav = librosa.load(str(fpath), hparams.sample_rate)[0] + if hparams.rescale: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + return wav + + @staticmethod + def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]): + """ + Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that + were fed to the synthesizer when training. + """ + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav = Synthesizer.load_preprocess_wav(fpath_or_wav) + else: + wav = fpath_or_wav + + mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32) + return mel_spectrogram + + @staticmethod + def griffin_lim(mel): + """ + Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built + with the same parameters present in hparams.py. + """ + return audio.inv_mel_spectrogram(mel, hparams) + + +def pad1d(x, max_len, pad_value=0): + return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value) diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..769f7f98b79100ff587af3609010dd55e3b2a146 --- /dev/null +++ b/synthesizer/models/tacotron.py @@ -0,0 +1,519 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import Union + + +class HighwayNetwork(nn.Module): + def __init__(self, size): + super().__init__() + self.W1 = nn.Linear(size, size) + self.W2 = nn.Linear(size, size) + self.W1.bias.data.fill_(0.) + + def forward(self, x): + x1 = self.W1(x) + x2 = self.W2(x) + g = torch.sigmoid(x2) + y = g * F.relu(x1) + (1. - g) * x + return y + + +class Encoder(nn.Module): + def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout): + super().__init__() + prenet_dims = (encoder_dims, encoder_dims) + cbhg_channels = encoder_dims + self.embedding = nn.Embedding(num_chars, embed_dims) + self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], + dropout=dropout) + self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels, + proj_channels=[cbhg_channels, cbhg_channels], + num_highways=num_highways) + + def forward(self, x, speaker_embedding=None): + x = self.embedding(x) + x = self.pre_net(x) + x.transpose_(1, 2) + x = self.cbhg(x) + if speaker_embedding is not None: + x = self.add_speaker_embedding(x, speaker_embedding) + return x + + def add_speaker_embedding(self, x, speaker_embedding): + # SV2TTS + # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims) + # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size) + # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size)) + # This concats the speaker embedding for each char in the encoder output + + # Save the dimensions as human-readable names + batch_size = x.size()[0] + num_chars = x.size()[1] + + if speaker_embedding.dim() == 1: + idx = 0 + else: + idx = 1 + + # Start by making a copy of each speaker embedding to match the input text length + # The output of this has size (batch_size, num_chars * tts_embed_dims) + speaker_embedding_size = speaker_embedding.size()[idx] + e = speaker_embedding.repeat_interleave(num_chars, dim=idx) + + # Reshape it and transpose + e = e.reshape(batch_size, speaker_embedding_size, num_chars) + e = e.transpose(1, 2) + + # Concatenate the tiled 
speaker embedding with the encoder output + x = torch.cat((x, e), 2) + return x + + +class BatchNormConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel, relu=True): + super().__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False) + self.bnorm = nn.BatchNorm1d(out_channels) + self.relu = relu + + def forward(self, x): + x = self.conv(x) + x = F.relu(x) if self.relu is True else x + return self.bnorm(x) + + +class CBHG(nn.Module): + def __init__(self, K, in_channels, channels, proj_channels, num_highways): + super().__init__() + + # List of all rnns to call `flatten_parameters()` on + self._to_flatten = [] + + self.bank_kernels = [i for i in range(1, K + 1)] + self.conv1d_bank = nn.ModuleList() + for k in self.bank_kernels: + conv = BatchNormConv(in_channels, channels, k) + self.conv1d_bank.append(conv) + + self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) + + self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3) + self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False) + + # Fix the highway input if necessary + if proj_channels[-1] != channels: + self.highway_mismatch = True + self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False) + else: + self.highway_mismatch = False + + self.highways = nn.ModuleList() + for i in range(num_highways): + hn = HighwayNetwork(channels) + self.highways.append(hn) + + self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True) + self._to_flatten.append(self.rnn) + + # Avoid fragmentation of RNN parameters and associated warning + self._flatten_parameters() + + def forward(self, x): + # Although we `_flatten_parameters()` on init, when using DataParallel + # the model gets replicated, making it no longer guaranteed that the + # weights are contiguous in GPU memory. Hence, we must call it again + self._flatten_parameters() + + # Save these for later + residual = x + seq_len = x.size(-1) + conv_bank = [] + + # Convolution Bank + for conv in self.conv1d_bank: + c = conv(x) # Convolution + conv_bank.append(c[:, :, :seq_len]) + + # Stack along the channel axis + conv_bank = torch.cat(conv_bank, dim=1) + + # dump the last padding to fit residual + x = self.maxpool(conv_bank)[:, :, :seq_len] + + # Conv1d projections + x = self.conv_project1(x) + x = self.conv_project2(x) + + # Residual Connect + x = x + residual + + # Through the highways + x = x.transpose(1, 2) + if self.highway_mismatch is True: + x = self.pre_highway(x) + for h in self.highways: x = h(x) + + # And then the RNN + x, _ = self.rnn(x) + return x + + def _flatten_parameters(self): + """Calls `flatten_parameters` on all the rnns used by the WaveRNN. 
Used + to improve efficiency and avoid PyTorch yelling at us.""" + [m.flatten_parameters() for m in self._to_flatten] + +class PreNet(nn.Module): + def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5): + super().__init__() + self.fc1 = nn.Linear(in_dims, fc1_dims) + self.fc2 = nn.Linear(fc1_dims, fc2_dims) + self.p = dropout + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = F.dropout(x, self.p, training=True) + x = self.fc2(x) + x = F.relu(x) + x = F.dropout(x, self.p, training=True) + return x + + +class Attention(nn.Module): + def __init__(self, attn_dims): + super().__init__() + self.W = nn.Linear(attn_dims, attn_dims, bias=False) + self.v = nn.Linear(attn_dims, 1, bias=False) + + def forward(self, encoder_seq_proj, query, t): + + # print(encoder_seq_proj.shape) + # Transform the query vector + query_proj = self.W(query).unsqueeze(1) + + # Compute the scores + u = self.v(torch.tanh(encoder_seq_proj + query_proj)) + scores = F.softmax(u, dim=1) + + return scores.transpose(1, 2) + + +class LSA(nn.Module): + def __init__(self, attn_dim, kernel_size=31, filters=32): + super().__init__() + self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True) + self.L = nn.Linear(filters, attn_dim, bias=False) + self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term + self.v = nn.Linear(attn_dim, 1, bias=False) + self.cumulative = None + self.attention = None + + def init_attention(self, encoder_seq_proj): + device = next(self.parameters()).device # use same device as parameters + b, t, c = encoder_seq_proj.size() + self.cumulative = torch.zeros(b, t, device=device) + self.attention = torch.zeros(b, t, device=device) + + def forward(self, encoder_seq_proj, query, t, chars): + + if t == 0: self.init_attention(encoder_seq_proj) + + processed_query = self.W(query).unsqueeze(1) + + location = self.cumulative.unsqueeze(1) + processed_loc = self.L(self.conv(location).transpose(1, 2)) + + u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc)) + u = u.squeeze(-1) + + # Mask zero padding chars + u = u * (chars != 0).float() + + # Smooth Attention + # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True) + scores = F.softmax(u, dim=1) + self.attention = scores + self.cumulative = self.cumulative + self.attention + + return scores.unsqueeze(-1).transpose(1, 2) + + +class Decoder(nn.Module): + # Class variable because its value doesn't change between classes + # yet ought to be scoped by class because its a property of a Decoder + max_r = 20 + def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims, + dropout, speaker_embedding_size): + super().__init__() + self.register_buffer("r", torch.tensor(1, dtype=torch.int)) + self.n_mels = n_mels + prenet_dims = (decoder_dims * 2, decoder_dims * 2) + self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], + dropout=dropout) + self.attn_net = LSA(decoder_dims) + self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims) + self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims) + self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims) + self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims) + self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False) + self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1) + + def zoneout(self, prev, current, p=0.1): + device = next(self.parameters()).device # Use 
same device as parameters + mask = torch.zeros(prev.size(), device=device).bernoulli_(p) + return prev * mask + current * (1 - mask) + + def forward(self, encoder_seq, encoder_seq_proj, prenet_in, + hidden_states, cell_states, context_vec, t, chars): + + # Need this for reshaping mels + batch_size = encoder_seq.size(0) + + # Unpack the hidden and cell states + attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states + rnn1_cell, rnn2_cell = cell_states + + # PreNet for the Attention RNN + prenet_out = self.prenet(prenet_in) + + # Compute the Attention RNN hidden state + attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) + attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) + + # Compute the attention scores + scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars) + + # Dot product to create the context vector + context_vec = scores @ encoder_seq + context_vec = context_vec.squeeze(1) + + # Concat Attention RNN output w. Context Vector & project + x = torch.cat([context_vec, attn_hidden], dim=1) + x = self.rnn_input(x) + + # Compute first Residual RNN + rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) + if self.training: + rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next) + else: + rnn1_hidden = rnn1_hidden_next + x = x + rnn1_hidden + + # Compute second Residual RNN + rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) + if self.training: + rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next) + else: + rnn2_hidden = rnn2_hidden_next + x = x + rnn2_hidden + + # Project Mels + mels = self.mel_proj(x) + mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] + hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) + cell_states = (rnn1_cell, rnn2_cell) + + # Stop token prediction + s = torch.cat((x, context_vec), dim=1) + s = self.stop_proj(s) + stop_tokens = torch.sigmoid(s) + + return mels, scores, hidden_states, cell_states, context_vec, stop_tokens + + +class Tacotron(nn.Module): + def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, + fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways, + dropout, stop_threshold, speaker_embedding_size): + super().__init__() + self.n_mels = n_mels + self.lstm_dims = lstm_dims + self.encoder_dims = encoder_dims + self.decoder_dims = decoder_dims + self.speaker_embedding_size = speaker_embedding_size + self.encoder = Encoder(embed_dims, num_chars, encoder_dims, + encoder_K, num_highways, dropout) + self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False) + self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims, + dropout, speaker_embedding_size) + self.postnet = CBHG(postnet_K, n_mels, postnet_dims, + [postnet_dims, fft_bins], num_highways) + self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False) + + self.init_model() + self.num_params() + + self.register_buffer("step", torch.zeros(1, dtype=torch.long)) + self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32)) + + @property + def r(self): + return self.decoder.r.item() + + @r.setter + def r(self, value): + self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False) + + def forward(self, x, m, speaker_embedding): + device = next(self.parameters()).device # use same device as parameters + + self.step += 1 + batch_size, _, steps = m.size() + + # Initialise all hidden states and pack into tuple + attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) 
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) + rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) + hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) + + # Initialise all lstm cell states and pack into tuple + rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) + rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) + cell_states = (rnn1_cell, rnn2_cell) + + # Frame for start of decoder loop + go_frame = torch.zeros(batch_size, self.n_mels, device=device) + + # Need an initial context vector + context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device) + + # SV2TTS: Run the encoder with the speaker embedding + # The projection avoids unnecessary matmuls in the decoder loop + encoder_seq = self.encoder(x, speaker_embedding) + encoder_seq_proj = self.encoder_proj(encoder_seq) + + # Need a couple of lists for outputs + mel_outputs, attn_scores, stop_outputs = [], [], [] + + # Run the decoder loop + for t in range(0, steps, self.r): + prenet_in = m[:, :, t - 1] if t > 0 else go_frame + mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \ + self.decoder(encoder_seq, encoder_seq_proj, prenet_in, + hidden_states, cell_states, context_vec, t, x) + mel_outputs.append(mel_frames) + attn_scores.append(scores) + stop_outputs.extend([stop_tokens] * self.r) + + # Concat the mel outputs into sequence + mel_outputs = torch.cat(mel_outputs, dim=2) + + # Post-Process for Linear Spectrograms + postnet_out = self.postnet(mel_outputs) + linear = self.post_proj(postnet_out) + linear = linear.transpose(1, 2) + + # For easy visualisation + attn_scores = torch.cat(attn_scores, 1) + # attn_scores = attn_scores.cpu().data.numpy() + stop_outputs = torch.cat(stop_outputs, 1) + + return mel_outputs, linear, attn_scores, stop_outputs + + def generate(self, x, speaker_embedding=None, steps=2000): + self.eval() + device = next(self.parameters()).device # use same device as parameters + + batch_size, _ = x.size() + + # Need to initialise all hidden states and pack into tuple for tidyness + attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) + rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) + rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) + hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) + + # Need to initialise all lstm cell states and pack into tuple for tidyness + rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) + rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) + cell_states = (rnn1_cell, rnn2_cell) + + # Need a Frame for start of decoder loop + go_frame = torch.zeros(batch_size, self.n_mels, device=device) + + # Need an initial context vector + context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device) + + # SV2TTS: Run the encoder with the speaker embedding + # The projection avoids unnecessary matmuls in the decoder loop + encoder_seq = self.encoder(x, speaker_embedding) + encoder_seq_proj = self.encoder_proj(encoder_seq) + + # Need a couple of lists for outputs + mel_outputs, attn_scores, stop_outputs = [], [], [] + + # Run the decoder loop + for t in range(0, steps, self.r): + prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame + mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \ + self.decoder(encoder_seq, encoder_seq_proj, prenet_in, + hidden_states, cell_states, context_vec, t, x) + 
mel_outputs.append(mel_frames) + attn_scores.append(scores) + stop_outputs.extend([stop_tokens] * self.r) + # Stop the loop when all stop tokens in batch exceed threshold + if (stop_tokens > 0.5).all() and t > 10: break + + # Concat the mel outputs into sequence + mel_outputs = torch.cat(mel_outputs, dim=2) + + # Post-Process for Linear Spectrograms + postnet_out = self.postnet(mel_outputs) + linear = self.post_proj(postnet_out) + + + linear = linear.transpose(1, 2) + + # For easy visualisation + attn_scores = torch.cat(attn_scores, 1) + stop_outputs = torch.cat(stop_outputs, 1) + + self.train() + + return mel_outputs, linear, attn_scores + + def init_model(self): + for p in self.parameters(): + if p.dim() > 1: nn.init.xavier_uniform_(p) + + def get_step(self): + return self.step.data.item() + + def reset_step(self): + # assignment to parameters or buffers is overloaded, updates internal dict entry + self.step = self.step.data.new_tensor(1) + + def log(self, path, msg): + with open(path, "a") as f: + print(msg, file=f) + + def load(self, path, optimizer=None): + # Use device of model params as location for loaded state + device = next(self.parameters()).device + checkpoint = torch.load(str(path), map_location=device) + self.load_state_dict(checkpoint["model_state"]) + + if "optimizer_state" in checkpoint and optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer_state"]) + + def save(self, path, optimizer=None): + if optimizer is not None: + torch.save({ + "model_state": self.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, str(path)) + else: + torch.save({ + "model_state": self.state_dict(), + }, str(path)) + + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print("Trainable Parameters: %.3fM" % parameters) + return parameters diff --git a/synthesizer/preprocess.py b/synthesizer/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..cde325c4163d6800404de214202d773addfff296 --- /dev/null +++ b/synthesizer/preprocess.py @@ -0,0 +1,259 @@ +from multiprocessing.pool import Pool +from synthesizer import audio +from functools import partial +from itertools import chain +from encoder import inference as encoder +from pathlib import Path +from utils import logmmse +from tqdm import tqdm +import numpy as np +import librosa + + +def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, + skip_existing: bool, hparams, no_alignments: bool, + datasets_name: str, subfolders: str): + # Gather the input directories + dataset_root = datasets_root.joinpath(datasets_name) + input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")] + print("\n ".join(map(str, ["Using data from:"] + input_dirs))) + assert all(input_dir.exists() for input_dir in input_dirs) + + # Create the output directories for each output file type + out_dir.joinpath("mels").mkdir(exist_ok=True) + out_dir.joinpath("audio").mkdir(exist_ok=True) + + # Create a metadata file + metadata_fpath = out_dir.joinpath("train.txt") + metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8") + + # Preprocess the dataset + speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs)) + func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing, + hparams=hparams, no_alignments=no_alignments) + job = 
Pool(n_processes).imap(func, speaker_dirs) + for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"): + for metadatum in speaker_metadata: + metadata_file.write("|".join(str(x) for x in metadatum) + "\n") + metadata_file.close() + + # Verify the contents of the metadata file + with metadata_fpath.open("r", encoding="utf-8") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + mel_frames = sum([int(m[4]) for m in metadata]) + timesteps = sum([int(m[3]) for m in metadata]) + sample_rate = hparams.sample_rate + hours = (timesteps / sample_rate) / 3600 + print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." % + (len(metadata), mel_frames, timesteps, hours)) + print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata)) + print("Max mel frames length: %d" % max(int(m[4]) for m in metadata)) + print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata)) + + +def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool): + metadata = [] + for book_dir in speaker_dir.glob("*"): + if no_alignments: + # Gather the utterance audios and texts + # LibriTTS uses .wav but we will include extensions for compatibility with other datasets + extensions = ["*.wav", "*.flac", "*.mp3"] + for extension in extensions: + wav_fpaths = book_dir.glob(extension) + + for wav_fpath in wav_fpaths: + # Load the audio waveform + wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate) + if hparams.rescale: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + + # Get the corresponding text + # Check for .txt (for compatibility with other datasets) + text_fpath = wav_fpath.with_suffix(".txt") + if not text_fpath.exists(): + # Check for .normalized.txt (LibriTTS) + text_fpath = wav_fpath.with_suffix(".normalized.txt") + assert text_fpath.exists() + with text_fpath.open("r") as text_file: + text = "".join([line for line in text_file]) + text = text.replace("\"", "") + text = text.strip() + + # Process the utterance + metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name), + skip_existing, hparams)) + else: + # Process alignment file (LibriSpeech support) + # Gather the utterance audios and texts + try: + alignments_fpath = next(book_dir.glob("*.alignment.txt")) + with alignments_fpath.open("r") as alignments_file: + alignments = [line.rstrip().split(" ") for line in alignments_file] + except StopIteration: + # A few alignment files will be missing + continue + + # Iterate over each entry in the alignments file + for wav_fname, words, end_times in alignments: + wav_fpath = book_dir.joinpath(wav_fname + ".flac") + assert wav_fpath.exists() + words = words.replace("\"", "").split(",") + end_times = list(map(float, end_times.replace("\"", "").split(","))) + + # Process each sub-utterance + wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams) + for i, (wav, text) in enumerate(zip(wavs, texts)): + sub_basename = "%s_%02d" % (wav_fname, i) + metadata.append(process_utterance(wav, text, out_dir, sub_basename, + skip_existing, hparams)) + + return [m for m in metadata if m is not None] + + +def split_on_silences(wav_fpath, words, end_times, hparams): + # Load the audio waveform + wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate) + if hparams.rescale: + wav = wav / np.abs(wav).max() * hparams.rescaling_max + + words = np.array(words) + start_times = np.array([0.0] + end_times[:-1]) + end_times = 
np.array(end_times) + assert len(words) == len(end_times) == len(start_times) + assert words[0] == "" and words[-1] == "" + + # Find pauses that are too long + mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split) + mask[0] = mask[-1] = True + breaks = np.where(mask)[0] + + # Profile the noise from the silences and perform noise reduction on the waveform + silence_times = [[start_times[i], end_times[i]] for i in breaks] + silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int) + noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times]) + if len(noisy_wav) > hparams.sample_rate * 0.02: + profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate) + wav = logmmse.denoise(wav, profile, eta=0) + + # Re-attach segments that are too short + segments = list(zip(breaks[:-1], breaks[1:])) + segment_durations = [start_times[end] - end_times[start] for start, end in segments] + i = 0 + while i < len(segments) and len(segments) > 1: + if segment_durations[i] < hparams.utterance_min_duration: + # See if the segment can be re-attached with the right or the left segment + left_duration = float("inf") if i == 0 else segment_durations[i - 1] + right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1] + joined_duration = segment_durations[i] + min(left_duration, right_duration) + + # Do not re-attach if it causes the joined utterance to be too long + if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate: + i += 1 + continue + + # Re-attach the segment with the neighbour of shortest duration + j = i - 1 if left_duration <= right_duration else i + segments[j] = (segments[j][0], segments[j + 1][1]) + segment_durations[j] = joined_duration + del segments[j + 1], segment_durations[j + 1] + else: + i += 1 + + # Split the utterance + segment_times = [[end_times[start], start_times[end]] for start, end in segments] + segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int) + wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times] + texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments] + + # # DEBUG: play the audio segments (run with -n=1) + # import sounddevice as sd + # if len(wavs) > 1: + # print("This sentence was split in %d segments:" % len(wavs)) + # else: + # print("There are no silences long enough for this sentence to be split:") + # for wav, text in zip(wavs, texts): + # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early + # # when playing them. You shouldn't need to do that in your parsers. + # wav = np.concatenate((wav, [0] * 16000)) + # print("\t%s" % text) + # sd.play(wav, 16000, blocking=True) + # print("") + + return wavs, texts + + +def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str, + skip_existing: bool, hparams): + ## FOR REFERENCE: + # For you not to lose your head if you ever wish to change things here or implement your own + # synthesizer. + # - Both the audios and the mel spectrograms are saved as numpy arrays + # - There is no processing done to the audios that will be saved to disk beyond volume + # normalization (in split_on_silences) + # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This + # is why we re-apply it on the audio on the side of the vocoder. + # - Librosa pads the waveform before computing the mel spectrogram. 
Here, the waveform is saved + # without extra padding. This means that you won't have an exact relation between the length + # of the wav and of the mel spectrogram. See the vocoder data loader. + + + # Skip existing utterances if needed + mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename) + wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename) + if skip_existing and mel_fpath.exists() and wav_fpath.exists(): + return None + + # Trim silence + if hparams.trim_silence: + wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True) + + # Skip utterances that are too short + if len(wav) < hparams.utterance_min_duration * hparams.sample_rate: + return None + + # Compute the mel spectrogram + mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32) + mel_frames = mel_spectrogram.shape[1] + + # Skip utterances that are too long + if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length: + return None + + # Write the spectrogram, embed and audio to disk + np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False) + np.save(wav_fpath, wav, allow_pickle=False) + + # Return a tuple describing this training example + return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text + + +def embed_utterance(fpaths, encoder_model_fpath): + if not encoder.is_loaded(): + encoder.load_model(encoder_model_fpath) + + # Compute the speaker embedding of the utterance + wav_fpath, embed_fpath = fpaths + wav = np.load(wav_fpath) + wav = encoder.preprocess_wav(wav) + embed = encoder.embed_utterance(wav) + np.save(embed_fpath, embed, allow_pickle=False) + + +def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int): + wav_dir = synthesizer_root.joinpath("audio") + metadata_fpath = synthesizer_root.joinpath("train.txt") + assert wav_dir.exists() and metadata_fpath.exists() + embed_dir = synthesizer_root.joinpath("embeds") + embed_dir.mkdir(exist_ok=True) + + # Gather the input wave filepath and the target output embed filepath + with metadata_fpath.open("r") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata] + + # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here. 
+ # Embed the utterances in separate threads + func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath) + job = Pool(n_processes).imap(func, fpaths) + list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) + diff --git a/synthesizer/synthesize.py b/synthesizer/synthesize.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc7dc2678e85006b9f66d910fcae3e307c521a8 --- /dev/null +++ b/synthesizer/synthesize.py @@ -0,0 +1,97 @@ +import torch +from torch.utils.data import DataLoader +from synthesizer.hparams import hparams_debug_string +from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer +from synthesizer.models.tacotron import Tacotron +from synthesizer.utils.text import text_to_sequence +from synthesizer.utils.symbols import symbols +import numpy as np +from pathlib import Path +from tqdm import tqdm +import platform + +def run_synthesis(in_dir, out_dir, model_dir, hparams): + # This generates ground truth-aligned mels for vocoder training + synth_dir = Path(out_dir).joinpath("mels_gta") + synth_dir.mkdir(exist_ok=True) + print(hparams_debug_string()) + + # Check for GPU + if torch.cuda.is_available(): + device = torch.device("cuda") + if hparams.synthesis_batch_size % torch.cuda.device_count() != 0: + raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!") + else: + device = torch.device("cpu") + print("Synthesizer using device:", device) + + # Instantiate Tacotron model + model = Tacotron(embed_dims=hparams.tts_embed_dims, + num_chars=len(symbols), + encoder_dims=hparams.tts_encoder_dims, + decoder_dims=hparams.tts_decoder_dims, + n_mels=hparams.num_mels, + fft_bins=hparams.num_mels, + postnet_dims=hparams.tts_postnet_dims, + encoder_K=hparams.tts_encoder_K, + lstm_dims=hparams.tts_lstm_dims, + postnet_K=hparams.tts_postnet_K, + num_highways=hparams.tts_num_highways, + dropout=0., # Use zero dropout for gta mels + stop_threshold=hparams.tts_stop_threshold, + speaker_embedding_size=hparams.speaker_embedding_size).to(device) + + # Load the weights + model_dir = Path(model_dir) + model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt") + print("\nLoading weights at %s" % model_fpath) + model.load(model_fpath) + print("Tacotron weights loaded from step %d" % model.step) + + # Synthesize using same reduction factor as the model is currently trained + r = np.int32(model.r) + + # Set model to eval mode (disable gradient and zoneout) + model.eval() + + # Initialize the dataset + in_dir = Path(in_dir) + metadata_fpath = in_dir.joinpath("train.txt") + mel_dir = in_dir.joinpath("mels") + embed_dir = in_dir.joinpath("embeds") + + dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams) + data_loader = DataLoader(dataset, + collate_fn=lambda batch: collate_synthesizer(batch, r, hparams), + batch_size=hparams.synthesis_batch_size, + num_workers=2 if platform.system() != "Windows" else 0, + shuffle=False, + pin_memory=True) + + # Generate GTA mels + meta_out_fpath = Path(out_dir).joinpath("synthesized.txt") + with open(meta_out_fpath, "w") as file: + for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)): + texts = texts.to(device) + mels = mels.to(device) + embeds = embeds.to(device) + + # Parallelize model onto GPUS using workaround due to python bug + if device.type == "cuda" and torch.cuda.device_count() > 1: + _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds) + else: + _, mels_out, _, _ = model(texts, mels, embeds) 
+ + for j, k in enumerate(idx): + # Note: outputs mel-spectrogram files and target ones have same names, just different folders + mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1]) + mel_out = mels_out[j].detach().cpu().numpy().T + + # Use the length of the ground truth mel to remove padding from the generated mels + mel_out = mel_out[:int(dataset.metadata[k][4])] + + # Write the spectrogram to disk + np.save(mel_filename, mel_out, allow_pickle=False) + + # Write metadata into the synthesized file + file.write("|".join(dataset.metadata[k])) diff --git a/synthesizer/synthesizer_dataset.py b/synthesizer/synthesizer_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9d552d16d0b6757871189037bf0b981c8dfebbaf --- /dev/null +++ b/synthesizer/synthesizer_dataset.py @@ -0,0 +1,92 @@ +import torch +from torch.utils.data import Dataset +import numpy as np +from pathlib import Path +from synthesizer.utils.text import text_to_sequence + + +class SynthesizerDataset(Dataset): + def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams): + print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir)) + + with metadata_fpath.open("r") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + + mel_fnames = [x[1] for x in metadata if int(x[4])] + mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames] + embed_fnames = [x[2] for x in metadata if int(x[4])] + embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames] + self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths)) + self.samples_texts = [x[5].strip() for x in metadata if int(x[4])] + self.metadata = metadata + self.hparams = hparams + + print("Found %d samples" % len(self.samples_fpaths)) + + def __getitem__(self, index): + # Sometimes index may be a list of 2 (not sure why this happens) + # If that is the case, return a single item corresponding to first element in index + if index is list: + index = index[0] + + mel_path, embed_path = self.samples_fpaths[index] + mel = np.load(mel_path).T.astype(np.float32) + + # Load the embed + embed = np.load(embed_path) + + # Get the text and clean it + text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names) + + # Convert the list returned by text_to_sequence to a numpy array + text = np.asarray(text).astype(np.int32) + + return text, mel.astype(np.float32), embed.astype(np.float32), index + + def __len__(self): + return len(self.samples_fpaths) + + +def collate_synthesizer(batch, r, hparams): + # Text + x_lens = [len(x[0]) for x in batch] + max_x_len = max(x_lens) + + chars = [pad1d(x[0], max_x_len) for x in batch] + chars = np.stack(chars) + + # Mel spectrogram + spec_lens = [x[1].shape[-1] for x in batch] + max_spec_len = max(spec_lens) + 1 + if max_spec_len % r != 0: + max_spec_len += r - max_spec_len % r + + # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence + # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence. 
+ if hparams.symmetric_mels: + mel_pad_value = -1 * hparams.max_abs_value + else: + mel_pad_value = 0 + + mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch] + mel = np.stack(mel) + + # Speaker embedding (SV2TTS) + embeds = [x[2] for x in batch] + + # Index (for vocoder preprocessing) + indices = [x[3] for x in batch] + + + # Convert all to tensor + chars = torch.tensor(chars).long() + mel = torch.tensor(mel) + embeds = torch.tensor(embeds) + + return chars, mel, embeds, indices + +def pad1d(x, max_len, pad_value=0): + return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value) + +def pad2d(x, max_len, pad_value=0): + return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value) diff --git a/synthesizer/train.py b/synthesizer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a136cf9b38538ca7dc428adf209c0cbb40e890d7 --- /dev/null +++ b/synthesizer/train.py @@ -0,0 +1,269 @@ +import torch +import torch.nn.functional as F +from torch import optim +from torch.utils.data import DataLoader +from synthesizer import audio +from synthesizer.models.tacotron import Tacotron +from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer +from synthesizer.utils import ValueWindow, data_parallel_workaround +from synthesizer.utils.plot import plot_spectrogram +from synthesizer.utils.symbols import symbols +from synthesizer.utils.text import sequence_to_text +from vocoder.display import * +from datetime import datetime +import numpy as np +from pathlib import Path +import sys +import time +import platform + + +def np_now(x: torch.Tensor): return x.detach().cpu().numpy() + +def time_string(): + return datetime.now().strftime("%Y-%m-%d %H:%M") + +def train(run_id: str, syn_dir: str, models_dir: str, save_every: int, + backup_every: int, force_restart:bool, hparams): + + syn_dir = Path(syn_dir) + models_dir = Path(models_dir) + models_dir.mkdir(exist_ok=True) + + model_dir = models_dir.joinpath(run_id) + plot_dir = model_dir.joinpath("plots") + wav_dir = model_dir.joinpath("wavs") + mel_output_dir = model_dir.joinpath("mel-spectrograms") + meta_folder = model_dir.joinpath("metas") + model_dir.mkdir(exist_ok=True) + plot_dir.mkdir(exist_ok=True) + wav_dir.mkdir(exist_ok=True) + mel_output_dir.mkdir(exist_ok=True) + meta_folder.mkdir(exist_ok=True) + + weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt") + metadata_fpath = syn_dir.joinpath("train.txt") + + print("Checkpoint path: {}".format(weights_fpath)) + print("Loading training data from: {}".format(metadata_fpath)) + print("Using model: Tacotron") + + # Book keeping + step = 0 + time_window = ValueWindow(100) + loss_window = ValueWindow(100) + + + # From WaveRNN/train_tacotron.py + if torch.cuda.is_available(): + device = torch.device("cuda") + + for session in hparams.tts_schedule: + _, _, _, batch_size = session + if batch_size % torch.cuda.device_count() != 0: + raise ValueError("`batch_size` must be evenly divisible by n_gpus!") + else: + device = torch.device("cpu") + print("Using device:", device) + + # Instantiate Tacotron Model + print("\nInitialising Tacotron Model...\n") + model = Tacotron(embed_dims=hparams.tts_embed_dims, + num_chars=len(symbols), + encoder_dims=hparams.tts_encoder_dims, + decoder_dims=hparams.tts_decoder_dims, + n_mels=hparams.num_mels, + fft_bins=hparams.num_mels, + postnet_dims=hparams.tts_postnet_dims, + encoder_K=hparams.tts_encoder_K, + lstm_dims=hparams.tts_lstm_dims, + 
postnet_K=hparams.tts_postnet_K, + num_highways=hparams.tts_num_highways, + dropout=hparams.tts_dropout, + stop_threshold=hparams.tts_stop_threshold, + speaker_embedding_size=hparams.speaker_embedding_size).to(device) + + # Initialize the optimizer + optimizer = optim.Adam(model.parameters()) + + # Load the weights + if force_restart or not weights_fpath.exists(): + print("\nStarting the training of Tacotron from scratch\n") + model.save(weights_fpath) + + # Embeddings metadata + char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv") + with open(char_embedding_fpath, "w", encoding="utf-8") as f: + for symbol in symbols: + if symbol == " ": + symbol = "\\s" # For visual purposes, swap space with \s + + f.write("{}\n".format(symbol)) + + else: + print("\nLoading weights at %s" % weights_fpath) + model.load(weights_fpath, optimizer) + print("Tacotron weights loaded from step %d" % model.step) + + # Initialize the dataset + metadata_fpath = syn_dir.joinpath("train.txt") + mel_dir = syn_dir.joinpath("mels") + embed_dir = syn_dir.joinpath("embeds") + dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams) + test_loader = DataLoader(dataset, + batch_size=1, + shuffle=True, + pin_memory=True) + + for i, session in enumerate(hparams.tts_schedule): + current_step = model.get_step() + + r, lr, max_step, batch_size = session + + training_steps = max_step - current_step + + # Do we need to change to the next session? + if current_step >= max_step: + # Are there no further sessions than the current one? + if i == len(hparams.tts_schedule) - 1: + # We have completed training. Save the model and exit + model.save(weights_fpath, optimizer) + break + else: + # There is a following session, go to it + continue + + model.r = r + + # Begin the training + simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"), + ("Batch Size", batch_size), + ("Learning Rate", lr), + ("Outputs/Step (r)", model.r)]) + + for p in optimizer.param_groups: + p["lr"] = lr + + data_loader = DataLoader(dataset, + collate_fn=lambda batch: collate_synthesizer(batch, r, hparams), + batch_size=batch_size, + num_workers=2 if platform.system() != "Windows" else 0, + shuffle=True, + pin_memory=True) + + total_iters = len(dataset) + steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32) + epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32) + + for epoch in range(1, epochs+1): + for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1): + start_time = time.time() + + # Generate stop tokens for training + stop = torch.ones(mels.shape[0], mels.shape[2]) + for j, k in enumerate(idx): + stop[j, :int(dataset.metadata[k][4])-1] = 0 + + texts = texts.to(device) + mels = mels.to(device) + embeds = embeds.to(device) + stop = stop.to(device) + + # Forward pass + # Parallelize model onto GPUS using workaround due to python bug + if device.type == "cuda" and torch.cuda.device_count() > 1: + m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, + mels, embeds) + else: + m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds) + + # Backward pass + m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels) + m2_loss = F.mse_loss(m2_hat, mels) + stop_loss = F.binary_cross_entropy(stop_pred, stop) + + loss = m1_loss + m2_loss + stop_loss + + optimizer.zero_grad() + loss.backward() + + if hparams.tts_clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm) + if 
np.isnan(grad_norm.cpu()): + print("grad_norm was NaN!") + + optimizer.step() + + time_window.append(time.time() - start_time) + loss_window.append(loss.item()) + + step = model.get_step() + k = step // 1000 + + msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | " + stream(msg) + + # Backup or save model as appropriate + if backup_every != 0 and step % backup_every == 0 : + backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k)) + model.save(backup_fpath, optimizer) + + if save_every != 0 and step % save_every == 0 : + # Must save latest optimizer state to ensure that resuming training + # doesn't produce artifacts + model.save(weights_fpath, optimizer) + + # Evaluate model to generate samples + epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done + step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps + if epoch_eval or step_eval: + for sample_idx in range(hparams.tts_eval_num_samples): + # At most, generate samples equal to number in the batch + if sample_idx + 1 <= len(texts): + # Remove padding from mels using frame length in metadata + mel_length = int(dataset.metadata[idx[sample_idx]][4]) + mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length] + target_spectrogram = np_now(mels[sample_idx]).T[:mel_length] + attention_len = mel_length // model.r + + eval_model(attention=np_now(attention[sample_idx][:, :attention_len]), + mel_prediction=mel_prediction, + target_spectrogram=target_spectrogram, + input_seq=np_now(texts[sample_idx]), + step=step, + plot_dir=plot_dir, + mel_output_dir=mel_output_dir, + wav_dir=wav_dir, + sample_num=sample_idx + 1, + loss=loss, + hparams=hparams) + + # Break out of loop to update training schedule + if step >= max_step: + break + + # Add line break after every epoch + print("") + +def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step, + plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams): + # Save some results for evaluation + attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num))) + save_attention(attention, attention_path) + + # save predicted mel spectrogram to disk (debug) + mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num)) + np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False) + + # save griffin lim inverted wav for debug (mel -> wav) + wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams) + wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num)) + audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate) + + # save real and predicted mel-spectrogram plot to disk (control purposes) + spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num)) + title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss) + plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str, + target_spectrogram=target_spectrogram, + max_len=target_spectrogram.size // hparams.num_mels) + print("Input at step {}: {}".format(step, sequence_to_text(input_seq))) diff --git a/synthesizer/utils/__init__.py b/synthesizer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae3e48110e61231acf1e666e5fa76af5e4ebdcd --- /dev/null +++ b/synthesizer/utils/__init__.py @@ -0,0 +1,45 @@ +import torch 
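# This module collects two small training helpers:
#  * data_parallel_workaround: a manual replicate / scatter / parallel_apply / gather,
#    used instead of nn.DataParallel as a workaround for the Python bug noted in the
#    comments of synthesizer/train.py and synthesizer/synthesize.py.
#  * ValueWindow: a fixed-size window over recent values, used to smooth the loss and
#    step-time readouts printed during training. A minimal usage sketch (numbers made up):
#
#       loss_window = ValueWindow(100)
#       loss_window.append(0.52)
#       loss_window.append(0.48)
#       loss_window.average   # mean of at most the last 100 appended values -> 0.50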
+ + +_output_ref = None +_replicas_ref = None + +def data_parallel_workaround(model, *input): + global _output_ref + global _replicas_ref + device_ids = list(range(torch.cuda.device_count())) + output_device = device_ids[0] + replicas = torch.nn.parallel.replicate(model, device_ids) + # input.shape = (num_args, batch, ...) + inputs = torch.nn.parallel.scatter(input, device_ids) + # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...) + replicas = replicas[:len(inputs)] + outputs = torch.nn.parallel.parallel_apply(replicas, inputs) + y_hat = torch.nn.parallel.gather(outputs, output_device) + _output_ref = outputs + _replicas_ref = replicas + return y_hat + + +class ValueWindow(): + def __init__(self, window_size=100): + self._window_size = window_size + self._values = [] + + def append(self, x): + self._values = self._values[-(self._window_size - 1):] + [x] + + @property + def sum(self): + return sum(self._values) + + @property + def count(self): + return len(self._values) + + @property + def average(self): + return self.sum / max(1, self.count) + + def reset(self): + self._values = [] diff --git a/synthesizer/utils/_cmudict.py b/synthesizer/utils/_cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..2cef1f896d4fb78478884fe8e810956998d5e3b3 --- /dev/null +++ b/synthesizer/utils/_cmudict.py @@ -0,0 +1,62 @@ +import re + +valid_symbols = [ + "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2", + "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2", + "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY", + "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", + "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0", + "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", + "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH" +] + +_valid_symbol_set = set(valid_symbols) + + +class CMUDict: + """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + + def __len__(self): + return len(self._entries) + + + def lookup(self, word): + """Returns list of ARPAbet pronunciations of the given word.""" + return self._entries.get(word.upper()) + + + +_alt_re = re.compile(r"\([0-9]+\)") + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(" ") + for part in parts: + if part not in _valid_symbol_set: + return None + return " ".join(parts) diff --git a/synthesizer/utils/cleaners.py b/synthesizer/utils/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..eab63f05c9cc7cc0b583992eac94058097f3c191 --- /dev/null +++ b/synthesizer/utils/cleaners.py @@ -0,0 +1,88 @@ +""" +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You"ll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +""" + +import re +from unidecode import unidecode +from .numbers import normalize_numbers + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + """lowercase input tokens.""" + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including number and abbreviation expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/synthesizer/utils/numbers.py b/synthesizer/utils/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..75020a0bd732830f603d7c7d250c9e087033cc24 --- /dev/null +++ b/synthesizer/utils/numbers.py @@ -0,0 +1,68 @@ +import re +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + 
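    # The substitutions below are order-sensitive: commas are stripped first so "2,500"
    # is read as a single number, currency is expanded before the generic decimal and
    # number rules can consume its digits, and ordinals are handled before plain numbers.
    # A rough worked example (exact wording depends on the inflect library):
    #
    #   "I paid $2,500.50 on the 3rd"
    #     -> "I paid $2500.50 on the 3rd"                  (_remove_commas)
    #     -> "I paid 2500 dollars, 50 cents on the 3rd"    (_expand_dollars)
    #     -> "I paid twenty-five hundred dollars, fifty cents on the third"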
text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/synthesizer/utils/plot.py b/synthesizer/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..f47d2713d4daa6cf387b37970fd879548abc8d88 --- /dev/null +++ b/synthesizer/utils/plot.py @@ -0,0 +1,76 @@ +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + + +def split_title_line(title_text, max_words=5): + """ + A function that splits any string based on specific character + (returning it with the string), with maximum number of words on it + """ + seq = title_text.split() + return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)]) + +def plot_alignment(alignment, path, title=None, split_title=False, max_len=None): + if max_len is not None: + alignment = alignment[:, :max_len] + + fig = plt.figure(figsize=(8, 6)) + ax = fig.add_subplot(111) + + im = ax.imshow( + alignment, + aspect="auto", + origin="lower", + interpolation="none") + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + + if split_title: + title = split_title_line(title) + + plt.xlabel(xlabel) + plt.title(title) + plt.ylabel("Encoder timestep") + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() + + +def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False): + if max_len is not None: + target_spectrogram = target_spectrogram[:max_len] + pred_spectrogram = pred_spectrogram[:max_len] + + if split_title: + title = split_title_line(title) + + fig = plt.figure(figsize=(10, 8)) + # Set common labels + fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16) + + #target spectrogram subplot + if target_spectrogram is not None: + ax1 = fig.add_subplot(311) + ax2 = fig.add_subplot(312) + + if auto_aspect: + im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none") + else: + im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none") + ax1.set_title("Target Mel-Spectrogram") + fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1) + ax2.set_title("Predicted Mel-Spectrogram") + else: + ax2 = fig.add_subplot(211) + + if auto_aspect: + im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none") + else: + im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none") + fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2) + + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() diff --git a/synthesizer/utils/symbols.py b/synthesizer/utils/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..132d3a612c3b13e2ada905a706001cff29a4f63a --- /dev/null +++ b/synthesizer/utils/symbols.py @@ -0,0 +1,17 @@ +""" +Defines the set of symbols used in text input to the model. + +The default is a set of ASCII characters that works well for English or text that has been run +through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. +""" +# from . import cmudict + +_pad = "_" +_eos = "~" +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? 
" + +# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): +#_arpabet = ["@' + s for s in cmudict.valid_symbols] + +# Export all symbols: +symbols = [_pad, _eos] + list(_characters) #+ _arpabet diff --git a/synthesizer/utils/text.py b/synthesizer/utils/text.py new file mode 100644 index 0000000000000000000000000000000000000000..29372174aec95cd2eac1ea40096fcc148f532b07 --- /dev/null +++ b/synthesizer/utils/text.py @@ -0,0 +1,74 @@ +from .symbols import symbols +from . import cleaners +import re + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +# Regular expression matching text enclosed in curly braces: +_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) + break + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) + + # Append EOS token + sequence.append(_symbol_to_id["~"]) + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] + result += s + return result.replace("}{", " ") + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + return text + + +def _symbols_to_sequence(symbols): + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + + +def _arpabet_to_sequence(text): + return _symbols_to_sequence(["@" + s for s in text.split()]) + + +def _should_keep_symbol(s): + return s in _symbol_to_id and s not in ("_", "~") diff --git a/synthesizer_preprocess_audio.py b/synthesizer_preprocess_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4d01d476d77391322aef9d9d5a005adb1f5c15 --- /dev/null +++ b/synthesizer_preprocess_audio.py @@ -0,0 +1,59 @@ +from synthesizer.preprocess import preprocess_dataset +from synthesizer.hparams import hparams +from utils.argutils import print_args +from pathlib import Path +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, encodes them as mel spectrograms " + "and writes them to the disk. 
Audio files are also saved, to be used by the " + "vocoder for training.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your LibriSpeech/TTS datasets.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms, the audios and the " + "embeds. Defaults to /SV2TTS/synthesizer/") + parser.add_argument("-n", "--n_processes", type=int, default=None, help=\ + "Number of processes in parallel.") + parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + "Whether to overwrite existing files with the same name. Useful if the preprocessing was " + "interrupted.") + parser.add_argument("--hparams", type=str, default="", help=\ + "Hyperparameter overrides as a comma-separated list of name-value pairs") + parser.add_argument("--no_trim", action="store_true", help=\ + "Preprocess audio without trimming silences (not recommended).") + parser.add_argument("--no_alignments", action="store_true", help=\ + "Use this option when dataset does not include alignments\ + (these are used to split long audio files into sub-utterances.)") + parser.add_argument("--datasets_name", type=str, default="LibriSpeech", help=\ + "Name of the dataset directory to process.") + parser.add_argument("--subfolders", type=str, default="train-clean-100, train-clean-360", help=\ + "Comma-separated list of subfolders to process inside your dataset directory") + args = parser.parse_args() + + # Process the arguments + if not hasattr(args, "out_dir"): + args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer") + + # Create directories + assert args.datasets_root.exists() + args.out_dir.mkdir(exist_ok=True, parents=True) + + # Verify webrtcvad is available + if not args.no_trim: + try: + import webrtcvad + except: + raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables " + "noise removal and is recommended. Please install and try again. If installation fails, " + "use --no_trim to disable this error message.") + del args.no_trim + + # Preprocess the dataset + print_args(args, parser) + args.hparams = hparams.parse(args.hparams) + preprocess_dataset(**vars(args)) diff --git a/synthesizer_preprocess_embeds.py b/synthesizer_preprocess_embeds.py new file mode 100644 index 0000000000000000000000000000000000000000..94f864d5d3c36c6177b211f5818e7c920a41cd8c --- /dev/null +++ b/synthesizer_preprocess_embeds.py @@ -0,0 +1,25 @@ +from synthesizer.preprocess import create_embeddings +from utils.argutils import print_args +from pathlib import Path +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Creates embeddings for the synthesizer from the LibriSpeech utterances.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("synthesizer_root", type=Path, help=\ + "Path to the synthesizer training data that contains the audios and the train.txt file. " + "If you let everything as default, it should be /SV2TTS/synthesizer/.") + parser.add_argument("-e", "--encoder_model_fpath", type=Path, + default="encoder/saved_models/pretrained.pt", help=\ + "Path your trained encoder model.") + parser.add_argument("-n", "--n_processes", type=int, default=4, help= \ + "Number of parallel processes. An encoder is created for each, so you may need to lower " + "this value on GPUs with low memory. 
Set it to 1 if CUDA is unhappy.") + args = parser.parse_args() + + # Preprocess the dataset + print_args(args, parser) + create_embeddings(**vars(args)) diff --git a/synthesizer_train.py b/synthesizer_train.py new file mode 100644 index 0000000000000000000000000000000000000000..2743d590d882f209734b68921b84a9d23492942c --- /dev/null +++ b/synthesizer_train.py @@ -0,0 +1,35 @@ +from synthesizer.hparams import hparams +from synthesizer.train import train +from utils.argutils import print_args +import argparse + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("run_id", type=str, help= \ + "Name for this model instance. If a model state from the same run ID was previously " + "saved, the training will restart from there. Pass -f to overwrite saved states and " + "restart from scratch.") + parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \ + "Path to the synthesizer directory that contains the ground truth mel spectrograms, " + "the wavs and the embeds.") + parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\ + "Path to the output directory that will contain the saved model weights and the logs.") + parser.add_argument("-s", "--save_every", type=int, default=1000, help= \ + "Number of steps between updates of the model on the disk. Set to 0 to never save the " + "model.") + parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \ + "Number of steps between backups of the model. Set to 0 to never make backups of the " + "model.") + parser.add_argument("-f", "--force_restart", action="store_true", help= \ + "Do not load any saved model and restart from scratch.") + parser.add_argument("--hparams", default="", + help="Hyperparameter overrides as a comma-separated list of name=value " + "pairs") + args = parser.parse_args() + print_args(args, parser) + + args.hparams = hparams.parse(args.hparams) + + # Run the training + train(**vars(args)) diff --git a/toolbox/__init__.py b/toolbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..531d6adef076007afd6116eb6472485f540e80de --- /dev/null +++ b/toolbox/__init__.py @@ -0,0 +1,357 @@ +from toolbox.ui import UI +from encoder import inference as encoder +from synthesizer.inference import Synthesizer +from vocoder import inference as vocoder +from pathlib import Path +from time import perf_counter as timer +from toolbox.utterance import Utterance +import numpy as np +import traceback +import sys +import torch +import librosa +from audioread.exceptions import NoBackendError + +# Use this directory structure for your datasets, or modify it to fit your needs +recognized_datasets = [ + "LibriSpeech/dev-clean", + "LibriSpeech/dev-other", + "LibriSpeech/test-clean", + "LibriSpeech/test-other", + "LibriSpeech/train-clean-100", + "LibriSpeech/train-clean-360", + "LibriSpeech/train-other-500", + "LibriTTS/dev-clean", + "LibriTTS/dev-other", + "LibriTTS/test-clean", + "LibriTTS/test-other", + "LibriTTS/train-clean-100", + "LibriTTS/train-clean-360", + "LibriTTS/train-other-500", + "LJSpeech-1.1", + "VoxCeleb1/wav", + "VoxCeleb1/test_wav", + "VoxCeleb2/dev/aac", + "VoxCeleb2/test/aac", + "VCTK-Corpus/wav48", +] + +#Maximum of generated wavs to keep on memory +MAX_WAVES = 15 + +class Toolbox: + def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support): + if not no_mp3_support: + try: + librosa.load("samples/6829_00000.mp3") + except NoBackendError: + print("Librosa 
will be unable to open mp3 files if additional software is not installed.\n" + "Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.") + exit(-1) + self.no_mp3_support = no_mp3_support + sys.excepthook = self.excepthook + self.datasets_root = datasets_root + self.utterances = set() + self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav + + self.synthesizer = None # type: Synthesizer + self.current_wav = None + self.waves_list = [] + self.waves_count = 0 + self.waves_namelist = [] + + # Check for webrtcvad (enables removal of silences in vocoder output) + try: + import webrtcvad + self.trim_silences = True + except: + self.trim_silences = False + + # Initialize the events and the interface + self.ui = UI() + self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed) + self.setup_events() + self.ui.start() + + def excepthook(self, exc_type, exc_value, exc_tb): + traceback.print_exception(exc_type, exc_value, exc_tb) + self.ui.log("Exception: %s" % exc_value) + + def setup_events(self): + # Dataset, speaker and utterance selection + self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser()) + random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root, + recognized_datasets, + level) + self.ui.random_dataset_button.clicked.connect(random_func(0)) + self.ui.random_speaker_button.clicked.connect(random_func(1)) + self.ui.random_utterance_button.clicked.connect(random_func(2)) + self.ui.dataset_box.currentIndexChanged.connect(random_func(1)) + self.ui.speaker_box.currentIndexChanged.connect(random_func(2)) + + # Model selection + self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder) + def func(): + self.synthesizer = None + self.ui.synthesizer_box.currentIndexChanged.connect(func) + self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder) + + # Utterance selection + func = lambda: self.load_from_browser(self.ui.browse_file()) + self.ui.browser_browse_button.clicked.connect(func) + func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current") + self.ui.utterance_history.currentIndexChanged.connect(func) + func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate) + self.ui.play_button.clicked.connect(func) + self.ui.stop_button.clicked.connect(self.ui.stop) + self.ui.record_button.clicked.connect(self.record) + + #Audio + self.ui.setup_audio_devices(Synthesizer.sample_rate) + + #Wav playback & save + func = lambda: self.replay_last_wav() + self.ui.replay_wav_button.clicked.connect(func) + func = lambda: self.export_current_wave() + self.ui.export_wav_button.clicked.connect(func) + self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav) + + # Generation + func = lambda: self.synthesize() or self.vocode() + self.ui.generate_button.clicked.connect(func) + self.ui.synthesize_button.clicked.connect(self.synthesize) + self.ui.vocode_button.clicked.connect(self.vocode) + self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox) + + # UMAP legend + self.ui.clear_button.clicked.connect(self.clear_utterances) + + def set_current_wav(self, index): + self.current_wav = self.waves_list[index] + + def export_current_wave(self): + self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate) + + def replay_last_wav(self): + self.ui.play(self.current_wav, Synthesizer.sample_rate) + + def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed): + 
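        # Repopulate the dataset/speaker/utterance browser, the three model selection
        # boxes and the generation options (seed + silence trimming) from these directories.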
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True) + self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir) + self.ui.populate_gen_options(seed, self.trim_silences) + + def load_from_browser(self, fpath=None): + if fpath is None: + fpath = Path(self.datasets_root, + self.ui.current_dataset_name, + self.ui.current_speaker_name, + self.ui.current_utterance_name) + name = str(fpath.relative_to(self.datasets_root)) + speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name + + # Select the next utterance + if self.ui.auto_next_checkbox.isChecked(): + self.ui.browser_select_next() + elif fpath == "": + return + else: + name = fpath.name + speaker_name = fpath.parent.name + + if fpath.suffix.lower() == ".mp3" and self.no_mp3_support: + self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used") + return + + # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for + # playback, so as to have a fair comparison with the generated audio + wav = Synthesizer.load_preprocess_wav(fpath) + self.ui.log("Loaded %s" % name) + + self.add_real_utterance(wav, name, speaker_name) + + def record(self): + wav = self.ui.record_one(encoder.sampling_rate, 5) + if wav is None: + return + self.ui.play(wav, encoder.sampling_rate) + + speaker_name = "user01" + name = speaker_name + "_rec_%05d" % np.random.randint(100000) + self.add_real_utterance(wav, name, speaker_name) + + def add_real_utterance(self, wav, name, speaker_name): + # Compute the mel spectrogram + spec = Synthesizer.make_spectrogram(wav) + self.ui.draw_spec(spec, "current") + + # Compute the embedding + if not encoder.is_loaded(): + self.init_encoder() + encoder_wav = encoder.preprocess_wav(wav) + embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + # Add the utterance + utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False) + self.utterances.add(utterance) + self.ui.register_utterance(utterance) + + # Plot it + self.ui.draw_embed(embed, name, "current") + self.ui.draw_umap_projections(self.utterances) + + def clear_utterances(self): + self.utterances.clear() + self.ui.draw_umap_projections(self.utterances) + + def synthesize(self): + self.ui.log("Generating the mel spectrogram...") + self.ui.set_loading(1) + + # Update the synthesizer random seed + if self.ui.random_seed_checkbox.isChecked(): + seed = int(self.ui.seed_textbox.text()) + self.ui.populate_gen_options(seed, self.trim_silences) + else: + seed = None + + if seed is not None: + torch.manual_seed(seed) + + # Synthesize the spectrogram + if self.synthesizer is None or seed is not None: + self.init_synthesizer() + + texts = self.ui.text_prompt.toPlainText().split("\n") + embed = self.ui.selected_utterance.embed + embeds = [embed] * len(texts) + specs = self.synthesizer.synthesize_spectrograms(texts, embeds) + breaks = [spec.shape[1] for spec in specs] + spec = np.concatenate(specs, axis=1) + + self.ui.draw_spec(spec, "generated") + self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None) + self.ui.set_loading(0) + + def vocode(self): + speaker_name, spec, breaks, _ = self.current_generated + assert spec is not None + + # Initialize the vocoder model and make it determinstic, if user provides a seed + if self.ui.random_seed_checkbox.isChecked(): + seed = int(self.ui.seed_textbox.text()) + self.ui.populate_gen_options(seed, self.trim_silences) + else: + seed = None + + if seed is 
not None: + torch.manual_seed(seed) + + # Synthesize the waveform + if not vocoder.is_loaded() or seed is not None: + self.init_vocoder() + + def vocoder_progress(i, seq_len, b_size, gen_rate): + real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000 + line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \ + % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor) + self.ui.log(line, "overwrite") + self.ui.set_loading(i, seq_len) + if self.ui.current_vocoder_fpath is not None: + self.ui.log("") + wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress) + else: + self.ui.log("Waveform generation with Griffin-Lim... ") + wav = Synthesizer.griffin_lim(spec) + self.ui.set_loading(0) + self.ui.log(" Done!", "append") + + # Add breaks + b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size) + b_starts = np.concatenate(([0], b_ends[:-1])) + wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)] + breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks) + wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)]) + + # Trim excessive silences + if self.ui.trim_silences_checkbox.isChecked(): + wav = encoder.preprocess_wav(wav) + + # Play it + wav = wav / np.abs(wav).max() * 0.97 + self.ui.play(wav, Synthesizer.sample_rate) + + # Name it (history displayed in combobox) + # TODO better naming for the combobox items? + wav_name = str(self.waves_count + 1) + + #Update waves combobox + self.waves_count += 1 + if self.waves_count > MAX_WAVES: + self.waves_list.pop() + self.waves_namelist.pop() + self.waves_list.insert(0, wav) + self.waves_namelist.insert(0, wav_name) + + self.ui.waves_cb.disconnect() + self.ui.waves_cb_model.setStringList(self.waves_namelist) + self.ui.waves_cb.setCurrentIndex(0) + self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav) + + # Update current wav + self.set_current_wav(0) + + #Enable replay and save buttons: + self.ui.replay_wav_button.setDisabled(False) + self.ui.export_wav_button.setDisabled(False) + + # Compute the embedding + # TODO: this is problematic with different sampling rates, gotta fix it + if not encoder.is_loaded(): + self.init_encoder() + encoder_wav = encoder.preprocess_wav(wav) + embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) + + # Add the utterance + name = speaker_name + "_gen_%05d" % np.random.randint(100000) + utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True) + self.utterances.add(utterance) + + # Plot it + self.ui.draw_embed(embed, name, "generated") + self.ui.draw_umap_projections(self.utterances) + + def init_encoder(self): + model_fpath = self.ui.current_encoder_fpath + + self.ui.log("Loading the encoder %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + encoder.load_model(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_synthesizer(self): + model_fpath = self.ui.current_synthesizer_fpath + + self.ui.log("Loading the synthesizer %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + self.synthesizer = Synthesizer(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_vocoder(self): + model_fpath = self.ui.current_vocoder_fpath + # Case of Griffin-lim + if model_fpath is None: + return + + self.ui.log("Loading the vocoder %s... 
" % model_fpath) + self.ui.set_loading(1) + start = timer() + vocoder.load_model(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def update_seed_textbox(self): + self.ui.update_seed_textbox() diff --git a/toolbox/ui.py b/toolbox/ui.py new file mode 100644 index 0000000000000000000000000000000000000000..d56b5740e276751f954aae1ca17e5ed485b48937 --- /dev/null +++ b/toolbox/ui.py @@ -0,0 +1,611 @@ +import matplotlib.pyplot as plt +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from PyQt5.QtCore import Qt, QStringListModel +from PyQt5.QtWidgets import * +from encoder.inference import plot_embedding_as_heatmap +from toolbox.utterance import Utterance +from pathlib import Path +from typing import List, Set +import sounddevice as sd +import soundfile as sf +import numpy as np +# from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP +from time import sleep +import umap +import sys +from warnings import filterwarnings, warn +filterwarnings("ignore") + + +colormap = np.array([ + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [97, 142, 151], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + [76, 255, 0], +], dtype=np.float) / 255 + +default_text = \ + "Welcome to the toolbox! To begin, load an utterance from your datasets or record one " \ + "yourself.\nOnce its embedding has been created, you can synthesize any text written here.\n" \ + "The synthesizer expects to generate " \ + "outputs that are somewhere between 5 and 12 seconds.\nTo mark breaks, write a new line. " \ + "Each line will be treated separately.\nThen, they are joined together to make the final " \ + "spectrogram. Use the vocoder to generate audio.\nThe vocoder generates almost in constant " \ + "time, so it will be more time efficient for longer inputs like this one.\nOn the left you " \ + "have the embedding projections. Load or record more utterances to see them.\nIf you have " \ + "at least 2 or 3 utterances from a same speaker, a cluster should form.\nSynthesized " \ + "utterances are of the same color as the speaker whose voice was used, but they're " \ + "represented with a cross." 
+ + +class UI(QDialog): + min_umap_points = 4 + max_log_lines = 5 + max_saved_utterances = 20 + + def draw_utterance(self, utterance: Utterance, which): + self.draw_spec(utterance.spec, which) + self.draw_embed(utterance.embed, utterance.name, which) + + def draw_embed(self, embed, name, which): + embed_ax, _ = self.current_ax if which == "current" else self.gen_ax + embed_ax.figure.suptitle("" if embed is None else name) + + ## Embedding + # Clear the plot + if len(embed_ax.images) > 0: + embed_ax.images[0].colorbar.remove() + embed_ax.clear() + + # Draw the embed + if embed is not None: + plot_embedding_as_heatmap(embed, embed_ax) + embed_ax.set_title("embedding") + embed_ax.set_aspect("equal", "datalim") + embed_ax.set_xticks([]) + embed_ax.set_yticks([]) + embed_ax.figure.canvas.draw() + + def draw_spec(self, spec, which): + _, spec_ax = self.current_ax if which == "current" else self.gen_ax + + ## Spectrogram + # Draw the spectrogram + spec_ax.clear() + if spec is not None: + im = spec_ax.imshow(spec, aspect="auto", interpolation="none") + # spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal", + # spec_ax=spec_ax) + spec_ax.set_title("mel spectrogram") + + spec_ax.set_xticks([]) + spec_ax.set_yticks([]) + spec_ax.figure.canvas.draw() + if which != "current": + self.vocode_button.setDisabled(spec is None) + + def draw_umap_projections(self, utterances: Set[Utterance]): + self.umap_ax.clear() + + speakers = np.unique([u.speaker_name for u in utterances]) + colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)} + embeds = [u.embed for u in utterances] + + # Display a message if there aren't enough points + if len(utterances) < self.min_umap_points: + self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" % + (self.min_umap_points - len(utterances)), + horizontalalignment='center', fontsize=15) + self.umap_ax.set_title("") + + # Compute the projections + else: + if not self.umap_hot: + self.log( + "Drawing UMAP projections for the first time, this will take a few seconds.") + self.umap_hot = True + + reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine") + # reducer = TSNE() + projections = reducer.fit_transform(embeds) + + speakers_done = set() + for projection, utterance in zip(projections, utterances): + color = colors[utterance.speaker_name] + mark = "x" if "_gen_" in utterance.name else "o" + label = None if utterance.speaker_name in speakers_done else utterance.speaker_name + speakers_done.add(utterance.speaker_name) + self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark, + label=label) + # self.umap_ax.set_title("UMAP projections") + self.umap_ax.legend(prop={'size': 10}) + + # Draw the plot + self.umap_ax.set_aspect("equal", "datalim") + self.umap_ax.set_xticks([]) + self.umap_ax.set_yticks([]) + self.umap_ax.figure.canvas.draw() + + def save_audio_file(self, wav, sample_rate): + dialog = QFileDialog() + dialog.setDefaultSuffix(".wav") + fpath, _ = dialog.getSaveFileName( + parent=self, + caption="Select a path to save the audio file", + filter="Audio Files (*.flac *.wav)" + ) + if fpath: + #Default format is wav + if Path(fpath).suffix == "": + fpath += ".wav" + sf.write(fpath, wav, sample_rate) + + def setup_audio_devices(self, sample_rate): + input_devices = [] + output_devices = [] + for device in sd.query_devices(): + # Check if valid input + try: + sd.check_input_settings(device=device["name"], samplerate=sample_rate) + input_devices.append(device["name"]) + except: + 
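                # Input check failed: this device cannot record at the required sample
                # rate, so it is silently skipped as an input device.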
pass + + # Check if valid output + try: + sd.check_output_settings(device=device["name"], samplerate=sample_rate) + output_devices.append(device["name"]) + except Exception as e: + # Log a warning only if the device is not an input + if not device["name"] in input_devices: + warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e))) + + if len(input_devices) == 0: + self.log("No audio input device detected. Recording may not work.") + self.audio_in_device = None + else: + self.audio_in_device = input_devices[0] + + if len(output_devices) == 0: + self.log("No supported output audio devices were found! Audio output may not work.") + self.audio_out_devices_cb.addItems(["None"]) + self.audio_out_devices_cb.setDisabled(True) + else: + self.audio_out_devices_cb.clear() + self.audio_out_devices_cb.addItems(output_devices) + self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device) + + self.set_audio_device() + + def set_audio_device(self): + + output_device = self.audio_out_devices_cb.currentText() + if output_device == "None": + output_device = None + + # If None, sounddevice queries portaudio + sd.default.device = (self.audio_in_device, output_device) + + def play(self, wav, sample_rate): + try: + sd.stop() + sd.play(wav, sample_rate) + except Exception as e: + print(e) + self.log("Error in audio playback. Try selecting a different audio output device.") + self.log("Your device must be connected before you start the toolbox.") + + def stop(self): + sd.stop() + + def record_one(self, sample_rate, duration): + self.record_button.setText("Recording...") + self.record_button.setDisabled(True) + + self.log("Recording %d seconds of audio" % duration) + sd.stop() + try: + wav = sd.rec(duration * sample_rate, sample_rate, 1) + except Exception as e: + print(e) + self.log("Could not record anything. Is your recording device enabled?") + self.log("Your device must be connected before you start the toolbox.") + return None + + for i in np.arange(0, duration, 0.1): + self.set_loading(i, duration) + sleep(0.1) + self.set_loading(duration, duration) + sd.wait() + + self.log("Done recording.") + self.record_button.setText("Record") + self.record_button.setDisabled(False) + + return wav.squeeze() + + @property + def current_dataset_name(self): + return self.dataset_box.currentText() + + @property + def current_speaker_name(self): + return self.speaker_box.currentText() + + @property + def current_utterance_name(self): + return self.utterance_box.currentText() + + def browse_file(self): + fpath = QFileDialog().getOpenFileName( + parent=self, + caption="Select an audio file", + filter="Audio Files (*.mp3 *.flac *.wav *.m4a)" + ) + return Path(fpath[0]) if fpath[0] != "" else "" + + @staticmethod + def repopulate_box(box, items, random=False): + """ + Resets a box and adds a list of items. 
Pass a list of (item, data) pairs instead to join + data to the items + """ + box.blockSignals(True) + box.clear() + for item in items: + item = list(item) if isinstance(item, tuple) else [item] + box.addItem(str(item[0]), *item[1:]) + if len(items) > 0: + box.setCurrentIndex(np.random.randint(len(items)) if random else 0) + box.setDisabled(len(items) == 0) + box.blockSignals(False) + + def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int, + random=True): + # Select a random dataset + if level <= 0: + if datasets_root is not None: + datasets = [datasets_root.joinpath(d) for d in recognized_datasets] + datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()] + self.browser_load_button.setDisabled(len(datasets) == 0) + if datasets_root is None or len(datasets) == 0: + msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \ + if datasets_root is None else "o not have any of the recognized datasets" \ + " in %s" % datasets_root) + self.log(msg) + msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \ + "can still use the toolbox by recording samples yourself." % \ + ("\n\t".join(recognized_datasets)) + print(msg, file=sys.stderr) + + self.random_utterance_button.setDisabled(True) + self.random_speaker_button.setDisabled(True) + self.random_dataset_button.setDisabled(True) + self.utterance_box.setDisabled(True) + self.speaker_box.setDisabled(True) + self.dataset_box.setDisabled(True) + self.browser_load_button.setDisabled(True) + self.auto_next_checkbox.setDisabled(True) + return + self.repopulate_box(self.dataset_box, datasets, random) + + # Select a random speaker + if level <= 1: + speakers_root = datasets_root.joinpath(self.current_dataset_name) + speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()] + self.repopulate_box(self.speaker_box, speaker_names, random) + + # Select a random utterance + if level <= 2: + utterances_root = datasets_root.joinpath( + self.current_dataset_name, + self.current_speaker_name + ) + utterances = [] + for extension in ['mp3', 'flac', 'wav', 'm4a']: + utterances.extend(Path(utterances_root).glob("**/*.%s" % extension)) + utterances = [fpath.relative_to(utterances_root) for fpath in utterances] + self.repopulate_box(self.utterance_box, utterances, random) + + def browser_select_next(self): + index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box) + self.utterance_box.setCurrentIndex(index) + + @property + def current_encoder_fpath(self): + return self.encoder_box.itemData(self.encoder_box.currentIndex()) + + @property + def current_synthesizer_fpath(self): + return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex()) + + @property + def current_vocoder_fpath(self): + return self.vocoder_box.itemData(self.vocoder_box.currentIndex()) + + def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path, + vocoder_models_dir: Path): + # Encoder + encoder_fpaths = list(encoder_models_dir.glob("*.pt")) + if len(encoder_fpaths) == 0: + raise Exception("No encoder models found in %s" % encoder_models_dir) + self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths]) + + # Synthesizer + synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt")) + if len(synthesizer_fpaths) == 0: + raise Exception("No synthesizer models found in %s" % synthesizer_models_dir) + self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths]) + + # Vocoder + vocoder_fpaths = 
list(vocoder_models_dir.glob("**/*.pt")) + vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)] + self.repopulate_box(self.vocoder_box, vocoder_items) + + @property + def selected_utterance(self): + return self.utterance_history.itemData(self.utterance_history.currentIndex()) + + def register_utterance(self, utterance: Utterance): + self.utterance_history.blockSignals(True) + self.utterance_history.insertItem(0, utterance.name, utterance) + self.utterance_history.setCurrentIndex(0) + self.utterance_history.blockSignals(False) + + if len(self.utterance_history) > self.max_saved_utterances: + self.utterance_history.removeItem(self.max_saved_utterances) + + self.play_button.setDisabled(False) + self.generate_button.setDisabled(False) + self.synthesize_button.setDisabled(False) + + def log(self, line, mode="newline"): + if mode == "newline": + self.logs.append(line) + if len(self.logs) > self.max_log_lines: + del self.logs[0] + elif mode == "append": + self.logs[-1] += line + elif mode == "overwrite": + self.logs[-1] = line + log_text = '\n'.join(self.logs) + + self.log_window.setText(log_text) + self.app.processEvents() + + def set_loading(self, value, maximum=1): + self.loading_bar.setValue(value * 100) + self.loading_bar.setMaximum(maximum * 100) + self.loading_bar.setTextVisible(value != 0) + self.app.processEvents() + + def populate_gen_options(self, seed, trim_silences): + if seed is not None: + self.random_seed_checkbox.setChecked(True) + self.seed_textbox.setText(str(seed)) + self.seed_textbox.setEnabled(True) + else: + self.random_seed_checkbox.setChecked(False) + self.seed_textbox.setText(str(0)) + self.seed_textbox.setEnabled(False) + + if not trim_silences: + self.trim_silences_checkbox.setChecked(False) + self.trim_silences_checkbox.setDisabled(True) + + def update_seed_textbox(self): + if self.random_seed_checkbox.isChecked(): + self.seed_textbox.setEnabled(True) + else: + self.seed_textbox.setEnabled(False) + + def reset_interface(self): + self.draw_embed(None, None, "current") + self.draw_embed(None, None, "generated") + self.draw_spec(None, "current") + self.draw_spec(None, "generated") + self.draw_umap_projections(set()) + self.set_loading(0) + self.play_button.setDisabled(True) + self.generate_button.setDisabled(True) + self.synthesize_button.setDisabled(True) + self.vocode_button.setDisabled(True) + self.replay_wav_button.setDisabled(True) + self.export_wav_button.setDisabled(True) + [self.log("") for _ in range(self.max_log_lines)] + + def __init__(self): + ## Initialize the application + self.app = QApplication(sys.argv) + super().__init__(None) + self.setWindowTitle("SV2TTS toolbox") + + + ## Main layouts + # Root + root_layout = QGridLayout() + self.setLayout(root_layout) + + # Browser + browser_layout = QGridLayout() + root_layout.addLayout(browser_layout, 0, 0, 1, 2) + + # Generation + gen_layout = QVBoxLayout() + root_layout.addLayout(gen_layout, 0, 2, 1, 2) + + # Projections + self.projections_layout = QVBoxLayout() + root_layout.addLayout(self.projections_layout, 1, 0, 1, 1) + + # Visualizations + vis_layout = QVBoxLayout() + root_layout.addLayout(vis_layout, 1, 1, 1, 3) + + + ## Projections + # UMap + fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0") + fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98) + self.projections_layout.addWidget(FigureCanvas(fig)) + self.umap_hot = False + self.clear_button = QPushButton("Clear") + self.projections_layout.addWidget(self.clear_button) + + + ## Browser + # 
Dataset, speaker and utterance selection + i = 0 + self.dataset_box = QComboBox() + browser_layout.addWidget(QLabel("Dataset"), i, 0) + browser_layout.addWidget(self.dataset_box, i + 1, 0) + self.speaker_box = QComboBox() + browser_layout.addWidget(QLabel("Speaker"), i, 1) + browser_layout.addWidget(self.speaker_box, i + 1, 1) + self.utterance_box = QComboBox() + browser_layout.addWidget(QLabel("Utterance"), i, 2) + browser_layout.addWidget(self.utterance_box, i + 1, 2) + self.browser_load_button = QPushButton("Load") + browser_layout.addWidget(self.browser_load_button, i + 1, 3) + i += 2 + + # Random buttons + self.random_dataset_button = QPushButton("Random") + browser_layout.addWidget(self.random_dataset_button, i, 0) + self.random_speaker_button = QPushButton("Random") + browser_layout.addWidget(self.random_speaker_button, i, 1) + self.random_utterance_button = QPushButton("Random") + browser_layout.addWidget(self.random_utterance_button, i, 2) + self.auto_next_checkbox = QCheckBox("Auto select next") + self.auto_next_checkbox.setChecked(True) + browser_layout.addWidget(self.auto_next_checkbox, i, 3) + i += 1 + + # Utterance box + browser_layout.addWidget(QLabel("Use embedding from:"), i, 0) + self.utterance_history = QComboBox() + browser_layout.addWidget(self.utterance_history, i, 1, 1, 3) + i += 1 + + # Random & next utterance buttons + self.browser_browse_button = QPushButton("Browse") + browser_layout.addWidget(self.browser_browse_button, i, 0) + self.record_button = QPushButton("Record") + browser_layout.addWidget(self.record_button, i, 1) + self.play_button = QPushButton("Play") + browser_layout.addWidget(self.play_button, i, 2) + self.stop_button = QPushButton("Stop") + browser_layout.addWidget(self.stop_button, i, 3) + i += 1 + + + # Model and audio output selection + self.encoder_box = QComboBox() + browser_layout.addWidget(QLabel("Encoder"), i, 0) + browser_layout.addWidget(self.encoder_box, i + 1, 0) + self.synthesizer_box = QComboBox() + browser_layout.addWidget(QLabel("Synthesizer"), i, 1) + browser_layout.addWidget(self.synthesizer_box, i + 1, 1) + self.vocoder_box = QComboBox() + browser_layout.addWidget(QLabel("Vocoder"), i, 2) + browser_layout.addWidget(self.vocoder_box, i + 1, 2) + + self.audio_out_devices_cb=QComboBox() + browser_layout.addWidget(QLabel("Audio Output"), i, 3) + browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3) + i += 2 + + #Replay & Save Audio + browser_layout.addWidget(QLabel("Toolbox Output:"), i, 0) + self.waves_cb = QComboBox() + self.waves_cb_model = QStringListModel() + self.waves_cb.setModel(self.waves_cb_model) + self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting") + browser_layout.addWidget(self.waves_cb, i, 1) + self.replay_wav_button = QPushButton("Replay") + self.replay_wav_button.setToolTip("Replay last generated vocoder") + browser_layout.addWidget(self.replay_wav_button, i, 2) + self.export_wav_button = QPushButton("Export") + self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file") + browser_layout.addWidget(self.export_wav_button, i, 3) + i += 1 + + + ## Embed & spectrograms + vis_layout.addStretch() + + gridspec_kw = {"width_ratios": [1, 4]} + fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0", + gridspec_kw=gridspec_kw) + fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8) + vis_layout.addWidget(FigureCanvas(fig)) + + fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), 
facecolor="#F0F0F0", + gridspec_kw=gridspec_kw) + fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8) + vis_layout.addWidget(FigureCanvas(fig)) + + for ax in self.current_ax.tolist() + self.gen_ax.tolist(): + ax.set_facecolor("#F0F0F0") + for side in ["top", "right", "bottom", "left"]: + ax.spines[side].set_visible(False) + + + ## Generation + self.text_prompt = QPlainTextEdit(default_text) + gen_layout.addWidget(self.text_prompt, stretch=1) + + self.generate_button = QPushButton("Synthesize and vocode") + gen_layout.addWidget(self.generate_button) + + layout = QHBoxLayout() + self.synthesize_button = QPushButton("Synthesize only") + layout.addWidget(self.synthesize_button) + self.vocode_button = QPushButton("Vocode only") + layout.addWidget(self.vocode_button) + gen_layout.addLayout(layout) + + layout_seed = QGridLayout() + self.random_seed_checkbox = QCheckBox("Random seed:") + self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.") + layout_seed.addWidget(self.random_seed_checkbox, 0, 0) + self.seed_textbox = QLineEdit() + self.seed_textbox.setMaximumWidth(80) + layout_seed.addWidget(self.seed_textbox, 0, 1) + self.trim_silences_checkbox = QCheckBox("Enhance vocoder output") + self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output." + " This feature requires `webrtcvad` to be installed.") + layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2) + gen_layout.addLayout(layout_seed) + + self.loading_bar = QProgressBar() + gen_layout.addWidget(self.loading_bar) + + self.log_window = QLabel() + self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft) + gen_layout.addWidget(self.log_window) + self.logs = [] + gen_layout.addStretch() + + + ## Set the size of the window and of the elements + max_size = QDesktopWidget().availableGeometry(self).size() * 0.8 + self.resize(max_size) + + ## Finalize the display + self.reset_interface() + self.show() + + def start(self): + self.app.exec_() diff --git a/toolbox/utterance.py b/toolbox/utterance.py new file mode 100644 index 0000000000000000000000000000000000000000..844c8a2adb0c8eba2992eaf5ea357d7add3c1896 --- /dev/null +++ b/toolbox/utterance.py @@ -0,0 +1,5 @@ +from collections import namedtuple + +Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth") +Utterance.__eq__ = lambda x, y: x.name == y.name +Utterance.__hash__ = lambda x: hash(x.name) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/argutils.py b/utils/argutils.py new file mode 100644 index 0000000000000000000000000000000000000000..db41683027173517c910e3b259f8da48207dcb38 --- /dev/null +++ b/utils/argutils.py @@ -0,0 +1,40 @@ +from pathlib import Path +import numpy as np +import argparse + +_type_priorities = [ # In decreasing order + Path, + str, + int, + float, + bool, +] + +def _priority(o): + p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None) + if p is not None: + return p + p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None) + if p is not None: + return p + return len(_type_priorities) + +def print_args(args: argparse.Namespace, parser=None): + args = vars(args) + if parser is None: + priorities = list(map(_priority, args.values())) + else: + all_params = [a.dest for g in parser._action_groups for a in g._group_actions ] + priority = lambda p: all_params.index(p) if p 
in all_params else len(all_params) + priorities = list(map(priority, args.keys())) + + pad = max(map(len, args.keys())) + 3 + indices = np.lexsort((list(args.keys()), priorities)) + items = list(args.items()) + + print("Arguments:") + for i in indices: + param, value = items[i] + print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value)) + print("") + \ No newline at end of file diff --git a/utils/logmmse.py b/utils/logmmse.py new file mode 100644 index 0000000000000000000000000000000000000000..58cc4502fa5ba0670678c3edaf5ba1587b8b58ea --- /dev/null +++ b/utils/logmmse.py @@ -0,0 +1,247 @@ +# The MIT License (MIT) +# +# Copyright (c) 2015 braindead +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# +# This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I +# simply modified the interface to meet my needs. + + +import numpy as np +import math +from scipy.special import expn +from collections import namedtuple + +NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2") + + +def profile_noise(noise, sampling_rate, window_size=0): + """ + Creates a profile of the noise in a given waveform. + + :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints. + :param sampling_rate: the sampling rate of the audio + :param window_size: the size of the window the logmmse algorithm operates on. A default value + will be picked if left as 0. + :return: a NoiseProfile object + """ + noise, dtype = to_float(noise) + noise += np.finfo(np.float64).eps + + if window_size == 0: + window_size = int(math.floor(0.02 * sampling_rate)) + + if window_size % 2 == 1: + window_size = window_size + 1 + + perc = 50 + len1 = int(math.floor(window_size * perc / 100)) + len2 = int(window_size - len1) + + win = np.hanning(window_size) + win = win * len2 / np.sum(win) + n_fft = 2 * window_size + + noise_mean = np.zeros(n_fft) + n_frames = len(noise) // window_size + for j in range(0, window_size * n_frames, window_size): + noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0)) + noise_mu2 = (noise_mean / n_frames) ** 2 + + return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2) + + +def denoise(wav, noise_profile: NoiseProfile, eta=0.15): + """ + Cleans the noise from a speech waveform given a noise profile. The waveform must have the + same sampling rate as the one used to create the noise profile. 
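    Example (editor's sketch, not part of the original patch; it assumes a synthetic
    float32 signal at 16 kHz rather than real speech):

        import numpy as np
        from utils.logmmse import profile_noise, denoise

        sr = 16000
        rng = np.random.RandomState(0)
        noise = rng.normal(0, 0.05, sr).astype(np.float32)        # 1 s of noise only
        tone = 0.3 * np.sin(2 * np.pi * 440 * np.arange(2 * sr) / sr)
        noisy = (tone + rng.normal(0, 0.05, tone.shape)).astype(np.float32)

        profile = profile_noise(noise, sr)   # estimate the noise spectrum
        clean = denoise(noisy, profile)      # same length and dtype as the input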
+ + :param wav: a speech waveform as a numpy array of floats or ints. + :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of + the same) waveform. + :param eta: voice threshold for noise update. While the voice activation detection value is + below this threshold, the noise profile will be continuously updated throughout the audio. + Set to 0 to disable updating the noise profile. + :return: the clean wav as a numpy array of floats or ints of the same length. + """ + wav, dtype = to_float(wav) + wav += np.finfo(np.float64).eps + p = noise_profile + + nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2)) + x_final = np.zeros(nframes * p.len2) + + aa = 0.98 + mu = 0.98 + ksi_min = 10 ** (-25 / 10) + + x_old = np.zeros(p.len1) + xk_prev = np.zeros(p.len1) + noise_mu2 = p.noise_mu2 + for k in range(0, nframes * p.len2, p.len2): + insign = p.win * wav[k:k + p.window_size] + + spec = np.fft.fft(insign, p.n_fft, axis=0) + sig = np.absolute(spec) + sig2 = sig ** 2 + + gammak = np.minimum(sig2 / noise_mu2, 40) + + if xk_prev.all() == 0: + ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) + else: + ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) + ksi = np.maximum(ksi_min, ksi) + + log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi) + vad_decision = np.sum(log_sigma_k) / p.window_size + if vad_decision < eta: + noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 + + a = ksi / (1 + ksi) + vk = a * gammak + ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) + hw = a * np.exp(ei_vk) + sig = sig * hw + xk_prev = sig ** 2 + xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0) + xi_w = np.real(xi_w) + + x_final[k:k + p.len2] = x_old + xi_w[0:p.len1] + x_old = xi_w[p.len1:p.window_size] + + output = from_float(x_final, dtype) + output = np.pad(output, (0, len(wav) - len(output)), mode="constant") + return output + + +## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that +## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of +## webrctvad +# def vad(wav, sampling_rate, eta=0.15, window_size=0): +# """ +# TODO: fix doc +# Creates a profile of the noise in a given waveform. +# +# :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints. +# :param sampling_rate: the sampling rate of the audio +# :param window_size: the size of the window the logmmse algorithm operates on. A default value +# will be picked if left as 0. +# :param eta: voice threshold for noise update. While the voice activation detection value is +# below this threshold, the noise profile will be continuously updated throughout the audio. +# Set to 0 to disable updating the noise profile. 
+# """ +# wav, dtype = to_float(wav) +# wav += np.finfo(np.float64).eps +# +# if window_size == 0: +# window_size = int(math.floor(0.02 * sampling_rate)) +# +# if window_size % 2 == 1: +# window_size = window_size + 1 +# +# perc = 50 +# len1 = int(math.floor(window_size * perc / 100)) +# len2 = int(window_size - len1) +# +# win = np.hanning(window_size) +# win = win * len2 / np.sum(win) +# n_fft = 2 * window_size +# +# wav_mean = np.zeros(n_fft) +# n_frames = len(wav) // window_size +# for j in range(0, window_size * n_frames, window_size): +# wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0)) +# noise_mu2 = (wav_mean / n_frames) ** 2 +# +# wav, dtype = to_float(wav) +# wav += np.finfo(np.float64).eps +# +# nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2)) +# vad = np.zeros(nframes * len2, dtype=np.bool) +# +# aa = 0.98 +# mu = 0.98 +# ksi_min = 10 ** (-25 / 10) +# +# xk_prev = np.zeros(len1) +# noise_mu2 = noise_mu2 +# for k in range(0, nframes * len2, len2): +# insign = win * wav[k:k + window_size] +# +# spec = np.fft.fft(insign, n_fft, axis=0) +# sig = np.absolute(spec) +# sig2 = sig ** 2 +# +# gammak = np.minimum(sig2 / noise_mu2, 40) +# +# if xk_prev.all() == 0: +# ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) +# else: +# ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) +# ksi = np.maximum(ksi_min, ksi) +# +# log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi) +# vad_decision = np.sum(log_sigma_k) / window_size +# if vad_decision < eta: +# noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 +# print(vad_decision) +# +# a = ksi / (1 + ksi) +# vk = a * gammak +# ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) +# hw = a * np.exp(ei_vk) +# sig = sig * hw +# xk_prev = sig ** 2 +# +# vad[k:k + len2] = vad_decision >= eta +# +# vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant") +# return vad + + +def to_float(_input): + if _input.dtype == np.float64: + return _input, _input.dtype + elif _input.dtype == np.float32: + return _input.astype(np.float64), _input.dtype + elif _input.dtype == np.uint8: + return (_input - 128) / 128., _input.dtype + elif _input.dtype == np.int16: + return _input / 32768., _input.dtype + elif _input.dtype == np.int32: + return _input / 2147483648., _input.dtype + raise ValueError('Unsupported wave file format') + + +def from_float(_input, dtype): + if dtype == np.float64: + return _input, np.float64 + elif dtype == np.float32: + return _input.astype(np.float32) + elif dtype == np.uint8: + return ((_input * 128) + 128).astype(np.uint8) + elif dtype == np.int16: + return (_input * 32768).astype(np.int16) + elif dtype == np.int32: + print(_input) + return (_input * 2147483648).astype(np.int32) + raise ValueError('Unsupported wave file format') diff --git a/utils/modelutils.py b/utils/modelutils.py new file mode 100644 index 0000000000000000000000000000000000000000..6acaa984e0c7876f9149fc1ff99001b7761dc80b --- /dev/null +++ b/utils/modelutils.py @@ -0,0 +1,17 @@ +from pathlib import Path + +def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path: Path): + # This function tests the model paths and makes sure at least one is valid. 
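    # Editorial illustration (hypothetical paths, not part of the original patch):
    #   check_model_paths(Path("encoder/saved_models/pretrained.pt"),
    #                     Path("does/not/exist.pt"),
    #                     Path("does/not/exist.pt"))
    # returns silently as long as at least one of the three paths exists; only when
    # all three are missing does it print the download instructions below and exit.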
+ if encoder_path.is_file() or encoder_path.is_dir(): + return + if synthesizer_path.is_file() or synthesizer_path.is_dir(): + return + if vocoder_path.is_file() or vocoder_path.is_dir(): + return + + # If none of the paths exist, remind the user to download models if needed + print("********************************************************************************") + print("Error: Model files not found. Follow these instructions to get and install the models:") + print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models") + print("********************************************************************************\n") + quit(-1) diff --git a/utils/profiler.py b/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..17175b9e1b0eb17fdc015199e5194a5c1afb8a28 --- /dev/null +++ b/utils/profiler.py @@ -0,0 +1,45 @@ +from time import perf_counter as timer +from collections import OrderedDict +import numpy as np + + +class Profiler: + def __init__(self, summarize_every=5, disabled=False): + self.last_tick = timer() + self.logs = OrderedDict() + self.summarize_every = summarize_every + self.disabled = disabled + + def tick(self, name): + if self.disabled: + return + + # Log the time needed to execute that function + if not name in self.logs: + self.logs[name] = [] + if len(self.logs[name]) >= self.summarize_every: + self.summarize() + self.purge_logs() + self.logs[name].append(timer() - self.last_tick) + + self.reset_timer() + + def purge_logs(self): + for name in self.logs: + self.logs[name].clear() + + def reset_timer(self): + self.last_tick = timer() + + def summarize(self): + n = max(map(len, self.logs.values())) + assert n == self.summarize_every + print("\nAverage execution time over %d steps:" % n) + + name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()] + pad = max(map(len, name_msgs)) + for name_msg, deltas in zip(name_msgs, self.logs.values()): + print(" %s mean: %4.0fms std: %4.0fms" % + (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000)) + print("", flush=True) + \ No newline at end of file diff --git a/vocoder/LICENSE.txt b/vocoder/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d6716174d0d0058f3fc3ae6a8e595119605acbf --- /dev/null +++ b/vocoder/LICENSE.txt @@ -0,0 +1,22 @@ +MIT License + +Original work Copyright (c) 2019 fatchord (https://github.com/fatchord) +Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vocoder/audio.py b/vocoder/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..116396261e184b9968971bd06fabc6f525e0c2fe --- /dev/null +++ b/vocoder/audio.py @@ -0,0 +1,108 @@ +import math +import numpy as np +import librosa +import vocoder.hparams as hp +from scipy.signal import lfilter +import soundfile as sf + + +def label_2_float(x, bits) : + return 2 * x / (2**bits - 1.) - 1. + + +def float_2_label(x, bits) : + assert abs(x).max() <= 1.0 + x = (x + 1.) * (2**bits - 1) / 2 + return x.clip(0, 2**bits - 1) + + +def load_wav(path) : + return librosa.load(str(path), sr=hp.sample_rate)[0] + + +def save_wav(x, path) : + sf.write(path, x.astype(np.float32), hp.sample_rate) + + +def split_signal(x) : + unsigned = x + 2**15 + coarse = unsigned // 256 + fine = unsigned % 256 + return coarse, fine + + +def combine_signal(coarse, fine) : + return coarse * 256 + fine - 2**15 + + +def encode_16bits(x) : + return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) + + +mel_basis = None + + +def linear_to_mel(spectrogram): + global mel_basis + if mel_basis is None: + mel_basis = build_mel_basis() + return np.dot(mel_basis, spectrogram) + + +def build_mel_basis(): + return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin) + + +def normalize(S): + return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1) + + +def denormalize(S): + return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db + + +def amp_to_db(x): + return 20 * np.log10(np.maximum(1e-5, x)) + + +def db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def spectrogram(y): + D = stft(y) + S = amp_to_db(np.abs(D)) - hp.ref_level_db + return normalize(S) + + +def melspectrogram(y): + D = stft(y) + S = amp_to_db(linear_to_mel(np.abs(D))) + return normalize(S) + + +def stft(y): + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length) + + +def pre_emphasis(x): + return lfilter([1, -hp.preemphasis], [1], x) + + +def de_emphasis(x): + return lfilter([1], [1, -hp.preemphasis], x) + + +def encode_mu_law(x, mu) : + mu = mu - 1 + fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) + return np.floor((fx + 1) / 2 * mu + 0.5) + + +def decode_mu_law(y, mu, from_labels=True) : + if from_labels: + y = label_2_float(y, math.log2(mu)) + mu = mu - 1 + x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1) + return x + diff --git a/vocoder/display.py b/vocoder/display.py new file mode 100644 index 0000000000000000000000000000000000000000..956880722a3f05613ebd06f5686b3d8a59642e92 --- /dev/null +++ b/vocoder/display.py @@ -0,0 +1,120 @@ +import matplotlib.pyplot as plt +import time +import numpy as np +import sys + + +def progbar(i, n, size=16): + done = (i * size) // n + bar = '' + for i in range(size): + bar += '█' if i <= done else '░' + return bar + + +def stream(message) : + try: + sys.stdout.write("\r{%s}" % message) + except: + #Remove non-ASCII characters from message + message = ''.join(i for i in message if ord(i)<128) + sys.stdout.write("\r{%s}" % message) + + +def simple_table(item_tuples) : + + border_pattern = '+---------------------------------------' + whitespace = ' ' + + headings, cells, = [], [] + + for item in item_tuples : + + heading, cell = str(item[0]), 
str(item[1]) + + pad_head = True if len(heading) < len(cell) else False + + pad = abs(len(heading) - len(cell)) + pad = whitespace[:pad] + + pad_left = pad[:len(pad)//2] + pad_right = pad[len(pad)//2:] + + if pad_head : + heading = pad_left + heading + pad_right + else : + cell = pad_left + cell + pad_right + + headings += [heading] + cells += [cell] + + border, head, body = '', '', '' + + for i in range(len(item_tuples)) : + + temp_head = f'| {headings[i]} ' + temp_body = f'| {cells[i]} ' + + border += border_pattern[:len(temp_head)] + head += temp_head + body += temp_body + + if i == len(item_tuples) - 1 : + head += '|' + body += '|' + border += '+' + + print(border) + print(head) + print(border) + print(body) + print(border) + print(' ') + + +def time_since(started) : + elapsed = time.time() - started + m = int(elapsed // 60) + s = int(elapsed % 60) + if m >= 60 : + h = int(m // 60) + m = m % 60 + return f'{h}h {m}m {s}s' + else : + return f'{m}m {s}s' + + +def save_attention(attn, path) : + fig = plt.figure(figsize=(12, 6)) + plt.imshow(attn.T, interpolation='nearest', aspect='auto') + fig.savefig(f'{path}.png', bbox_inches='tight') + plt.close(fig) + + +def save_spectrogram(M, path, length=None) : + M = np.flip(M, axis=0) + if length : M = M[:, :length] + fig = plt.figure(figsize=(12, 6)) + plt.imshow(M, interpolation='nearest', aspect='auto') + fig.savefig(f'{path}.png', bbox_inches='tight') + plt.close(fig) + + +def plot(array) : + fig = plt.figure(figsize=(30, 5)) + ax = fig.add_subplot(111) + ax.xaxis.label.set_color('grey') + ax.yaxis.label.set_color('grey') + ax.xaxis.label.set_fontsize(23) + ax.yaxis.label.set_fontsize(23) + ax.tick_params(axis='x', colors='grey', labelsize=23) + ax.tick_params(axis='y', colors='grey', labelsize=23) + plt.plot(array) + + +def plot_spec(M) : + M = np.flip(M, axis=0) + plt.figure(figsize=(18,4)) + plt.imshow(M, interpolation='nearest', aspect='auto') + plt.show() + diff --git a/vocoder/distribution.py b/vocoder/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..d3119a5ba1e77bc25a92d2664f83d366f12399c0 --- /dev/null +++ b/vocoder/distribution.py @@ -0,0 +1,132 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +def log_sum_exp(x): + """ numerically stable log_sum_exp implementation that prevents overflow """ + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, + log_scale_min=None, reduce=True): + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + y_hat = y_hat.permute(0,2,1) + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1. 
/ (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) + + # tf equivalent + """ + log_probs = tf.where(x < -0.999, log_cdf_plus, + tf.where(x > 0.999, log_one_minus_cdf_min, + tf.where(cdf_delta > 1e-5, + tf.log(tf.maximum(cdf_delta, 1e-12)), + log_pdf_mid - np.log(127.5)))) + """ + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * \ + torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ + (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1. - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + else: + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. + """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(- torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum( + y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) + + x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 
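    # Editorial shape note (not part of the original patch): with nr_mix = 10 the input
    # y must carry 3 * 10 = 30 channels (mixture logits, means, log scales), e.g.
    #   y = torch.randn(2, 30, 100)                    # (B, C, T)
    #   x = sample_from_discretized_mix_logistic(y)    # (B, T), values in [-1, 1]
    # which matches the vocoder's MOL mode, where n_classes is set to 30.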
+ + return x + + +def to_one_hot(tensor, n, fill_with=1.): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() + if tensor.is_cuda: + one_hot = one_hot.cuda() + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot diff --git a/vocoder/gen_wavernn.py b/vocoder/gen_wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..2036737f805f6055893812e48f99d524624aab07 --- /dev/null +++ b/vocoder/gen_wavernn.py @@ -0,0 +1,31 @@ +from vocoder.models.fatchord_version import WaveRNN +from vocoder.audio import * + + +def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path): + k = model.get_step() // 1000 + + for i, (m, x) in enumerate(test_set, 1): + if i > samples: + break + + print('\n| Generating: %i/%i' % (i, samples)) + + x = x[0].numpy() + + bits = 16 if hp.voc_mode == 'MOL' else hp.bits + + if hp.mu_law and hp.voc_mode != 'MOL' : + x = decode_mu_law(x, 2**bits, from_labels=True) + else : + x = label_2_float(x, bits) + + save_wav(x, save_path.joinpath("%dk_steps_%d_target.wav" % (k, i))) + + batch_str = "gen_batched_target%d_overlap%d" % (target, overlap) if batched else \ + "gen_not_batched" + save_str = save_path.joinpath("%dk_steps_%d_%s.wav" % (k, i, batch_str)) + + wav = model.generate(m, batched, target, overlap, hp.mu_law) + save_wav(wav, save_str) + diff --git a/vocoder/hparams.py b/vocoder/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..c1de9f7dcc2926735b80a28ed1226ff1b5824753 --- /dev/null +++ b/vocoder/hparams.py @@ -0,0 +1,44 @@ +from synthesizer.hparams import hparams as _syn_hp + + +# Audio settings------------------------------------------------------------------------ +# Match the values of the synthesizer +sample_rate = _syn_hp.sample_rate +n_fft = _syn_hp.n_fft +num_mels = _syn_hp.num_mels +hop_length = _syn_hp.hop_size +win_length = _syn_hp.win_size +fmin = _syn_hp.fmin +min_level_db = _syn_hp.min_level_db +ref_level_db = _syn_hp.ref_level_db +mel_max_abs_value = _syn_hp.max_abs_value +preemphasis = _syn_hp.preemphasis +apply_preemphasis = _syn_hp.preemphasize + +bits = 9 # bit depth of signal +mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode + # below + + +# WAVERNN / VOCODER -------------------------------------------------------------------------------- +voc_mode = 'RAW' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from +# mixture of logistics) +voc_upsample_factors = (5, 5, 8) # NB - this needs to correctly factorise hop_length +voc_rnn_dims = 512 +voc_fc_dims = 512 +voc_compute_dims = 128 +voc_res_out_dims = 128 +voc_res_blocks = 10 + +# Training +voc_batch_size = 100 +voc_lr = 1e-4 +voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint +voc_pad = 2 # this will pad the input so that the resnet can 'see' wider + # than input length +voc_seq_len = hop_length * 5 # must be a multiple of hop_length + +# Generating / Synthesizing +voc_gen_batched = True # very fast (realtime+) single utterance batched generation +voc_target = 8000 # target number of samples to be generated in each batch entry +voc_overlap = 400 # number of samples for crossfading between batches diff --git a/vocoder/inference.py b/vocoder/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7e546845da0b8cdb18b34fbd332b9aaa39cea55c --- /dev/null +++ b/vocoder/inference.py @@ -0,0 +1,64 @@ +from vocoder.models.fatchord_version import 
WaveRNN +from vocoder import hparams as hp +import torch + + +_model = None # type: WaveRNN + +def load_model(weights_fpath, verbose=True): + global _model, _device + + if verbose: + print("Building Wave-RNN") + _model = WaveRNN( + rnn_dims=hp.voc_rnn_dims, + fc_dims=hp.voc_fc_dims, + bits=hp.bits, + pad=hp.voc_pad, + upsample_factors=hp.voc_upsample_factors, + feat_dims=hp.num_mels, + compute_dims=hp.voc_compute_dims, + res_out_dims=hp.voc_res_out_dims, + res_blocks=hp.voc_res_blocks, + hop_length=hp.hop_length, + sample_rate=hp.sample_rate, + mode=hp.voc_mode + ) + + if torch.cuda.is_available(): + _model = _model.cuda() + _device = torch.device('cuda') + else: + _device = torch.device('cpu') + + if verbose: + print("Loading model weights at %s" % weights_fpath) + checkpoint = torch.load(weights_fpath, _device) + _model.load_state_dict(checkpoint['model_state']) + _model.eval() + + +def is_loaded(): + return _model is not None + + +def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800, + progress_callback=None): + """ + Infers the waveform of a mel spectrogram output by the synthesizer (the format must match + that of the synthesizer!) + + :param normalize: + :param batched: + :param target: + :param overlap: + :return: + """ + if _model is None: + raise Exception("Please load Wave-RNN in memory before using it") + + if normalize: + mel = mel / hp.mel_max_abs_value + mel = torch.from_numpy(mel[None, ...]) + wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback) + return wav diff --git a/vocoder/models/deepmind_version.py b/vocoder/models/deepmind_version.py new file mode 100644 index 0000000000000000000000000000000000000000..1d973d9b8b9ab547571abc5a3f5ea86226a25924 --- /dev/null +++ b/vocoder/models/deepmind_version.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.display import * +from utils.dsp import * + + +class WaveRNN(nn.Module) : + def __init__(self, hidden_size=896, quantisation=256) : + super(WaveRNN, self).__init__() + + self.hidden_size = hidden_size + self.split_size = hidden_size // 2 + + # The main matmul + self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + + # Output fc layers + self.O1 = nn.Linear(self.split_size, self.split_size) + self.O2 = nn.Linear(self.split_size, quantisation) + self.O3 = nn.Linear(self.split_size, self.split_size) + self.O4 = nn.Linear(self.split_size, quantisation) + + # Input fc layers + self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False) + self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False) + + # biases for the gates + self.bias_u = nn.Parameter(torch.zeros(self.hidden_size)) + self.bias_r = nn.Parameter(torch.zeros(self.hidden_size)) + self.bias_e = nn.Parameter(torch.zeros(self.hidden_size)) + + # display num params + self.num_params() + + + def forward(self, prev_y, prev_hidden, current_coarse) : + + # Main matmul - the projection is split 3 ways + R_hidden = self.R(prev_hidden) + R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1) + + # Project the prev input + coarse_input_proj = self.I_coarse(prev_y) + I_coarse_u, I_coarse_r, I_coarse_e = \ + torch.split(coarse_input_proj, self.split_size, dim=1) + + # Project the prev input and current coarse sample + fine_input = torch.cat([prev_y, current_coarse], dim=1) + fine_input_proj = self.I_fine(fine_input) + I_fine_u, I_fine_r, I_fine_e = \ + torch.split(fine_input_proj, self.split_size, dim=1) + + # concatenate for the gates + I_u = 
torch.cat([I_coarse_u, I_fine_u], dim=1) + I_r = torch.cat([I_coarse_r, I_fine_r], dim=1) + I_e = torch.cat([I_coarse_e, I_fine_e], dim=1) + + # Compute all gates for coarse and fine + u = F.sigmoid(R_u + I_u + self.bias_u) + r = F.sigmoid(R_r + I_r + self.bias_r) + e = F.tanh(r * R_e + I_e + self.bias_e) + hidden = u * prev_hidden + (1. - u) * e + + # Split the hidden state + hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1) + + # Compute outputs + out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) + out_fine = self.O4(F.relu(self.O3(hidden_fine))) + + return out_coarse, out_fine, hidden + + + def generate(self, seq_len): + with torch.no_grad(): + # First split up the biases for the gates + b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size) + b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size) + b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size) + + # Lists for the two output seqs + c_outputs, f_outputs = [], [] + + # Some initial inputs + out_coarse = torch.LongTensor([0]).cuda() + out_fine = torch.LongTensor([0]).cuda() + + # We'll meed a hidden state + hidden = self.init_hidden() + + # Need a clock for display + start = time.time() + + # Loop for generation + for i in range(seq_len) : + + # Split into two hidden states + hidden_coarse, hidden_fine = \ + torch.split(hidden, self.split_size, dim=1) + + # Scale and concat previous predictions + out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1. + out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1. + prev_outputs = torch.cat([out_coarse, out_fine], dim=1) + + # Project input + coarse_input_proj = self.I_coarse(prev_outputs) + I_coarse_u, I_coarse_r, I_coarse_e = \ + torch.split(coarse_input_proj, self.split_size, dim=1) + + # Project hidden state and split 6 ways + R_hidden = self.R(hidden) + R_coarse_u , R_fine_u, \ + R_coarse_r, R_fine_r, \ + R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1) + + # Compute the coarse gates + u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u) + r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r) + e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e) + hidden_coarse = u * hidden_coarse + (1. - u) * e + + # Compute the coarse output + out_coarse = self.O2(F.relu(self.O1(hidden_coarse))) + posterior = F.softmax(out_coarse, dim=1) + distrib = torch.distributions.Categorical(posterior) + out_coarse = distrib.sample() + c_outputs.append(out_coarse) + + # Project the [prev outputs and predicted coarse sample] + coarse_pred = out_coarse.float() / 127.5 - 1. + fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1) + fine_input_proj = self.I_fine(fine_input) + I_fine_u, I_fine_r, I_fine_e = \ + torch.split(fine_input_proj, self.split_size, dim=1) + + # Compute the fine gates + u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u) + r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r) + e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e) + hidden_fine = u * hidden_fine + (1. 
- u) * e + + # Compute the fine output + out_fine = self.O4(F.relu(self.O3(hidden_fine))) + posterior = F.softmax(out_fine, dim=1) + distrib = torch.distributions.Categorical(posterior) + out_fine = distrib.sample() + f_outputs.append(out_fine) + + # Put the hidden state back together + hidden = torch.cat([hidden_coarse, hidden_fine], dim=1) + + # Display progress + speed = (i + 1) / (time.time() - start) + stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed)) + + coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy() + fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy() + output = combine_signal(coarse, fine) + + return output, coarse, fine + + def init_hidden(self, batch_size=1) : + return torch.zeros(batch_size, self.hidden_size).cuda() + + def num_params(self) : + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + print('Trainable Parameters: %.3f million' % parameters) \ No newline at end of file diff --git a/vocoder/models/fatchord_version.py b/vocoder/models/fatchord_version.py new file mode 100644 index 0000000000000000000000000000000000000000..70ef1e3f6b99f32cc4fa95f64acfa58268d71ad7 --- /dev/null +++ b/vocoder/models/fatchord_version.py @@ -0,0 +1,434 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from vocoder.distribution import sample_from_discretized_mix_logistic +from vocoder.display import * +from vocoder.audio import * + + +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Module): + def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + k_size = pad * 2 + 1 + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for i in range(res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__(self, feat_dims, upsample_scales, compute_dims, + res_blocks, res_out_dims, pad): + super().__init__() + total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * total_scale + self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1. 
/ k_size[1]) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + m = m.unsqueeze(1) + for f in self.up_layers: m = f(m) + m = m.squeeze(1)[:, :, self.indent:-self.indent] + return m.transpose(1, 2), aux.transpose(1, 2) + + +class WaveRNN(nn.Module): + def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors, + feat_dims, compute_dims, res_out_dims, res_blocks, + hop_length, sample_rate, mode='RAW'): + super().__init__() + self.mode = mode + self.pad = pad + if self.mode == 'RAW' : + self.n_classes = 2 ** bits + elif self.mode == 'MOL' : + self.n_classes = 30 + else : + RuntimeError("Unknown model mode value - ", self.mode) + + self.rnn_dims = rnn_dims + self.aux_dims = res_out_dims // 4 + self.hop_length = hop_length + self.sample_rate = sample_rate + + self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad) + self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) + self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) + self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) + self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) + self.fc3 = nn.Linear(fc_dims, self.n_classes) + + self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False) + self.num_params() + + def forward(self, x, mels): + self.step += 1 + bsize = x.size(0) + if torch.cuda.is_available(): + h1 = torch.zeros(1, bsize, self.rnn_dims).cuda() + h2 = torch.zeros(1, bsize, self.rnn_dims).cuda() + else: + h1 = torch.zeros(1, bsize, self.rnn_dims).cpu() + h2 = torch.zeros(1, bsize, self.rnn_dims).cpu() + mels, aux = self.upsample(mels) + + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0]:aux_idx[1]] + a2 = aux[:, :, aux_idx[1]:aux_idx[2]] + a3 = aux[:, :, aux_idx[2]:aux_idx[3]] + a4 = aux[:, :, aux_idx[3]:aux_idx[4]] + + x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + x = self.I(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def generate(self, mels, batched, target, overlap, mu_law, progress_callback=None): + mu_law = mu_law if self.mode == 'RAW' else False + progress_callback = progress_callback or self.gen_display + + self.eval() + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + if torch.cuda.is_available(): + mels = mels.cuda() + else: + mels = mels.cpu() + wave_len = (mels.size(-1) - 1) * self.hop_length + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both') + mels, aux = self.upsample(mels.transpose(1, 2)) + + if batched: + mels = self.fold_with_overlap(mels, target, overlap) + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + if torch.cuda.is_available(): + h1 = torch.zeros(b_size, self.rnn_dims).cuda() + h2 = torch.zeros(b_size, self.rnn_dims).cuda() + x = torch.zeros(b_size, 1).cuda() + else: + h1 = torch.zeros(b_size, self.rnn_dims).cpu() + h2 = torch.zeros(b_size, self.rnn_dims).cpu() + x = torch.zeros(b_size, 1).cpu() + + d = self.aux_dims + aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)] + + for i in range(seq_len): 
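                # Editorial note (not part of the original patch): each pass through this
                # loop emits one audio sample per batched fold; seq_len is the number of
                # upsampled conditioning frames, i.e. roughly hop_length output samples
                # for every input mel frame.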
+ + m_t = mels[:, i, :] + + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = torch.cat([x, m_t, a1_t], dim=1) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.mode == 'MOL': + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + if torch.cuda.is_available(): + # x = torch.FloatTensor([[sample]]).cuda() + x = sample.transpose(0, 1).cuda() + else: + x = sample.transpose(0, 1) + + elif self.mode == 'RAW' : + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError("Unknown model mode value - ", self.mode) + + if i % 100 == 0: + gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 + progress_callback(i, seq_len, b_size, gen_rate) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu().numpy() + output = output.astype(np.float64) + + if batched: + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + if mu_law: + output = decode_mu_law(output, self.n_classes, False) + if hp.apply_preemphasis: + output = de_emphasis(output) + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.hop_length) + output = output[:wave_len] + output[-20 * self.hop_length:] *= fade_out + + self.train() + + return output + + + def gen_display(self, i, seq_len, b_size, gen_rate): + pbar = progbar(i, seq_len) + msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | ' + stream(msg) + + def get_gru_cell(self, gru): + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + def pad_tensor(self, x, pad, side='both'): + # NB - this is just a quick method i need right now + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == 'both' else t + pad + if torch.cuda.is_available(): + padded = torch.zeros(b, total, c).cuda() + else: + padded = torch.zeros(b, total, c).cpu() + if side == 'before' or side == 'both': + padded[:, pad:pad + t, :] = x + elif side == 'after': + padded[:, :t, :] = x + return padded + + def fold_with_overlap(self, x, target, overlap): + + ''' Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + + Args: + x (tensor) : Upsampled conditioning features. + shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + + Details: + x = [[h1, h2, ... 
hn]] + + Where each h is a vector of conditioning features + + Eg: target=2, overlap=1 with x.size(1)=10 + + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + ''' + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side='after') + + if torch.cuda.is_available(): + folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda() + else: + folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu() + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + def xfade_and_unfold(self, y, target, overlap): + + ''' Applies a crossfade and unfolds into a 1d array. + + Args: + y (ndarry) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + + Return: + (ndarry) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + + Details: + y = [[seq1], + [seq2], + [seq3]] + + Apply a gain envelope at both ends of the sequences + + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + + Stagger and add up the groups of samples: + + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] + + ''' + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([fade_out, silence]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded + + def get_step(self) : + return self.step.data.item() + + def checkpoint(self, model_dir, optimizer) : + k_steps = self.get_step() // 1000 + self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer) + + def log(self, path, msg) : + with open(path, 'a') as f: + print(msg, file=f) + + def load(self, path, optimizer) : + checkpoint = torch.load(path) + if "optimizer_state" in checkpoint: + self.load_state_dict(checkpoint["model_state"]) + optimizer.load_state_dict(checkpoint["optimizer_state"]) + else: + # Backwards compatibility + self.load_state_dict(checkpoint) + + def save(self, path, optimizer) : + torch.save({ + "model_state": self.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, path) + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out : + print('Trainable Parameters: 
%.3fM' % parameters) diff --git a/vocoder/train.py b/vocoder/train.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc2f892e1fc134b311e2c9ee42250a2d3713547 --- /dev/null +++ b/vocoder/train.py @@ -0,0 +1,127 @@ +from vocoder.models.fatchord_version import WaveRNN +from vocoder.vocoder_dataset import VocoderDataset, collate_vocoder +from vocoder.distribution import discretized_mix_logistic_loss +from vocoder.display import stream, simple_table +from vocoder.gen_wavernn import gen_testset +from torch.utils.data import DataLoader +from pathlib import Path +from torch import optim +import torch.nn.functional as F +import vocoder.hparams as hp +import numpy as np +import time +import torch +import platform + +def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool, + save_every: int, backup_every: int, force_restart: bool): + # Check to make sure the hop length is correctly factorised + assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length + + # Instantiate the model + print("Initializing the model...") + model = WaveRNN( + rnn_dims=hp.voc_rnn_dims, + fc_dims=hp.voc_fc_dims, + bits=hp.bits, + pad=hp.voc_pad, + upsample_factors=hp.voc_upsample_factors, + feat_dims=hp.num_mels, + compute_dims=hp.voc_compute_dims, + res_out_dims=hp.voc_res_out_dims, + res_blocks=hp.voc_res_blocks, + hop_length=hp.hop_length, + sample_rate=hp.sample_rate, + mode=hp.voc_mode + ) + + if torch.cuda.is_available(): + model = model.cuda() + device = torch.device('cuda') + else: + device = torch.device('cpu') + + # Initialize the optimizer + optimizer = optim.Adam(model.parameters()) + for p in optimizer.param_groups: + p["lr"] = hp.voc_lr + loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss + + # Load the weights + model_dir = models_dir.joinpath(run_id) + model_dir.mkdir(exist_ok=True) + weights_fpath = model_dir.joinpath(run_id + ".pt") + if force_restart or not weights_fpath.exists(): + print("\nStarting the training of WaveRNN from scratch\n") + model.save(weights_fpath, optimizer) + else: + print("\nLoading weights at %s" % weights_fpath) + model.load(weights_fpath, optimizer) + print("WaveRNN weights loaded from step %d" % model.step) + + # Initialize the dataset + metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \ + voc_dir.joinpath("synthesized.txt") + mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta") + wav_dir = syn_dir.joinpath("audio") + dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir) + test_loader = DataLoader(dataset, + batch_size=1, + shuffle=True, + pin_memory=True) + + # Begin the training + simple_table([('Batch size', hp.voc_batch_size), + ('LR', hp.voc_lr), + ('Sequence Len', hp.voc_seq_len)]) + + for epoch in range(1, 350): + data_loader = DataLoader(dataset, + collate_fn=collate_vocoder, + batch_size=hp.voc_batch_size, + num_workers=2 if platform.system() != "Windows" else 0, + shuffle=True, + pin_memory=True) + start = time.time() + running_loss = 0. 
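        # Editorial note (not part of the original patch): collate_vocoder yields
        #   x: (batch, voc_seq_len) waveform input as floats in [-1, 1],
        #   y: (batch, voc_seq_len) next-sample targets (class labels in RAW mode),
        #   m: (batch, num_mels, voc_seq_len // hop_length + 2 * voc_pad) mel frames,
        # so the model learns to predict each sample from the preceding one plus the
        # local mel conditioning.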
+
+        for i, (x, y, m) in enumerate(data_loader, 1):
+            if torch.cuda.is_available():
+                x, m, y = x.cuda(), m.cuda(), y.cuda()
+
+            # Forward pass
+            y_hat = model(x, m)
+            if model.mode == 'RAW':
+                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
+            elif model.mode == 'MOL':
+                y = y.float()
+            y = y.unsqueeze(-1)
+
+            # Backward pass
+            loss = loss_func(y_hat, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+            speed = i / (time.time() - start)
+            avg_loss = running_loss / i
+
+            step = model.get_step()
+            k = step // 1000
+
+            if backup_every != 0 and step % backup_every == 0:
+                model.checkpoint(model_dir, optimizer)
+
+            if save_every != 0 and step % save_every == 0:
+                model.save(weights_fpath, optimizer)
+
+            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
+                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
+                  f"steps/s | Step: {k}k | "
+            stream(msg)
+
+
+        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
+                    hp.voc_target, hp.voc_overlap, model_dir)
+        print("")
diff --git a/vocoder/vocoder_dataset.py b/vocoder/vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eae1b5f20117feef0a06e264a99b3c0c6143bac
--- /dev/null
+++ b/vocoder/vocoder_dataset.py
@@ -0,0 +1,84 @@
+from torch.utils.data import Dataset
+from pathlib import Path
+from vocoder import audio
+import vocoder.hparams as hp
+import numpy as np
+import torch
+
+
+class VocoderDataset(Dataset):
+    def __init__(self, metadata_fpath: Path, mel_dir: Path, wav_dir: Path):
+        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, wav_dir))
+
+        with metadata_fpath.open("r") as metadata_file:
+            metadata = [line.split("|") for line in metadata_file]
+
+        gta_fnames = [x[1] for x in metadata if int(x[4])]
+        gta_fpaths = [mel_dir.joinpath(fname) for fname in gta_fnames]
+        wav_fnames = [x[0] for x in metadata if int(x[4])]
+        wav_fpaths = [wav_dir.joinpath(fname) for fname in wav_fnames]
+        self.samples_fpaths = list(zip(gta_fpaths, wav_fpaths))
+
+        print("Found %d samples" % len(self.samples_fpaths))
+
+    def __getitem__(self, index):
+        mel_path, wav_path = self.samples_fpaths[index]
+
+        # Load the mel spectrogram and adjust its range to [-1, 1]
+        mel = np.load(mel_path).T.astype(np.float32) / hp.mel_max_abs_value
+
+        # Load the wav
+        wav = np.load(wav_path)
+        if hp.apply_preemphasis:
+            wav = audio.pre_emphasis(wav)
+        wav = np.clip(wav, -1, 1)
+
+        # Fix for missing padding  # TODO: settle on whether this is actually useful
+        r_pad = (len(wav) // hp.hop_length + 1) * hp.hop_length - len(wav)
+        wav = np.pad(wav, (0, r_pad), mode='constant')
+        assert len(wav) >= mel.shape[1] * hp.hop_length
+        wav = wav[:mel.shape[1] * hp.hop_length]
+        assert len(wav) % hp.hop_length == 0
+
+        # Quantize the wav
+        if hp.voc_mode == 'RAW':
+            if hp.mu_law:
+                quant = audio.encode_mu_law(wav, mu=2 ** hp.bits)
+            else:
+                quant = audio.float_2_label(wav, bits=hp.bits)
+        elif hp.voc_mode == 'MOL':
+            quant = audio.float_2_label(wav, bits=16)
+
+        return mel.astype(np.float32), quant.astype(np.int64)
+
+    def __len__(self):
+        return len(self.samples_fpaths)
+
+
+def collate_vocoder(batch):
+    mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
+    max_offsets = [x[0].shape[-1] - 2 - (mel_win + 2 * hp.voc_pad) for x in batch]
+    mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
+    sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]
+
+    mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
+
+    labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]
+
+    mels = np.stack(mels).astype(np.float32)
+    labels = np.stack(labels).astype(np.int64)
+
+    mels = torch.tensor(mels)
+    labels = torch.tensor(labels).long()
+
+    x = labels[:, :hp.voc_seq_len]
+    y = labels[:, 1:]
+
+    bits = 16 if hp.voc_mode == 'MOL' else hp.bits
+
+    x = audio.label_2_float(x.float(), bits)
+
+    if hp.voc_mode == 'MOL':
+        y = audio.label_2_float(y.float(), bits)
+
+    return x, y, mels
\ No newline at end of file
diff --git a/vocoder_preprocess.py b/vocoder_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ede3dfb95972e2de575de35b9d4a9c6d642885e
--- /dev/null
+++ b/vocoder_preprocess.py
@@ -0,0 +1,59 @@
+from synthesizer.synthesize import run_synthesis
+from synthesizer.hparams import hparams
+from utils.argutils import print_args
+import argparse
+import os
+
+
+if __name__ == "__main__":
+    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
+        pass
+
+    parser = argparse.ArgumentParser(
+        description="Creates ground-truth aligned (GTA) spectrograms for training the vocoder.",
+        formatter_class=MyFormatter
+    )
+    parser.add_argument("datasets_root", type=str, help=\
+        "Path to the directory containing your SV2TTS directory. If you specify both --in_dir and "
+        "--out_dir, this argument won't be used.")
+    parser.add_argument("--model_dir", type=str,
+                        default="synthesizer/saved_models/pretrained/", help=\
+        "Path to the pretrained model directory.")
+    parser.add_argument("-i", "--in_dir", type=str, default=argparse.SUPPRESS, help= \
+        "Path to the synthesizer directory that contains the mel spectrograms, the wavs and the "
+        "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
+    parser.add_argument("-o", "--out_dir", type=str, default=argparse.SUPPRESS, help= \
+        "Path to the output vocoder directory that will contain the ground truth aligned mel "
+        "spectrograms. Defaults to <datasets_root>/SV2TTS/vocoder/.")
+    parser.add_argument("--hparams", default="",
+                        help="Hyperparameter overrides as a comma-separated list of name=value "
+                             "pairs")
+    parser.add_argument("--no_trim", action="store_true", help=\
+        "Preprocess audio without trimming silences (not recommended).")
+    parser.add_argument("--cpu", action="store_true", help=\
+        "If True, processing is done on CPU, even when a GPU is available.")
+    args = parser.parse_args()
+    print_args(args, parser)
+    modified_hp = hparams.parse(args.hparams)
+
+    if not hasattr(args, "in_dir"):
+        args.in_dir = os.path.join(args.datasets_root, "SV2TTS", "synthesizer")
+    if not hasattr(args, "out_dir"):
+        args.out_dir = os.path.join(args.datasets_root, "SV2TTS", "vocoder")
+
+    if args.cpu:
+        # Hide GPUs from Pytorch to force CPU processing
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+    # Verify webrtcvad is available
+    if not args.no_trim:
+        try:
+            import webrtcvad
+        except ImportError:
+            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
+                                      "noise removal and is recommended. Please install and try again. If installation fails, "
+                                      "use --no_trim to disable this error message.")
+    del args.no_trim
+
+    run_synthesis(args.in_dir, args.out_dir, args.model_dir, modified_hp)
+
diff --git a/vocoder_train.py b/vocoder_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d712ffa3e6c92a091aa18dc90f0027f46940e400
--- /dev/null
+++ b/vocoder_train.py
@@ -0,0 +1,56 @@
+from utils.argutils import print_args
+from vocoder.train import train
+from pathlib import Path
+import argparse
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Trains the vocoder from the synthesizer audios and the GTA synthesized mels, "
+                    "or ground truth mels.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument("run_id", type=str, help= \
+        "Name for this model instance. If a model state from the same run ID was previously "
+        "saved, the training will restart from there. Pass -f to overwrite saved states and "
+        "restart from scratch.")
+    parser.add_argument("datasets_root", type=str, help= \
+        "Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
+        "will take priority over this argument.")
+    parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
+        "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
+        "the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
+    parser.add_argument("--voc_dir", type=str, default=argparse.SUPPRESS, help= \
+        "Path to the vocoder directory that contains the GTA synthesized mel spectrograms. "
+        "Defaults to <datasets_root>/SV2TTS/vocoder/. Unused if --ground_truth is passed.")
+    parser.add_argument("-m", "--models_dir", type=str, default="vocoder/saved_models/", help=\
+        "Path to the directory that will contain the saved model weights, as well as backups "
+        "of those weights and wavs generated during training.")
+    parser.add_argument("-g", "--ground_truth", action="store_true", help= \
+        "Train on ground truth spectrograms (<datasets_root>/SV2TTS/synthesizer/mels).")
+    parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
+        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
+        "model.")
+    parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
+        "Number of steps between backups of the model. Set to 0 to never make backups of the "
+        "model.")
+    parser.add_argument("-f", "--force_restart", action="store_true", help= \
+        "Do not load any saved model and restart from scratch.")
+    args = parser.parse_args()
+
+    # Process the arguments
+    if not hasattr(args, "syn_dir"):
+        args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
+    args.syn_dir = Path(args.syn_dir)
+    if not hasattr(args, "voc_dir"):
+        args.voc_dir = Path(args.datasets_root, "SV2TTS", "vocoder")
+    args.voc_dir = Path(args.voc_dir)
+    del args.datasets_root
+    args.models_dir = Path(args.models_dir)
+    args.models_dir.mkdir(exist_ok=True)
+
+    # Run the training
+    print_args(args, parser)
+    train(**vars(args))
+
\ No newline at end of file
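
For reference, a minimal sketch of how the new training entry point could be driven programmatically instead of through the CLI. It mirrors the argument handling in vocoder_train.py and passes only the keyword arguments that train(**vars(args)) receives; the run name and datasets location below are placeholder assumptions, not values taken from this patch.

from pathlib import Path
from vocoder.train import train

# Hypothetical dataset location; replace with your own datasets_root.
datasets_root = Path("datasets")

# Same default as the -m flag; vocoder_train.py creates it before training.
models_dir = Path("vocoder/saved_models/")
models_dir.mkdir(exist_ok=True)

train(run_id="my_run",                                   # placeholder run name
      syn_dir=datasets_root / "SV2TTS" / "synthesizer",  # ground truth mels, wavs and embeds
      voc_dir=datasets_root / "SV2TTS" / "vocoder",      # GTA mels written by vocoder_preprocess.py
      models_dir=models_dir,
      ground_truth=False,                                # train on GTA mels, as the CLI does by default
      save_every=1000,
      backup_every=25000,
      force_restart=False)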