diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..23334b4fdeebfe2f0b8cb6d1c63b488ba56958d2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +g_00120000 filter=lfs diff=lfs merge=lfs -text +g_05000000 filter=lfs diff=lfs merge=lfs -text diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..329e9e8a1ecab118fee15061c9ac30d3919bf213 --- /dev/null +++ b/app.py @@ -0,0 +1,217 @@ + +from pathlib import Path +import torchaudio +import gradio as gr + +import numpy as np + +import torch + + +from hifigan.config import v1 +from hifigan.denoiser import Denoiser +from hifigan.env import AttrDict +from hifigan.models import Generator as HiFiGAN + + +#from BigVGAN.models import BigVGAN +#from BigVGAN.env import AttrDict as BigVGANAttrDict + + +from pflow.models.pflow_tts import pflowTTS +from pflow.text import text_to_sequence, sequence_to_text +from pflow.utils.utils import intersperse +from pflow.data.text_mel_datamodule import mel_spectrogram +from pflow.utils.model import normalize + + + +BIGVGAN_CONFIG = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": True, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": False, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} + +PFLOW_MODEL_PATH = 'checkpoint_epoch=499.ckpt' +VOCODER_MODEL_PATH = 'g_00120000' +VOCODER_BIGVGAN_MODEL_PATH = 'g_05000000' + +wav, sr = torchaudio.load('prompt.wav') + +prompt = mel_spectrogram( + wav, + 1024, + 80, + 22050, + 256, + 1024, + 0, + 8000, + center=False, + )[:,:,:264] + + + +def process_text(text: str, device: torch.device): + x = torch.tensor( + intersperse(text_to_sequence(text, ["ukr_cleaners"]), 0), + dtype=torch.long, + device=device, + )[None] + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + x_phones = sequence_to_text(x.squeeze(0).tolist()) + return {"x_orig": text, "x": x, "x_lengths": x_lengths, 'x_phones':x_phones} + + + + +def load_hifigan(checkpoint_path, device): + h = AttrDict(v1) + hifigan = HiFiGAN(h).to(device) + hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"]) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def load_bigvgan(checkpoint_path, device): + print("Loading '{}'".format(checkpoint_path)) + checkpoint_dict = torch.load(checkpoint_path, map_location=device) + + + h = BigVGANAttrDict(BIGVGAN_CONFIG) + torch.manual_seed(h.seed) + + generator = BigVGAN(h).to(device) + generator.load_state_dict(checkpoint_dict['generator']) + generator.eval() + generator.remove_weight_norm() + return generator + + +def to_waveform(mel, vocoder, denoiser=None): + audio = vocoder(mel).clamp(-1, 1) + if denoiser is not None: + audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze() + + return audio.cpu().squeeze() + + + + + + +def get_device(): + if torch.cuda.is_available(): + print("[+] GPU Available! Using GPU") + device = torch.device("cuda") + else: + print("[-] GPU not available or forced CPU run! Using CPU") + device = torch.device("cpu") + return device + + +device = get_device() +model = pflowTTS.load_from_checkpoint(PFLOW_MODEL_PATH, map_location=device) +_ = model.eval() +#vocoder = load_bigvgan(VOCODER_BIGVGAN_MODEL_PATH, device) +vocoder = load_hifigan(VOCODER_MODEL_PATH, device) +denoiser = Denoiser(vocoder, mode="zeros") + +@torch.inference_mode() +def synthesise(text, temperature, speed): + if len(text) > 1000: + raise gr.Error("Текст повинен бути коротшим за 1000 символів.") + + text_processed = process_text(text.strip(), device) + + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=40, + temperature=temperature, + length_scale=1/speed, + prompt= normalize(prompt, model.mel_mean, model.mel_std) + ) + waveform = to_waveform(output["mel"], vocoder, denoiser) + + return text_processed['x_phones'][1::2], (22050, waveform.numpy()) + + +description = f''' +# Експериментальна апка для генерації аудіо з тексту. + + pflow checkpoint {PFLOW_MODEL_PATH} + vocoder: HIFIGAN(трейнутий на датасеті, з нуля) - {VOCODER_MODEL_PATH} +''' + + +if __name__ == "__main__": + i = gr.Interface( + fn=synthesise, + description=description, + inputs=[ + gr.Text(label='Текст для синтезу:', lines=5, max_lines=10), + gr.Slider(minimum=0.0, maximum=1.0, label="Температура", value=0.2), + gr.Slider(minimum=0.6, maximum=2.0, label="Швидкість", value=1.0) + ], + outputs=[ + gr.Text(label='Фонемізований текст:', lines=5), + gr.Audio( + label="Згенероване аудіо:", + autoplay=False, + streaming=False, + type="numpy", + ) + + ], + allow_flagging ='manual', + flagging_options=[("Якщо дуже погоне аудіо, тисни цю кнопку.", "negative")], + cache_examples=True, + title='', + # description=description, + # article=article, + # examples=examples, + ) + i.queue(max_size=20, default_concurrency_limit=4) + i.launch(share=False, server_name="0.0.0.0") diff --git a/checkpoint_epoch=499.ckpt b/checkpoint_epoch=499.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..b956ae1d9042cdcb4f0b9c5a2fd590f8b261d9c8 --- /dev/null +++ b/checkpoint_epoch=499.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39051170c6c0d9abce47d0073f796912d5ce3854ade8f707cb30333f50160d99 +size 279562867 diff --git a/g_00120000 b/g_00120000 new file mode 100644 index 0000000000000000000000000000000000000000..9bbb40f605a2d4549368c6540106862e72b5a2f2 --- /dev/null +++ b/g_00120000 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f25c6dbc515ed387edd5d2e5683a50510aa33986e8a79273efe1216084f0f078 +size 55824433 diff --git a/hifigan/LICENSE b/hifigan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..91751daed806f63ac594cf077a3065f719a41662 --- /dev/null +++ b/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/hifigan/README.md b/hifigan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5db25850451a794b1db1b15b08e82c1d802edbb3 --- /dev/null +++ b/hifigan/README.md @@ -0,0 +1,101 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis + +### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae + +In our [paper](https://arxiv.org/abs/2010.05646), +we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
+We provide our implementation and pretrained models as open source in this repository. + +**Abstract :** +Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. +Although such methods improve the sampling efficiency and memory usage, +their sample quality has not yet reached that of autoregressive and flow-based generative models. +In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. +As speech audio consists of sinusoidal signals with various periods, +we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. +A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method +demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than +real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen +speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times +faster than real-time on CPU with comparable quality to an autoregressive counterpart. + +Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. + +## Pre-requisites + +1. Python >= 3.6 +2. Clone this repository. +3. Install python requirements. Please refer [requirements.txt](requirements.txt) +4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). + And move all wav files to `LJSpeech-1.1/wavs` + +## Training + +``` +python train.py --config config_v1.json +``` + +To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
+Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.
+You can change the path by adding `--checkpoint_path` option. + +Validation loss during training with V1 generator.
+![validation loss](./validation_loss.png) + +## Pretrained Model + +You can also use pretrained models we provide.
+[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
+Details of each folder are as in follows: + +| Folder Name | Generator | Dataset | Fine-Tuned | +| ------------ | --------- | --------- | ------------------------------------------------------ | +| LJ_V1 | V1 | LJSpeech | No | +| LJ_V2 | V2 | LJSpeech | No | +| LJ_V3 | V3 | LJSpeech | No | +| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| VCTK_V1 | V1 | VCTK | No | +| VCTK_V2 | V2 | VCTK | No | +| VCTK_V3 | V3 | VCTK | No | +| UNIVERSAL_V1 | V1 | Universal | No | + +We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. + +## Fine-Tuning + +1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
+ Example: + ` Audio File : LJ001-0001.wav +Mel-Spectrogram File : LJ001-0001.npy` +2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
+3. Run the following command. + ``` + python train.py --fine_tuning True --config config_v1.json + ``` + For other command line options, please refer to the training section. + +## Inference from wav file + +1. Make `test_files` directory and copy wav files into the directory. +2. Run the following command. + ` python inference.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files` by default.
+ You can change the path by adding `--output_dir` option. + +## Inference for end-to-end speech synthesis + +1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), + [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. +2. Run the following command. + ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files_from_mel` by default.
+ You can change the path by adding `--output_dir` option. + +## Acknowledgements + +We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) +and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/hifigan/__init__.py b/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hifigan/config.py b/hifigan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b3abea9e151a08864353d32066bd4935e24b82e7 --- /dev/null +++ b/hifigan/config.py @@ -0,0 +1,28 @@ +v1 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0004, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 256, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, +} diff --git a/hifigan/denoiser.py b/hifigan/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd33312a09b1940374a0e29a97fe3a1a1dac7d2 --- /dev/null +++ b/hifigan/denoiser.py @@ -0,0 +1,64 @@ +# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py + +"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" +import torch + + +class Denoiser(torch.nn.Module): + """Removes model bias from audio produced with waveglow""" + + def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): + super().__init__() + self.filter_length = filter_length + self.hop_length = int(filter_length / n_overlap) + self.win_length = win_length + + dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device + self.device = device + if mode == "zeros": + mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) + elif mode == "normal": + mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) + else: + raise Exception(f"Mode {mode} if not supported") + + def stft_fn(audio, n_fft, hop_length, win_length, window): + spec = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + spec = torch.view_as_real(spec) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) + + self.stft = lambda x: stft_fn( + audio=x, + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + self.istft = lambda x, y: torch.istft( + torch.complex(x * torch.cos(y), x * torch.sin(y)), + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + + with torch.no_grad(): + bias_audio = vocoder(mel_input).float().squeeze(0) + bias_spec, _ = self.stft(bias_audio) + + self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) + + @torch.inference_mode() + def forward(self, audio, strength=0.0005): + audio_spec, audio_angles = self.stft(audio) + audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.istft(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/hifigan/env.py b/hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea4f948a3f002921bf9bc24f52cbc1c0b1fc2ec --- /dev/null +++ b/hifigan/env.py @@ -0,0 +1,17 @@ +""" from https://github.com/jik876/hifi-gan """ + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/hifigan/meldataset.py b/hifigan/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8b43ea7965e04a52d5427a485ee911b743057c4a --- /dev/null +++ b/hifigan/meldataset.py @@ -0,0 +1,217 @@ +""" from https://github.com/jik876/hifi-gan """ + +import math +import os +import random + +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from librosa.util import normalize +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + + with open(a.input_validation_file, encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: + mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] + else: + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/hifigan/models.py b/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d209d9a4e99ec29e4167a5a2eaa62d72b3eff694 --- /dev/null +++ b/hifigan/models.py @@ -0,0 +1,368 @@ +""" from https://github.com/jik876/hifi-gan """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .xutils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/hifigan/xutils.py b/hifigan/xutils.py new file mode 100644 index 0000000000000000000000000000000000000000..eefadcb7a1d0bf9015e636b88fee3e22c9771bc5 --- /dev/null +++ b/hifigan/xutils.py @@ -0,0 +1,60 @@ +""" from https://github.com/jik876/hifi-gan """ + +import glob +import os + +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/pflow/__init__.py b/pflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pflow/data/__init__.py b/pflow/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pflow/data/components/__init__.py b/pflow/data/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pflow/data/text_mel_datamodule.py b/pflow/data/text_mel_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..b87f5dec6d949d4defb5244fa01c0eb9ff27ef82 --- /dev/null +++ b/pflow/data/text_mel_datamodule.py @@ -0,0 +1,256 @@ +import random +from typing import Any, Dict, Optional + +import torch +import torchaudio as ta +from lightning import LightningDataModule +from torch.utils.data.dataloader import DataLoader + +from pflow.text import text_to_sequence +from pflow.utils.audio import mel_spectrogram +from pflow.utils.model import fix_len_compatibility, normalize +from pflow.utils.utils import intersperse + + +def parse_filelist(filelist_path, split_char="|"): + with open(filelist_path, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split_char) for line in f] + return filepaths_and_text + + +class TextMelDataModule(LightningDataModule): + def __init__( # pylint: disable=unused-argument + self, + name, + train_filelist_path, + valid_filelist_path, + batch_size, + num_workers, + pin_memory, + cleaners, + add_blank, + n_spks, + n_fft, + n_feats, + sample_rate, + hop_length, + win_length, + f_min, + f_max, + data_statistics, + seed, + ): + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + def setup(self, stage: Optional[str] = None): # pylint: disable=unused-argument + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be + careful not to execute things like random split twice! + """ + # load and split datasets only if not loaded already + + self.trainset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.train_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + ) + self.validset = TextMelDataset( # pylint: disable=attribute-defined-outside-init + self.hparams.valid_filelist_path, + self.hparams.n_spks, + self.hparams.cleaners, + self.hparams.add_blank, + self.hparams.n_fft, + self.hparams.n_feats, + self.hparams.sample_rate, + self.hparams.hop_length, + self.hparams.win_length, + self.hparams.f_min, + self.hparams.f_max, + self.hparams.data_statistics, + self.hparams.seed, + ) + + def train_dataloader(self): + return DataLoader( + dataset=self.trainset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.validset, + batch_size=self.hparams.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=TextMelBatchCollate(self.hparams.n_spks), + ) + + def teardown(self, stage: Optional[str] = None): + """Clean up after fit or test.""" + pass # pylint: disable=unnecessary-pass + + def state_dict(self): # pylint: disable=no-self-use + """Extra things to save to checkpoint.""" + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + """Things to do when loading checkpoint.""" + pass # pylint: disable=unnecessary-pass + + +class TextMelDataset(torch.utils.data.Dataset): + def __init__( + self, + filelist_path, + n_spks, + cleaners, + add_blank=True, + n_fft=1024, + n_mels=80, + sample_rate=22050, + hop_length=256, + win_length=1024, + f_min=0.0, + f_max=8000, + data_parameters=None, + seed=None, + ): + self.filepaths_and_text = parse_filelist(filelist_path) + self.n_spks = n_spks + self.cleaners = cleaners + self.add_blank = add_blank + self.n_fft = n_fft + self.n_mels = n_mels + self.sample_rate = sample_rate + self.hop_length = hop_length + self.win_length = win_length + self.f_min = f_min + self.f_max = f_max + if data_parameters is not None: + self.data_parameters = data_parameters + else: + self.data_parameters = {"mel_mean": 0, "mel_std": 1} + random.seed(seed) + random.shuffle(self.filepaths_and_text) + + def get_datapoint(self, filepath_and_text): + if self.n_spks > 1: + filepath, spk, text = ( + filepath_and_text[0], + int(filepath_and_text[1]), + filepath_and_text[2], + ) + else: + filepath, text = filepath_and_text[0], filepath_and_text[1] + spk = None + + text = self.get_text(text, add_blank=self.add_blank) + mel, audio = self.get_mel(filepath) + # TODO: make dictionary to get different spec for same speaker + # right now naively repeating target mel for testing purposes + return {"x": text, "y": mel, "spk": spk, "wav":audio} + + def get_mel(self, filepath): + audio, sr = ta.load(filepath) + assert sr == self.sample_rate + mel = mel_spectrogram( + audio, + self.n_fft, + self.n_mels, + self.sample_rate, + self.hop_length, + self.win_length, + self.f_min, + self.f_max, + center=False, + ).squeeze() + mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"]) + return mel, audio + + def get_text(self, text, add_blank=True): + text_norm = text_to_sequence(text, self.cleaners) + if self.add_blank: + text_norm = intersperse(text_norm, 0) + text_norm = torch.IntTensor(text_norm) + return text_norm + + def __getitem__(self, index): + datapoint = self.get_datapoint(self.filepaths_and_text[index]) + if datapoint["wav"].shape[1] <= 66150: + ''' + skip datapoint if too short (3s) + TODO To not waste data, we can concatenate wavs less than 3s and use them + TODO as a hyperparameter; multispeaker dataset can use another wav of same speaker + ''' + return self.__getitem__(random.randint(0, len(self.filepaths_and_text)-1)) + return datapoint + + def __len__(self): + return len(self.filepaths_and_text) + + +class TextMelBatchCollate: + def __init__(self, n_spks): + self.n_spks = n_spks + + def __call__(self, batch): + B = len(batch) + y_max_length = max([item["y"].shape[-1] for item in batch]) + y_max_length = fix_len_compatibility(y_max_length) + wav_max_length = y_max_length * 256 + x_max_length = max([item["x"].shape[-1] for item in batch]) + n_feats = batch[0]["y"].shape[-2] + + y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32) + x = torch.zeros((B, x_max_length), dtype=torch.long) + wav = torch.zeros((B, 1, wav_max_length), dtype=torch.float32) + y_lengths, x_lengths = [], [] + wav_lengths = [] + spks = [] + for i, item in enumerate(batch): + y_, x_ = item["y"], item["x"] + wav_ = item["wav"][:,:wav_max_length] if item["wav"].shape[-1] > wav_max_length else item["wav"] + y_lengths.append(y_.shape[-1]) + x_lengths.append(x_.shape[-1]) + wav_lengths.append(wav_.shape[-1]) + y[i, :, : y_.shape[-1]] = y_ + x[i, : x_.shape[-1]] = x_ + wav[i, :, : wav_.shape[-1]] = wav_ + spks.append(item["spk"]) + + y_lengths = torch.tensor(y_lengths, dtype=torch.long) + x_lengths = torch.tensor(x_lengths, dtype=torch.long) + wav_lengths = torch.tensor(wav_lengths, dtype=torch.long) + spks = torch.tensor(spks, dtype=torch.long) if self.n_spks > 1 else None + + return { + "x": x, + "x_lengths": x_lengths, + "y": y, + "y_lengths": y_lengths, + "spks": spks, + "wav":wav, + "wav_lengths":wav_lengths, + "prompt_spec": y, + "prompt_lengths": y_lengths, + } \ No newline at end of file diff --git a/pflow/models/__init__.py b/pflow/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pflow/models/baselightningmodule.py b/pflow/models/baselightningmodule.py new file mode 100644 index 0000000000000000000000000000000000000000..4e91f2c6e12e9a50a5b02bee59be7ac231d0f184 --- /dev/null +++ b/pflow/models/baselightningmodule.py @@ -0,0 +1,247 @@ +""" +This is a base lightning module that can be used to train a model. +The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. +""" +import inspect +from abc import ABC +from typing import Any, Dict + +import torch +from lightning import LightningModule +from lightning.pytorch.utilities import grad_norm + +from pflow import utils +from pflow.utils.utils import plot_tensor +from pflow.models.components import commons + +log = utils.get_pylogger(__name__) + + +class BaseLightningClass(LightningModule, ABC): + def update_data_statistics(self, data_statistics): + if data_statistics is None: + data_statistics = { + "mel_mean": 0.0, + "mel_std": 1.0, + } + + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + + def configure_optimizers(self) -> Any: + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.scheduler not in (None, {}): + scheduler_args = {} + # Manage last epoch for exponential schedulers + if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: + if hasattr(self, "ckpt_loaded_epoch"): + current_epoch = self.ckpt_loaded_epoch - 1 + else: + current_epoch = -1 + + scheduler_args.update({"optimizer": optimizer}) + scheduler = self.hparams.scheduler.scheduler(**scheduler_args) + print(self.ckpt_loaded_epoch - 1) + if hasattr(self, "ckpt_loaded_epoch"): + scheduler.last_epoch = self.ckpt_loaded_epoch - 1 + else: + scheduler.last_epoch = -1 + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + # "interval": self.hparams.scheduler.lightning_args.interval, + # "frequency": self.hparams.scheduler.lightning_args.frequency, + # "name": "learning_rate", + "monitor": "val_loss", + }, + } + + return {"optimizer": optimizer} + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + # prompt_spec = batch["prompt_spec"] + # prompt_lengths = batch["prompt_lengths"] + # prompt_slice, ids_slice = commons.rand_slice_segments( + # prompt_spec, + # prompt_lengths, + # self.prompt_size + # ) + prompt_slice = None + dur_loss, prior_loss, diff_loss, attn = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + prompt=prompt_slice, + ) + return ({ + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + }, + { + "attn": attn + } + ) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init + + def training_step(self, batch: Any, batch_idx: int): + loss_dict, attn_dict = self.get_losses(batch) + + self.log( + "step", + float(self.global_step), + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + self.log( + "sub_loss/train_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/train_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/train", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + attn = attn_dict["attn"][0] + self.logger.experiment.add_image( + f"train/alignment", + plot_tensor(attn.cpu()), + self.current_epoch, + dataformats="HWC", + ) + return {"loss": total_loss, "log": loss_dict} + + def validation_step(self, batch: Any, batch_idx: int): + loss_dict, attn_dict = self.get_losses(batch) + self.log( + "sub_loss/val_dur_loss", + loss_dict["dur_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_prior_loss", + loss_dict["prior_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + self.log( + "sub_loss/val_diff_loss", + loss_dict["diff_loss"], + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) + + total_loss = sum(loss_dict.values()) + self.log( + "loss/val", + total_loss, + on_step=True, + on_epoch=True, + logger=True, + prog_bar=True, + sync_dist=True, + ) + + attn = attn_dict["attn"][0] + self.logger.experiment.add_image( + f"val/alignment", + plot_tensor(attn.cpu()), + self.current_epoch, + dataformats="HWC", + ) + return total_loss + + def on_validation_end(self) -> None: + if self.trainer.is_global_zero: + one_batch = next(iter(self.trainer.val_dataloaders)) + + if self.current_epoch == 0: + log.debug("Plotting original samples") + for i in range(2): + y = one_batch["y"][i].unsqueeze(0).to(self.device) + self.logger.experiment.add_image( + f"original/{i}", + plot_tensor(y.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + log.debug("Synthesising...") + for i in range(2): + x = one_batch["x"][i].unsqueeze(0).to(self.device) + x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) + y = one_batch["y"][i].unsqueeze(0).to(self.device) + y_lengths = one_batch["y_lengths"][i].unsqueeze(0).to(self.device) + # prompt = one_batch["prompt_spec"][i].unsqueeze(0).to(self.device) + # prompt_lengths = one_batch["prompt_lengths"][i].unsqueeze(0).to(self.device) + prompt = y + prompt_lengths = y_lengths + prompt_slice, ids_slice = commons.rand_slice_segments( + prompt, prompt_lengths, self.prompt_size + ) + output = self.synthesise(x[:, :x_lengths], x_lengths, prompt=prompt_slice, n_timesteps=10, guidance_scale=0.0) + y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] + attn = output["attn"] + self.logger.experiment.add_image( + f"generated_enc/{i}", + plot_tensor(y_enc.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"generated_dec/{i}", + plot_tensor(y_dec.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + self.logger.experiment.add_image( + f"alignment/{i}", + plot_tensor(attn.squeeze().cpu()), + self.current_epoch, + dataformats="HWC", + ) + + def on_before_optimizer_step(self, optimizer): + self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) diff --git a/pflow/models/components/__init__.py b/pflow/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pflow/models/components/aligner.py b/pflow/models/components/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..d89fe1b50d9a364ef62c37a821e1f683522a7c83 --- /dev/null +++ b/pflow/models/components/aligner.py @@ -0,0 +1,235 @@ +from typing import Tuple +import numpy as np + +import torch +from torch import nn, Tensor +from torch.nn import Module +import torch.nn.functional as F + +from einops import rearrange, repeat + +from beartype import beartype +from beartype.typing import Optional + +def exists(val): + return val is not None + +class AlignerNet(Module): + """alignment model https://arxiv.org/pdf/2108.10447.pdf """ + def __init__( + self, + dim_in=80, + dim_hidden=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + + self.key_layers = nn.ModuleList([ + nn.Conv1d( + dim_hidden, + dim_hidden * 2, + kernel_size=3, + padding=1, + bias=True, + ), + nn.ReLU(inplace=True), + nn.Conv1d(dim_hidden * 2, attn_channels, kernel_size=1, padding=0, bias=True) + ]) + + self.query_layers = nn.ModuleList([ + nn.Conv1d( + dim_in, + dim_in * 2, + kernel_size=3, + padding=1, + bias=True, + ), + nn.ReLU(inplace=True), + nn.Conv1d(dim_in * 2, dim_in, kernel_size=1, padding=0, bias=True), + nn.ReLU(inplace=True), + nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True) + ]) + + @beartype + def forward( + self, + queries: Tensor, + keys: Tensor, + mask: Optional[Tensor] = None + ): + key_out = keys + for layer in self.key_layers: + key_out = layer(key_out) + + query_out = queries + for layer in self.query_layers: + query_out = layer(query_out) + + key_out = rearrange(key_out, 'b c t -> b t c') + query_out = rearrange(query_out, 'b c t -> b t c') + + attn_logp = torch.cdist(query_out, key_out) + attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...') + + if exists(mask): + mask = rearrange(mask.bool(), '... c -> ... 1 c') + attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max) + + attn = attn_logp.softmax(dim = -1) + return attn, attn_logp + +def pad_tensor(input, pad, value=0): + pad = [item for sublist in reversed(pad) for item in sublist] # Flatten the tuple + assert len(pad) // 2 == len(input.shape), 'Padding dimensions do not match input dimensions' + return F.pad(input, pad, mode='constant', value=value) + +def maximum_path(value, mask, const=None): + device = value.device + dtype = value.dtype + if not exists(const): + const = torch.tensor(float('-inf')).to(device) # Patch for Sphinx complaint + value = value * mask + + b, t_x, t_y = value.shape + direction = torch.zeros(value.shape, dtype=torch.int64, device=device) + v = torch.zeros((b, t_x), dtype=torch.float32, device=device) + x_range = torch.arange(t_x, dtype=torch.float32, device=device).view(1, -1) + + for j in range(t_y): + v0 = pad_tensor(v, ((0, 0), (1, 0)), value = const)[:, :-1] + v1 = v + max_mask = v1 >= v0 + v_max = torch.where(max_mask, v1, v0) + direction[:, :, j] = max_mask + + index_mask = x_range <= j + v = torch.where(index_mask.view(1,-1), v_max + value[:, :, j], const) + + direction = torch.where(mask.bool(), direction, 1) + + path = torch.zeros(value.shape, dtype=torch.float32, device=device) + index = mask[:, :, 0].sum(1).long() - 1 + index_range = torch.arange(b, device=device) + + for j in reversed(range(t_y)): + path[index_range, index, j] = 1 + index = index + direction[index_range, index, j] - 1 + + path = path * mask.float() + path = path.to(dtype=dtype) + return path + +class ForwardSumLoss(Module): + def __init__( + self, + blank_logprob = -1 + ): + super().__init__() + self.blank_logprob = blank_logprob + + self.ctc_loss = torch.nn.CTCLoss( + blank = 0, # check this value + zero_infinity = True + ) + + def forward(self, attn_logprob, key_lens, query_lens): + device, blank_logprob = attn_logprob.device, self.blank_logprob + max_key_len = attn_logprob.size(-1) + + # Reorder input to [query_len, batch_size, key_len] + attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t') + + # Add blank label + attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob) + + # Convert to log probabilities + # Note: Mask out probs beyond key_len + mask_value = -torch.finfo(attn_logprob.dtype).max + attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value) + + attn_logprob = attn_logprob.log_softmax(dim = -1) + + # Target sequences + target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long) + target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel()) + + # Evaluate CTC loss + cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens) + + return cost + +class BinLoss(Module): + def forward(self, attn_hard, attn_logprob, key_lens): + batch, device = attn_logprob.shape[0], attn_logprob.device + max_key_len = attn_logprob.size(-1) + + # Reorder input to [query_len, batch_size, key_len] + attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t') + attn_hard = rearrange(attn_hard, 'b t c -> c b t') + + mask_value = -torch.finfo(attn_logprob.dtype).max + + attn_logprob.masked_fill_(torch.arange(max_key_len, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value) + attn_logprob = attn_logprob.log_softmax(dim = -1) + + return (attn_hard * attn_logprob).sum() / batch + +class Aligner(Module): + def __init__( + self, + dim_in, + dim_hidden, + attn_channels=80, + temperature=0.0005 + ): + super().__init__() + self.dim_in = dim_in + self.dim_hidden = dim_hidden + self.attn_channels = attn_channels + self.temperature = temperature + self.aligner = AlignerNet( + dim_in = self.dim_in, + dim_hidden = self.dim_hidden, + attn_channels = self.attn_channels, + temperature = self.temperature + ) + + def forward( + self, + x, + x_mask, + y, + y_mask + ): + alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask) + + x_mask = rearrange(x_mask, '... i -> ... i 1') + y_mask = rearrange(y_mask, '... j -> ... 1 j') + attn_mask = x_mask * y_mask + attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j') + + alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c') + alignment_mask = maximum_path(alignment_soft, attn_mask) + + alignment_hard = torch.sum(alignment_mask, -1).int() + return alignment_hard, alignment_soft, alignment_logprob, alignment_mask + +if __name__ == '__main__': + batch_size = 10 + seq_len_y = 200 # length of sequence y + seq_len_x = 35 + feature_dim = 80 # feature dimension + + x = torch.randn(batch_size, 512, seq_len_x) + x = x.transpose(1,2) #dim-1 is the channels for conv + y = torch.randn(batch_size, seq_len_y, feature_dim) + y = y.transpose(1,2) #dim-1 is the channels for conv + + # Create masks + x_mask = torch.ones(batch_size, 1, seq_len_x) + y_mask = torch.ones(batch_size, 1, seq_len_y) + + align = Aligner(dim_in = 80, dim_hidden=512, attn_channels=80) + alignment_hard, alignment_soft, alignment_logprob, alignment_mas = align(x, x_mask, y, y_mask) \ No newline at end of file diff --git a/pflow/models/components/attentions.py b/pflow/models/components/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab1a1e653489b067ee6cd47fb21c650797bbacd --- /dev/null +++ b/pflow/models/components/attentions.py @@ -0,0 +1,491 @@ +# from https://github.com/jaywalnut310/vits +import math +import torch +from torch import nn +from torch.nn import functional as F + +import commons +from modules import LayerNorm + + +class Encoder(nn.Module): # backward compatible vits2 encoder + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + # if kwargs has spk_emb_dim, then add a linear layer to project spk_emb_dim to hidden_channels + self.cond_layer_idx = self.n_layers + if "gin_channels" in kwargs: + self.gin_channels = kwargs["gin_channels"] + if self.gin_channels != 0: + self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) + # vits2 says 3rd block, so idx is 2 by default + self.cond_layer_idx = ( + kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 + ) + assert ( + self.cond_layer_idx < self.n_layers + ), "cond_layer_idx should be less than n_layers" + + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + if i == self.cond_layer_idx and g is not None: + g = self.spk_emb_linear(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + # query = query.view( + # b, + # self.n_heads, + # self.k_channels, + # t_t + # ).transpose(2, 3) #[b,h,t_t,c], d=h*c + # key = key.view( + # b, + # self.n_heads, + # self.k_channels, + # t_s + # ).transpose(2, 3) #[b,h,t_s,c] + # value = value.view( + # b, + # self.n_heads, + # self.k_channels, + # t_s + # ).transpose(2, 3) #[b,h,t_s,c] + # scores = torch.matmul( + # query / math.sqrt(self.k_channels), key.transpose(-2, -1) + # ) #[b,h,t_t,t_s] + query = query.view(b, self.n_heads, self.k_channels, t_t) # [b,h,c,t_t] + key = key.view(b, self.n_heads, self.k_channels, t_s) # [b,h,c,t_s] + value = value.view(b, self.n_heads, self.k_channels, t_s) # [b,h,c,t_s] + scores = torch.einsum( + "bhdt,bhds -> bhts", query / math.sqrt(self.k_channels), key + ) # [b,h,t_t,t_s] + # if self.window_size is not None: + # assert t_s == t_t, "Relative attention is only available for self-attention." + # key_relative_embeddings = self._get_relative_embeddings( + # self.emb_rel_k, t_s + # ) + # rel_logits = self._matmul_with_relative_keys( + # query / math.sqrt(self.k_channels), key_relative_embeddings + # ) #[b,h,t_t,d],[h or 1,e,d] ->[b,h,t_t,e] + # scores_local = self._relative_position_to_absolute_position(rel_logits) + # scores = scores + scores_local + # if self.proximal_bias: + # assert t_s == t_t, "Proximal bias is only available for self-attention." + # scores = scores + \ + # self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + # if mask is not None: + # scores = scores.masked_fill(mask == 0, -1e4) + # if self.block_length is not None: + # assert t_s == t_t, "Local attention is only available for self-attention." + # block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + # scores = scores.masked_fill(block_mask == 0, -1e4) + # p_attn = F.softmax(scores, dim=-1) # [b, h, t_t, t_s] + # p_attn = self.drop(p_attn) + # output = torch.matmul(p_attn, value) # [b,h,t_t,t_s],[b,h,t_s,c] -> [b,h,t_t,c] + # if self.window_size is not None: + # relative_weights = self._absolute_position_to_relative_position(p_attn) #[b, h, t_t, 2*t_t-1] + # value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) #[h or 1, 2*t_t-1, c] + # output = output + \ + # self._matmul_with_relative_values( + # relative_weights, value_relative_embeddings) # [b, h, t_t, 2*t_t-1],[h or 1, 2*t_t-1, c] -> [b, h, t_t, c] + # output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, c] -> [b,h,c,t_t] -> [b, d, t_t] + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = torch.einsum( + "bhdt,hed->bhte", + query / math.sqrt(self.k_channels), + key_relative_embeddings, + ) # [b,h,c,t_t],[h or 1,e,c] ->[b,h,t_t,e] + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.einsum( + "bhcs,bhts->bhct", value, p_attn + ) # [b,h,c,t_s],[b,h,t_t,t_s] -> [b,h,c,t_t] + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position( + p_attn + ) # [b, h, t_t, 2*t_t-1] + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) # [h or 1, 2*t_t-1, c] + output = output + torch.einsum( + "bhte,hec->bhct", relative_weights, value_relative_embeddings + ) # [b, h, t_t, 2*t_t-1],[h or 1, 2*t_t-1, c] -> [b, h, c, t_t] + output = output.view(b, d, t_t) # [b, h, c, t_t] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + # ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + ret = torch.einsum("bhld,hmd -> bhlm", x, y) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/pflow/models/components/commons.py b/pflow/models/components/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc627fdeed7dd36921a2aa3fd2159d2be34a9d1 --- /dev/null +++ b/pflow/models/components/commons.py @@ -0,0 +1,179 @@ +# from https://github.com/jaywalnut310/vits +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to( + dtype=torch.long + ) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def rand_slice_segments_for_cat(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = torch.rand([b // 2]).to(device=x.device) + ids_str = (torch.cat([ids_str, ids_str], dim=0) * ids_str_max).to(dtype=torch.long) + ids_str = torch.max(torch.zeros(ids_str.size()).to(ids_str.device), ids_str).to( + dtype=torch.long + ) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/pflow/models/components/decoder.py b/pflow/models/components/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8baf4c1b89dc740d0b6941dc0f8376f7f4099f --- /dev/null +++ b/pflow/models/components/decoder.py @@ -0,0 +1,444 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat + +from pflow.models.components.transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None, training=True): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + output = output * mask + + return output * mask diff --git a/pflow/models/components/flow_matching.py b/pflow/models/components/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..8720cc0684a22f5ce5622a62263446b28dae102c --- /dev/null +++ b/pflow/models/components/flow_matching.py @@ -0,0 +1,148 @@ +from abc import ABC + +import torch +import torch.nn.functional as F + +from pflow.models.components.decoder import Decoder +from pflow.models.components.wn_pflow_decoder import DiffSingerNet +from pflow.models.components.vits_wn_decoder import VitsWNDecoder + +from pflow.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + n_feats, + cfm_params, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + + self.estimator = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, cond=None, training=False, guidance_scale=0.0): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, cond=cond, training=training, guidance_scale=guidance_scale) + + def solve_euler(self, x, t_span, mu, mask, cond, training=False, guidance_scale=0.0): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + steps = 1 + while steps <= len(t_span) - 1: + dphi_dt = self.estimator(x, mask, mu, t, cond, training=training) + if guidance_scale > 0.0: + mu_avg = mu.mean(2, keepdims=True).expand_as(mu) + dphi_avg = self.estimator(x, mask, mu_avg, t, cond, training=training) + dphi_dt = dphi_dt + guidance_scale * (dphi_dt - dphi_avg) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if steps < len(t_span) - 1: + dt = t_span[steps + 1] - t + steps += 1 + + return sol[-1] + + def compute_loss(self, x1, mask, mu, cond=None, training=True, loss_mask=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # y = u * t + z + estimator_out = self.estimator(y, mask, mu, t.squeeze(), training=training) + + if loss_mask is not None: + mask = loss_mask + loss = F.mse_loss(estimator_out*mask, u*mask, reduction="sum") / ( + torch.sum(mask) * u.shape[1] + ) + return loss, y + + +class CFM(BASECFM): + def __init__(self, in_channels, out_channel, cfm_params, decoder_params): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + ) + + # Just change the architecture of the estimator here + self.estimator = Decoder(in_channels=in_channels*2, out_channels=out_channel, **decoder_params) + # self.estimator = DiffSingerNet(in_dims=in_channels, encoder_hidden=out_channel) + # self.estimator = VitsWNDecoder( + # in_channels=in_channels, + # out_channels=out_channel, + # hidden_channels=out_channel, + # kernel_size=3, + # dilation_rate=1, + # n_layers=18, + # gin_channels=out_channel*2 + # ) + diff --git a/pflow/models/components/speech_prompt_encoder.py b/pflow/models/components/speech_prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..12f8a4eaf669763b90eb56e1bcd2c688afbaa7cc --- /dev/null +++ b/pflow/models/components/speech_prompt_encoder.py @@ -0,0 +1,636 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +import pflow.utils as utils +from pflow.utils.model import sequence_mask +from pflow.models.components import commons +from pflow.models.components.vits_posterior import PosteriorEncoder +from pflow.models.components.transformer import BasicTransformerBlock + +log = utils.get_pylogger(__name__) + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + # x = torch.relu(x) + return x * x_mask + +class DurationPredictorNS2(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout=0.5 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = LayerNorm(filter_channels) + + self.module_list = nn.ModuleList() + self.module_list.append(self.conv_1) + self.module_list.append(nn.ReLU()) + self.module_list.append(self.norm_1) + self.module_list.append(self.drop) + + for i in range(12): + self.module_list.append(nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)) + self.module_list.append(nn.ReLU()) + self.module_list.append(LayerNorm(filter_channels)) + self.module_list.append(nn.Dropout(p_dropout)) + + + # attention layer every 3 layers + self.attn_list = nn.ModuleList() + for i in range(4): + self.attn_list.append( + Encoder( + filter_channels, + filter_channels, + 8, + 10, + 3, + p_dropout=p_dropout, + ) + ) + + for i in range(30): + if i+1 % 3 == 0: + self.module_list.append(self.attn_list[i//3]) + + self.proj = nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = torch.detach(x) + for layer in self.module_list: + x = layer(x * x_mask) + x = self.proj(x * x_mask) + # x = torch.relu(x) + return x * x_mask + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. + """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. + x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + speech_in_channels, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + self.speech_in_channels = speech_in_channels + self.speech_out_channels = self.n_channels + self.speech_prompt_proj = torch.nn.Conv1d(self.speech_in_channels, self.speech_out_channels, 1) + # self.speech_prompt_proj = PosteriorEncoder( + # self.speech_in_channels, + # self.speech_out_channels, + # self.speech_out_channels, + # 1, + # 1, + # 1, + # gin_channels=0, + # ) + + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0, + ) + + self.speech_prompt_encoder = Encoder( + encoder_params.n_channels, + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.text_base_encoder = Encoder( + encoder_params.n_channels, + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.decoder = Decoder( + encoder_params.n_channels, + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.transformerblock = BasicTransformerBlock( + encoder_params.n_channels, + encoder_params.n_heads, + encoder_params.n_channels // encoder_params.n_heads, + encoder_params.p_dropout, + encoder_params.n_channels, + activation_fn="gelu", + attention_bias=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + norm_elementwise_affine=True, + norm_type="layer_norm", + final_dropout=False, + ) + self.proj_m = torch.nn.Conv1d(self.n_channels, self.n_feats, 1) + + self.proj_w = DurationPredictor( + self.n_channels, + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + # self.proj_w = DurationPredictorNS2( + # self.n_channels, + # duration_predictor_params.filter_channels_dp, + # duration_predictor_params.kernel_size, + # duration_predictor_params.p_dropout, + # ) + + # self.speech_prompt_pos_emb = RotaryPositionalEmbeddings(self.n_channels * 0.5) + # self.text_pos_emb = RotaryPositionalEmbeddings(self.n_channels * 0.5) + + def forward( + self, + x_input, + x_lengths, + speech_prompt, + ): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + speech_prompt (torch.Tensor): speech prompt input + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + + x_emb = self.emb(x_input) * math.sqrt(self.n_channels) + x_emb = torch.transpose(x_emb, 1, -1) + x_emb_mask = torch.unsqueeze(sequence_mask(x_lengths, x_emb.size(2)), 1).to(x_emb.dtype) + x_emb = self.text_base_encoder(x_emb, x_emb_mask) + + x_speech_lengths = x_lengths + speech_prompt.size(2) + speech_lengths = x_speech_lengths - x_lengths + speech_mask = torch.unsqueeze(sequence_mask(speech_lengths, speech_prompt.size(2)), 1).to(x_emb.dtype) + + speech_prompt_proj = self.speech_prompt_proj(speech_prompt) + # speech_prompt_proj, speech_mask = self.speech_prompt_proj(speech_prompt, speech_lengths) + # speech_prompt_proj = self.speech_prompt_encoder(speech_prompt_proj, speech_mask) + + x_speech_cat = torch.cat([speech_prompt_proj, x_emb], dim=2) + x_speech_mask = torch.unsqueeze(sequence_mask(x_speech_lengths, x_speech_cat.size(2)), 1).to(x_speech_cat.dtype) + + x_prenet = self.prenet(x_speech_cat, x_speech_mask) + # split speech prompt and text input + speech_prompt_proj = x_prenet[:, :, :speech_prompt_proj.size(2)] + x_split = x_prenet[:, :, speech_prompt_proj.size(2):] + + # add positional encoding to speech prompt and x_split + # x_split = self.text_pos_emb(x_split.unsqueeze(1).transpose(-2,-1)).squeeze(1).transpose(-2,-1) + x_split_mask = torch.unsqueeze(sequence_mask(x_lengths, x_split.size(2)), 1).to(x_split.dtype) + + # speech_prompt = self.speech_prompt_pos_emb(speech_prompt_proj.unsqueeze(1).transpose(-2,-1)).squeeze(1).transpose(-2,-1) + # x_split = self.decoder(x_split, x_split_mask, speech_prompt, speech_mask) + + x_split = self.transformerblock(x_split.transpose(1,2), x_split_mask, speech_prompt_proj.transpose(1,2), speech_mask) + x_split = x_split.transpose(1,2) + + # x_split_mask = torch.unsqueeze(sequence_mask(x_lengths, x_split.size(2)), 1).to(x.dtype) + + # x_split = x_split + x_emb + + mu = self.proj_m(x_split) * x_split_mask + + x_dp = torch.detach(x_split) + logw = self.proj_w(x_dp, x_split_mask) + + return mu, logw, x_split_mask diff --git a/pflow/models/components/speech_prompt_encoder_v0.py b/pflow/models/components/speech_prompt_encoder_v0.py new file mode 100644 index 0000000000000000000000000000000000000000..660a32613613ef66622f169c605fc073353eebd8 --- /dev/null +++ b/pflow/models/components/speech_prompt_encoder_v0.py @@ -0,0 +1,618 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +import pflow.utils as utils +from pflow.utils.model import sequence_mask +from pflow.models.components import commons +from pflow.models.components.vits_posterior import PosteriorEncoder +from pflow.models.components.transformer import BasicTransformerBlock + +log = utils.get_pylogger(__name__) + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + # x = torch.relu(x) + return x * x_mask + +class DurationPredictorNS2(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout=0.5 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = LayerNorm(filter_channels) + + self.module_list = nn.ModuleList() + self.module_list.append(self.conv_1) + self.module_list.append(nn.ReLU()) + self.module_list.append(self.norm_1) + self.module_list.append(self.drop) + + for i in range(12): + self.module_list.append(nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)) + self.module_list.append(nn.ReLU()) + self.module_list.append(LayerNorm(filter_channels)) + self.module_list.append(nn.Dropout(p_dropout)) + + + # attention layer every 3 layers + self.attn_list = nn.ModuleList() + for i in range(4): + self.attn_list.append( + Encoder( + filter_channels, + filter_channels, + 8, + 10, + 3, + p_dropout=p_dropout, + ) + ) + + for i in range(30): + if i+1 % 3 == 0: + self.module_list.append(self.attn_list[i//3]) + + self.proj = nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = torch.detach(x) + for layer in self.module_list: + x = layer(x * x_mask) + x = self.proj(x * x_mask) + # x = torch.relu(x) + return x * x_mask + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. + """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. + x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + speech_in_channels, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + self.speech_in_channels = speech_in_channels + self.speech_out_channels = self.n_channels + # self.speech_prompt_proj = torch.nn.Conv1d(self.speech_in_channels, self.speech_out_channels, 1) + self.speech_prompt_proj = PosteriorEncoder( + self.speech_in_channels, + self.speech_out_channels, + self.speech_out_channels, + 1, + 1, + 1, + gin_channels=0, + ) + + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0, + ) + + # self.speech_prompt_encoder = Encoder( + # encoder_params.n_channels, + # encoder_params.filter_channels, + # encoder_params.n_heads, + # encoder_params.n_layers, + # encoder_params.kernel_size, + # encoder_params.p_dropout, + # ) + + self.text_base_encoder = Encoder( + encoder_params.n_channels, + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + # self.decoder = Decoder( + # encoder_params.n_channels, + # encoder_params.filter_channels, + # encoder_params.n_heads, + # encoder_params.n_layers, + # encoder_params.kernel_size, + # encoder_params.p_dropout, + # ) + + self.transformerblock = BasicTransformerBlock( + encoder_params.n_channels, + encoder_params.n_heads, + encoder_params.n_channels // encoder_params.n_heads, + encoder_params.p_dropout, + encoder_params.n_channels, + activation_fn="gelu", + attention_bias=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + norm_elementwise_affine=True, + norm_type="layer_norm", + final_dropout=False, + ) + self.proj_m = torch.nn.Conv1d(self.n_channels, self.n_feats, 1) + + self.proj_w = DurationPredictor( + self.n_channels, + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + # self.proj_w = DurationPredictorNS2( + # self.n_channels, + # duration_predictor_params.filter_channels_dp, + # duration_predictor_params.kernel_size, + # duration_predictor_params.p_dropout, + # ) + + def forward( + self, + x_input, + x_lengths, + speech_prompt, + ): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + speech_prompt (torch.Tensor): speech prompt input + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x_emb = self.emb(x_input) * math.sqrt(self.n_channels) + x_emb = torch.transpose(x_emb, 1, -1) + x_speech_lengths = x_lengths + speech_prompt.size(2) + speech_lengths = x_speech_lengths - x_lengths + # speech_prompt_proj = self.speech_prompt_proj(speech_prompt) + speech_prompt_proj, speech_mask = self.speech_prompt_proj(speech_prompt, speech_lengths) + x_speech_cat = torch.cat([speech_prompt_proj, x_emb], dim=2) + x_speech_mask = torch.unsqueeze(sequence_mask(x_speech_lengths, x_speech_cat.size(2)), 1).to(x_speech_cat.dtype) + + x_prenet = self.prenet(x_speech_cat, x_speech_mask) + # split speech prompt and text input + speech_split = x_prenet[:, :, :speech_prompt_proj.size(2)] + x_split = x_prenet[:, :, speech_prompt_proj.size(2):] + x_split_mask = torch.unsqueeze(sequence_mask(x_lengths, x_split.size(2)), 1).to(x_split.dtype) + speech_lengths = x_speech_lengths - x_lengths + speech_mask = torch.unsqueeze(sequence_mask(speech_lengths, speech_split.size(2)), 1).to(x_split.dtype) + + x_split = self.transformerblock(x_split.transpose(1,2), x_split_mask, speech_split.transpose(1,2), speech_mask) + x_split = x_split.transpose(1,2) + + # x_split_mask = torch.unsqueeze(sequence_mask(x_lengths, x_split.size(2)), 1).to(x.dtype) + + mu = self.proj_m(x_split) * x_split_mask + x_dp = torch.detach(x_split) + logw = self.proj_w(x_dp, x_split_mask) + + return mu, logw, x_split_mask diff --git a/pflow/models/components/test.py b/pflow/models/components/test.py new file mode 100644 index 0000000000000000000000000000000000000000..c223702eb99fde465cc06c9b1ca55f904081c35b --- /dev/null +++ b/pflow/models/components/test.py @@ -0,0 +1,6 @@ +from pflow.hifigan.meldataset import mel_spectrogram +import torch + +audio = torch.randn(2,1, 1000) +mels = mel_spectrogram(audio, 1024, 80, 22050, 256, 1024, 0, 8000, center=False) +print(mels.shape) \ No newline at end of file diff --git a/pflow/models/components/text_encoder.py b/pflow/models/components/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8176ea474666a5ee30485f0ac53a10b5970860fe --- /dev/null +++ b/pflow/models/components/text_encoder.py @@ -0,0 +1,425 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +import pflow.utils as utils +from pflow.utils.model import sequence_mask + +log = utils.get_pylogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. + """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. + x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + if encoder_params.prenet: + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5, + ) + else: + self.prenet = lambda x, x_mask: x + + self.encoder = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.encoder_dp = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + # self.proj_v = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) + + self.proj_w = DurationPredictor( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + + def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x_dp = torch.detach(x) + x_dp = self.encoder_dp(x_dp, x_mask) + + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + # logs = self.proj_v(x) * x_mask + + # x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask diff --git a/pflow/models/components/transformer.py b/pflow/models/components/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1afa3aff5383912209e508676c6885e13ef4ee --- /dev/null +++ b/pflow/models/components/transformer.py @@ -0,0 +1,316 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super().__init__() + self.in_features = out_features if isinstance(out_features, list) else [out_features] + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/pflow/models/components/vits_modules.py b/pflow/models/components/vits_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..08147eb125d09dfe3fbfe414fe1ffebfb66f4ada --- /dev/null +++ b/pflow/models/components/vits_modules.py @@ -0,0 +1,194 @@ +# from https://github.com/jaywalnut310/vits +import torch +from torch import nn +from torch.nn import functional as F + +from pflow.models.components import commons + +LRELU_SLOPE = 0.1 + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers-1): + self.conv_layers.append(nn.Conv1d( + hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """Dialted and Depth-Separable Convolution""" + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2*hidden_channels, + kernel_size, + dilation=dilation, + padding=padding + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d( + hidden_channels, res_skip_channels, 1 + ) + res_skip_layer = torch.nn.utils.weight_norm( + res_skip_layer, name='weight' + ) + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + if g is not None: + g = g.unsqueeze(-1) + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset+2*self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor + ) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, :self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + diff --git a/pflow/models/components/vits_posterior.py b/pflow/models/components/vits_posterior.py new file mode 100644 index 0000000000000000000000000000000000000000..9241d179f85f4bad1d5c4ab220fc221b6532f251 --- /dev/null +++ b/pflow/models/components/vits_posterior.py @@ -0,0 +1,43 @@ +import torch.nn as nn +import torch + +import pflow.models.components.vits_modules as modules +import pflow.models.components.commons as commons + +class PosteriorEncoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), + 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + # m, logs = torch.split(stats, self.out_channels, dim=1) + # z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + # z = m * x_mask + return stats, x_mask diff --git a/pflow/models/components/vits_wn_decoder.py b/pflow/models/components/vits_wn_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2033493d09ac69591fe614bb0f0a861f58ade08c --- /dev/null +++ b/pflow/models/components/vits_wn_decoder.py @@ -0,0 +1,79 @@ +import math + +import torch.nn as nn +import torch +import torch.nn.functional as F +import pflow.models.components.vits_modules as modules +import pflow.models.components.commons as commons + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super(SinusoidalPosEmb, self).__init__() + self.dim = dim + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class VitsWNDecoder(nn.Module): + + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + pe_scale=1000 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.pe_scale = pe_scale + self.time_pos_emb = SinusoidalPosEmb(hidden_channels * 2) + dim = hidden_channels * 2 + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels * 2, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels * 2, out_channels, 1) + + def forward(self, x, x_mask, mu, t, *args, **kwargs): + # x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), + # 1).to(x.dtype) + t = self.time_pos_emb(t, scale=self.pe_scale) + t = self.mlp(t) + + x = self.pre(x) * x_mask + mu = self.pre(mu) + x = torch.cat((x, mu), dim=1) + x = self.enc(x, x_mask, g=t) + stats = self.proj(x) * x_mask + + return stats diff --git a/pflow/models/components/wn_pflow_decoder.py b/pflow/models/components/wn_pflow_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6f503f1dda8f809d09e70233eb8356cd644fc5b4 --- /dev/null +++ b/pflow/models/components/wn_pflow_decoder.py @@ -0,0 +1,117 @@ +''' +https://github.com/cantabile-kwok/VoiceFlow-TTS/blob/main/model/diffsinger.py#L51 +This is the original implementation of the DiffSinger model. +It is a slightly modified WV which can be used for initial tests. +Will update this into original p-flow implementation later. +''' +import math + +import torch.nn as nn +import torch +from torch.nn import Conv1d, Linear +import math +import torch.nn.functional as F + + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super(SinusoidalPosEmb, self).__init__() + self.dim = dim + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = Linear(residual_channels, residual_channels) + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x + residual) / math.sqrt(2.0), skip + +class DiffSingerNet(nn.Module): + def __init__( + self, + in_dims=80, + residual_channels=256, + encoder_hidden=80, + dilation_cycle_length=1, + residual_layers=20, + pe_scale=1000 + ): + super().__init__() + + self.pe_scale = pe_scale + + self.input_projection = Conv1d(in_dims, residual_channels, 1) + self.time_pos_emb = SinusoidalPosEmb(residual_channels) + dim = residual_channels + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock(encoder_hidden, residual_channels, 2 ** (i % dilation_cycle_length)) + for i in range(residual_layers) + ]) + self.skip_projection = Conv1d(residual_channels, residual_channels, 1) + self.output_projection = Conv1d(residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, spec_mask, mu, t, *args, **kwargs): + """ + :param spec: [B, M, T] + :param t: [B, ] + :param mu: [B, M, T] + :return: + """ + # x = spec[:, 0] + x = spec + x = self.input_projection(x) # x [B, residual_channel, T] + + x = F.relu(x) + + t = self.time_pos_emb(t, scale=self.pe_scale) + t = self.mlp(t) + + cond = mu + + skip = [] + for layer_id, layer in enumerate(self.residual_layers): + x, skip_connection = layer(x, cond, t) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, M, T] + return x * spec_mask \ No newline at end of file diff --git a/pflow/models/pflow_tts.py b/pflow/models/pflow_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..e67f889bf0985f77ea0edd440f891132f5f1fc32 --- /dev/null +++ b/pflow/models/pflow_tts.py @@ -0,0 +1,182 @@ +import datetime as dt +import math +import random + +import torch +import torch.nn.functional as F + + +from pflow.models.baselightningmodule import BaseLightningClass +from pflow.models.components.flow_matching import CFM +from pflow.models.components.speech_prompt_encoder import TextEncoder +from pflow.utils.model import ( + denormalize, + duration_loss, + fix_len_compatibility, + generate_path, + sequence_mask, +) +from pflow.models.components import commons +from pflow.models.components.aligner import Aligner, ForwardSumLoss, BinLoss + + + +class pflowTTS(BaseLightningClass): # + def __init__( + self, + n_vocab, + n_feats, + encoder, + decoder, + cfm, + data_statistics, + prompt_size=264, + optimizer=None, + scheduler=None, + **kwargs, + ): + super().__init__() + + self.save_hyperparameters(logger=False) + + self.n_vocab = n_vocab + self.n_feats = n_feats + self.prompt_size = prompt_size + speech_in_channels = n_feats + + self.encoder = TextEncoder( + encoder.encoder_type, + encoder.encoder_params, + encoder.duration_predictor_params, + n_vocab, + speech_in_channels, + ) + + # self.aligner = Aligner( + # dim_in=encoder.encoder_params.n_feats, + # dim_hidden=encoder.encoder_params.n_feats, + # attn_channels=encoder.encoder_params.n_feats, + # ) + + # self.aligner_loss = ForwardSumLoss() + # self.bin_loss = BinLoss() + # self.aligner_bin_loss_weight = 0.0 + + self.decoder = CFM( + in_channels=encoder.encoder_params.n_feats, + out_channel=encoder.encoder_params.n_feats, + cfm_params=cfm, + decoder_params=decoder, + ) + + self.proj_prompt = torch.nn.Conv1d(encoder.encoder_params.n_channels, self.n_feats, 1) + + self.update_data_statistics(data_statistics) + + @torch.inference_mode() + def synthesise(self, x, x_lengths, prompt, n_timesteps, temperature=1.0, length_scale=1.0, guidance_scale=0.0): + + # For RTF computation + t = dt.datetime.now() + assert prompt is not None, "Prompt must be provided for synthesis" + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, prompt) + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = y_lengths.max() + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Generate sample tracing the probability flow + decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, guidance_scale=guidance_scale) + decoder_outputs = decoder_outputs[:, :, :y_max_length] + + t = (dt.datetime.now() - t).total_seconds() + rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) + + return { + "encoder_outputs": encoder_outputs, + "decoder_outputs": decoder_outputs, + "attn": attn[:, :, :y_max_length], + "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), + "mel_lengths": y_lengths, + "rtf": rtf, + } + + def forward(self, x, x_lengths, y, y_lengths, prompt=None, cond=None, **kwargs): + if prompt is None: + prompt_slice, ids_slice = commons.rand_slice_segments( + y, y_lengths, self.prompt_size + ) + else: + prompt_slice = prompt + mu_x, logw, x_mask = self.encoder(x, x_lengths, prompt_slice) + + y_max_length = y.shape[-1] + + y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.ones_like(mu_x) # [b, d, t] + # s_p_sq_r = torch.exp(-2 * logx) + neg_cent1 = torch.sum( + -0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True + ) + # neg_cent1 = torch.sum( + # -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True + # ) # [b, 1, t_s] + neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r) + neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r)) + neg_cent4 = torch.sum( + -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True + ) + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + from pflow.utils.monotonic_align import maximum_path + attn = ( + maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + ) + + logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask + dur_loss = duration_loss(logw, logw_, x_lengths) + + # aln_hard, aln_soft, aln_log, aln_mask = self.aligner( + # mu_x.transpose(1,2), x_mask, y, y_mask + # ) + # attn = aln_mask.transpose(1,2).unsqueeze(1) + # align_loss = self.aligner_loss(aln_log, x_lengths, y_lengths) + # if self.aligner_bin_loss_weight > 0.: + # align_bin_loss = self.bin_loss(aln_mask, aln_log, x_lengths) * self.aligner_bin_loss_weight + # align_loss = align_loss + align_bin_loss + # dur_loss = F.l1_loss(logw, attn.sum(2)) + # dur_loss = dur_loss + align_loss + + # Align encoded text with mel-spectrogram and get mu_y segment + attn = attn.squeeze(1).transpose(1,2) + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + y_loss_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + if prompt is None: + for i in range(y.size(0)): + y_loss_mask[i,:,ids_slice[i]:ids_slice[i] + self.prompt_size] = False + # Compute loss of the decoder + diff_loss, _ = self.decoder.compute_loss(x1=y.detach(), mask=y_mask, mu=mu_y, cond=cond, loss_mask=y_loss_mask) + + prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_loss_mask) + prior_loss = prior_loss / (torch.sum(y_loss_mask) * self.n_feats) + + return dur_loss, prior_loss, diff_loss, attn \ No newline at end of file diff --git a/pflow/text/__init__.py b/pflow/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd1879ef22528064b07b8acf118e10f2047952a6 --- /dev/null +++ b/pflow/text/__init__.py @@ -0,0 +1,53 @@ +""" from https://github.com/keithito/tacotron """ +from pflow.text import cleaners +from pflow.text.symbols import symbols + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension + + +def text_to_sequence(text, cleaner_names): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence + + +def cleaned_text_to_sequence(cleaned_text): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + return text diff --git a/pflow/text/cleaners.py b/pflow/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..dc22a909fccf545c3454c6b7710827eafe9e9ae2 --- /dev/null +++ b/pflow/text/cleaners.py @@ -0,0 +1,19 @@ +from pflow.text.textnormalizer import norm +from ukrainian_word_stress import Stressifier +import regex +import re +from ipa_uk import ipa +stressify = Stressifier() + + +_whitespace_re = re.compile(r"\s+") +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def ukr_cleaners(text): + text = collapse_whitespace(text) + text = norm(text).lower() + + text = regex.sub(r'[^\p{L}\p{N}\?\!\,\.\-\: ]', '', text) + return ipa(stressify(text), False) diff --git a/pflow/text/numbers.py b/pflow/text/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..f99a8686dcb73532091122613e74bd643a8a327f --- /dev/null +++ b/pflow/text/numbers.py @@ -0,0 +1,71 @@ +""" from https://github.com/keithito/tacotron """ + +import re + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return f"{dollars} {dollar_unit}, {cents} {cent_unit}" + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return f"{dollars} {dollar_unit}" + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return f"{cents} {cent_unit}" + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/pflow/text/symbols.py b/pflow/text/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..0a74ff9d49bc56cbf60c397dd33a11ac31b5af75 --- /dev/null +++ b/pflow/text/symbols.py @@ -0,0 +1,17 @@ +""" from https://github.com/keithito/tacotron + +Defines the set of symbols used in text input to the model. +""" +_pad = "_" +_punctuation = '-´;:,.!?¡¿—…"«»“” ' +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_letters_ipa = ( + "éýíó'̯'͡ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" +) + + +# Export all symbols: +symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbols.index(" ") diff --git a/pflow/text/textnormalizer.py b/pflow/text/textnormalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..00d9e82eee2f3da7d937e73b79bc0991a7ec2102 --- /dev/null +++ b/pflow/text/textnormalizer.py @@ -0,0 +1,198 @@ +import regex +from num2words import num2words +import unicodedata + +simple_replacements = { + '№' : 'номер', + '§': 'номер' +} + +masc_replacments_dict = { + '%':['відсоток', 'відсотки', 'відсотків'], + 'мм': ['міліметр', 'міліметри', 'міліметрів'], + 'см': ['сантиметр', 'сантиметри', 'сантиметрів'], + 'мм': ['міліметр', 'міліметри', 'міліметрів'], + # 'м': ['метр', 'метри', 'метрів'], + 'км': ['кілометр', 'кілометри', 'кілометрів'], + 'гц': ['герц', 'герци', 'герців'], + 'кгц': ['кілогерц', 'кілогерци', 'кілогерців'], + 'мгц': ['мегагерц', 'мегагерци', 'мегагерців'], + 'ггц': ['гігагерц', 'гігагерци', 'гігагерців'], + 'вт': ['ват', 'вати', 'ватів'], + 'квт': ['кіловат', 'кіловати', 'кіловатів'], + 'мвт': ['мегават', 'мегавати', 'мегаватів'], + 'гвт': ['гігават', 'гігавати', 'гігаватів'], + 'дж': ['джоуль', 'джоулі', 'джоулів'], + 'кдж': ['кілоджоуль', 'кілоджоулі', 'кілоджоулів'], + 'мдж': ['мегаджоуль', 'мегаджоулі', 'мегаджоулів'], + 'см2': ['сантиметр квадратний', 'сантиметри квадратні', 'сантиметрів квадратних'], + 'м2': ['метр квадратний', 'метри квадратні', 'метрів квадратних'], + 'м2': ['кілометр квадратний', 'кілометри квадратні', 'кілометрів квадратних'], + '$': ['долар', 'долари', 'доларів'], + '€': ['євро', 'євро', 'євро'], +} + +fem_replacments_dict = { + 'кал': ['калорія', 'калорії', 'калорій'], + 'ккал': ['кілокалорія', 'кілокалорії', 'кілокалорій'], + 'грн': ['гривня', 'гривні', 'гривень'], + 'грв': ['гривня', 'гривні', 'гривень'], + '₴': ['гривня', 'гривні', 'гривень'], +} + +neu_replacments_dict = { + '€': ['євро', 'євро', 'євро'], +} + +all_replacments_keys = list(masc_replacments_dict.keys()) + list(fem_replacments_dict.keys()) + list(neu_replacments_dict.keys()) + +#Ordinal types +#Називний +ordinal_nominative_masculine_cases = ('й','ий') +ordinal_nominative_feminine_cases = ('a','ша', 'я') +ordinal_nominative_neuter_cases = ('е',) + +#Родовий +ordinal_genitive_masculine_case = ('го','о',) +ordinal_genitive_feminine_case = ('ї', 'ої') + + +#Давальний +ordinal_dative_masculine_case = ('му',) +ordinal_dative_feminine_case = ('й','ій') + +#Знахідний +ordinal_accusative_masculine_case = ordinal_genitive_masculine_case +ordinal_accusative_feminine_case = ('у',) + +#Орудний +ordinal_instrumental_masculine_case = ('им', 'ім') +ordinal_instrumental_feminine_case = ('ю') + + +#Місцевий +# ordinal_locative_masculine_case = ordinal_dative_masculine_case +# ordinal_locative_feminine_case = ordinal_dative_feminine_case + +numcases_r = regex.compile(rf'((?:^|\s)(\d+)\s*(\-?)(([^\d,]*?)|(\-\.+))(?:\.|,|:|-)?)(\s+[^,.:\-]|$)', regex.IGNORECASE, regex.UNICODE) + +print(numcases_r) +cardinal_genitive_endings = ('а', 'e', 'є', 'й') +ordinal_genitive_cases = ('року',) + +def number_form(number): + if number[-1] == "1": + return 0 + elif number[-1] in ("2", "3", "4"): + return 1 + else: + return 2 + +def replace_cases(number, dash, case='', next_word=''): + print(f'{number}, {dash}, {case}, {next_word}') + gender = 'masculine' + m_case = 'nominative' + to = 'ordinal' + repl = '' + if not dash: + if case in all_replacments_keys: + if case in masc_replacments_dict.keys(): + repl = masc_replacments_dict.get(case)[number_form(number)] + gender = 'masculine' + elif case in fem_replacments_dict.keys(): + repl = fem_replacments_dict.get(case)[number_form(number)] + gender = 'feminine' + elif case in neu_replacments_dict.keys(): + repl = neu_replacments_dict.get(case)[number_form(number)] + gender = 'neuter' + to = 'cardinal' + else: + if len(case) < 3 and case and case[-1] in cardinal_genitive_endings: + m_case = 'genitive' + gender='masculine' + to = 'cardinal' + elif case in ordinal_genitive_cases: + to = 'ordinal' + m_case = 'genitive' + repl = case + else: + to = 'cardinal' + repl = case + + else: + if case in ordinal_nominative_masculine_cases: + m_case = 'nominative' + gender = 'masculine' + elif case in ordinal_nominative_feminine_cases: + m_case = 'nominative' + gender = 'feminine' + elif case in ordinal_nominative_neuter_cases: + m_case = 'nominative' + gender = 'neuter' + elif case in ordinal_genitive_masculine_case: + m_case = 'genitive' + gender = 'masculine' + elif case in ordinal_genitive_feminine_case: + m_case = 'genitive' + gender = 'feminine' + elif case in ordinal_dative_masculine_case: + m_case = 'dative' + gender = 'masculine' + elif case in ordinal_dative_feminine_case: + m_case = 'dative' + gender = 'feminine' + elif case in ordinal_accusative_feminine_case: + m_case = 'accusative' + gender = 'feminine' + elif case in ordinal_instrumental_masculine_case: + m_case = 'instrumental' + gender = 'masculine' + elif case in ordinal_instrumental_feminine_case: + m_case = 'instrumental' + gender = 'feminine' + else: + if case and case[-1] in cardinal_genitive_endings: + m_case = 'genitive' + gender='masculine' + to = 'cardinal' + repl = case + else: + print(f'UNKNOWN CASE {number}-{case}') + + return_str = num2words(number, to=to, lang='uk', case=m_case, gender=gender) + if repl: + return_str += ' ' + repl + if not next_word or (next_word and next_word.strip().isupper()): + return_str += '.' + return return_str + +def norm(text): + text = regex.sub(r'[\t\n]', ' ', text) + text = regex.sub(rf"[{''.join(simple_replacements.keys())}]", lambda x: f' {simple_replacements[x.group()]} ', text) + text = regex.sub(r"(\d)\s+(\d)", r"\1\2", text) + text = regex.sub(r'\s+', ' ', text) + text = unicodedata.normalize('NFC', text) + matches = numcases_r.finditer(text) + pos = 0 + new_text = '' + for m in matches: + repl = replace_cases(m.group(2), m.group(3), m.group(4), m.group(7)) + new_text += text[pos:m.start(0)]+ ' ' + repl + pos = m.end(1) + new_text += text[pos:] + return new_text.strip() + + + +#1-го квітня, на 1-му поверсі Яринка загубила 2грн але знайшла 5€. Але її 4-річна сестричка забрала 50% її знахідки. +#Також 2003 року щось там сталося і 40-річний чоловік помер. Його знайшли через 3 години. + +#01:51:37.250 -> 01:51:44.650: Серед міленіалів цей показник становить 39%, серед покоління X – 30%, +#39 +#30 +#MATCHED: серед міленіалів цей показник становить тридцять девять , серед покоління Х - тридцять , +#Skipped because contains inapropirate characters + +#05:28:52.350 -> 05:29:00.000: 2016 рік завершився з чистими збитками 1,2 мільярди доларів США. +#2016 +#MATCHED: дві тисячі шістнадцять рік завершився з чистими збитками 1,2 млрд доларів США. \ No newline at end of file diff --git a/pflow/utils/__init__.py b/pflow/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f74fe9fd88225854a2eb5c1d624c98f5f704836 --- /dev/null +++ b/pflow/utils/__init__.py @@ -0,0 +1,5 @@ +from pflow.utils.instantiators import instantiate_callbacks, instantiate_loggers +from pflow.utils.logging_utils import log_hyperparameters +from pflow.utils.pylogger import get_pylogger +from pflow.utils.rich_utils import enforce_tags, print_config_tree +from pflow.utils.utils import extras, get_metric_value, task_wrapper diff --git a/pflow/utils/audio.py b/pflow/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcd74df47fb006f68deb5a5f4a4c2fb0aa84f57 --- /dev/null +++ b/pflow/utils/audio.py @@ -0,0 +1,82 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/pflow/utils/generate_data_statistics.py b/pflow/utils/generate_data_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..fd09abc8f6b3d670d446ec0d808cee6217eae3d1 --- /dev/null +++ b/pflow/utils/generate_data_statistics.py @@ -0,0 +1,115 @@ +r""" +The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it +when needed. + +Parameters from hparam.py will be used +""" +import os + +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) + +import argparse +import json +import sys +from pathlib import Path + +import rootutils +import torch +from hydra import compose, initialize +from omegaconf import open_dict +from tqdm.auto import tqdm + +from pflow.data.text_mel_datamodule import TextMelDataModule +from pflow.utils.logging_utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): + """Generate data mean and standard deviation helpful in data normalisation + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + """ + total_mel_sum = 0 + total_mel_sq_sum = 0 + total_mel_len = 0 + + for batch in tqdm(data_loader, leave=False): + mels = batch["y"] + mel_lengths = batch["y_lengths"] + + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + + return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--input-config", + type=str, + default="vctk.yaml", + help="The name of the yaml config file under configs/data", + ) + + parser.add_argument( + "-b", + "--batch-size", + type=int, + default="256", + help="Can have increased batch size for faster computation", + ) + + parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + required=False, + help="force overwrite the file", + ) + args = parser.parse_args() + output_file = Path(args.input_config).with_suffix(".json") + + if os.path.exists(output_file) and not args.force: + print("File already exists. Use -f to force overwrite") + sys.exit(1) + + with initialize(version_base="1.3", config_path="../../configs/data"): + cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) + + root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") + + with open_dict(cfg): + del cfg["hydra"] + del cfg["_target_"] + cfg["data_statistics"] = None + cfg["seed"] = 1234 + cfg["batch_size"] = args.batch_size + cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) + cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) + + text_mel_datamodule = TextMelDataModule(**cfg) + text_mel_datamodule.setup() + data_loader = text_mel_datamodule.train_dataloader() + log.info("Dataloader loaded! Now computing stats...") + params = compute_data_statistics(data_loader, cfg["n_feats"]) + print(params) + json.dump( + params, + open(output_file, "w"), + ) + + +if __name__ == "__main__": + main() diff --git a/pflow/utils/instantiators.py b/pflow/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..d585d48c2a9d9c8dffc55fd9e759f01dc345f991 --- /dev/null +++ b/pflow/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from pflow.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/pflow/utils/logging_utils.py b/pflow/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5aba0ab7c82c43a5a25bfde8f429405d5b8d9e1 --- /dev/null +++ b/pflow/utils/logging_utils.py @@ -0,0 +1,53 @@ +from typing import Any, Dict + +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import OmegaConf + +from pflow.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) + hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/pflow/utils/model.py b/pflow/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..869cc6092f5952930534c47544fae88308e96abf --- /dev/null +++ b/pflow/utils/model.py @@ -0,0 +1,90 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import numpy as np +import torch + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if not torch.onnx.is_in_onnx_export(): + return length.int().item() + else: + return length + + +def convert_pad_shape(pad_shape): + inverted_shape = pad_shape[::-1] + pad_shape = [item for sublist in inverted_shape for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) + return loss + + +def normalize(data, mu, std): + if not isinstance(mu, (float, int)): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, (float, int)): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return (data - mu) / std + + +def denormalize(data, mu, std): + if not isinstance(mu, float): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, float): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return data * std + mu diff --git a/pflow/utils/monotonic_align/__init__.py b/pflow/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65155e16c5568e2236a328eb8d0ada3ed937d69e --- /dev/null +++ b/pflow/utils/monotonic_align/__init__.py @@ -0,0 +1,19 @@ +import numpy as np +import torch +from pflow.utils.monotonic_align.core import maximum_path_c + + +def maximum_path(neg_cent, mask): + """Cython optimized version. + neg_cent: [b, t_t, t_s] + mask: [b, t_t, t_s] + """ + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) + path = np.zeros(neg_cent.shape, dtype=np.int32) + + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) + maximum_path_c(path, neg_cent, t_t_max, t_s_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/pflow/utils/monotonic_align/core.pyx b/pflow/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..bfaabd4d21c2299cdd978f0cc0caefa20ad186e5 --- /dev/null +++ b/pflow/utils/monotonic_align/core.pyx @@ -0,0 +1,42 @@ +cimport cython +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y-1, x] + if x == 0: + if y == 0: + v_prev = 0. + else: + v_prev = max_neg_val + else: + v_prev = value[y-1, x-1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/pflow/utils/pylogger.py b/pflow/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..61600678029362e110f655edb91d5f3bc5b1cd1c --- /dev/null +++ b/pflow/utils/pylogger.py @@ -0,0 +1,21 @@ +import logging + +from lightning.pytorch.utilities import rank_zero_only + + +def get_pylogger(name: str = __name__) -> logging.Logger: + """Initializes a multi-GPU-friendly python command line logger. + + :param name: The name of the logger, defaults to ``__name__``. + + :return: A logger object. + """ + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/pflow/utils/rich_utils.py b/pflow/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..086e80bffeea4ac151032b0d26c722ed32feaeea --- /dev/null +++ b/pflow/utils/rich_utils.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from pflow.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + _ = ( + queue.append(field) + if field in cfg + else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/pflow/utils/utils.py b/pflow/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2781793f6c5d396045484f7f77185728ad9b1225 --- /dev/null +++ b/pflow/utils/utils.py @@ -0,0 +1,218 @@ +import os +import sys +import warnings +from importlib.util import find_spec +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import gdown +import matplotlib.pyplot as plt +import numpy as np +import torch +import wget +from omegaconf import DictConfig + +from pflow.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: The name of the metric to retrieve. + :return: The value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def intersperse(lst, item): + # Adds blank symbol + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_tensor(tensor): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def save_plot(tensor, savepath): + plt.style.use("default") + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.tight_layout() + fig.canvas.draw() + plt.savefig(savepath) + plt.close() + + +def to_numpy(tensor): + if isinstance(tensor, np.ndarray): + return tensor + elif isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + elif isinstance(tensor, list): + return np.array(tensor) + else: + raise TypeError("Unsupported type for conversion to numpy array") + + +def get_user_data_dir(appname="pflow_tts"): + """ + Args: + appname (str): Name of application + + Returns: + Path: path to user data directory + """ + + PFLOW_HOME = os.environ.get("PFLOW_HOME") + if PFLOW_HOME is not None: + ans = Path(PFLOW_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + + final_path = ans.joinpath(appname) + final_path.mkdir(parents=True, exist_ok=True) + return final_path + + +def assert_model_downloaded(checkpoint_path, url, use_wget=False): + print(checkpoint_path) + if Path(checkpoint_path).exists(): + log.debug(f"[+] Model already present at {checkpoint_path}!") + return + log.info(f"[-] Model not found at {checkpoint_path}! Will download it") + checkpoint_path = str(checkpoint_path) + if not use_wget: + gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) + else: + wget.download(url=url, out=checkpoint_path) diff --git a/prompt.wav b/prompt.wav new file mode 100644 index 0000000000000000000000000000000000000000..06f8e9c071f80d3cdaf952d1f3e9f0dd5c07f7ed Binary files /dev/null and b/prompt.wav differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b727e04b523b199f60d2e5d8e633251de290f33 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ + +torch>=2.0.0 +hydra-core==1.3.2 +lightning>=2.0.0 +conformer==0.3.2 +diffusers==0.21.3 + +gdown +wget + +numpy +beartype + + +torchaudio +librosa +gradio + + +regex +ukrainian_word_stress +ipa_uk@git+https://github.com/lang-uk/ipa-uk.git +num2words@git+https://github.com/patriotyk/num2words.git \ No newline at end of file