diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..426a637f01f0823b91a7604be8e5bd8c68f253d0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,6 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text +*.json filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.txt filter=lfs diff=lfs merge=lfs -text +*.pt* filter=lfs diff=lfs merge=lfs -text +*.ckpt* filter=lfs diff=lfs merge=lfs -text +*.pl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fd5b5a1b41088920f428c9b2673bbd75aad4784d --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +Temp_Audios/ + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +*.ini + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/Languages/ben/G_100000.pth b/Languages/ben/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..c4c69f1ca7edc83ffe7aee0b10e034a4606923df --- /dev/null +++ b/Languages/ben/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8c098eab2e5e378fc52bec57683839bbc641b2241033dab17174f6e37db29a4 +size 145512166 diff --git a/Languages/ben/config.json b/Languages/ben/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/ben/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/ben/vocab.txt b/Languages/ben/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..c05473ed9e09353e0f55ce9d584c621caeb9b776 --- /dev/null +++ b/Languages/ben/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7085f1a1f6040b4da0ac55bb3ff91b77229d1ed14f7d86df2b23676a1a2cb81b +size 268 diff --git a/Languages/ell/G_100000.pth b/Languages/ell/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..73d239b5150ab7ae018c24b23ff5ee546082ba62 --- /dev/null +++ b/Languages/ell/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75bfa237f0fe859b34c4340bc7dccd944678cf9984bce5b5a82e2c90ca268db8 +size 145504497 diff --git a/Languages/ell/config.json b/Languages/ell/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/ell/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/ell/vocab.txt b/Languages/ell/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..e2181e9e0727478421bd325ed5fcb9c52a6f5f52 --- /dev/null +++ b/Languages/ell/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c53d89f446eba9d061b510e31900d235ef0e021e44a790978dae5a4350a4013 +size 164 diff --git a/Languages/fra/G_100000.pth b/Languages/fra/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..557469051d4247405493b333a6efdf9096643813 --- /dev/null +++ b/Languages/fra/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63725b5a9201548b2247af02bd69a059335bddf52c1b858dbe38a43a40478bd7 +size 145489135 diff --git a/Languages/fra/config.json b/Languages/fra/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/fra/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/fra/vocab.txt b/Languages/fra/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..868c0e022d1946f20046501af0df3197df8e0201 --- /dev/null +++ b/Languages/fra/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b57f0f246b488fe914508d82a8607e1aea357beb0f801069b39bfeb3a4c0d47 +size 104 diff --git a/Languages/guj/G_100000.pth b/Languages/guj/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..4388ca95e7c36800d610069f3cf7c0f05e65a055 --- /dev/null +++ 
b/Languages/guj/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:427ac3c74f61be494b389cae7d771311d0bcf576f4e2f1b22f257539e26e323a +size 145501427 diff --git a/Languages/guj/config.json b/Languages/guj/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/guj/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/guj/vocab.txt b/Languages/guj/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..236fda9cac231294a0fd1c8fccebe697329b2928 --- /dev/null +++ b/Languages/guj/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:611d4c5d7ba4bce727c1154277aea43df7a534e22e523877d1885a36727d63c3 +size 232 diff --git a/Languages/hin/G_100000.pth b/Languages/hin/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..154702dfcec0fc671ddaa9c9962814cdecf86717 --- /dev/null +++ b/Languages/hin/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f1d5e47edd7368ff40ff5673ddfc606ea713e785420d26c2da396b555458d3b +size 145510619 diff --git a/Languages/hin/config.json b/Languages/hin/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/hin/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/hin/vocab.txt b/Languages/hin/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..0f6d3fc331fb6a3d572920c573f35678a36238e2 --- /dev/null +++ b/Languages/hin/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eea03474615c78a1c42d1299c345b6865421c07485544ac3361bff472e5005ac +size 266 diff --git a/Languages/nld/G_100000.pth b/Languages/nld/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..9c0f8442f4b9be63ec1de8fcd423e965d723871e --- /dev/null +++ b/Languages/nld/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b09e9917b07f06dd911045c8fc8738594b4c4d65c55223c46335093a4904816 +size 145486855 diff --git a/Languages/nld/config.json b/Languages/nld/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/nld/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/nld/vocab.txt b/Languages/nld/vocab.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6e0086df0e85c97320bd76f2d9cf145e6e3e674 --- /dev/null +++ b/Languages/nld/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d7e65d89b8be768ac4a1e53643aebfe42830b82910bfed87f13904e2c5292a4 +size 94 diff --git a/Languages/pol/G_100000.pth b/Languages/pol/G_100000.pth new file mode 100644 index 0000000000000000000000000000000000000000..53996893b4a8e22ebed02c03d7e690e8014d6630 --- /dev/null +++ b/Languages/pol/G_100000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d6f4a9de92a6eb15bca8cb01826d8a9938ab6fb2c04a1c13a06d1d170c88ba6 +size 145490647 diff --git 
a/Languages/pol/config.json b/Languages/pol/config.json new file mode 100644 index 0000000000000000000000000000000000000000..4c206eccc73982de2c97df33fa7a2dc81982b91a --- /dev/null +++ b/Languages/pol/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec453d50be1976ddb3b501c7dc09128cdf360b83e5b0cf7dec68dc5caa460f49 +size 1887 diff --git a/Languages/pol/vocab.txt b/Languages/pol/vocab.txt new file mode 100755 index 0000000000000000000000000000000000000000..09345c0e0a92eb5eaa62d78e354b3a843952cd69 --- /dev/null +++ b/Languages/pol/vocab.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5514b17eb0cd17950849e3f2f2f22a6c7c2d18f2729f5b2fbfc2f2e5f035dc4a +size 103 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0f96d1cd7a8b4eae3e75be0549ae15b8c98c466a --- /dev/null +++ b/app.py @@ -0,0 +1,317 @@ + +# load the libraries for the application +# ------------------------------------------- +import os +import re +import nltk +import torch +import librosa +import tempfile +import subprocess + +import gradio as gr + +from scipy.io import wavfile +from nnet import utils, commons +from transformers import pipeline +from scipy.io.wavfile import write +from faster_whisper import WhisperModel +from nnet.models import SynthesizerTrn as vitsTRN +from nnet.models_vc import SynthesizerTrn as freeTRN +from nnet.mel_processing import mel_spectrogram_torch +from configurations.get_constants import constantConfig + +from speaker_encoder.voice_encoder import SpeakerEncoder + +from df.enhance import enhance, init_df, load_audio, save_audio +from configurations.get_hyperparameters import hyperparameterConfig +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline + +nltk.download('punkt') +from nltk.tokenize import sent_tokenize + +# making the FreeVC function +# --------------------------------- +class FreeVCModel: + def __init__(self, config, ptfile, speaker_model, wavLM_model, device='cpu'): + self.hps = utils.get_hparams_from_file(config) + + self.net_g = freeTRN( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + **self.hps.model + ).to(hyperparameters.device) + _ = self.net_g.eval() + _ = utils.load_checkpoint(ptfile, self.net_g, None, True) + + self.cmodel = utils.get_cmodel(device, wavLM_model) + + if self.hps.model.use_spk: + self.smodel = SpeakerEncoder(speaker_model) + + def convert(self, src, tgt): + fs_src, src_audio = src + fs_tgt, tgt_audio = tgt + + src = f"{constants.temp_audio_folder}/src.wav" + tgt = f"{constants.temp_audio_folder}/tgt.wav" + out = f"{constants.temp_audio_folder}/cnvr.wav" + with torch.no_grad(): + wavfile.write(tgt, fs_tgt, tgt_audio) + wav_tgt, _ = librosa.load(tgt, sr=self.hps.data.sampling_rate) + wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) + if self.hps.model.use_spk: + g_tgt = self.smodel.embed_utterance(wav_tgt) + g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(hyperparameters.device.type) + else: + wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(hyperparameters.device.type) + mel_tgt = mel_spectrogram_torch( + wav_tgt, + self.hps.data.filter_length, + self.hps.data.n_mel_channels, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + self.hps.data.mel_fmin, + self.hps.data.mel_fmax, + ) + wavfile.write(src, fs_src, src_audio) + wav_src, _ = librosa.load(src, sr=self.hps.data.sampling_rate) + wav_src = 
torch.from_numpy(wav_src).unsqueeze(0).to(hyperparameters.device.type)
+            c = utils.get_content(self.cmodel, wav_src)
+
+            if self.hps.model.use_spk:
+                audio = self.net_g.infer(c, g=g_tgt)
+            else:
+                audio = self.net_g.infer(c, mel=mel_tgt)
+            audio = audio[0][0].data.cpu().float().numpy()
+            write(out, 24000, audio)
+
+        return out
+
+# load the system configurations
+constants = constantConfig()
+hyperparameters = hyperparameterConfig()
+
+# load the models
+model, df_state, _ = init_df(hyperparameters.voice_enhacing_model, config_allow_defaults=True) # voice enhancing model
+stt_model = WhisperModel(hyperparameters.stt_model, device=hyperparameters.device.type, compute_type="float32") # speech to text model
+
+trans_model = AutoModelForSeq2SeqLM.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model], torch_dtype=torch.bfloat16).to(hyperparameters.device)
+trans_tokenizer = AutoTokenizer.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model])
+
+modelConvertSpeech = FreeVCModel(config=hyperparameters.text2speech_config, ptfile=hyperparameters.text2speech_model,
+                                 speaker_model=hyperparameters.text2speech_encoder, wavLM_model=hyperparameters.wavlm_model,
+                                 device=hyperparameters.device.type)
+
+# download the language model if it doesn't exist
+# ----------------------------------------------------
+def download(lang, lang_directory):
+
+    if not os.path.exists(f"{lang_directory}/{lang}"):
+        cmd = ";".join([
+            f"wget {constants.language_download_web}/{lang}.tar.gz -O {lang_directory}/{lang}.tar.gz",
+            f"tar zxvf {lang_directory}/{lang}.tar.gz -C {lang_directory}"
+        ])
+        subprocess.check_output(cmd, shell=True)
+        try:
+            os.remove(f"{lang_directory}/{lang}.tar.gz")
+        except:
+            pass
+    return f"{lang_directory}/{lang}"
+
+def preprocess_char(text, lang=None):
+    """
+    Special treatment of characters in certain languages
+    """
+    if lang == 'ron':
+        text = text.replace("ț", "ţ")
+    return text
+
+def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
+    txt = preprocess_char(txt, lang=lang)
+    is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
+    if is_uroman:
+        txt = text_mapper.uromanize(txt, f'{uroman_dir}/bin/uroman.pl')
+
+    txt = txt.lower()
+    txt = text_mapper.filter_oov(txt)
+    return txt
+
+def detect_language(text, LID):
+    predictions = LID.predict(text)
+    detected_lang_code = predictions[0][0].replace("__label__", "")
+    return detected_lang_code
+
+# text to speech
+class TextMapper(object):
+    def __init__(self, vocab_file):
+        self.symbols = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
+        self.SPACE_ID = self.symbols.index(" ")
+        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
+        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
+
+    def text_to_sequence(self, text, cleaner_names):
+        '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+        Args:
+            text: string to convert to a sequence
+            cleaner_names: names of the cleaner functions to run the text through
+        Returns:
+            List of integers corresponding to the symbols in the text
+        '''
+        sequence = []
+        clean_text = text.strip()
+        for symbol in clean_text:
+            symbol_id = self._symbol_to_id[symbol]
+            sequence += [symbol_id]
+        return sequence
+
+    def uromanize(self, text, uroman_pl):
+        with tempfile.NamedTemporaryFile() as tf, \
+                tempfile.NamedTemporaryFile() as tf2:
+            with open(tf.name, "w") as f:
+                f.write("\n".join([text]))
+            cmd = f"perl " + uroman_pl
+            cmd += f" -l xxx "
+            cmd += f" < {tf.name} > {tf2.name}"
+            os.system(cmd)
+            outtexts = []
+            with open(tf2.name) as f:
+                for line in f:
+                    line = re.sub(r"\s+", " ", line).strip()
+                    outtexts.append(line)
+            outtext = outtexts[0]
+        return outtext
+
+    def get_text(self, text, hps):
+        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
+        if hps.data.add_blank:
+            text_norm = commons.intersperse(text_norm, 0)
+        text_norm = torch.LongTensor(text_norm)
+        return text_norm
+
+    def filter_oov(self, text):
+        val_chars = self._symbol_to_id
+        txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
+        return txt_filt
+
+def speech_to_text(audio_file):
+    try:
+        fs, audio = audio_file
+        wavfile.write(constants.input_speech_file, fs, audio)
+        audio0, _ = load_audio(constants.input_speech_file, sr=df_state.sr())
+
+        # Enhance the SNR of the audio
+        enhanced = enhance(model, df_state, audio0)
+        save_audio(constants.enhanced_speech_file, enhanced, df_state.sr())
+
+        segments, info = stt_model.transcribe(constants.enhanced_speech_file)
+
+        speech_text = ''
+        for segment in segments:
+            speech_text = f'{speech_text}{segment.text}'
+        try:
+            source_lang_nllb = [k for k, v in constants.flores_codes_to_tts_codes.items() if v[:2] == info.language][0]
+        except:
+            source_lang_nllb = 'language cannot be determined, select manually'
+
+        # text translation
+        return speech_text, gr.Dropdown.update(value=source_lang_nllb)
+    except:
+        return '', gr.Dropdown.update(value='English')
+
+# Text to speech
+def text_to_speech(text, target_lang):
+    txt = text
+
+    # LANG = get_target_tts_lang(target_lang)
+    LANG = constants.flores_codes_to_tts_codes[target_lang]
+    ckpt_dir = download(LANG, lang_directory=constants.language_directory)
+
+    vocab_file = f"{ckpt_dir}/{constants.language_vocab_text}"
+    config_file = f"{ckpt_dir}/{constants.language_vocab_configuration}"
+    hps = utils.get_hparams_from_file(config_file)
+    text_mapper = TextMapper(vocab_file)
+    net_g = vitsTRN(
+        len(text_mapper.symbols),
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        **hps.model)
+    net_g.to(hyperparameters.device)
+    _ = net_g.eval()
+
+    g_pth = f"{ckpt_dir}/{constants.language_vocab_model}"
+
+    _ = utils.load_checkpoint(g_pth, net_g, None)
+
+    txt = preprocess_text(txt, text_mapper, hps, lang=LANG, uroman_dir=constants.uroman_directory)
+    stn_tst = text_mapper.get_text(txt, hps)
+    with torch.no_grad():
+        x_tst = stn_tst.unsqueeze(0).to(hyperparameters.device)
+        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(hyperparameters.device)
+        hyp = net_g.infer(
+            x_tst, x_tst_lengths, noise_scale=.667,
+            noise_scale_w=0.8, length_scale=1.0
+        )[0][0,0].cpu().float().numpy()
+
+    return hps.data.sampling_rate, hyp
+
+def translation(audio, text, source_lang_nllb, target_code_nllb, output_type, sentence_mode):
+    target_code = constants.flores_codes[target_code_nllb]
+    translator = pipeline('translation', model=trans_model,
tokenizer=trans_tokenizer, src_lang=source_lang_nllb, tgt_lang=target_code, device=hyperparameters.device) + + # output = translator(text, max_length=400)[0]['translation_text'] + if sentence_mode == "Sentence-wise": + sentences = sent_tokenize(text) + translated_sentences = [] + for sentence in sentences: + translated_sentence = translator(sentence, max_length=400)[0]['translation_text'] + translated_sentences.append(translated_sentence) + output = ' '.join(translated_sentences) + else: + output = translator(text, max_length=1024)[0]['translation_text'] + + # get the text to speech + fs_out, audio_out = text_to_speech(output, target_code_nllb) + + if output_type == 'own voice': + out_file = modelConvertSpeech.convert((fs_out, audio_out), audio) + return output, out_file + + wavfile.write(constants.text2speech_wavfile, fs_out, audio_out) + return output, constants.text2speech_wavfile + +with gr.Blocks(title = "Octopus Translation App") as octopus_translator: + with gr.Row(): + audio_file = gr.Audio(source="microphone") + + with gr.Row(): + input_text = gr.Textbox(label="Input text") + source_language = gr.Dropdown(list(constants.flores_codes.keys()), value='English', label='Source (Autoselected)', interactive=True) + + with gr.Row(): + output_text = gr.Textbox(label='Translated text') + target_language = gr.Dropdown(list(constants.flores_codes.keys()), value='German', label='Target', interactive=True) + + + with gr.Row(): + output_speech = gr.Audio(label='Translated speech') + translate_button = gr.Button('Translate') + + + with gr.Row(): + enhance_audio = gr.Radio(['yes', 'no'], value='yes', label='Enhance input voice', interactive=True) + input_type = gr.Radio(['Whole text', 'Sentence-wise'],value='Sentence-wise', label="Translation Mode", interactive=True) + output_audio_type = gr.Radio(['standard speaker', 'voice transfer'], value='voice transfer', label='Enhance output voice', interactive=True) + + audio_file.change(speech_to_text, + inputs=[audio_file], + outputs=[input_text, source_language]) + + translate_button.click(translation, + inputs=[audio_file, input_text, + source_language, target_language, + output_audio_type, input_type], + outputs=[output_text, output_speech]) + +octopus_translator.launch(share=False) diff --git a/aux_files/uroman.pl b/aux_files/uroman.pl new file mode 100755 index 0000000000000000000000000000000000000000..bb978366c6ae65604bdc3051cf5fc257149a353e --- /dev/null +++ b/aux_files/uroman.pl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ceece2c05343e8bc3b1a7cdc8cecd530af94a7928013c0e4224fd5c729fb29a +size 5347 diff --git a/configurations/__init__.py b/configurations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configurations/get_constants.py b/configurations/get_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..295a48c52046268483534ac3eca3245da59500f5 --- /dev/null +++ b/configurations/get_constants.py @@ -0,0 +1,176 @@ +import os + +class constantConfig: + def __init__(self): + self.flores_codes={'Acehnese (Arabic script)': 'ace_Arab', + 'Acehnese (Latin script)': 'ace_Latn', + 'Mesopotamian Arabic': 'acm_Arab', + 'Ta’izzi-Adeni Arabic': 'acq_Arab', + 'Tunisian Arabic': 'aeb_Arab', + 'Afrikaans': 'afr_Latn', + 'South Levantine Arabic': 'ajp_Arab', + 'Akan': 'aka_Latn', + 'Amharic': 'amh_Ethi', + 'North Levantine Arabic': 'apc_Arab', + 'Modern Standard Arabic': 'arb_Arab', + 'Modern Standard Arabic (Romanized)': 
'arb_Latn', + 'Najdi Arabic': 'ars_Arab', + 'Moroccan Arabic': 'ary_Arab', + 'Egyptian Arabic': 'arz_Arab', + 'Assamese': 'asm_Beng', + 'Asturian': 'ast_Latn', + 'Awadhi': 'awa_Deva', + 'Central Aymara': 'ayr_Latn', + 'South Azerbaijani': 'azb_Arab', + 'North Azerbaijani': 'azj_Latn', + 'Bashkir': 'bak_Cyrl', + 'Bambara': 'bam_Latn', + 'Balinese': 'ban_Latn', + 'Belarusian': 'bel_Cyrl', + 'Bemba': 'bem_Latn', + 'Bengali': 'ben_Beng', + 'Bhojpuri': 'bho_Deva', + 'Banjar (Arabic script)': 'bjn_Arab', + 'Banjar (Latin script)': 'bjn_Latn', + 'Standard Tibetan': 'bod_Tibt', + 'Bosnian': 'bos_Latn', + 'Buginese': 'bug_Latn', + 'Bulgarian': 'bul_Cyrl', + 'Catalan': 'cat_Latn', + 'Cebuano': 'ceb_Latn', + 'Czech': 'ces_Latn', + 'Chokwe': 'cjk_Latn', + 'Central Kurdish': 'ckb_Arab', + 'Crimean Tatar': 'crh_Latn', + 'Welsh': 'cym_Latn', + 'Danish': 'dan_Latn', + 'German': 'deu_Latn', + 'Southwestern Dinka': 'dik_Latn', + 'Dyula': 'dyu_Latn', + 'Dzongkha': 'dzo_Tibt', + 'Greek': 'ell_Grek', + 'English': 'eng_Latn', + 'Esperanto': 'epo_Latn', + 'Estonian': 'est_Latn', + 'Basque': 'eus_Latn', + 'Ewe': 'ewe_Latn', + 'Faroese': 'fao_Latn', + 'Fijian': 'fij_Latn', + 'Finnish': 'fin_Latn', + 'Fon': 'fon_Latn', + 'French': 'fra_Latn', + 'Friulian': 'fur_Latn', + 'Nigerian Fulfulde': 'fuv_Latn', + 'Scottish Gaelic': 'gla_Latn', + 'Irish': 'gle_Latn', + 'Galician': 'glg_Latn', + 'Guarani': 'grn_Latn', + 'Gujarati': 'guj_Gujr', + 'Haitian Creole': 'hat_Latn', + 'Hausa': 'hau_Latn', + 'Hebrew': 'heb_Hebr', + 'Hindi': 'hin_Deva', + 'Chhattisgarhi': 'hne_Deva', + 'Croatian': 'hrv_Latn', + 'Hungarian': 'hun_Latn', + 'Armenian': 'hye_Armn', + 'Igbo': 'ibo_Latn', + 'Ilocano': 'ilo_Latn', + 'Indonesian': 'ind_Latn', + 'Icelandic': 'isl_Latn', + 'Italian': 'ita_Latn', + 'Javanese': 'jav_Latn', + 'Japanese': 'jpn_Jpan', + 'Kabyle': 'kab_Latn', + 'Jingpho': 'kac_Latn', + 'Kamba': 'kam_Latn', + 'Kannada': 'kan_Knda', + 'Kashmiri (Arabic script)': 'kas_Arab', + 'Kashmiri (Devanagari script)': 'kas_Deva', + 'Georgian': 'kat_Geor', + 'Central Kanuri (Arabic script)': 'knc_Arab', + 'Central Kanuri (Latin script)': 'knc_Latn', + 'Kazakh': 'kaz_Cyrl', + 'Kabiyè': 'kbp_Latn', + 'Kabuverdianu': 'kea_Latn', + 'Khmer': 'khm_Khmr', + 'Kikuyu': 'kik_Latn', + 'Kinyarwanda': 'kin_Latn', 'Kyrgyz': 'kir_Cyrl', 'Kimbundu': 'kmb_Latn', + 'Northern Kurdish': 'kmr_Latn', 'Kikongo': 'kon_Latn', + 'Korean': 'kor_Hang', 'Lao': 'lao_Laoo', 'Ligurian': 'lij_Latn', + 'Limburgish': 'lim_Latn', 'Lingala': 'lin_Latn', 'Lithuanian': 'lit_Latn', 'Lombard': 'lmo_Latn', + 'Latgalian': 'ltg_Latn', 'Luxembourgish': 'ltz_Latn', 'Luba-Kasai': 'lua_Latn', 'Ganda': 'lug_Latn', + 'Luo': 'luo_Latn', 'Mizo': 'lus_Latn', 'Standard Latvian': 'lvs_Latn', 'Magahi': 'mag_Deva', + 'Maithili': 'mai_Deva', 'Malayalam': 'mal_Mlym', 'Marathi': 'mar_Deva', + 'Minangkabau (Arabic script)': 'min_Arab', 'Minangkabau (Latin script)': 'min_Latn', + 'Macedonian': 'mkd_Cyrl', 'Plateau Malagasy': 'plt_Latn', 'Maltese': 'mlt_Latn', + 'Meitei (Bengali script)': 'mni_Beng', 'Halh Mongolian': 'khk_Cyrl', 'Mossi': 'mos_Latn', + 'Maori': 'mri_Latn', 'Burmese': 'mya_Mymr', 'Dutch': 'nld_Latn', 'Norwegian Nynorsk': 'nno_Latn', + 'Norwegian Bokmål': 'nob_Latn', 'Nepali': 'npi_Deva', 'Northern Sotho': 'nso_Latn', + 'Nuer': 'nus_Latn', + 'Nyanja': 'nya_Latn', 'Occitan': 'oci_Latn', 'West Central Oromo': 'gaz_Latn', 'Odia': 'ory_Orya', + 'Pangasinan': 'pag_Latn', 'Eastern Panjabi': 'pan_Guru', 'Papiamento': 'pap_Latn', + 'Western Persian': 'pes_Arab', + 'Polish': 'pol_Latn', 'Portuguese': 
'por_Latn', 'Dari': 'prs_Arab', 'Southern Pashto': 'pbt_Arab', + 'Ayacucho Quechua': 'quy_Latn', 'Romanian': 'ron_Latn', 'Rundi': 'run_Latn', 'Russian': 'rus_Cyrl', + 'Sango': 'sag_Latn', 'Sanskrit': 'san_Deva', 'Santali': 'sat_Olck', 'Sicilian': 'scn_Latn', + 'Shan': 'shn_Mymr', + 'Sinhala': 'sin_Sinh', 'Slovak': 'slk_Latn', 'Slovenian': 'slv_Latn', 'Samoan': 'smo_Latn', + 'Shona': 'sna_Latn', + 'Sindhi': 'snd_Arab', 'Somali': 'som_Latn', 'Southern Sotho': 'sot_Latn', 'Spanish': 'spa_Latn', + 'Tosk Albanian': 'als_Latn', 'Sardinian': 'srd_Latn', 'Serbian': 'srp_Cyrl', 'Swati': 'ssw_Latn', + 'Sundanese': 'sun_Latn', 'Swedish': 'swe_Latn', 'Swahili': 'swh_Latn', 'Silesian': 'szl_Latn', + 'Tamil': 'tam_Taml', 'Tatar': 'tat_Cyrl', 'Telugu': 'tel_Telu', 'Tajik': 'tgk_Cyrl', + 'Tagalog': 'tgl_Latn', + 'Thai': 'tha_Thai', 'Tigrinya': 'tir_Ethi', 'Tamasheq (Latin script)': 'taq_Latn', + 'Tamasheq (Tifinagh script)': 'taq_Tfng', + 'Tok Pisin': 'tpi_Latn', 'Tswana': 'tsn_Latn', 'Tsonga': 'tso_Latn', 'Turkmen': 'tuk_Latn', 'Tumbuka': 'tum_Latn', + 'Turkish': 'tur_Latn', 'Twi': 'twi_Latn', 'Central Atlas Tamazight': 'tzm_Tfng', + 'Uyghur': 'uig_Arab', + 'Ukrainian': 'ukr_Cyrl', 'Umbundu': 'umb_Latn', 'Urdu': 'urd_Arab', 'Northern Uzbek': 'uzn_Latn', + 'Venetian': 'vec_Latn', + 'Vietnamese': 'vie_Latn', 'Waray': 'war_Latn', 'Wolof': 'wol_Latn', 'Xhosa': 'xho_Latn', + 'Eastern Yiddish': 'ydd_Hebr', + 'Yoruba': 'yor_Latn', 'Yue Chinese': 'yue_Hant', 'Chinese (Simplified)': 'zho_Hans', + 'Chinese (Traditional)': 'zho_Hant', + 'Standard Malay': 'zsm_Latn', 'Zulu': 'zul_Latn'} + + self.model_name_dict = {'0.6B': 'facebook/nllb-200-distilled-600M', + '1.3B': 'facebook/nllb-200-distilled-1.3B', + '3.3B': 'facebook/nllb-200-3.3B', + } + + self.whisper_codes_to_flores_codes = {"de" : self.flores_codes['German'], + "en" : self.flores_codes['English'], + "pl" : self.flores_codes['Polish'], + "hi" : self.flores_codes['Hindi'] + } + + self.flores_codes_to_tts_codes = {'Acehnese': 'ace', 'Mesopotamian Arabic': 'acm', 'Ta’izzi-Adeni Arabic': 'acq', 'Tunisian Arabic': 'aeb', 'Afrikaans': 'afr', 'South Levantine Arabic': 'ajp', 'Akan': 'aka', 'Amharic': 'amh', 'North Levantine Arabic': 'apc', 'Modern Standard Arabic': 'arb', 'Najdi Arabic': 'ars', 'Moroccan Arabic': 'ary', 'Egyptian Arabic': 'arz', 'Assamese': 'asm', 'Asturian': 'ast', 'Awadhi': 'awa', 'Central Aymara': 'ayr', 'South Azerbaijani': 'azb', 'North Azerbaijani': 'azj', 'Bashkir': 'bak', 'Bambara': 'bam', 'Balinese': 'ban', 'Belarusian': 'bel', 'Bemba': 'bem', 'Bengali': 'ben', 'Bhojpuri': 'bho', 'Banjar': 'bjn', 'Standard Tibetan': 'bod', 'Bosnian': 'bos', 'Buginese': 'bug', 'Bulgarian': 'bul', 'Catalan': 'cat', 'Cebuano': 'ceb', 'Czech': 'ces', 'Chokwe': 'cjk', 'Central Kurdish': 'ckb', 'Crimean Tatar': 'crh', 'Welsh': 'cym', 'Danish': 'dan', 'German': 'deu', 'Southwestern Dinka': 'dik', 'Dyula': 'dyu', 'Dzongkha': 'dzo', 'Greek': 'ell', 'English': 'eng', 'Esperanto': 'epo', 'Estonian': 'est', 'Basque': 'eus', 'Ewe': 'ewe', 'Faroese': 'fao', 'Fijian': 'fij', 'Finnish': 'fin', 'Fon': 'fon', 'French': 'fra', 'Friulian': 'fur', 'Nigerian Fulfulde': 'fuv', 'Scottish Gaelic': 'gla', 'Irish': 'gle', 'Galician': 'glg', 'Guarani': 'grn', 'Gujarati': 'guj', 'Haitian Creole': 'hat', 'Hausa': 'hau', 'Hebrew': 'heb', 'Hindi': 'hin', 'Chhattisgarhi': 'hne', 'Croatian': 'hrv', 'Hungarian': 'hun', 'Armenian': 'hye', 'Igbo': 'ibo', 'Ilocano': 'ilo', 'Indonesian': 'ind', 'Icelandic': 'isl', 'Italian': 'ita', 'Javanese': 'jav', 'Japanese': 'jpn', 'Kabyle': 'kab', 
'Jingpho': 'kac', 'Kamba': 'kam', 'Kannada': 'kan', 'Kashmiri': 'kas', 'Georgian': 'kat', 'Central Kanuri': 'knc', 'Kazakh': 'kaz', 'Kabiyè': 'kbp', 'Kabuverdianu': 'kea', 'Khmer': 'khm', 'Kikuyu': 'kik', 'Kinyarwanda': 'kin', 'Kyrgyz': 'kir', 'Kimbundu': 'kmb', 'Northern Kurdish': 'kmr', 'Kikongo': 'kon', 'Korean': 'kor', 'Lao': 'lao', 'Ligurian': 'lij', 'Limburgish': 'lim', 'Lingala': 'lin', 'Lithuanian': 'lit', 'Lombard': 'lmo', 'Latgalian': 'ltg', 'Luxembourgish': 'ltz', 'Luba-Kasai': 'lua', 'Ganda': 'lug', 'Luo': 'luo', 'Mizo': 'lus', 'Standard Latvian': 'lvs', 'Magahi': 'mag', 'Maithili': 'mai', 'Malayalam': 'mal', 'Marathi': 'mar', 'Minangkabau': 'min', 'Macedonian': 'mkd', 'Plateau Malagasy': 'plt', 'Maltese': 'mlt', 'Meitei': 'mni', 'Halh Mongolian': 'khk', 'Mossi': 'mos', 'Maori': 'mri', 'Burmese': 'mya', 'Dutch': 'nld', 'Norwegian Nynorsk': 'nno', 'Norwegian Bokmål': 'nob', 'Nepali': 'npi', 'Northern Sotho': 'nso', 'Nuer': 'nus', 'Nyanja': 'nya', 'Occitan': 'oci', 'West Central Oromo': 'gaz', 'Odia': 'ory', 'Pangasinan': 'pag', 'Eastern Panjabi': 'pan', 'Papiamento': 'pap', 'Western Persian': 'pes', 'Polish': 'pol', 'Portuguese': 'por', 'Dari': 'prs', 'Southern Pashto': 'pbt', 'Ayacucho Quechua': 'quy', 'Romanian': 'ron', 'Rundi': 'run', 'Russian': 'rus', 'Sango': 'sag', 'Sanskrit': 'san', 'Santali': 'sat', 'Sicilian': 'scn', 'Shan': 'shn', 'Sinhala': 'sin', 'Slovak': 'slk', 'Slovenian': 'slv', 'Samoan': 'smo', 'Shona': 'sna', 'Sindhi': 'snd', 'Somali': 'som', 'Southern Sotho': 'sot', 'Spanish': 'spa', 'Tosk Albanian': 'als', 'Sardinian': 'srd', 'Serbian': 'srp', 'Swati': 'ssw', 'Sundanese': 'sun', 'Swedish': 'swe', 'Swahili': 'swh', 'Silesian': 'szl', 'Tamil': 'tam', 'Tatar': 'tat', 'Telugu': 'tel', 'Tajik': 'tgk', 'Tagalog': 'tgl', 'Thai': 'tha', 'Tigrinya': 'tir', 'Tamasheq': 'taq', 'Tok Pisin': 'tpi', 'Tswana': 'tsn', 'Tsonga': 'tso', 'Turkmen': 'tuk', 'Tumbuka': 'tum', 'Turkish': 'tur', 'Twi': 'twi', 'Central Atlas Tamazight': 'tzm', 'Uyghur': 'uig', 'Ukrainian': 'ukr', 'Umbundu': 'umb', 'Urdu': 'urd', 'Northern Uzbek': 'uzn', 'Venetian': 'vec', 'Vietnamese': 'vie', 'Waray': 'war', 'Wolof': 'wol', 'Xhosa': 'xho', 'Eastern Yiddish': 'ydd', 'Yoruba': 'yor', 'Yue Chinese': 'yue', 'Chinese': 'zho', 'Standard Malay': 'zsm', 'Zulu': 'zul'} + + self.language_directory = 'Languages' + self.uroman_directory = 'aux_files' + + self.language_download_web = 'https://dl.fbaipublicfiles.com/mms/tts' + self.language_vocab_text = "vocab.txt" + self.language_vocab_configuration = "config.json" + self.language_vocab_model = "G_100000.pth" + + # creating the audio files temporary + # --------------------------------------- + self.temp_audio_folder = 'Temp_Audios' + + self.text2speech_wavfile = f'{self.temp_audio_folder}/text2speech.wav' + self.enhanced_speech_file = f"{self.temp_audio_folder}/enhanced.mp3" + self.input_speech_file = f'{self.temp_audio_folder}/output.wav' + + + try: + os.makedirs(self.language_directory) + except: + pass + + try: + os.makedirs(self.temp_audio_folder) + except: + pass diff --git a/configurations/get_hyperparameters.py b/configurations/get_hyperparameters.py new file mode 100644 index 0000000000000000000000000000000000000000..5ceb76ac752cb1165ccd1c9f1aaa86991fc1aa16 --- /dev/null +++ b/configurations/get_hyperparameters.py @@ -0,0 +1,19 @@ +import torch + +class hyperparameterConfig: + def __init__(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.stt_model = "large-v2" + self.nllb_model = '1.3B' + + # text to speech 
model + self.text2speech_model = 'model_weights/voiceover/freevc-24.pth' + self.text2speech_config = 'model_weights/voiceover/freevc-24.json' + self.text2speech_encoder = 'model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt' + + # voice enhancing model + self.voice_enhacing_model = 'model_weights/voice_enhance' + + # loading the wavlm model + self.wavlm_model = 'model_weights/wavlm_models/WavLM-Large.pt' \ No newline at end of file diff --git a/df/__init__.py b/df/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abcccd3f005af7532a9337f8cb6a9a7e410a6ff4 --- /dev/null +++ b/df/__init__.py @@ -0,0 +1,3 @@ +from .config import config + +__all__ = ["config"] diff --git a/df/checkpoint.py b/df/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..24b805234d988b34254b77b401ec842444b9bbe0 --- /dev/null +++ b/df/checkpoint.py @@ -0,0 +1,213 @@ +import glob +import os +import re +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from loguru import logger +from torch import nn + +from df.config import Csv, config +from df.model import init_model +from df.utils import check_finite_module +from libdf import DF + + +def get_epoch(cp) -> int: + return int(os.path.basename(cp).split(".")[0].split("_")[-1]) + + +def load_model( + cp_dir: Optional[str], + df_state: DF, + jit: bool = False, + mask_only: bool = False, + train_df_only: bool = False, + extension: str = "ckpt", + epoch: Union[str, int, None] = "latest", +) -> Tuple[nn.Module, int]: + if mask_only and train_df_only: + raise ValueError("Only one of `mask_only` `train_df_only` can be enabled") + model = init_model(df_state, run_df=mask_only is False, train_mask=train_df_only is False) + if jit: + model = torch.jit.script(model) + blacklist: List[str] = config("CP_BLACKLIST", [], Csv(), save=False, section="train") # type: ignore + if cp_dir is not None: + epoch = read_cp( + model, "model", cp_dir, blacklist=blacklist, extension=extension, epoch=epoch + ) + epoch = 0 if epoch is None else epoch + else: + epoch = 0 + return model, epoch + + +def read_cp( + obj: Union[torch.optim.Optimizer, nn.Module], + name: str, + dirname: str, + epoch: Union[str, int, None] = "latest", + extension="ckpt", + blacklist=[], + log: bool = True, +): + checkpoints = [] + if isinstance(epoch, str): + assert epoch in ("best", "latest") + if epoch == "best": + checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}.best")) + if len(checkpoints) == 0: + logger.warning("Could not find `best` checkpoint. 
Checking for default...") + if len(checkpoints) == 0: + checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}")) + checkpoints += glob.glob(os.path.join(dirname, f"{name}*.{extension}.best")) + if len(checkpoints) == 0: + return None + if isinstance(epoch, int): + latest = next((x for x in checkpoints if get_epoch(x) == epoch), None) + if latest is None: + logger.error(f"Could not find checkpoint of epoch {epoch}") + exit(1) + else: + latest = max(checkpoints, key=get_epoch) + epoch = get_epoch(latest) + if log: + logger.info("Found checkpoint {} with epoch {}".format(latest, epoch)) + latest = torch.load(latest, map_location="cpu") + latest = {k.replace("clc", "df"): v for k, v in latest.items()} + if blacklist: + reg = re.compile("".join(f"({b})|" for b in blacklist)[:-1]) + len_before = len(latest) + latest = {k: v for k, v in latest.items() if reg.search(k) is None} + if len(latest) < len_before: + logger.info("Filtered checkpoint modules: {}".format(blacklist)) + if isinstance(obj, nn.Module): + while True: + try: + missing, unexpected = obj.load_state_dict(latest, strict=False) + except RuntimeError as e: + e_str = str(e) + logger.warning(e_str) + if "size mismatch" in e_str: + latest = {k: v for k, v in latest.items() if k not in e_str} + continue + raise e + break + for key in missing: + logger.warning(f"Missing key: '{key}'") + for key in unexpected: + if key.endswith(".h0"): + continue + logger.warning(f"Unexpected key: {key}") + return epoch + obj.load_state_dict(latest) + + +def write_cp( + obj: Union[torch.optim.Optimizer, nn.Module], + name: str, + dirname: str, + epoch: int, + extension="ckpt", + metric: Optional[float] = None, + cmp="min", +): + check_finite_module(obj) + n_keep = config("n_checkpoint_history", default=3, cast=int, section="train") + n_keep_best = config("n_best_checkpoint_history", default=5, cast=int, section="train") + if metric is not None: + assert cmp in ("min", "max") + metric = float(metric) # Make sure it is not an integer + # Each line contains a previous best with entries: (epoch, metric) + with open(os.path.join(dirname, ".best"), "a+") as prev_best_f: + prev_best_f.seek(0) # "a+" creates a file in read/write mode without truncating + lines = prev_best_f.readlines() + if len(lines) == 0: + prev_best = float("inf" if cmp == "min" else "-inf") + else: + prev_best = float(lines[-1].strip().split(" ")[1]) + cmp = "__lt__" if cmp == "min" else "__gt__" + if getattr(metric, cmp)(prev_best): + logger.info(f"Saving new best checkpoint at epoch {epoch} with metric: {metric}") + prev_best_f.seek(0, os.SEEK_END) + np.savetxt(prev_best_f, np.array([[float(epoch), metric]])) + cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}.best") + torch.save(obj.state_dict(), cp_name) + cleanup(name, dirname, extension + ".best", nkeep=n_keep_best) + cp_name = os.path.join(dirname, f"{name}_{epoch}.{extension}") + logger.info(f"Writing checkpoint {cp_name} with epoch {epoch}") + torch.save(obj.state_dict(), cp_name) + cleanup(name, dirname, extension, nkeep=n_keep) + + +def cleanup(name: str, dirname: str, extension: str, nkeep=5): + if nkeep < 0: + return + checkpoints = glob.glob(os.path.join(dirname, f"{name}*.{extension}")) + if len(checkpoints) == 0: + return + checkpoints = sorted(checkpoints, key=get_epoch, reverse=True) + for cp in checkpoints[nkeep:]: + logger.debug("Removing old checkpoint: {}".format(cp)) + os.remove(cp) + + +def check_patience( + dirname: str, max_patience: int, new_metric: float, cmp: str = "min", raise_: bool = True 
+): + cmp = "__lt__" if cmp == "min" else "__gt__" + new_metric = float(new_metric) # Make sure it is not an integer + prev_patience, prev_metric = read_patience(dirname) + if prev_patience is None or getattr(new_metric, cmp)(prev_metric): + # We have a better new_metric, reset patience + write_patience(dirname, 0, new_metric) + else: + # We don't have a better metric, decrement patience + new_patience = prev_patience + 1 + write_patience(dirname, new_patience, prev_metric) + if new_patience >= max_patience: + if raise_: + raise ValueError( + f"No improvements on validation metric ({new_metric}) for {max_patience} epochs. " + "Stopping." + ) + else: + return False + return True + + +def read_patience(dirname: str) -> Tuple[Optional[int], float]: + fn = os.path.join(dirname, ".patience") + if not os.path.isfile(fn): + return None, 0.0 + patience, metric = np.loadtxt(fn) + return int(patience), float(metric) + + +def write_patience(dirname: str, new_patience: int, metric: float): + return np.savetxt(os.path.join(dirname, ".patience"), [new_patience, metric]) + + +def test_check_patience(): + import tempfile + + with tempfile.TemporaryDirectory() as d: + check_patience(d, 3, 1.0) + check_patience(d, 3, 1.0) + check_patience(d, 3, 1.0) + assert check_patience(d, 3, 1.0, raise_=False) is False + + with tempfile.TemporaryDirectory() as d: + check_patience(d, 3, 1.0) + check_patience(d, 3, 0.9) + check_patience(d, 3, 1.0) + check_patience(d, 3, 1.0) + assert check_patience(d, 3, 1.0, raise_=False) is False + + with tempfile.TemporaryDirectory() as d: + check_patience(d, 3, 1.0, cmp="max") + check_patience(d, 3, 1.9, cmp="max") + check_patience(d, 3, 1.0, cmp="max") + check_patience(d, 3, 1.0, cmp="max") + assert check_patience(d, 3, 1.0, cmp="max", raise_=False) is False diff --git a/df/config.py b/df/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c785140d559532ac3afa69e139c4ac063a0517ef --- /dev/null +++ b/df/config.py @@ -0,0 +1,266 @@ +import os +import string +from configparser import ConfigParser +from shlex import shlex +from typing import Any, List, Optional, Tuple, Type, TypeVar, Union + +from loguru import logger + +T = TypeVar("T") + + +class DfParams: + def __init__(self): + # Sampling rate used for training + self.sr: int = config("SR", cast=int, default=48_000, section="DF") + # FFT size in samples + self.fft_size: int = config("FFT_SIZE", cast=int, default=960, section="DF") + # STFT Hop size in samples + self.hop_size: int = config("HOP_SIZE", cast=int, default=480, section="DF") + # Number of ERB bands + self.nb_erb: int = config("NB_ERB", cast=int, default=32, section="DF") + # Number of deep filtering bins; DF is applied from 0th to nb_df-th frequency bins + self.nb_df: int = config("NB_DF", cast=int, default=96, section="DF") + # Normalization decay factor; used for complex and erb features + self.norm_tau: float = config("NORM_TAU", 1, float, section="DF") + # Local SNR minimum value, ground truth will be truncated + self.lsnr_max: int = config("LSNR_MAX", 35, int, section="DF") + # Local SNR maximum value, ground truth will be truncated + self.lsnr_min: int = config("LSNR_MIN", -15, int, section="DF") + # Minimum number of frequency bins per ERB band + self.min_nb_freqs = config("MIN_NB_ERB_FREQS", 2, int, section="DF") + # Deep Filtering order + self.df_order: int = config("DF_ORDER", cast=int, default=5, section="DF") + # Deep Filtering look-ahead + self.df_lookahead: int = config("DF_LOOKAHEAD", cast=int, default=0, section="DF") + # Pad 
mode. By default, padding will be handled on the input side: + # - `input`, which pads the input features passed to the model + # - `output`, which pads the output spectrogram corresponding to `df_lookahead` + self.pad_mode: str = config("PAD_MODE", default="input_specf", section="DF") + + +class Config: + """Adopted from python-decouple""" + + DEFAULT_SECTION = "settings" + + def __init__(self): + self.parser: ConfigParser = None # type: ignore + self.path = "" + self.modified = False + self.allow_defaults = True + + def load( + self, path: Optional[str], config_must_exist=False, allow_defaults=True, allow_reload=False + ): + self.allow_defaults = allow_defaults + if self.parser is not None and not allow_reload: + raise ValueError("Config already loaded") + self.parser = ConfigParser() + self.path = path + if path is not None and os.path.isfile(path): + with open(path) as f: + self.parser.read_file(f) + else: + if config_must_exist: + raise ValueError(f"No config file found at '{path}'.") + if not self.parser.has_section(self.DEFAULT_SECTION): + self.parser.add_section(self.DEFAULT_SECTION) + self._fix_clc() + self._fix_df() + + def use_defaults(self): + self.load(path=None, config_must_exist=False) + + def save(self, path: str): + if not self.modified: + logger.debug("Config not modified. No need to overwrite on disk.") + return + if self.parser is None: + self.parser = ConfigParser() + for section in self.parser.sections(): + if len(self.parser[section]) == 0: + self.parser.remove_section(section) + with open(path, mode="w") as f: + self.parser.write(f) + + def tostr(self, value, cast): + if isinstance(cast, Csv) and isinstance(value, (tuple, list)): + return "".join(str(v) + cast.delimiter for v in value)[:-1] + return str(value) + + def set(self, option: str, value: T, cast: Type[T], section: Optional[str] = None) -> T: + section = self.DEFAULT_SECTION if section is None else section + section = section.lower() + if not self.parser.has_section(section): + self.parser.add_section(section) + if self.parser.has_option(section, option): + if value == self.cast(self.parser.get(section, option), cast): + return value + self.modified = True + self.parser.set(section, option, self.tostr(value, cast)) + return value + + def __call__( + self, + option: str, + default: Any = None, + cast: Type[T] = str, + save: bool = True, + section: Optional[str] = None, + ) -> T: + # Get value either from an ENV or from the .ini file + section = self.DEFAULT_SECTION if section is None else section + value = None + if self.parser is None: + raise ValueError("No configuration loaded") + if not self.parser.has_section(section.lower()): + self.parser.add_section(section.lower()) + if option in os.environ: + value = os.environ[option] + if save: + self.parser.set(section, option, self.tostr(value, cast)) + elif self.parser.has_option(section, option): + value = self.read_from_section(section, option, default, cast=cast, save=save) + elif self.parser.has_option(section.lower(), option): + value = self.read_from_section(section.lower(), option, default, cast=cast, save=save) + elif self.parser.has_option(self.DEFAULT_SECTION, option): + logger.warning( + f"Couldn't find option {option} in section {section}. " + "Falling back to default settings section." 
+ ) + value = self.read_from_section(self.DEFAULT_SECTION, option, cast=cast, save=save) + elif default is None: + raise ValueError("Value {} not found.".format(option)) + elif not self.allow_defaults and save: + raise ValueError(f"Value '{option}' not found in config (defaults not allowed).") + else: + value = default + if save: + self.set(option, value, cast, section) + return self.cast(value, cast) + + def cast(self, value, cast): + # Do the casting to get the correct type + if cast is bool: + value = str(value).lower() + if value in {"true", "yes", "y", "on", "1"}: + return True # type: ignore + elif value in {"false", "no", "n", "off", "0"}: + return False # type: ignore + raise ValueError("Parse error") + return cast(value) + + def get(self, option: str, cast: Type[T] = str, section: Optional[str] = None) -> T: + section = self.DEFAULT_SECTION if section is None else section + if not self.parser.has_section(section): + raise KeyError(section) + if not self.parser.has_option(section, option): + raise KeyError(option) + return self.cast(self.parser.get(section, option), cast) + + def read_from_section( + self, section: str, option: str, default: Any = None, cast: Type = str, save: bool = True + ) -> str: + value = self.parser.get(section, option) + if not save: + # Set to default or remove to not read it at trainig start again + if default is None: + self.parser.remove_option(section, option) + elif not self.allow_defaults: + raise ValueError(f"Value '{option}' not found in config (defaults not allowed).") + else: + self.parser.set(section, option, self.tostr(default, cast)) + elif section.lower() != section: + self.parser.set(section.lower(), option, self.tostr(value, cast)) + self.parser.remove_option(section, option) + self.modified = True + return value + + def overwrite(self, section: str, option: str, value: Any): + if not self.parser.has_section(section): + return ValueError(f"Section not found: '{section}'") + if not self.parser.has_option(section, option): + return ValueError(f"Option not found '{option}' in section '{section}'") + self.modified = True + cast = type(value) + return self.parser.set(section, option, self.tostr(value, cast)) + + def _fix_df(self): + """Renaming of some groups/options for compatibility with old models.""" + if self.parser.has_section("deepfilternet") and self.parser.has_section("df"): + sec_deepfilternet = self.parser["deepfilternet"] + sec_df = self.parser["df"] + if "df_order" in sec_deepfilternet: + sec_df["df_order"] = sec_deepfilternet["df_order"] + del sec_deepfilternet["df_order"] + if "df_lookahead" in sec_deepfilternet: + sec_df["df_lookahead"] = sec_deepfilternet["df_lookahead"] + del sec_deepfilternet["df_lookahead"] + + def _fix_clc(self): + """Renaming of some groups/options for compatibility with old models.""" + if ( + not self.parser.has_section("deepfilternet") + and self.parser.has_section("train") + and self.parser.get("train", "model") == "convgru5" + ): + self.overwrite("train", "model", "deepfilternet") + self.parser.add_section("deepfilternet") + self.parser["deepfilternet"] = self.parser["convgru"] + del self.parser["convgru"] + if not self.parser.has_section("df") and self.parser.has_section("clc"): + self.parser["df"] = self.parser["clc"] + del self.parser["clc"] + for section in self.parser.sections(): + for k, v in self.parser[section].items(): + if "clc" in k.lower(): + self.parser.set(section, k.lower().replace("clc", "df"), v) + del self.parser[section][k] + + def __repr__(self): + msg = "" + for section in 
self.parser.sections(): + msg += f"{section}:\n" + for k, v in self.parser[section].items(): + msg += f" {k}: {v}\n" + return msg + + +config = Config() + + +class Csv(object): + """ + Produces a csv parser that return a list of transformed elements. From python-decouple. + """ + + def __init__( + self, cast: Type[T] = str, delimiter=",", strip=string.whitespace, post_process=list + ): + """ + Parameters: + cast -- callable that transforms the item just before it's added to the list. + delimiter -- string of delimiters chars passed to shlex. + strip -- string of non-relevant characters to be passed to str.strip after the split. + post_process -- callable to post process all casted values. Default is `list`. + """ + self.cast: Type[T] = cast + self.delimiter = delimiter + self.strip = strip + self.post_process = post_process + + def __call__(self, value: Union[str, Tuple[T], List[T]]) -> List[T]: + """The actual transformation""" + if isinstance(value, (tuple, list)): + # if default value is a list + value = "".join(str(v) + self.delimiter for v in value)[:-1] + + def transform(s): + return self.cast(s.strip(self.strip)) + + splitter = shlex(value, posix=True) + splitter.whitespace = self.delimiter + splitter.whitespace_split = True + + return self.post_process(transform(s) for s in splitter) diff --git a/df/deepfilternet2.py b/df/deepfilternet2.py new file mode 100644 index 0000000000000000000000000000000000000000..0e624923e61b143201ce604c25c4f465a753b690 --- /dev/null +++ b/df/deepfilternet2.py @@ -0,0 +1,453 @@ +from functools import partial +from typing import Final, List, Optional, Tuple, Union + +import torch +from loguru import logger +from torch import Tensor, nn + +from df.config import Csv, DfParams, config +from df.modules import ( + Conv2dNormAct, + ConvTranspose2dNormAct, + DfOp, + GroupedGRU, + GroupedLinear, + GroupedLinearEinsum, + Mask, + SqueezedGRU, + erb_fb, + get_device, +) +from df.multiframe import MF_METHODS, MultiFrameModule +from libdf import DF + + +class ModelParams(DfParams): + section = "deepfilternet" + + def __init__(self): + super().__init__() + self.conv_lookahead: int = config( + "CONV_LOOKAHEAD", cast=int, default=0, section=self.section + ) + self.conv_ch: int = config("CONV_CH", cast=int, default=16, section=self.section) + self.conv_depthwise: bool = config( + "CONV_DEPTHWISE", cast=bool, default=True, section=self.section + ) + self.convt_depthwise: bool = config( + "CONVT_DEPTHWISE", cast=bool, default=True, section=self.section + ) + self.conv_kernel: List[int] = config( + "CONV_KERNEL", cast=Csv(int), default=(1, 3), section=self.section # type: ignore + ) + self.conv_kernel_inp: List[int] = config( + "CONV_KERNEL_INP", cast=Csv(int), default=(3, 3), section=self.section # type: ignore + ) + self.emb_hidden_dim: int = config( + "EMB_HIDDEN_DIM", cast=int, default=256, section=self.section + ) + self.emb_num_layers: int = config( + "EMB_NUM_LAYERS", cast=int, default=2, section=self.section + ) + self.df_hidden_dim: int = config( + "DF_HIDDEN_DIM", cast=int, default=256, section=self.section + ) + self.df_gru_skip: str = config("DF_GRU_SKIP", default="none", section=self.section) + self.df_output_layer: str = config( + "DF_OUTPUT_LAYER", default="linear", section=self.section + ) + self.df_pathway_kernel_size_t: int = config( + "DF_PATHWAY_KERNEL_SIZE_T", cast=int, default=1, section=self.section + ) + self.enc_concat: bool = config("ENC_CONCAT", cast=bool, default=False, section=self.section) + self.df_num_layers: int = config("DF_NUM_LAYERS", 
cast=int, default=3, section=self.section) + self.df_n_iter: int = config("DF_N_ITER", cast=int, default=2, section=self.section) + self.gru_type: str = config("GRU_TYPE", default="grouped", section=self.section) + self.gru_groups: int = config("GRU_GROUPS", cast=int, default=1, section=self.section) + self.lin_groups: int = config("LINEAR_GROUPS", cast=int, default=1, section=self.section) + self.group_shuffle: bool = config( + "GROUP_SHUFFLE", cast=bool, default=True, section=self.section + ) + self.dfop_method: str = config("DFOP_METHOD", cast=str, default="df", section=self.section) + self.mask_pf: bool = config("MASK_PF", cast=bool, default=False, section=self.section) + + +def init_model(df_state: Optional[DF] = None, run_df: bool = True, train_mask: bool = True): + p = ModelParams() + if df_state is None: + df_state = DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb) + erb = erb_fb(df_state.erb_widths(), p.sr, inverse=False) + erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True) + model = DfNet(erb, erb_inverse, run_df, train_mask) + return model.to(device=get_device()) + + +class Add(nn.Module): + def forward(self, a, b): + return a + b + + +class Concat(nn.Module): + def forward(self, a, b): + return torch.cat((a, b), dim=-1) + + +class Encoder(nn.Module): + def __init__(self): + super().__init__() + p = ModelParams() + assert p.nb_erb % 4 == 0, "erb_bins should be divisible by 4" + + self.erb_conv0 = Conv2dNormAct( + 1, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True + ) + conv_layer = partial( + Conv2dNormAct, + in_ch=p.conv_ch, + out_ch=p.conv_ch, + kernel_size=p.conv_kernel, + bias=False, + separable=True, + ) + self.erb_conv1 = conv_layer(fstride=2) + self.erb_conv2 = conv_layer(fstride=2) + self.erb_conv3 = conv_layer(fstride=1) + self.df_conv0 = Conv2dNormAct( + 2, p.conv_ch, kernel_size=p.conv_kernel_inp, bias=False, separable=True + ) + self.df_conv1 = conv_layer(fstride=2) + self.erb_bins = p.nb_erb + self.emb_in_dim = p.conv_ch * p.nb_erb // 4 + self.emb_out_dim = p.emb_hidden_dim + if p.gru_type == "grouped": + self.df_fc_emb = GroupedLinear( + p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups + ) + else: + df_fc_emb = GroupedLinearEinsum( + p.conv_ch * p.nb_df // 2, self.emb_in_dim, groups=p.lin_groups + ) + self.df_fc_emb = nn.Sequential(df_fc_emb, nn.ReLU(inplace=True)) + if p.enc_concat: + self.emb_in_dim *= 2 + self.combine = Concat() + else: + self.combine = Add() + self.emb_out_dim = p.emb_hidden_dim + self.emb_n_layers = p.emb_num_layers + assert p.gru_type in ("grouped", "squeeze"), f"But got {p.gru_type}" + if p.gru_type == "grouped": + self.emb_gru = GroupedGRU( + self.emb_in_dim, + self.emb_out_dim, + num_layers=1, + batch_first=True, + groups=p.gru_groups, + shuffle=p.group_shuffle, + add_outputs=True, + ) + else: + self.emb_gru = SqueezedGRU( + self.emb_in_dim, + self.emb_out_dim, + num_layers=1, + batch_first=True, + linear_groups=p.lin_groups, + linear_act_layer=partial(nn.ReLU, inplace=True), + ) + self.lsnr_fc = nn.Sequential(nn.Linear(self.emb_out_dim, 1), nn.Sigmoid()) + self.lsnr_scale = p.lsnr_max - p.lsnr_min + self.lsnr_offset = p.lsnr_min + + def forward( + self, feat_erb: Tensor, feat_spec: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands. 
+ # erb: [B, 1, T, Fe] + # spec: [B, 2, T, Fc] + # b, _, t, _ = feat_erb.shape + e0 = self.erb_conv0(feat_erb) # [B, C, T, F] + e1 = self.erb_conv1(e0) # [B, C*2, T, F/2] + e2 = self.erb_conv2(e1) # [B, C*4, T, F/4] + e3 = self.erb_conv3(e2) # [B, C*4, T, F/4] + c0 = self.df_conv0(feat_spec) # [B, C, T, Fc] + c1 = self.df_conv1(c0) # [B, C*2, T, Fc] + cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1] + cemb = self.df_fc_emb(cemb) # [T, B, C * F/4] + emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F/4] + emb = self.combine(emb, cemb) + emb, _ = self.emb_gru(emb) # [B, T, -1] + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + return e0, e1, e2, e3, emb, c0, lsnr + + +class ErbDecoder(nn.Module): + def __init__(self): + super().__init__() + p = ModelParams() + assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8" + + self.emb_out_dim = p.emb_hidden_dim + + if p.gru_type == "grouped": + self.emb_gru = GroupedGRU( + p.conv_ch * p.nb_erb // 4, # For compat + self.emb_out_dim, + num_layers=p.emb_num_layers - 1, + batch_first=True, + groups=p.gru_groups, + shuffle=p.group_shuffle, + add_outputs=True, + ) + # SqueezedGRU uses GroupedLinearEinsum, so let's use it here as well + fc_emb = GroupedLinear( + p.emb_hidden_dim, + p.conv_ch * p.nb_erb // 4, + groups=p.lin_groups, + shuffle=p.group_shuffle, + ) + self.fc_emb = nn.Sequential(fc_emb, nn.ReLU(inplace=True)) + else: + self.emb_gru = SqueezedGRU( + self.emb_out_dim, + self.emb_out_dim, + output_size=p.conv_ch * p.nb_erb // 4, + num_layers=p.emb_num_layers - 1, + batch_first=True, + gru_skip_op=nn.Identity, + linear_groups=p.lin_groups, + linear_act_layer=partial(nn.ReLU, inplace=True), + ) + self.fc_emb = nn.Identity() + tconv_layer = partial( + ConvTranspose2dNormAct, + kernel_size=p.conv_kernel, + bias=False, + separable=True, + ) + conv_layer = partial( + Conv2dNormAct, + bias=False, + separable=True, + ) + # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions + self.conv3p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.convt3 = conv_layer(p.conv_ch, p.conv_ch, kernel_size=p.conv_kernel) + self.conv2p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.convt2 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2) + self.conv1p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.convt1 = tconv_layer(p.conv_ch, p.conv_ch, fstride=2) + self.conv0p = conv_layer(p.conv_ch, p.conv_ch, kernel_size=1) + self.conv0_out = conv_layer( + p.conv_ch, 1, kernel_size=p.conv_kernel, activation_layer=nn.Sigmoid + ) + + def forward(self, emb, e3, e2, e1, e0) -> Tensor: + # Estimates erb mask + b, _, t, f8 = e3.shape + emb, _ = self.emb_gru(emb) + emb = self.fc_emb(emb) + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8] + e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4] + e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2] + e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F] + m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F] + return m + + +class DfOutputReshapeMF(nn.Module): + """Coefficients output reshape for multiframe/MultiFrameModule + + Requires input of shape B, C, T, F, 2. 
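+    The DF decoder emits the coefficients as a flat [B, T, F, O*2] tensor; forward() reorders
+    this into [B, O, T, F, 2] (O: df_order) so that the multi-frame filtering op can consume it.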
+ """ + + def __init__(self, df_order: int, df_bins: int): + super().__init__() + self.df_order = df_order + self.df_bins = df_bins + + def forward(self, coefs: Tensor) -> Tensor: + # [B, T, F, O*2] -> [B, O, T, F, 2] + coefs = coefs.view(*coefs.shape[:-1], -1, 2) + coefs = coefs.permute(0, 3, 1, 2, 4) + return coefs + + +class DfDecoder(nn.Module): + def __init__(self, out_channels: int = -1): + super().__init__() + p = ModelParams() + layer_width = p.conv_ch + self.emb_dim = p.emb_hidden_dim + + self.df_n_hidden = p.df_hidden_dim + self.df_n_layers = p.df_num_layers + self.df_order = p.df_order + self.df_bins = p.nb_df + self.gru_groups = p.gru_groups + self.df_out_ch = out_channels if out_channels > 0 else p.df_order * 2 + + conv_layer = partial(Conv2dNormAct, separable=True, bias=False) + kt = p.df_pathway_kernel_size_t + self.df_convp = conv_layer(layer_width, self.df_out_ch, fstride=1, kernel_size=(kt, 1)) + if p.gru_type == "grouped": + self.df_gru = GroupedGRU( + p.emb_hidden_dim, + p.df_hidden_dim, + num_layers=self.df_n_layers, + batch_first=True, + groups=p.gru_groups, + shuffle=p.group_shuffle, + add_outputs=True, + ) + else: + self.df_gru = SqueezedGRU( + p.emb_hidden_dim, + p.df_hidden_dim, + num_layers=self.df_n_layers, + batch_first=True, + gru_skip_op=nn.Identity, + linear_act_layer=partial(nn.ReLU, inplace=True), + ) + p.df_gru_skip = p.df_gru_skip.lower() + assert p.df_gru_skip in ("none", "identity", "groupedlinear") + self.df_skip: Optional[nn.Module] + if p.df_gru_skip == "none": + self.df_skip = None + elif p.df_gru_skip == "identity": + assert p.emb_hidden_dim == p.df_hidden_dim, "Dimensions do not match" + self.df_skip = nn.Identity() + elif p.df_gru_skip == "groupedlinear": + self.df_skip = GroupedLinearEinsum( + p.emb_hidden_dim, p.df_hidden_dim, groups=p.lin_groups + ) + else: + raise NotImplementedError() + assert p.df_output_layer in ("linear", "groupedlinear") + self.df_out: nn.Module + out_dim = self.df_bins * self.df_out_ch + if p.df_output_layer == "linear": + df_out = nn.Linear(self.df_n_hidden, out_dim) + elif p.df_output_layer == "groupedlinear": + df_out = GroupedLinearEinsum(self.df_n_hidden, out_dim, groups=p.lin_groups) + else: + raise NotImplementedError + self.df_out = nn.Sequential(df_out, nn.Tanh()) + self.df_fc_a = nn.Sequential(nn.Linear(self.df_n_hidden, 1), nn.Sigmoid()) + self.out_transform = DfOutputReshapeMF(self.df_order, self.df_bins) + + def forward(self, emb: Tensor, c0: Tensor) -> Tuple[Tensor, Tensor]: + b, t, _ = emb.shape + c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden + if self.df_skip is not None: + c += self.df_skip(emb) + c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last + alpha = self.df_fc_a(c) # [B, T, 1] + c = self.df_out(c) # [B, T, F*O*2], O: df_order + c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2] + c = self.out_transform(c) + return c, alpha + + +class DfNet(nn.Module): + run_df: Final[bool] + pad_specf: Final[bool] + + def __init__( + self, + erb_fb: Tensor, + erb_inv_fb: Tensor, + run_df: bool = True, + train_mask: bool = True, + ): + super().__init__() + p = ModelParams() + layer_width = p.conv_ch + assert p.nb_erb % 8 == 0, "erb_bins should be divisible by 8" + self.df_lookahead = p.df_lookahead if p.pad_mode == "model" else 0 + self.nb_df = p.nb_df + self.freq_bins: int = p.fft_size // 2 + 1 + self.emb_dim: int = layer_width * p.nb_erb + self.erb_bins: int = p.nb_erb + if p.conv_lookahead > 0 and p.pad_mode.startswith("input"): + self.pad_feat = 
nn.ConstantPad2d((0, 0, -p.conv_lookahead, p.conv_lookahead), 0.0) + else: + self.pad_feat = nn.Identity() + self.pad_specf = p.pad_mode.endswith("specf") + if p.df_lookahead > 0 and self.pad_specf: + self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -p.df_lookahead, p.df_lookahead), 0.0) + else: + self.pad_spec = nn.Identity() + if (p.conv_lookahead > 0 or p.df_lookahead > 0) and p.pad_mode.startswith("output"): + assert p.conv_lookahead == p.df_lookahead + pad = (0, 0, 0, 0, -p.conv_lookahead, p.conv_lookahead) + self.pad_out = nn.ConstantPad3d(pad, 0.0) + else: + self.pad_out = nn.Identity() + self.register_buffer("erb_fb", erb_fb) + self.enc = Encoder() + self.erb_dec = ErbDecoder() + self.mask = Mask(erb_inv_fb, post_filter=p.mask_pf) + + self.df_order = p.df_order + self.df_bins = p.nb_df + self.df_op: Union[DfOp, MultiFrameModule] + if p.dfop_method == "real_unfold": + raise ValueError("RealUnfold DF OP is now unsupported.") + assert p.df_output_layer != "linear", "Must be used with `groupedlinear`" + self.df_op = MF_METHODS[p.dfop_method]( + num_freqs=p.nb_df, frame_size=p.df_order, lookahead=self.df_lookahead + ) + n_ch_out = self.df_op.num_channels() + self.df_dec = DfDecoder(out_channels=n_ch_out) + + self.run_df = run_df + if not run_df: + logger.warning("Runing without DF") + self.train_mask = train_mask + assert p.df_n_iter == 1 + + def forward( + self, + spec: Tensor, + feat_erb: Tensor, + feat_spec: Tensor, # Not used, take spec modified by mask instead + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Forward method of DeepFilterNet2. + + Args: + spec (Tensor): Spectrum of shape [B, 1, T, F, 2] + feat_erb (Tensor): ERB features of shape [B, 1, T, E] + feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F'] + + Returns: + spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2] + m (Tensor): ERB mask estimate of shape [B, 1, T, E] + lsnr (Tensor): Local SNR estimate of shape [B, T, 1] + """ + feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2) + + feat_erb = self.pad_feat(feat_erb) + feat_spec = self.pad_feat(feat_spec) + e0, e1, e2, e3, emb, c0, lsnr = self.enc(feat_erb, feat_spec) + m = self.erb_dec(emb, e3, e2, e1, e0) + + m = self.pad_out(m.unsqueeze(-1)).squeeze(-1) + spec = self.mask(spec, m) + + if self.run_df: + df_coefs, df_alpha = self.df_dec(emb, c0) + df_coefs = self.pad_out(df_coefs) + + if self.pad_specf: + # Only pad the lower part of the spectrum. 
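+                # Deep filtering is applied to the lowest `nb_df` frequency bins only; the
+                # remaining bins keep the ERB-mask-enhanced spectrum computed above.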
+ spec_f = self.pad_spec(spec) + spec_f = self.df_op(spec_f, df_coefs) + spec[..., : self.nb_df, :] = spec_f[..., : self.nb_df, :] + else: + spec = self.pad_spec(spec) + spec = self.df_op(spec, df_coefs) + else: + df_alpha = torch.zeros(()) + + return spec, m, lsnr, df_alpha diff --git a/df/enhance.py b/df/enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3a226b0e3a95b905a07da5328c467b0cc5644a --- /dev/null +++ b/df/enhance.py @@ -0,0 +1,333 @@ +import argparse +import os +import time +import warnings +from typing import Optional, Tuple, Union + +import torch +import torchaudio as ta +from loguru import logger +from numpy import ndarray +from torch import Tensor, nn +from torch.nn import functional as F +from torchaudio.backend.common import AudioMetaData + +import df +from df import config +from df.checkpoint import load_model as load_model_cp +from df.logger import init_logger, warn_once +from df.model import ModelParams +from df.modules import get_device +from df.utils import as_complex, as_real, get_norm_alpha, resample +from libdf import DF, erb, erb_norm, unit_norm + + +def main(args): + model, df_state, suffix = init_df( + args.model_base_dir, + post_filter=args.pf, + log_level=args.log_level, + config_allow_defaults=True, + epoch=args.epoch, + ) + if args.output_dir is None: + args.output_dir = "." + elif not os.path.isdir(args.output_dir): + os.mkdir(args.output_dir) + df_sr = ModelParams().sr + n_samples = len(args.noisy_audio_files) + for i, file in enumerate(args.noisy_audio_files): + progress = (i + 1) / n_samples * 100 + audio, meta = load_audio(file, df_sr) + t0 = time.time() + audio = enhance( + model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim + ) + t1 = time.time() + t_audio = audio.shape[-1] / df_sr + t = t1 - t0 + rtf = t / t_audio + fn = os.path.basename(file) + p_str = f"{progress:2.0f}% | " if n_samples > 1 else "" + logger.info(f"{p_str}Enhanced noisy audio file '{fn}' in {t:.1f}s (RT factor: {rtf:.3f})") + audio = resample(audio, df_sr, meta.sample_rate) + save_audio( + file, audio, sr=meta.sample_rate, output_dir=args.output_dir, suffix=suffix, log=False + ) + + +def init_df( + model_base_dir: Optional[str] = None, + post_filter: bool = False, + log_level: str = "INFO", + log_file: Optional[str] = "enhance.log", + config_allow_defaults: bool = False, + epoch: Union[str, int, None] = "best", + default_model: str = "DeepFilterNet2", +) -> Tuple[nn.Module, DF, str]: + """Initializes and loads config, model and deep filtering state. + + Args: + model_base_dir (str): Path to the model directory containing checkpoint and config. If None, + load the pretrained DeepFilterNet2 model. + post_filter (bool): Enable post filter for some minor, extra noise reduction. + log_level (str): Control amount of logging. Defaults to `INFO`. + log_file (str): Optional log file name. None disables it. Defaults to `enhance.log`. + config_allow_defaults (bool): Whether to allow initializing new config values with defaults. + epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, ``, and `none`. + `none` disables checkpoint loading. Defaults to `best`. + + Returns: + model (nn.Modules): Intialized model, moved to GPU if available. + df_state (DF): Deep filtering state for stft/istft/erb + suffix (str): Suffix based on the model name. This can be used for saving the enhanced + audio. 
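+
+    Example (minimal sketch; "noisy.wav" is a placeholder file name):
+        >>> model, df_state, suffix = init_df()  # loads the pretrained DeepFilterNet2 model
+        >>> audio, meta = load_audio("noisy.wav", sr=ModelParams().sr)
+        >>> enhanced = enhance(model, df_state, audio)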
+ """ + try: + from icecream import ic, install + + ic.configureOutput(includeContext=True) + install() + except ImportError: + pass + use_default_model = False + if model_base_dir == "DeepFilterNet": + default_model = "DeepFilterNet" + use_default_model = True + elif model_base_dir == "DeepFilterNet2": + use_default_model = True + if model_base_dir is None or use_default_model: + use_default_model = True + model_base_dir = os.path.relpath( + os.path.join( + os.path.dirname(df.__file__), os.pardir, "pretrained_models", default_model + ) + ) + if not os.path.isdir(model_base_dir): + raise NotADirectoryError("Base directory not found at {}".format(model_base_dir)) + log_file = os.path.join(model_base_dir, log_file) if log_file is not None else None + init_logger(file=log_file, level=log_level, model=model_base_dir) + if use_default_model: + logger.info(f"Using {default_model} model at {model_base_dir}") + config.load( + os.path.join(model_base_dir, "config.ini"), + config_must_exist=True, + allow_defaults=config_allow_defaults, + allow_reload=True, + ) + if post_filter: + config.set("mask_pf", True, bool, ModelParams().section) + logger.info("Running with post-filter") + p = ModelParams() + df_state = DF( + sr=p.sr, + fft_size=p.fft_size, + hop_size=p.hop_size, + nb_bands=p.nb_erb, + min_nb_erb_freqs=p.min_nb_freqs, + ) + checkpoint_dir = os.path.join(model_base_dir, "checkpoints") + load_cp = epoch is not None and not (isinstance(epoch, str) and epoch.lower() == "none") + if not load_cp: + checkpoint_dir = None + try: + mask_only = config.get("mask_only", cast=bool, section="train") + except KeyError: + mask_only = False + model, epoch = load_model_cp(checkpoint_dir, df_state, epoch=epoch, mask_only=mask_only) + if (epoch is None or epoch == 0) and load_cp: + logger.error("Could not find a checkpoint") + exit(1) + logger.debug(f"Loaded checkpoint from epoch {epoch}") + model = model.to(get_device()) + # Set suffix to model name + suffix = os.path.basename(os.path.abspath(model_base_dir)) + if post_filter: + suffix += "_pf" + logger.info("Model loaded") + return model, df_state, suffix + + +def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]: + spec = df.analysis(audio.numpy()) # [C, Tf] -> [C, Tf, F] + a = get_norm_alpha(False) + erb_fb = df.erb_widths() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + erb_feat = torch.as_tensor(erb_norm(erb(spec, erb_fb), a)).unsqueeze(1) + spec_feat = as_real(torch.as_tensor(unit_norm(spec[..., :nb_df], a)).unsqueeze(1)) + spec = as_real(torch.as_tensor(spec).unsqueeze(1)) + if device is not None: + spec = spec.to(device) + erb_feat = erb_feat.to(device) + spec_feat = spec_feat.to(device) + return spec, erb_feat, spec_feat + + +def load_audio( + file: str, sr: Optional[int], verbose=True, **kwargs +) -> Tuple[Tensor, AudioMetaData]: + """Loads an audio file using torchaudio. + + Args: + file (str): Path to an audio file. + sr (int): Optionally resample audio to specified target sampling rate. + **kwargs: Passed to torchaudio.load(). Depends on the backend. The resample method + may be set via `method` which is passed to `resample()`. + + Returns: + audio (Tensor): Audio tensor of shape [C, T], if channels_first=True (default). + info (AudioMetaData): Meta data of the original audio file. Contains the original sr. 
+ """ + ikwargs = {} + if "format" in kwargs: + ikwargs["format"] = kwargs["format"] + rkwargs = {} + if "method" in kwargs: + rkwargs["method"] = kwargs.pop("method") + info: AudioMetaData = ta.info(file, **ikwargs) + audio, orig_sr = ta.load(file, **kwargs) + if sr is not None and orig_sr != sr: + if verbose: + warn_once( + f"Audio sampling rate does not match model sampling rate ({orig_sr}, {sr}). " + "Resampling..." + ) + audio = resample(audio, orig_sr, sr, **rkwargs) + return audio, info + + +def save_audio( + file: str, + audio: Union[Tensor, ndarray], + sr: int, + output_dir: Optional[str] = None, + suffix: Optional[str] = None, + log: bool = False, + dtype=torch.int16, +): + outpath = file + if suffix is not None: + file, ext = os.path.splitext(file) + outpath = file + f"_{suffix}" + ext + if output_dir is not None: + outpath = os.path.join(output_dir, os.path.basename(outpath)) + if log: + logger.info(f"Saving audio file '{outpath}'") + audio = torch.as_tensor(audio) + if audio.ndim == 1: + audio.unsqueeze_(0) + if dtype == torch.int16 and audio.dtype != torch.int16: + audio = (audio * (1 << 15)).to(torch.int16) + if dtype == torch.float32 and audio.dtype != torch.float32: + audio = audio.to(torch.float32) / (1 << 15) + ta.save(outpath, audio, sr) + + +@torch.no_grad() +def enhance( + model: nn.Module, df_state: DF, audio: Tensor, pad=False, atten_lim_db: Optional[float] = None +): + model.eval() + bs = audio.shape[0] + if hasattr(model, "reset_h0"): + model.reset_h0(batch_size=bs, device=get_device()) + orig_len = audio.shape[-1] + n_fft, hop = 0, 0 + if pad: + n_fft, hop = df_state.fft_size(), df_state.hop_size() + # Pad audio to compensate for the delay due to the real-time STFT implementation + audio = F.pad(audio, (0, n_fft)) + nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df)) + spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device()) + enhanced = model(spec, erb_feat, spec_feat)[0].cpu() + enhanced = as_complex(enhanced.squeeze(1)) + if atten_lim_db is not None and abs(atten_lim_db) > 0: + lim = 10 ** (-abs(atten_lim_db) / 20) + enhanced = as_complex(spec.squeeze(1)) * lim + enhanced * (1 - lim) + audio = torch.as_tensor(df_state.synthesis(enhanced.numpy())) + if pad: + # The frame size is equal to p.hop_size. Given a new frame, the STFT loop requires e.g. + # ceil((n_fft-hop)/hop). I.e. for 50% overlap, then hop=n_fft//2 + # requires 1 additional frame lookahead; 75% requires 3 additional frames lookahead. + # Thus, the STFT/ISTFT loop introduces an algorithmic delay of n_fft - hop. + assert n_fft % hop == 0 # This is only tested for 50% and 75% overlap + d = n_fft - hop + audio = audio[:, d : orig_len + d] + return audio + + +def parse_epoch_type(value: str) -> Union[int, str]: + try: + return int(value) + except ValueError: + assert value in ("best", "latest") + return value + + +def setup_df_argument_parser(default_log_level: str = "INFO") -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-base-dir", + "-m", + type=str, + default=None, + help="Model directory containing checkpoints and config. " + "To load a pretrained model, you may just provide the model name, e.g. `DeepFilterNet`. 
" + "By default, the pretrained DeepFilterNet2 model is loaded.", + ) + parser.add_argument( + "--pf", + help="Post-filter that slightly over-attenuates very noisy sections.", + action="store_true", + ) + parser.add_argument( + "--output-dir", + "-o", + type=str, + default=None, + help="Directory in which the enhanced audio files will be stored.", + ) + parser.add_argument( + "--log-level", + type=str, + default=default_log_level, + help="Logger verbosity. Can be one of (debug, info, error, none)", + ) + parser.add_argument("--debug", "-d", action="store_const", const="DEBUG", dest="log_level") + parser.add_argument( + "--epoch", + "-e", + default="best", + type=parse_epoch_type, + help="Epoch for checkpoint loading. Can be one of ['best', 'latest', ].", + ) + return parser + + +def run(): + parser = setup_df_argument_parser() + parser.add_argument( + "--compensate-delay", + "-D", + action="store_true", + help="Add some paddig to compensate the delay introduced by the real-time STFT/ISTFT implementation.", + ) + parser.add_argument( + "--atten-lim", + "-a", + type=int, + default=None, + help="Attenuation limit in dB by mixing the enhanced signal with the noisy signal.", + ) + parser.add_argument( + "noisy_audio_files", + type=str, + nargs="+", + help="List of noise files to mix with the clean speech file.", + ) + main(parser.parse_args()) + + +if __name__ == "__main__": + run() diff --git a/df/logger.py b/df/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..db917cd9317383c8d929219bb8b7ca7b90e74ef4 --- /dev/null +++ b/df/logger.py @@ -0,0 +1,212 @@ +import os +import sys +import warnings +from collections import defaultdict +from copy import deepcopy +from typing import Dict, Optional, Tuple + +import numpy as np +import torch +from loguru import logger +from torch.types import Number + +from df.modules import GroupedLinearEinsum +from df.utils import get_branch_name, get_commit_hash, get_device, get_host + +_logger_initialized = False +WARN_ONCE_NO = logger.level("WARNING").no + 1 +DEPRECATED_NO = logger.level("WARNING").no + 2 + + +def init_logger(file: Optional[str] = None, level: str = "INFO", model: Optional[str] = None): + global _logger_initialized, _duplicate_filter + if _logger_initialized: + logger.debug("Logger already initialized.") + else: + logger.remove() + level = level.upper() + if level.lower() != "none": + log_format = Formatter(debug=logger.level(level).no <= logger.level("DEBUG").no).format + logger.add( + sys.stdout, + level=level, + format=log_format, + filter=lambda r: r["level"].no not in {WARN_ONCE_NO, DEPRECATED_NO}, + ) + if file is not None: + logger.add( + file, + level=level, + format=log_format, + filter=lambda r: r["level"].no != WARN_ONCE_NO, + ) + + logger.info(f"Running on torch {torch.__version__}") + logger.info(f"Running on host {get_host()}") + commit = get_commit_hash() + if commit is not None: + logger.info(f"Git commit: {commit}, branch: {get_branch_name()}") + if (jobid := os.getenv("SLURM_JOB_ID")) is not None: + logger.info(f"Slurm jobid: {jobid}") + logger.level("WARNONCE", no=WARN_ONCE_NO, color="") + logger.add( + sys.stderr, + level=max(logger.level(level).no, WARN_ONCE_NO), + format=log_format, + filter=lambda r: r["level"].no == WARN_ONCE_NO and _duplicate_filter(r), + ) + logger.level("DEPRECATED", no=DEPRECATED_NO, color="") + logger.add( + sys.stderr, + level=max(logger.level(level).no, DEPRECATED_NO), + format=log_format, + filter=lambda r: r["level"].no == DEPRECATED_NO and _duplicate_filter(r), + ) + if 
model is not None: + logger.info("Loading model settings of {}", os.path.basename(model.rstrip("/"))) + _logger_initialized = True + + +def warn_once(message, *args, **kwargs): + logger.log("WARNONCE", message, *args, **kwargs) + + +def log_deprecated(message, *args, **kwargs): + logger.log("DEPRECATED", message, *args, **kwargs) + + +class Formatter: + def __init__(self, debug=False): + if debug: + self.fmt = ( + "{time:YYYY-MM-DD HH:mm:ss}" + " | {level: <8}" + " | {name}:{function}:{line}" + " | {message}" + ) + else: + self.fmt = ( + "{time:YYYY-MM-DD HH:mm:ss}" + " | {level: <8}" + " | DF" + " | {message}" + ) + self.fmt += "\n{exception}" + + def format(self, record): + if record["level"].no == WARN_ONCE_NO: + return self.fmt.replace("{level: <8}", "WARNING ") + return self.fmt + + +def _metrics_key(k_: Tuple[str, float]): + k0 = k_[0] + ks = k0.split("_") + if len(ks) > 2: + try: + return int(ks[-1]) + except ValueError: + return 1000 + elif k0 == "loss": + return -999 + elif "loss" in k0.lower(): + return -998 + elif k0 == "lr": + return 998 + elif k0 == "wd": + return 999 + else: + return -101 + + +def log_metrics(prefix: str, metrics: Dict[str, Number], level="INFO"): + msg = "" + stages = defaultdict(str) + loss_msg = "" + for n, v in sorted(metrics.items(), key=_metrics_key): + if abs(v) > 1e-3: + m = f" | {n}: {v:.5f}" + else: + m = f" | {n}: {v:.3E}" + if "stage" in n: + s = n.split("stage_")[1].split("_snr")[0] + stages[s] += m.replace(f"stage_{s}_", "") + elif ("valid" in prefix or "test" in prefix) and "loss" in n.lower(): + loss_msg += m + else: + msg += m + for s, msg_s in stages.items(): + logger.log(level, f"{prefix} | stage {s}" + msg_s) + if len(stages) == 0: + logger.log(level, prefix + msg) + if len(loss_msg) > 0: + logger.log(level, prefix + loss_msg) + + +class DuplicateFilter: + """ + Filters away duplicate log messages. + Modified version of: https://stackoverflow.com/a/60462619 + """ + + def __init__(self): + self.msgs = set() + + def __call__(self, record) -> bool: + k = f"{record['level']}{record['message']}" + if k in self.msgs: + return False + else: + self.msgs.add(k) + return True + + +_duplicate_filter = DuplicateFilter() + + +def log_model_summary(model: torch.nn.Module, verbose=False): + try: + import ptflops + except ImportError: + logger.debug("Failed to import ptflops. 
Cannot print model summary.") + return + + from df.model import ModelParams + + # Generate input of 1 second audio + # Necessary inputs are: + # spec: [B, 1, T, F, 2], F: freq bin + # feat_erb: [B, 1, T, E], E: ERB bands + # feat_spec: [B, 2, T, C*2], C: Complex features + p = ModelParams() + b = 1 + t = p.sr // p.hop_size + device = get_device() + spec = torch.randn([b, 1, t, p.fft_size // 2 + 1, 2]).to(device) + feat_erb = torch.randn([b, 1, t, p.nb_erb]).to(device) + feat_spec = torch.randn([b, 1, t, p.nb_df, 2]).to(device) + + warnings.filterwarnings("ignore", "RNN module weights", category=UserWarning, module="torch") + macs, params = ptflops.get_model_complexity_info( + deepcopy(model), + (t,), + input_constructor=lambda _: {"spec": spec, "feat_erb": feat_erb, "feat_spec": feat_spec}, + as_strings=False, + print_per_layer_stat=verbose, + verbose=verbose, + custom_modules_hooks={ + GroupedLinearEinsum: grouped_linear_flops_counter_hook, + }, + ) + logger.info(f"Model complexity: {params/1e6:.3f}M #Params, {macs/1e6:.1f}M MACS") + + +def grouped_linear_flops_counter_hook(module: GroupedLinearEinsum, input, output): + # input: ([B, T, I],) + # output: [B, T, H] + input = input[0] # [B, T, I] + output_last_dim = module.weight.shape[-1] + input = input.unflatten(-1, (module.groups, module.ws)) # [B, T, G, I/G] + # GroupedLinear calculates "...gi,...gih->...gh" + weight_flops = np.prod(input.shape) * output_last_dim + module.__flops__ += int(weight_flops) # type: ignore diff --git a/df/model.py b/df/model.py new file mode 100644 index 0000000000000000000000000000000000000000..00328a21acd715bc89fc749d64a6c3343e45634f --- /dev/null +++ b/df/model.py @@ -0,0 +1,24 @@ +from importlib import import_module + +import torch +from loguru import logger + +from df.config import DfParams, config + + +class ModelParams(DfParams): + def __init__(self): + self.__model = config("MODEL", default="deepfilternet", section="train") + self.__params = getattr(import_module("df." + self.__model), "ModelParams")() + + def __getattr__(self, attr: str): + return getattr(self.__params, attr) + + +def init_model(*args, **kwargs): + """Initialize the model specified in the config.""" + model = config("MODEL", default="deepfilternet", section="train") + logger.info(f"Initializing model `{model}`") + model = getattr(import_module("df." 
+ model), "init_model")(*args, **kwargs) + model.to(memory_format=torch.channels_last) + return model diff --git a/df/modules.py b/df/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfa2e4ba7f69a5fdf104127ae7f107a65f2961b --- /dev/null +++ b/df/modules.py @@ -0,0 +1,956 @@ +import math +from collections import OrderedDict +from typing import Callable, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn import init +from torch.nn.parameter import Parameter +from typing_extensions import Final + +from df.model import ModelParams +from df.utils import as_complex, as_real, get_device, get_norm_alpha +from libdf import unit_norm_init + + +class Conv2dNormAct(nn.Sequential): + def __init__( + self, + in_ch: int, + out_ch: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + ): + """Causal Conv2d by delaying the signal for any lookahead. + + Expected input format: [B, C, T, F] + """ + lookahead = 0 # This needs to be handled on the input feature side + # Padding on time axis + kernel_size = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + ) + if fpad: + fpad_ = kernel_size[1] // 2 + dilation - 1 + else: + fpad_ = 0 + pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + groups = math.gcd(in_ch, out_ch) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + layers.append( + nn.Conv2d( + in_ch, + out_ch, + kernel_size=kernel_size, + padding=(0, fpad_), + stride=(1, fstride), # Stride over time is always 1 + dilation=(1, dilation), # Same for dilation + groups=groups, + bias=bias, + ) + ) + if separable: + layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)) + if norm_layer is not None: + layers.append(norm_layer(out_ch)) + if activation_layer is not None: + layers.append(activation_layer()) + super().__init__(*layers) + + +class ConvTranspose2dNormAct(nn.Sequential): + def __init__( + self, + in_ch: int, + out_ch: int, + kernel_size: Union[int, Tuple[int, int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + ): + """Causal ConvTranspose2d. 
+ + Expected input format: [B, C, T, F] + """ + # Padding on time axis, with lookahead = 0 + lookahead = 0 # This needs to be handled on the input feature side + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + if fpad: + fpad_ = kernel_size[1] // 2 + else: + fpad_ = 0 + pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + groups = math.gcd(in_ch, out_ch) if separable else 1 + if groups == 1: + separable = False + layers.append( + nn.ConvTranspose2d( + in_ch, + out_ch, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad_ + dilation - 1), + output_padding=(0, fpad_), + stride=(1, fstride), # Stride over time is always 1 + dilation=(1, dilation), + groups=groups, + bias=bias, + ) + ) + if separable: + layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)) + if norm_layer is not None: + layers.append(norm_layer(out_ch)) + if activation_layer is not None: + layers.append(activation_layer()) + super().__init__(*layers) + + +def convkxf( + in_ch: int, + out_ch: Optional[int] = None, + k: int = 1, + f: int = 3, + fstride: int = 2, + lookahead: int = 0, + batch_norm: bool = False, + act: nn.Module = nn.ReLU(inplace=True), + mode="normal", # must be "normal", "transposed" or "upsample" + depthwise: bool = True, + complex_in: bool = False, +): + bias = batch_norm is False + assert f % 2 == 1 + stride = 1 if f == 1 else (1, fstride) + if out_ch is None: + out_ch = in_ch * 2 if mode == "normal" else in_ch // 2 + fpad = (f - 1) // 2 + convpad = (0, fpad) + modules = [] + # Manually pad for time axis kernel to not introduce delay + pad = (0, 0, k - 1 - lookahead, lookahead) + if any(p > 0 for p in pad): + modules.append(("pad", nn.ConstantPad2d(pad, 0.0))) + if depthwise: + groups = min(in_ch, out_ch) + else: + groups = 1 + if in_ch % groups != 0 or out_ch % groups != 0: + groups = 1 + if complex_in and groups % 2 == 0: + groups //= 2 + convkwargs = { + "in_channels": in_ch, + "out_channels": out_ch, + "kernel_size": (k, f), + "stride": stride, + "groups": groups, + "bias": bias, + } + if mode == "normal": + modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs))) + elif mode == "transposed": + # Since pytorch's transposed conv padding does not correspond to the actual padding but + # rather the padding that was used in the encoder conv, we need to set time axis padding + # according to k. 
E.g., this disables the padding for k=2: + # dilation - (k - 1) - padding + # = 1 - (2 - 1) - 1 = 0; => padding = fpad (=1 for k=2) + padding = (k - 1, fpad) + modules.append( + ("sconvt", nn.ConvTranspose2d(padding=padding, output_padding=convpad, **convkwargs)) + ) + elif mode == "upsample": + modules.append(("upsample", FreqUpsample(fstride))) + convkwargs["stride"] = 1 + modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs))) + else: + raise NotImplementedError() + if groups > 1: + modules.append(("1x1conv", nn.Conv2d(out_ch, out_ch, 1, bias=False))) + if batch_norm: + modules.append(("norm", nn.BatchNorm2d(out_ch))) + modules.append(("act", act)) + return nn.Sequential(OrderedDict(modules)) + + +class FreqUpsample(nn.Module): + def __init__(self, factor: int, mode="nearest"): + super().__init__() + self.f = float(factor) + self.mode = mode + + def forward(self, x: Tensor) -> Tensor: + return F.interpolate(x, scale_factor=[1.0, self.f], mode=self.mode) + + +def erb_fb(widths: np.ndarray, sr: int, normalized: bool = True, inverse: bool = False) -> Tensor: + n_freqs = int(np.sum(widths)) + all_freqs = torch.linspace(0, sr // 2, n_freqs + 1)[:-1] + + b_pts = np.cumsum([0] + widths.tolist()).astype(int)[:-1] + + fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0])) + for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())): + fb[b : b + w, i] = 1 + # Normalize to constant energy per resulting band + if inverse: + fb = fb.t() + if not normalized: + fb /= fb.sum(dim=1, keepdim=True) + else: + if normalized: + fb /= fb.sum(dim=0) + return fb.to(device=get_device()) + + +class Mask(nn.Module): + def __init__(self, erb_inv_fb: Tensor, post_filter: bool = False, eps: float = 1e-12): + super().__init__() + self.erb_inv_fb: Tensor + self.register_buffer("erb_inv_fb", erb_inv_fb) + self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0" + self.post_filter = post_filter + self.eps = eps + + def pf(self, mask: Tensor, beta: float = 0.02) -> Tensor: + """Post-Filter proposed by Valin et al. [1]. + + Args: + mask (Tensor): Real valued mask, typically of shape [B, C, T, F]. + beta: Global gain factor. + Refs: + [1]: Valin et al.: A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, spec: Tensor, mask: Tensor, atten_lim: Optional[Tensor] = None) -> Tensor: + # spec (real) [B, 1, T, F, 2], F: freq_bins + # mask (real): [B, 1, T, Fe], Fe: erb_bins + # atten_lim: [B] + if not self.training and self.post_filter: + mask = self.pf(mask) + if atten_lim is not None: + # dB to amplitude + atten_lim = 10 ** (-atten_lim / 20) + # Greater equal (__ge__) not implemented for TorchVersion. + if self.clamp_tensor: + # Supported by torch >= 1.9 + mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1)) + else: + m_out = [] + for i in range(atten_lim.shape[0]): + m_out.append(mask[i].clamp_min(atten_lim[i].item())) + mask = torch.stack(m_out, dim=0) + mask = mask.matmul(self.erb_inv_fb) # [B, 1, T, F] + return spec * mask.unsqueeze(4) + + +class ExponentialUnitNorm(nn.Module): + """Unit norm for a complex spectrogram. + + This should match the rust code: + ```rust + for (x, s) in xs.iter_mut().zip(state.iter_mut()) { + *s = x.norm() * (1. 
- alpha) + *s * alpha; + *x /= s.sqrt(); + } + ``` + """ + + alpha: Final[float] + eps: Final[float] + + def __init__(self, alpha: float, num_freq_bins: int, eps: float = 1e-14): + super().__init__() + self.alpha = alpha + self.eps = eps + self.init_state: Tensor + s = torch.from_numpy(unit_norm_init(num_freq_bins)).view(1, 1, num_freq_bins, 1) + self.register_buffer("init_state", s) + + def forward(self, x: Tensor) -> Tensor: + # x: [B, C, T, F, 2] + b, c, t, f, _ = x.shape + x_abs = x.square().sum(dim=-1, keepdim=True).clamp_min(self.eps).sqrt() + state = self.init_state.clone().expand(b, c, f, 1) + out_states: List[Tensor] = [] + for t in range(t): + state = x_abs[:, :, t] * (1 - self.alpha) + state * self.alpha + out_states.append(state) + return x / torch.stack(out_states, 2).sqrt() + + +class DfOp(nn.Module): + df_order: Final[int] + df_bins: Final[int] + df_lookahead: Final[int] + freq_bins: Final[int] + + def __init__( + self, + df_bins: int, + df_order: int = 5, + df_lookahead: int = 0, + method: str = "complex_strided", + freq_bins: int = 0, + ): + super().__init__() + self.df_order = df_order + self.df_bins = df_bins + self.df_lookahead = df_lookahead + self.freq_bins = freq_bins + self.set_forward(method) + + def set_forward(self, method: str): + # All forward methods should be mathematically similar. + # DeepFilterNet results are obtained with 'real_unfold'. + forward_methods = { + "real_loop": self.forward_real_loop, + "real_strided": self.forward_real_strided, + "real_unfold": self.forward_real_unfold, + "complex_strided": self.forward_complex_strided, + "real_one_step": self.forward_real_no_pad_one_step, + "real_hidden_state_loop": self.forward_real_hidden_state_loop, + } + if method not in forward_methods.keys(): + raise NotImplementedError(f"`method` must be one of {forward_methods.keys()}") + if method == "real_hidden_state_loop": + assert self.freq_bins >= self.df_bins + self.spec_buf: Tensor + # Currently only designed for batch size of 1 + self.register_buffer( + "spec_buf", torch.zeros(1, 1, self.df_order, self.freq_bins, 2), persistent=False + ) + self.forward = forward_methods[method] + + def forward_real_loop( + self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None + ) -> Tensor: + # Version 0: Manual loop over df_order, maybe best for onnx export? 
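+        # Each iteration accumulates one complex multiplication of a padded spectrogram frame
+        # with its filter coefficient: (a+ib)(c+id) = (ac - bd) + i(bc + ad).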
+ b, _, t, _, _ = spec.shape + f = self.df_bins + padded = spec_pad( + spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3 + ) + + spec_f = torch.zeros((b, t, f, 2), device=spec.device) + for i in range(self.df_order): + spec_f[..., 0] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 0] + spec_f[..., 0] -= padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 1] + spec_f[..., 1] += padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 0] + spec_f[..., 1] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 1] + return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha) + + def forward_real_strided( + self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None + ) -> Tensor: + # Version1: Use as_strided instead of unfold + # spec (real) [B, 1, T, F, 2], O: df_order + # coefs (real) [B, T, O, F, 2] + # alpha (real) [B, T, 1] + padded = as_strided( + spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3 + ) + # Complex numbers are not supported by onnx + re = padded[..., 0] * coefs[..., 0] + re -= padded[..., 1] * coefs[..., 1] + im = padded[..., 1] * coefs[..., 0] + im += padded[..., 0] * coefs[..., 1] + spec_f = torch.stack((re, im), -1).sum(2) + return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha) + + def forward_real_unfold( + self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None + ) -> Tensor: + # Version2: Unfold + # spec (real) [B, 1, T, F, 2], O: df_order + # coefs (real) [B, T, O, F, 2] + # alpha (real) [B, T, 1] + padded = spec_pad( + spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3 + ) + padded = padded.unfold(dimension=1, size=self.df_order, step=1) # [B, T, F, 2, O] + padded = padded.permute(0, 1, 4, 2, 3) + spec_f = torch.empty_like(padded) + spec_f[..., 0] = padded[..., 0] * coefs[..., 0] # re1 + spec_f[..., 0] -= padded[..., 1] * coefs[..., 1] # re2 + spec_f[..., 1] = padded[..., 1] * coefs[..., 0] # im1 + spec_f[..., 1] += padded[..., 0] * coefs[..., 1] # im2 + spec_f = spec_f.sum(dim=2) + return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha) + + def forward_complex_strided( + self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None + ) -> Tensor: + # Version3: Complex strided; definatly nicest, no permute, no indexing, but complex gradient + # spec (real) [B, 1, T, F, 2], O: df_order + # coefs (real) [B, T, O, F, 2] + # alpha (real) [B, T, 1] + padded = as_strided( + spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3 + ) + spec_f = torch.sum(torch.view_as_complex(padded) * torch.view_as_complex(coefs), dim=2) + spec_f = torch.view_as_real(spec_f) + return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha) + + def forward_real_no_pad_one_step( + self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None + ) -> Tensor: + # Version4: Only viable for onnx handling. `spec` needs external (ring-)buffer handling. + # Thus, time steps `t` must be equal to `df_order`. 
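+        # Intended for streaming inference: the caller keeps a rolling buffer of the last
+        # `df_order` spectrogram frames and calls this method once per time step.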
+ + # spec (real) [B, 1, O, F', 2] + # coefs (real) [B, 1, O, F, 2] + assert ( + spec.shape[2] == self.df_order + ), "This forward method needs spectrogram buffer with `df_order` time steps as input" + assert coefs.shape[1] == 1, "This forward method is only valid for 1 time step" + sre, sim = spec[..., : self.df_bins, :].split(1, -1) + cre, cim = coefs.split(1, -1) + outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1) + outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1) + spec_f = torch.stack((outr, outi), dim=-1) + return assign_df( + spec[:, :, self.df_order - self.df_lookahead - 1], + spec_f.unsqueeze(1), + self.df_bins, + alpha, + ) + + def forward_real_hidden_state_loop(self, spec: Tensor, coefs: Tensor, alpha: Tensor) -> Tensor: + # Version5: Designed for onnx export. `spec` buffer handling is done via a torch buffer. + + # spec (real) [B, 1, T, F', 2] + # coefs (real) [B, T, O, F, 2] + b, _, t, _, _ = spec.shape + spec_out = torch.empty((b, 1, t, self.freq_bins, 2), device=spec.device) + for t in range(spec.shape[2]): + self.spec_buf = self.spec_buf.roll(-1, dims=2) + self.spec_buf[:, :, -1] = spec[:, :, t] + sre, sim = self.spec_buf[..., : self.df_bins, :].split(1, -1) + cre, cim = coefs[:, t : t + 1].split(1, -1) + outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1) + outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1) + spec_f = torch.stack((outr, outi), dim=-1) + spec_out[:, :, t] = assign_df( + self.spec_buf[:, :, self.df_order - self.df_lookahead - 1].unsqueeze(2), + spec_f.unsqueeze(1), + self.df_bins, + alpha[:, t], + ).squeeze(2) + return spec_out + + +def assign_df(spec: Tensor, spec_f: Tensor, df_bins: int, alpha: Optional[Tensor]): + spec_out = spec.clone() + if alpha is not None: + b = spec.shape[0] + alpha = alpha.view(b, 1, -1, 1, 1) + spec_out[..., :df_bins, :] = spec_f * alpha + spec[..., :df_bins, :] * (1 - alpha) + else: + spec_out[..., :df_bins, :] = spec_f + return spec_out + + +def spec_pad(x: Tensor, window_size: int, lookahead: int, dim: int = 0) -> Tensor: + pad = [0] * x.dim() * 2 + if dim >= 0: + pad[(x.dim() - dim - 1) * 2] = window_size - lookahead - 1 + pad[(x.dim() - dim - 1) * 2 + 1] = lookahead + else: + pad[(-dim - 1) * 2] = window_size - lookahead - 1 + pad[(-dim - 1) * 2 + 1] = lookahead + return F.pad(x, pad) + + +def as_strided(x: Tensor, window_size: int, lookahead: int, step: int = 1, dim: int = 0) -> Tensor: + shape = list(x.shape) + shape.insert(dim + 1, window_size) + x = spec_pad(x, window_size, lookahead, dim=dim) + # torch.fx workaround + step = 1 + stride = [x.stride(0), x.stride(1), x.stride(2), x.stride(3)] + stride.insert(dim, stride[dim] * step) + return torch.as_strided(x, shape, stride) + + +class GroupedGRULayer(nn.Module): + input_size: Final[int] + hidden_size: Final[int] + out_size: Final[int] + bidirectional: Final[bool] + num_directions: Final[int] + groups: Final[int] + batch_first: Final[bool] + + def __init__( + self, + input_size: int, + hidden_size: int, + groups: int, + batch_first: bool = True, + bias: bool = True, + dropout: float = 0, + bidirectional: bool = False, + ): + super().__init__() + assert input_size % groups == 0 + assert hidden_size % groups == 0 + kwargs = { + "bias": bias, + "batch_first": batch_first, + "dropout": dropout, + "bidirectional": bidirectional, + } + self.input_size = input_size // groups + self.hidden_size = hidden_size // groups + self.out_size = hidden_size + self.bidirectional = bidirectional + self.num_directions = 2 if bidirectional else 1 + self.groups 
= groups + self.batch_first = batch_first + assert (self.hidden_size % groups) == 0, "Hidden size must be divisible by groups" + self.layers = nn.ModuleList( + (nn.GRU(self.input_size, self.hidden_size, **kwargs) for _ in range(groups)) + ) + + def flatten_parameters(self): + for layer in self.layers: + layer.flatten_parameters() + + def get_h0(self, batch_size: int = 1, device: torch.device = torch.device("cpu")): + return torch.zeros( + self.groups * self.num_directions, + batch_size, + self.hidden_size, + device=device, + ) + + def forward(self, input: Tensor, h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + # input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size + # state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size + + if h0 is None: + dim0, dim1 = input.shape[:2] + bs = dim0 if self.batch_first else dim1 + h0 = self.get_h0(bs, device=input.device) + outputs: List[Tensor] = [] + outstates: List[Tensor] = [] + for i, layer in enumerate(self.layers): + o, s = layer( + input[..., i * self.input_size : (i + 1) * self.input_size], + h0[i * self.num_directions : (i + 1) * self.num_directions].detach(), + ) + outputs.append(o) + outstates.append(s) + output = torch.cat(outputs, dim=-1) + h = torch.cat(outstates, dim=0) + return output, h + + +class GroupedGRU(nn.Module): + groups: Final[int] + num_layers: Final[int] + batch_first: Final[bool] + hidden_size: Final[int] + bidirectional: Final[bool] + num_directions: Final[int] + shuffle: Final[bool] + add_outputs: Final[bool] + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + groups: int = 4, + bias: bool = True, + batch_first: bool = True, + dropout: float = 0, + bidirectional: bool = False, + shuffle: bool = True, + add_outputs: bool = False, + ): + super().__init__() + kwargs = { + "groups": groups, + "bias": bias, + "batch_first": batch_first, + "dropout": dropout, + "bidirectional": bidirectional, + } + assert input_size % groups == 0 + assert hidden_size % groups == 0 + assert num_layers > 0 + self.input_size = input_size + self.groups = groups + self.num_layers = num_layers + self.batch_first = batch_first + self.hidden_size = hidden_size // groups + self.bidirectional = bidirectional + self.num_directions = 2 if bidirectional else 1 + if groups == 1: + shuffle = False # Fully connected, no need to shuffle + self.shuffle = shuffle + self.add_outputs = add_outputs + self.grus: List[GroupedGRULayer] = nn.ModuleList() # type: ignore + self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs)) + for _ in range(1, num_layers): + self.grus.append(GroupedGRULayer(hidden_size, hidden_size, **kwargs)) + self.flatten_parameters() + + def flatten_parameters(self): + for gru in self.grus: + gru.flatten_parameters() + + def get_h0(self, batch_size: int, device: torch.device = torch.device("cpu")) -> Tensor: + return torch.zeros( + (self.num_layers * self.groups * self.num_directions, batch_size, self.hidden_size), + device=device, + ) + + def forward(self, input: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + dim0, dim1, _ = input.shape + b = dim0 if self.batch_first else dim1 + if state is None: + state = self.get_h0(b, input.device) + output = torch.zeros( + dim0, dim1, self.hidden_size * self.num_directions * self.groups, device=input.device + ) + outstates = [] + h = self.groups * self.num_directions + for i, gru in enumerate(self.grus): + input, s = gru(input, state[i * h : (i + 1) * h]) + outstates.append(s) + if 
self.shuffle and i < self.num_layers - 1: + input = ( + input.view(dim0, dim1, -1, self.groups).transpose(2, 3).reshape(dim0, dim1, -1) + ) + if self.add_outputs: + output += input + else: + output = input + outstate = torch.cat(outstates, dim=0) + return output, outstate + + +class SqueezedGRU(nn.Module): + input_size: Final[int] + hidden_size: Final[int] + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None, + linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity, + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.linear_in = nn.Sequential( + GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer() + ) + self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first) + self.gru_skip = gru_skip_op() if gru_skip_op is not None else None + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer() + ) + else: + self.linear_out = nn.Identity() + + def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]: + input = self.linear_in(input) + x, h = self.gru(input, h) + if self.gru_skip is not None: + x = x + self.gru_skip(input) + x = self.linear_out(x) + return x, h + + +class GroupedLinearEinsum(nn.Module): + input_size: Final[int] + hidden_size: Final[int] + groups: Final[int] + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0 + self.ws = input_size // groups + self.register_parameter( + "weight", + Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: Tensor) -> Tensor: + # x: [..., I] + x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("...gi,...gih->...gh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) # [B, T, H] + return x + + +class GroupedLinear(nn.Module): + input_size: Final[int] + hidden_size: Final[int] + groups: Final[int] + shuffle: Final[bool] + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1, shuffle: bool = True): + super().__init__() + assert input_size % groups == 0 + assert hidden_size % groups == 0 + self.groups = groups + self.input_size = input_size // groups + self.hidden_size = hidden_size // groups + if groups == 1: + shuffle = False + self.shuffle = shuffle + self.layers = nn.ModuleList( + nn.Linear(self.input_size, self.hidden_size) for _ in range(groups) + ) + + def forward(self, x: Tensor) -> Tensor: + outputs: List[Tensor] = [] + for i, layer in enumerate(self.layers): + outputs.append(layer(x[..., i * self.input_size : (i + 1) * self.input_size])) + output = torch.cat(outputs, dim=-1) + if self.shuffle: + orig_shape = output.shape + output = ( + output.view(-1, self.hidden_size, self.groups).transpose(-1, -2).reshape(orig_shape) + ) + return output + + +class LocalSnrTarget(nn.Module): + def __init__( + self, ws: int = 20, db: bool = True, ws_ns: Optional[int] = None, target_snr_range=None + ): + super().__init__() + self.ws = 
self.calc_ws(ws) + self.ws_ns = self.ws * 2 if ws_ns is None else self.calc_ws(ws_ns) + self.db = db + self.range = target_snr_range + + def calc_ws(self, ws_ms: int) -> int: + # Calculates windows size in stft domain given a window size in ms + p = ModelParams() + ws = ws_ms - p.fft_size / p.sr * 1000 # length ms of an fft_window + ws = 1 + ws / (p.hop_size / p.sr * 1000) # consider hop_size + return max(int(round(ws)), 1) + + def forward(self, clean: Tensor, noise: Tensor, max_bin: Optional[int] = None) -> Tensor: + # clean: [B, 1, T, F] + # out: [B, T'] + if max_bin is not None: + clean = as_complex(clean[..., :max_bin]) + noise = as_complex(noise[..., :max_bin]) + return ( + local_snr(clean, noise, window_size=self.ws, db=self.db, window_size_ns=self.ws_ns)[0] + .clamp(self.range[0], self.range[1]) + .squeeze(1) + ) + + +def _local_energy(x: Tensor, ws: int, device: torch.device) -> Tensor: + if (ws % 2) == 0: + ws += 1 + ws_half = ws // 2 + x = F.pad(x.pow(2).sum(-1).sum(-1), (ws_half, ws_half, 0, 0)) + w = torch.hann_window(ws, device=device, dtype=x.dtype) + x = x.unfold(-1, size=ws, step=1) * w + return torch.sum(x, dim=-1).div(ws) + + +def local_snr( + clean: Tensor, + noise: Tensor, + window_size: int, + db: bool = False, + window_size_ns: Optional[int] = None, + eps: float = 1e-12, +) -> Tuple[Tensor, Tensor, Tensor]: + # clean shape: [B, C, T, F] + clean = as_real(clean) + noise = as_real(noise) + assert clean.dim() == 5 + + E_speech = _local_energy(clean, window_size, clean.device) + window_size_ns = window_size if window_size_ns is None else window_size_ns + E_noise = _local_energy(noise, window_size_ns, clean.device) + + snr = E_speech / E_noise.clamp_min(eps) + if db: + snr = snr.clamp_min(eps).log10().mul(10) + return snr, E_speech, E_noise + + +def test_grouped_gru(): + from icecream import ic + + g = 2 # groups + h = 4 # hidden_size + i = 2 # input_size + b = 1 # batch_size + t = 5 # time_steps + m = GroupedGRULayer(i, h, g, batch_first=True) + ic(m) + input = torch.randn((b, t, i)) + h0 = m.get_h0(b) + assert list(h0.shape) == [g, b, h // g] + out, hout = m(input, h0) + + # Should be exportable as raw nn.Module + torch.onnx.export( + m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 + ) + # Should be exportable as traced + m = torch.jit.trace(m, (input, h0)) + torch.onnx.export( + m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 + ) + # and as scripted module + m = torch.jit.script(m) + torch.onnx.export( + m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 + ) + + # now grouped gru + num = 2 + m = GroupedGRU(i, h, num, g, batch_first=True, shuffle=True) + ic(m) + h0 = m.get_h0(b) + assert list(h0.shape) == [num * g, b, h // g] + out, hout = m(input, h0) + + # Should be exportable as traced + m = torch.jit.trace(m, (input, h0)) + torch.onnx.export( + m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 + ) + # and scripted module + m = torch.jit.script(m) + torch.onnx.export( + m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 + ) + + +def test_erb(): + import libdf + from df.config import config + + config.use_defaults() + p = ModelParams() + n_freq = p.fft_size // 2 + 1 + df_state = libdf.DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb) + erb = erb_fb(df_state.erb_widths(), p.sr) + erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True) + input = torch.randn((1, 1, 1, n_freq), 
dtype=torch.complex64) + input_abs = input.abs().square() + erb_widths = df_state.erb_widths() + df_erb = torch.from_numpy(libdf.erb(input.numpy(), erb_widths, False)) + py_erb = torch.matmul(input_abs, erb) + assert torch.allclose(df_erb, py_erb) + df_out = torch.from_numpy(libdf.erb_inv(df_erb.numpy(), erb_widths)) + py_out = torch.matmul(py_erb, erb_inverse) + assert torch.allclose(df_out, py_out) + + +def test_unit_norm(): + from df.config import config + from libdf import unit_norm + + config.use_defaults() + p = ModelParams() + b = 2 + F = p.nb_df + t = 100 + spec = torch.randn(b, 1, t, F, 2) + alpha = get_norm_alpha(log=False) + # Expects complex input of shape [C, T, F] + norm_lib = torch.as_tensor(unit_norm(torch.view_as_complex(spec).squeeze(1).numpy(), alpha)) + m = ExponentialUnitNorm(alpha, F) + norm_torch = torch.view_as_complex(m(spec).squeeze(1)) + assert torch.allclose(norm_lib.real, norm_torch.real) + assert torch.allclose(norm_lib.imag, norm_torch.imag) + assert torch.allclose(norm_lib.abs(), norm_torch.abs()) + + +def test_dfop(): + from df.config import config + + config.use_defaults() + p = ModelParams() + f = p.nb_df + F = f * 2 + o = p.df_order + d = p.df_lookahead + t = 100 + spec = torch.randn(1, 1, t, F, 2) + coefs = torch.randn(1, t, o, f, 2) + alpha = torch.randn(1, t, 1) + dfop = DfOp(df_bins=p.nb_df) + dfop.set_forward("real_loop") + out1 = dfop(spec, coefs, alpha) + dfop.set_forward("real_strided") + out2 = dfop(spec, coefs, alpha) + dfop.set_forward("real_unfold") + out3 = dfop(spec, coefs, alpha) + dfop.set_forward("complex_strided") + out4 = dfop(spec, coefs, alpha) + torch.testing.assert_allclose(out1, out2) + torch.testing.assert_allclose(out1, out3) + torch.testing.assert_allclose(out1, out4) + # This forward method requires external padding/lookahead as well as spectrogram buffer + # handling, i.e. via a ring buffer. Could be used in real time usage. + dfop.set_forward("real_one_step") + spec_padded = spec_pad(spec, o, d, dim=-3) + out5 = torch.zeros_like(out1) + for i in range(t): + out5[:, :, i] = dfop( + spec_padded[:, :, i : i + o], coefs[:, i].unsqueeze(1), alpha[:, i].unsqueeze(1) + ) + torch.testing.assert_allclose(out1, out5) + # Forward method that does the padding/lookahead handling using an internal hidden state. + dfop.freq_bins = F + dfop.set_forward("real_hidden_state_loop") + out6 = dfop(spec, coefs, alpha) + torch.testing.assert_allclose(out1, out6) diff --git a/df/multiframe.py b/df/multiframe.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e94d028030dfe49c434ab30e698618d2baa952 --- /dev/null +++ b/df/multiframe.py @@ -0,0 +1,329 @@ +from abc import ABC, abstractmethod +from typing import Dict, Final + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class MultiFrameModule(nn.Module, ABC): + """Multi-frame speech enhancement modules. + + Signal model and notation: + Noisy: `x = s + n` + Enhanced: `y = f(x)` + Objective: `min ||s - y||` + + PSD: Power spectral density, notated eg. as `Rxx` for noisy PSD. + IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx` + """ + + num_freqs: Final[int] + frame_size: Final[int] + need_unfold: Final[bool] + + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0): + """Multi-Frame filtering module. + + Args: + num_freqs (int): Number of frequency bins used for filtering. + frame_size (int): Frame size in FD domain. + lookahead (int): Lookahead, may be used to select the output time step. 
Note: This + module does not add additional padding according to lookahead! + """ + super().__init__() + self.num_freqs = num_freqs + self.frame_size = frame_size + self.pad = nn.ConstantPad2d((0, 0, frame_size - 1, 0), 0.0) + self.need_unfold = frame_size > 1 + self.lookahead = lookahead + + def spec_unfold(self, spec: Tensor): + """Pads and unfolds the spectrogram according to frame_size. + + Args: + spec (complex Tensor): Spectrogram of shape [B, C, T, F] + Returns: + spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size. + """ + if self.need_unfold: + return self.pad(spec).unfold(2, self.frame_size, 1) + return spec.unsqueeze(-1) + + def forward(self, spec: Tensor, coefs: Tensor): + """Pads and unfolds the spectrogram and forwards to impl. + + Args: + spec (Tensor): Spectrogram of shape [B, C, T, F, 2] + coefs (Tensor): Spectrogram of shape [B, C, T, F, 2] + """ + spec_u = self.spec_unfold(torch.view_as_complex(spec)) + coefs = torch.view_as_complex(coefs) + spec_f = spec_u.narrow(-2, 0, self.num_freqs) + spec_f = self.forward_impl(spec_f, coefs) + if self.training: + spec = spec.clone() + spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f) + return spec + + @abstractmethod + def forward_impl(self, spec: Tensor, coefs: Tensor) -> Tensor: + """Forward impl taking complex spectrogram and coefficients. + + Args: + spec (complex Tensor): Spectrogram of shape [B, C1, T, F, N] + coefs (complex Tensor): Coefficients [B, C2, T, F] + + Returns: + spec (complex Tensor): Enhanced spectrogram of shape [B, C1, T, F] + """ + ... + + @abstractmethod + def num_channels(self) -> int: + """Return the number of required channels. + + If multiple inputs are required, then all these should be combined in one Tensor containing + the summed channels. + """ + ... + + +def psd(x: Tensor, n: int) -> Tensor: + """Compute the PSD correlation matrix Rxx for a spectrogram. + + That is, `X*conj(X)`, where `*` is the outer product. + + Args: + x (complex Tensor): Spectrogram of shape [B, C, T, F]. Will be unfolded with `n` steps over + the time axis. + + Returns: + Rxx (complex Tensor): Correlation matrix of shape [B, C, T, F, N, N] + """ + x = F.pad(x, (0, 0, n - 1, 0)).unfold(-2, n, 1) + return torch.einsum("...n,...m->...mn", x, x.conj()) + + +def df(spec: Tensor, coefs: Tensor) -> Tensor: + """Deep filter implemenation using `torch.einsum`. Requires unfolded spectrogram. 
+ + Args: + spec (complex Tensor): Spectrogram of shape [B, C, T, F, N] + coefs (complex Tensor): Spectrogram of shape [B, C, N, T, F] + + Returns: + spec (complex Tensor): Spectrogram of shape [B, C, T, F] + """ + return torch.einsum("...tfn,...ntf->...tf", spec, coefs) + + +class CRM(MultiFrameModule): + """Complex ratio mask.""" + + def __init__(self, num_freqs: int, frame_size: int = 1, lookahead: int = 0): + assert frame_size == 1 and lookahead == 0, (frame_size, lookahead) + super().__init__(num_freqs, 1) + + def forward_impl(self, spec: Tensor, coefs: Tensor): + return spec.squeeze(-1).mul(coefs) + + def num_channels(self): + return 2 + + +class DF(MultiFrameModule): + conj: Final[bool] + """Deep Filtering.""" + + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False): + super().__init__(num_freqs, frame_size, lookahead) + self.conj = conj + + def forward_impl(self, spec: Tensor, coefs: Tensor): + coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:]) + if self.conj: + coefs = coefs.conj() + return df(spec, coefs) + + def num_channels(self): + return self.frame_size * 2 + + +class MfWf(MultiFrameModule): + """Multi-frame Wiener filter base module.""" + + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0): + """Multi-frame Wiener Filter. + + Several implementation methods are available resulting in different number of required input + coefficient channels. + + Methods: + psd_ifc: Predict PSD `Rxx` and IFC `rss`. + df: Use deep filtering to predict speech and noisy spectrograms. These will be used for + PSD calculation for Wiener filtering. Alias: `df_sx` + c: Directly predict Wiener filter coefficients. Computation same as deep filtering. + + """ + super().__init__(num_freqs, frame_size, lookahead=0) + self.idx = -lookahead + + def num_channels(self): + return self.num_channels + + @staticmethod + def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> Tensor: + return torch.einsum( + "...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss + ) # [T, F, N] + + @abstractmethod + def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: + """Multi-frame Wiener filter impl taking complex spectrogram and coefficients. + + Coefficients may be split into multiple parts w.g. for multiple DF coefs or PSDs. + + Args: + spec (complex Tensor): Spectrogram of shape [B, C1, T, F, N] + coefs (complex Tensor): Coefficients [B, C2, T, F] + + Returns: + c (complex Tensor): MfWf coefs of shape [B, C1, T, F, N] + """ + ... 
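+    # `solve` above computes the multi-frame Wiener filter coefficients
+    #     c = (Rxx + diag_loading)^-1 @ rss,
+    # where the diagonal loading is proportional to the trace of Rxx (see `_tik_reg`)
+    # and rss is the speech inter-frame correlation (IFC) vector. Subclasses differ only
+    # in how Rxx and rss (or c directly) are obtained from the predicted coefficients.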
+ + def forward_impl(self, spec: Tensor, coefs: Tensor) -> Tensor: + coefs = self.mfwf(spec, coefs) + return self.apply_coefs(spec, coefs) + + @staticmethod + def apply_coefs(spec: Tensor, coefs: Tensor) -> Tensor: + # spec: [B, C, T, F, N] + # coefs: [B, C, T, F, N] + return torch.einsum("...n,...n->...", spec, coefs) + + +class MfWfDf(MfWf): + eps_diag: Final[float] + + def __init__( + self, + num_freqs: int, + frame_size: int, + lookahead: int = 0, + eps_diag: float = 1e-7, + eps: float = 1e-7, + ): + super().__init__(num_freqs, frame_size, lookahead) + self.eps_diag = eps_diag + self.eps = eps + + def num_channels(self): + # frame_size/df_order * 2 (x/s) * 2 (re/im) + return self.frame_size * 4 + + def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: + coefs.chunk + df_s, df_x = torch.chunk(coefs, 2, 1) # [B, C, T, F, N] + df_s = df_s.unflatten(1, (-1, self.frame_size)) + df_x = df_x.unflatten(1, (-1, self.frame_size)) + spec_s = df(spec, df_s) # [B, C, T, F] + spec_x = df(spec, df_x) + Rss = psd(spec_s, self.frame_size) # [B, C, T, F, N. N] + Rxx = psd(spec_x, self.frame_size) + rss = Rss[..., -1] # TODO: use -1 or self.idx? + c = self.solve(Rxx, rss, self.eps_diag, self.eps) # [B, C, T, F, N] + return c + + +class MfWfPsd(MfWf): + """Multi-frame Wiener filter by predicting noisy PSD `Rxx` and speech IFC `rss`.""" + + def num_channels(self): + # (Rxx + rss) * 2 (re/im) + return (self.frame_size**2 + self.frame_size) * 2 + + def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: # type: ignore + Rxx, rss = torch.split(coefs.movedim(1, -1), [self.frame_size**2, self.frame_size], -1) + c = self.solve(Rxx.unflatten(-1, (self.frame_size, self.frame_size)), rss) + return c + + +class MfWfC(MfWf): + """Multi-frame Wiener filter by directly predicting the MfWf coefficients.""" + + def num_channels(self): + # mfwf coefs * 2 (re/im) + return self.frame_size * 2 + + def mfwf(self, spec: Tensor, coefs: Tensor) -> Tensor: # type: ignore + coefs = coefs.unflatten(1, (-1, self.frame_size)).permute( + 0, 1, 3, 4, 2 + ) # [B, C*N, T, F] -> [B, C, T, F, N] + return coefs + + +class MvdrSouden(MultiFrameModule): + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0): + super().__init__(num_freqs, frame_size, lookahead) + + +class MvdrEvd(MultiFrameModule): + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0): + super().__init__(num_freqs, frame_size, lookahead) + + +class MvdrRtfPower(MultiFrameModule): + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0): + super().__init__(num_freqs, frame_size, lookahead) + + +MF_METHODS: Dict[str, MultiFrameModule] = { + "crm": CRM, + "df": DF, + "mfwf_df": MfWfDf, + "mfwf_df_sx": MfWfDf, + "mfwf_psd": MfWfPsd, + "mfwf_psd_ifc": MfWfPsd, + "mfwf_c": MfWfC, +} + + +# From torchaudio +def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor: + r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions. + Args: + input (torch.Tensor): Tensor of dimension `(..., channel, channel)` + dim1 (int, optional): the first dimension of the diagonal matrix + (Default: -1) + dim2 (int, optional): the second dimension of the diagonal matrix + (Default: -2) + Returns: + Tensor: trace of the input Tensor + """ + assert input.ndim >= 2, "The dimension of the tensor must be at least 2." + assert ( + input.shape[dim1] == input.shape[dim2] + ), "The size of ``dim1`` and ``dim2`` must be the same." 
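+    # Extract the diagonal along (dim1, dim2) and sum it to obtain the per-matrix trace;
+    # leading batch dimensions are preserved.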
+ input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2) + return input.sum(dim=-1) + + +def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor: + """Perform Tikhonov regularization (only modifying real part). + Args: + mat (torch.Tensor): input matrix (..., channel, channel) + reg (float, optional): regularization factor (Default: 1e-8) + eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``) + Returns: + Tensor: regularized matrix (..., channel, channel) + """ + # Add eps + C = mat.size(-1) + eye = torch.eye(C, dtype=mat.dtype, device=mat.device) + epsilon = _compute_mat_trace(mat).real[..., None, None] * reg + # in case that correlation_matrix is all-zero + epsilon = epsilon + eps + mat = mat + epsilon * eye[..., :, :] + return mat diff --git a/df/utils.py b/df/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a750457ea4a83e3b5abf9cc115952cf784478d5 --- /dev/null +++ b/df/utils.py @@ -0,0 +1,230 @@ +import collections +import math +import os +import random +import subprocess +from socket import gethostname +from typing import Any, Dict, Set, Tuple, Union + +import numpy as np +import torch +from loguru import logger +from torch import Tensor +#from torch._six import string_classes +from torch.autograd import Function +from torch.types import Number + +from df.config import config +from df.model import ModelParams + +try: + from torchaudio.functional import resample as ta_resample +except ImportError: + from torchaudio.compliance.kaldi import resample_waveform as ta_resample # type: ignore + + +def get_resample_params(method: str) -> Dict[str, Any]: + params = { + "sinc_fast": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 16}, + "sinc_best": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 64}, + "kaiser_fast": { + "resampling_method": "kaiser_window", + "lowpass_filter_width": 16, + "rolloff": 0.85, + "beta": 8.555504641634386, + }, + "kaiser_best": { + "resampling_method": "kaiser_window", + "lowpass_filter_width": 16, + "rolloff": 0.9475937167399596, + "beta": 14.769656459379492, + }, + } + assert method in params.keys(), f"method must be one of {list(params.keys())}" + return params[method] + + +def resample(audio: Tensor, orig_sr: int, new_sr: int, method="sinc_fast"): + params = get_resample_params(method) + return ta_resample(audio, orig_sr, new_sr, **params) + + +def get_device(): + s = config("DEVICE", default="", section="train") + if s == "": + if torch.cuda.is_available(): + DEVICE = torch.device("cuda:0") + else: + DEVICE = torch.device("cpu") + else: + DEVICE = torch.device(s) + return DEVICE + + +def as_complex(x: Tensor): + if torch.is_complex(x): + return x + if x.shape[-1] != 2: + raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}") + if x.stride(-1) != 1: + x = x.contiguous() + return torch.view_as_complex(x) + + +def as_real(x: Tensor): + if torch.is_complex(x): + return torch.view_as_real(x) + return x + + +class angle_re_im(Function): + """Similar to torch.angle but robustify the gradient for zero magnitude.""" + + @staticmethod + def forward(ctx, re: Tensor, im: Tensor): + ctx.save_for_backward(re, im) + return torch.atan2(im, re) + + @staticmethod + def backward(ctx, grad: Tensor) -> Tuple[Tensor, Tensor]: + re, im = ctx.saved_tensors + grad_inv = grad / (re.square() + im.square()).clamp_min_(1e-10) + return -im * grad_inv, re * grad_inv + + +class angle(Function): + """Similar to 
torch.angle but robustify the gradient for zero magnitude.""" + + @staticmethod + def forward(ctx, x: Tensor): + ctx.save_for_backward(x) + return torch.atan2(x.imag, x.real) + + @staticmethod + def backward(ctx, grad: Tensor): + (x,) = ctx.saved_tensors + grad_inv = grad / (x.real.square() + x.imag.square()).clamp_min_(1e-10) + return torch.view_as_complex(torch.stack((-x.imag * grad_inv, x.real * grad_inv), dim=-1)) + + +def check_finite_module(obj, name="Module", _raise=True) -> Set[str]: + out: Set[str] = set() + if isinstance(obj, torch.nn.Module): + for name, child in obj.named_children(): + out = out | check_finite_module(child, name) + for name, param in obj.named_parameters(): + out = out | check_finite_module(param, name) + for name, buf in obj.named_buffers(): + out = out | check_finite_module(buf, name) + if _raise and len(out) > 0: + raise ValueError(f"{name} not finite during checkpoint writing including: {out}") + return out + + +def make_np(x: Union[Tensor, np.ndarray, Number]) -> np.ndarray: + """Transforms Tensor to numpy. + Args: + x: An instance of torch tensor or caffe blob name + + Returns: + numpy.array: Numpy array + """ + if isinstance(x, np.ndarray): + return x + if np.isscalar(x): + return np.array([x]) + if isinstance(x, Tensor): + return x.detach().cpu().numpy() + raise NotImplementedError( + "Got {}, but numpy array, scalar, or torch tensor are expected.".format(type(x)) + ) + + +def get_norm_alpha(log: bool = True) -> float: + p = ModelParams() + a_ = _calculate_norm_alpha(sr=p.sr, hop_size=p.hop_size, tau=p.norm_tau) + precision = 3 + a = 1.0 + while a >= 1.0: + a = round(a_, precision) + precision += 1 + if log: + logger.info(f"Running with normalization window alpha = '{a}'") + return a + + +def _calculate_norm_alpha(sr: int, hop_size: int, tau: float): + """Exponential decay factor alpha for a given tau (decay window size [s]).""" + dt = hop_size / sr + return math.exp(-dt / tau) + + +def check_manual_seed(seed: int = None): + """If manual seed is not specified, choose a random one and communicate it to the user.""" + seed = seed or random.randint(1, 10000) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + return seed + + +def get_git_root(): + git_local_dir = os.path.dirname(os.path.abspath(__file__)) + args = ["git", "-C", git_local_dir, "rev-parse", "--show-toplevel"] + return subprocess.check_output(args).strip().decode() + + +def get_commit_hash(): + """Returns the current git commit.""" + try: + git_dir = get_git_root() + args = ["git", "-C", git_dir, "rev-parse", "--short", "--verify", "HEAD"] + commit = subprocess.check_output(args).strip().decode() + except subprocess.CalledProcessError: + # probably not in git repo + commit = None + return commit + + +def get_host() -> str: + return gethostname() + + +def get_branch_name(): + try: + git_dir = os.path.dirname(os.path.abspath(__file__)) + args = ["git", "-C", git_dir, "rev-parse", "--abbrev-ref", "HEAD"] + branch = subprocess.check_output(args).strip().decode() + except subprocess.CalledProcessError: + # probably not in git repo + branch = None + return branch + + +# from pytorch/ignite: +def apply_to_tensor(input_, func): + """Apply a function on a tensor or mapping, or sequence of tensors.""" + if isinstance(input_, torch.nn.Module): + return [apply_to_tensor(c, func) for c in input_.children()] + elif isinstance(input_, torch.nn.Parameter): + return func(input_.data) + elif isinstance(input_, Tensor): + return func(input_) + elif isinstance(input_, str): + return input_ + 
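+    # Strings are returned as-is above so they do not fall through to the generic
+    # Iterable branch below and get split into single characters.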
elif isinstance(input_, collections.Mapping): + return {k: apply_to_tensor(sample, func) for k, sample in input_.items()} + elif isinstance(input_, collections.Iterable): + return [apply_to_tensor(sample, func) for sample in input_] + elif input_ is None: + return input_ + else: + return input_ + + +def detach_hidden(hidden: Any) -> Any: + """Cut backpropagation graph. + Auxillary function to cut the backpropagation graph by detaching the hidden + vector. + """ + return apply_to_tensor(hidden, Tensor.detach) diff --git a/libdf/__init__.py b/libdf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b005e66db60c0539603a202e65ffd21939d5f019 --- /dev/null +++ b/libdf/__init__.py @@ -0,0 +1,3 @@ +from .libdf import * + +__doc__ = libdf.__doc__ diff --git a/libdf/__init__.pyi b/libdf/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b4acbcf124f2ae605f2fe7b5058d836e7d9d5f37 --- /dev/null +++ b/libdf/__init__.pyi @@ -0,0 +1,57 @@ +from typing import List, Optional, Union + +from numpy import ndarray + +class DF: + def __init__( + self, + sr: int, + fft_size: int, + hop_size: int, + nb_bands: int, + min_nb_erb_freqs: Optional[int] = 1, + ): + """DeepFilter state used for analysis and synthesis. + + Args: + sr (int): Sampling rate. + fft_size (int): Window length used for the Fast Fourier transform. + hop_size (int): Hop size between two analysis windows. Also called frame size. + nb_bands (int): Number of ERB bands. + min_nb_erb_freqs (int): Minimum number of frequency bands per ERB band. Defaults to 1. + """ + ... + def analysis(self, input: ndarray) -> ndarray: + """Analysis of a time-domain signal. + + Args: + input (ndarray): 2D real-valued array of shape [C, T]. + Output: + output (ndarray): 3D complex-valued array of shape [C, T', F], where F is the `fft_size`, + and T' the original time T divided by `hop_size`. + """ + ... + def synthesis(self, input: ndarray) -> ndarray: + """Synthesis of a frequency-domain signal. + + Args: + input (ndarray): 3D complex-valued array of shape [C, T, F]. + Output: + output (ndarray): 2D real-valued array of shape [C, T]. + """ + ... + def erb_widths(self) -> ndarray: ... + def fft_window(self) -> ndarray: ... + def sr(self) -> int: ... + def fft_size(self) -> int: ... + def hop_size(self) -> int: ... + def nb_erb(self) -> int: ... + def reset(self) -> None: ... + +def erb( + input: ndarray, erb_fb: Union[ndarray, List[int]], db: Optional[bool] = None +) -> ndarray: ... +def erb_inv(input: ndarray, erb_fb: Union[ndarray, List[int]]) -> ndarray: ... +def erb_norm(erb: ndarray, alpha: float, state: Optional[ndarray] = None) -> ndarray: ... +def unit_norm(spec: ndarray, alpha: float, state: Optional[ndarray] = None) -> ndarray: ... +def unit_norm_init(num_freq_bins: int) -> ndarray: ... 
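+# Usage sketch (illustrative only; the parameter values below are assumptions, not defaults):
+#
+#     import numpy as np
+#     from libdf import DF
+#
+#     df_state = DF(sr=48000, fft_size=960, hop_size=480, nb_bands=32)
+#     audio = np.zeros((1, 48000), dtype=np.float32)  # [C, T] real time-domain signal
+#     spec = df_state.analysis(audio)                 # [C, T', F] complex spectrogram
+#     resynth = df_state.synthesis(spec)              # [C, T] real time-domain signal
+#     widths = df_state.erb_widths()                  # band widths for the ERB helpers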
diff --git a/libdf/py.typed b/libdf/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_weights/voice_enhance/checkpoints/model_96.ckpt.best b/model_weights/voice_enhance/checkpoints/model_96.ckpt.best new file mode 100644 index 0000000000000000000000000000000000000000..261cbe85caed3680390cfa95a3ef184a0d0295d4 --- /dev/null +++ b/model_weights/voice_enhance/checkpoints/model_96.ckpt.best @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb5eccb429e675bb4ec5ec9e280f048bfff9787b40bd3eb835fd11509eb14a3e +size 9397209 diff --git a/model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt b/model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt new file mode 100644 index 0000000000000000000000000000000000000000..662d22b686114b4b6124330a688007d9495d22c8 --- /dev/null +++ b/model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca +size 17090379 diff --git a/model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt b/model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt new file mode 100644 index 0000000000000000000000000000000000000000..b8bea30038f105a76f1064e31702b5e3f56f7ae3 --- /dev/null +++ b/model_weights/voice_enhance/ckpt/pretrained_bak_5805000.pt.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7dfd48d0da24db35ee4a653d0d36a4104cb26873050a5c3584675eee21937621 +size 69 diff --git a/model_weights/voiceover/freevc-24.json b/model_weights/voiceover/freevc-24.json new file mode 100644 index 0000000000000000000000000000000000000000..b474cd546ffe8ea0d0b7a0ea4e8a0853e0e4ba3a --- /dev/null +++ b/model_weights/voiceover/freevc-24.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:872360b61e6bbe09bec29810e7ad0d16318e379f6195a7ff3b06e50efb08ad31 +size 1264 diff --git a/model_weights/voiceover/freevc-24.pth b/model_weights/voiceover/freevc-24.pth new file mode 100644 index 0000000000000000000000000000000000000000..d256c31ffd327d6002980112b891f0db5f7ae849 --- /dev/null +++ b/model_weights/voiceover/freevc-24.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20 +size 472644351 diff --git a/model_weights/wavlm_models/WavLM-Large.pt b/model_weights/wavlm_models/WavLM-Large.pt new file mode 100644 index 0000000000000000000000000000000000000000..b704cf463982904321df737eb8a3fe092a0aa019 --- /dev/null +++ b/model_weights/wavlm_models/WavLM-Large.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f +size 1261965425 diff --git a/model_weights/wavlm_models/WavLM-Large.pt.txt b/model_weights/wavlm_models/WavLM-Large.pt.txt new file mode 100644 index 0000000000000000000000000000000000000000..8f5a784907f2da3f11265b755f6987a48869ae8b --- /dev/null +++ b/model_weights/wavlm_models/WavLM-Large.pt.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9836bca8ab0e9d0b4797aa78f41b367800d26cfd25ade7b1edcb35bc3c171e4 +size 52 diff --git a/nnet/__init__.py b/nnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/nnet/attentions.py b/nnet/attentions.py new file mode 100644 index 
0000000000000000000000000000000000000000..418a9f1408b253e255b95efdae078af7f5e4a2d7 --- /dev/null +++ b/nnet/attentions.py @@ -0,0 +1,300 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from nnet import commons +from nnet.modules import LayerNorm + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = 
self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." 
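+            # Band mask: restrict attention to positions within block_length of the diagonal.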
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) + x_flat = x.view([batch, heads, length**2 + length*(length -1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. 
+ Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/nnet/commons.py b/nnet/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..0e602b4493dc4a1aed9f902c17ea2c450dc0381a --- /dev/null +++ b/nnet/commons.py @@ -0,0 +1,167 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def rand_spec_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, 
t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2,3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/nnet/language_funcs.py b/nnet/language_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..5549f7d44f3666bb26bf9f7b0855477ac4bb8efe --- /dev/null +++ b/nnet/language_funcs.py @@ -0,0 +1,59 @@ +import time +import nltk + +from transformers import pipeline + +nltk.download('punkt') +from nltk.tokenize import sent_tokenize + +def detect_language(text,LID): + predictions = LID.predict(text) + detected_lang_code = predictions[0][0].replace("__label__", "") + return detected_lang_code + +def translation(model_name, + sentence_mode, selection_mode, + source, target, + text, + flores_codes, + model_dict, device): + start_time = time.time() + + # Determine the source language + if selection_mode == "Auto-detect": + detected_lang_code = detect_language(text) + flores_source_code = detected_lang_code + source_code = flores_source_code + else: + if source == "Auto-detect": # Make sure we don't use "Auto-detect" as a key + return {'error': "Source language cannot be 'Auto-detect' when selection mode is manual."} + source_code = flores_codes.get(source) + if not source_code: + return {'error': f"Source language {source} not found in flores_codes."} + + + target_code = flores_codes[target] + model = model_dict[model_name + '_model'] + tokenizer = model_dict[model_name + '_tokenizer'] + + translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source_code, tgt_lang=target_code, device=device) + + if sentence_mode == "Sentence-wise": + sentences = sent_tokenize(text) + translated_sentences = [] + for sentence in sentences: + translated_sentence = translator(sentence, max_length=400)[0]['translation_text'] + translated_sentences.append(translated_sentence) + output = ' '.join(translated_sentences) + else: + output = translator(text, max_length=400)[0]['translation_text'] + + end_time = time.time() + + result = { + 'inference_time': end_time - start_time, + 'source_language': source_code, + 'target_language': target_code, + 'result': output + } + return result \ No newline at end of file diff --git a/nnet/mel_processing.py b/nnet/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..92c9141694c197e2ea40e34f4fbffc9c08714916 --- /dev/null +++ b/nnet/mel_processing.py @@ -0,0 +1,103 @@ +import torch +import torch.utils.data +import torch.nn.functional as F + +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + 
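+# A minimal sketch of how the two helpers above relate; the tensor values and the helper
+# name are illustrative (hypothetical, not used elsewhere in this module).
+def _compression_roundtrip_example():
+    mag = torch.rand(2, 80, 100) + 1e-3                        # e.g. mel magnitudes [B, n_mels, T]
+    compressed = dynamic_range_compression_torch(mag)          # log(clamp(mag, 1e-5) * C), C=1
+    restored = dynamic_range_decompression_torch(compressed)   # exp(...) / C inverts it above clip_val
+    assert torch.allclose(restored, mag)
+    return restored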
+def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + '_' + str(spec.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/nnet/models.py b/nnet/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ca9870b04a6f26416c91fcae974d37dba03019 --- /dev/null +++ b/nnet/models.py @@ -0,0 +1,544 @@ +import torch +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +import math +import torch +from torch import nn +from torch.nn import functional as F + +from nnet import modules, attentions, monotonic_align + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from nnet.commons import 
init_weights, get_padding +from nnet import commons + +def load_models(device, model_name_dict): + model_dict = {} + for call_name, real_name in model_name_dict.items(): + print('\tLoading model: %s' % call_name) + model = AutoModelForSeq2SeqLM.from_pretrained(real_name, torch_dtype=torch.bfloat16).to(device) + tokenizer = AutoTokenizer.from_pretrained(real_name) + model_dict[call_name+'_model'] = model + model_dict[call_name+'_tokenizer'] = tokenizer + + return model_dict + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) + logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def 
__init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = 
n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), 
"reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2,3,5,7,11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + **kwargs): + + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + + self.enc_p = TextEncoder(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, 
gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + + if n_speakers > 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, x, x_lengths, y, y_lengths, sid=None): + + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:,:,:max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): + assert self.n_speakers > 0, "n_speakers have to be larger than 0." 
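+        # Conversion path: encode y with the source speaker embedding, push the posterior z
+        # through the flow into the shared prior space, then invert the flow conditioned on
+        # the target speaker and decode to a waveform.
+        # Hedged usage sketch (tensor names are illustrative, not defined in this file):
+        #   sid_src, sid_tgt = torch.LongTensor([0]), torch.LongTensor([3])
+        #   audio_hat, _, _ = net_g.voice_conversion(spec, spec_lengths, sid_src, sid_tgt)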
+ g_src = self.emb_g(sid_src).unsqueeze(-1) + g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + diff --git a/nnet/models_vc.py b/nnet/models_vc.py new file mode 100644 index 0000000000000000000000000000000000000000..312741e90cd1fb2df816f595110e5b2503f378d9 --- /dev/null +++ b/nnet/models_vc.py @@ -0,0 +1,350 @@ +import torch + +from torch import nn +from torch.nn import functional as F + +from nnet import commons +from nnet import modules + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from nnet.commons import init_weights, get_padding + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) 
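+        # Each transposed conv upsamples time by its rate u, so the stacked stages expand
+        # frames to samples by prod(upsample_rates); that product is expected to match the
+        # hop length used to extract the input features, otherwise output length drifts.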
+ + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2,3,5,7,11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, 
use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SpeakerEncoder(torch.nn.Module): + def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): + super(SpeakerEncoder, self).__init__() + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): + mel_slices = [] + for i in range(0, total_frames-partial_frames, partial_hop): + mel_range = torch.arange(i, i+partial_frames) + mel_slices.append(mel_range) + + return mel_slices + + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): + mel_len = mel.size(1) + last_mel = mel[:,-partial_frames:] + + if mel_len > partial_frames: + mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) + mels = list(mel[:,s] for s in mel_slices) + mels.append(last_mel) + mels = torch.stack(tuple(mels), 0).squeeze(1) + + with torch.no_grad(): + partial_embeds = self(mels) + embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) + #embed = embed / torch.linalg.norm(embed, 2) + else: + with torch.no_grad(): + embed = self(last_mel) + + return embed + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + use_spk, + **kwargs): + + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + self.ssl_dim = ssl_dim + self.use_spk = use_spk + + self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + if not self.use_spk: + self.enc_spk = 
SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels) + + def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None): + if c_lengths == None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if spec_lengths == None: + spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) + + if not self.use_spk: + g = self.enc_spk(mel.transpose(1,2)) + g = g.unsqueeze(-1) + + _, m_p, logs_p, _ = self.enc_p(c, c_lengths) + z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) + z_p = self.flow(z, spec_mask, g=g) + + z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, c, g=None, mel=None, c_lengths=None): + if c_lengths == None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if not self.use_spk: + g = self.enc_spk.embed_utterance(mel.transpose(1,2)) + g = g.unsqueeze(-1) + + z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g) + + return o diff --git a/nnet/modules.py b/nnet/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c49508920138fce27b632eb9f81d56e0c59b7baf --- /dev/null +++ b/nnet/modules.py @@ -0,0 +1,387 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from nnet import commons +from nnet.commons import init_weights, get_padding +from nnet.transforms import piecewise_rational_quadratic_transform + + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." 
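+        # Note: the assert checks n_layers > 1 even though its message says "larger than 0".
+        # The stack built below is one in_channels->hidden conv plus (n_layers - 1)
+        # hidden->hidden convs, each followed by LayerNorm and ReLU+dropout.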
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers-1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, + groups=channels, dilation=dilation, padding=padding + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + self.hidden_channels =hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) 
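+        # cond_layer projects g once to 2*hidden_channels*n_layers channels; each layer i
+        # below slices out its own 2*hidden_channels chunk at cond_offset and feeds it into
+        # the fused tanh/sigmoid gate together with x_in.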
+ + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:,:self.hidden_channels,:] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:,self.hidden_channels:,:] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = 
torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels,1)) + self.logs = nn.Parameter(torch.zeros(channels,1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1,2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
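+        # After the permute, h has shape [b, half_channels, t, num_bins*3 - 1]; the last axis
+        # is split below into num_bins widths, num_bins heights and num_bins - 1 interior
+        # derivatives for the rational-quadratic spline applied to x1.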
+ + unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins:] + + x1, logabsdet = piecewise_rational_quadratic_transform(x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails='linear', + tail_bound=self.tail_bound + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1,2]) + if not reverse: + return x, logdet + else: + return x diff --git a/nnet/monotonic_align/__init__.py b/nnet/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46c9183c2de2eb23bf51d6ae4d129ec13782dc3d --- /dev/null +++ b/nnet/monotonic_align/__init__.py @@ -0,0 +1,19 @@ +# import numpy as np +# import torch +# from .core import maximum_path_c + + +# def maximum_path(neg_cent, mask): +# """ Cython optimized version. +# neg_cent: [b, t_t, t_s] +# mask: [b, t_t, t_s] +# """ +# device = neg_cent.device +# dtype = neg_cent.dtype +# neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) +# path = np.zeros(neg_cent.shape, dtype=np.int32) + +# t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) +# t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) +# maximum_path_c(path, neg_cent, t_t_max, t_s_max) +# return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/nnet/monotonic_align/core.pyx b/nnet/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..bfaabd4d21c2299cdd978f0cc0caefa20ad186e5 --- /dev/null +++ b/nnet/monotonic_align/core.pyx @@ -0,0 +1,42 @@ +cimport cython +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y-1, x] + if x == 0: + if y == 0: + v_prev = 0. 
+ else: + v_prev = max_neg_val + else: + v_prev = value[y-1, x-1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/nnet/monotonic_align/setup.py b/nnet/monotonic_align/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..30c224807a70faa9df9c9eb75f8e80c8c867b16b --- /dev/null +++ b/nnet/monotonic_align/setup.py @@ -0,0 +1,9 @@ +from distutils.core import setup +from Cython.Build import cythonize +import numpy + +setup( + name = 'monotonic_align', + ext_modules = cythonize("core.pyx"), + include_dirs=[numpy.get_include()] +) diff --git a/nnet/transforms.py b/nnet/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..4793d67ca5a5630e0ffe0f9fb29445c949e64dae --- /dev/null +++ b/nnet/transforms.py @@ -0,0 +1,193 @@ +import torch +from torch.nn import functional as F + +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = { + 'tails': tails, + 'tail_bound': tail_bound + } + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum( + inputs[..., None] >= bin_locations, + dim=-1 + ) - 1 + + +def unconstrained_rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails='linear', + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == 'linear': + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError('{} tails are not implemented.'.format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + 
unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative + ) + + return outputs, logabsdet + +def rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0., right=1., bottom=0., top=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError('Input to a transform is not within its domain') + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError('Minimal bin width too large for the number of bins') + if min_bin_height * num_bins > 1.0: + raise ValueError('Minimal bin height too large for the number of bins') + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (((inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta) + + input_heights * (input_delta - input_derivatives))) + b = (input_heights * input_derivatives + - (inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta)) + c = - input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * 
torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/nnet/utils.py b/nnet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..88a2ab27f9ac77881766c6220542c74b71796151 --- /dev/null +++ b/nnet/utils.py @@ -0,0 +1,292 @@ +import os +import sys +import glob +import json +import torch +import logging +import argparse +import subprocess +import torchvision +import numpy as np +from scipy.io.wavfile import read +from wavlm.WavLM import WavLMConfig +from wavlm.WavLM import WavLM + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging + + +def get_cmodel(device, ckpt_model): + checkpoint = torch.load(ckpt_model) + cfg = WavLMConfig(checkpoint['cfg']) + cmodel = WavLM(cfg).to(device) + cmodel.load_state_dict(checkpoint['model']) + cmodel.eval() + return cmodel + + +def get_content(cmodel, y): + with torch.no_grad(): + c = cmodel.extract_features(y.squeeze(1))[0] + c = c.transpose(1, 2) + return c + +def transform(mel, height): + tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1))) + if height >= mel.size(-2): + return tgt[:, :mel.size(-2), :] + else: + silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1) + silence += torch.randn_like(silence) / 10 + return torch.cat((tgt, silence), 1) + + +def stretch(mel, width): # 0.5-2 + return torchvision.transforms.functional.resize(mel, (mel.size(-2), width)) + + +def load_checkpoint(checkpoint_path, model, optimizer=None, strict=False): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + iteration = checkpoint_dict['iteration'] + learning_rate = checkpoint_dict['learning_rate'] + if optimizer is not None: + optimizer.load_state_dict(checkpoint_dict['optimizer']) + saved_state_dict = checkpoint_dict['model'] + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + if strict: + assert state_dict.keys() == saved_state_dict.keys(), "Mismatched model config and checkpoint." 
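+    # With strict=False (the default), parameters missing from the checkpoint fall back to
+    # the model's freshly initialised values in the loop below, and each missing key is logged.
+    # Hedged usage sketch (the path and variable names are illustrative only):
+    #   net_g, optim_g, lr, step = load_checkpoint("logs/mymodel/G_100000.pth", net_g, optim_g)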
+ new_state_dict= {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, 'module'): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + logger.info("Loaded checkpoint '{}' (iteration {})" .format( + checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info("Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path)) + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save({'model': state_dict, + 'iteration': iteration, + 'optimizer': optimizer.state_dict(), + 'learning_rate': learning_rate}, checkpoint_path) + + +def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats='HWC') + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10,2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', + interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding='utf-8') as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True): + parser = argparse.ArgumentParser() + 
parser.add_argument('-c', '--config', type=str, default="./configs/base.json", + help='JSON file for configuration') + parser.add_argument('-m', '--model', type=str, required=True, + help='Model name') + + args = parser.parse_args() + model_dir = os.path.join("./logs", args.model) + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + config_path = args.config + config_save_path = os.path.join(model_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams =HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams =HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + )) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn("git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8])) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams(): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..353999910bb70b5219cdd48c81b4640840f2d5b9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba281fe65a5b1db7b15fc3ce7e84b198e8a998f136552a34ca8d52eb8b3fadf7 +size 1915 diff --git a/speaker_encoder/__init__.py b/speaker_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/speaker_encoder/audio.py b/speaker_encoder/audio.py new file mode 100644 index 
0000000000000000000000000000000000000000..2fcb77ad1d3a85f523e24f84691886736a5686cb --- /dev/null +++ b/speaker_encoder/audio.py @@ -0,0 +1,107 @@ +from scipy.ndimage.morphology import binary_dilation +from speaker_encoder.params_data import * +from pathlib import Path +from typing import Optional, Union +import numpy as np +import webrtcvad +import librosa +import struct + +int16_max = (2 ** 15) - 1 + + +def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], + source_sr: Optional[int] = None): + """ + Applies the preprocessing operations used in training the Speaker Encoder to a waveform + either on disk or in memory. The waveform will be resampled to match the data hyperparameters. + + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + just .wav), either the waveform as a numpy array of floats. + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. If passing a filepath, the sampling rate will be automatically detected and + this argument will be ignored. + """ + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(fpath_or_wav, sr=None) + else: + wav = fpath_or_wav + + # Resample the wav if needed + if source_sr is not None and source_sr != sampling_rate: + wav = librosa.resample(wav, source_sr, sampling_rate) + + # Apply the preprocessing: normalize volume and shorten long silences + wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) + wav = trim_long_silences(wav) + + return wav + + +def wav_to_mel_spectrogram(wav): + """ + Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. + Note: this not a log-mel spectrogram. + """ + frames = librosa.feature.melspectrogram( + y=wav, + sr=sampling_rate, + n_fft=int(sampling_rate * mel_window_length / 1000), + hop_length=int(sampling_rate * mel_window_step / 1000), + n_mels=mel_n_channels + ) + return frames.astype(np.float32).T + + +def trim_long_silences(wav): + """ + Ensures that segments without voice in the waveform remain no longer than a + threshold determined by the VAD parameters in params.py. 
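+    (In this repository those VAD parameters live in speaker_encoder/params_data.py.)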
+ + :param wav: the raw waveform as a numpy array of floats + :return: the same waveform with silences trimmed away (length <= original wav length) + """ + # Compute the voice detection window size + samples_per_window = (vad_window_length * sampling_rate) // 1000 + + # Trim the end of the audio to have a multiple of the window size + wav = wav[:len(wav) - (len(wav) % samples_per_window)] + + # Convert the float waveform to 16-bit mono PCM + pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) + + # Perform voice activation detection + voice_flags = [] + vad = webrtcvad.Vad(mode=3) + for window_start in range(0, len(wav), samples_per_window): + window_end = window_start + samples_per_window + voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], + sample_rate=sampling_rate)) + voice_flags = np.array(voice_flags) + + # Smooth the voice detection with a moving average + def moving_average(array, width): + array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) + ret = np.cumsum(array_padded, dtype=float) + ret[width:] = ret[width:] - ret[:-width] + return ret[width - 1:] / width + + audio_mask = moving_average(voice_flags, vad_moving_average_width) + audio_mask = np.round(audio_mask).astype(np.bool) + + # Dilate the voiced regions + audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) + audio_mask = np.repeat(audio_mask, samples_per_window) + + return wav[audio_mask == True] + + +def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) + if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): + return wav + return wav * (10 ** (dBFS_change / 20)) diff --git a/speaker_encoder/config.py b/speaker_encoder/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c21312f3de971bfa008254c6035cebc09f05e4c --- /dev/null +++ b/speaker_encoder/config.py @@ -0,0 +1,45 @@ +librispeech_datasets = { + "train": { + "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], + "other": ["LibriSpeech/train-other-500"] + }, + "test": { + "clean": ["LibriSpeech/test-clean"], + "other": ["LibriSpeech/test-other"] + }, + "dev": { + "clean": ["LibriSpeech/dev-clean"], + "other": ["LibriSpeech/dev-other"] + }, +} +libritts_datasets = { + "train": { + "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], + "other": ["LibriTTS/train-other-500"] + }, + "test": { + "clean": ["LibriTTS/test-clean"], + "other": ["LibriTTS/test-other"] + }, + "dev": { + "clean": ["LibriTTS/dev-clean"], + "other": ["LibriTTS/dev-other"] + }, +} +voxceleb_datasets = { + "voxceleb1" : { + "train": ["VoxCeleb1/wav"], + "test": ["VoxCeleb1/test_wav"] + }, + "voxceleb2" : { + "train": ["VoxCeleb2/dev/aac"], + "test": ["VoxCeleb2/test_wav"] + } +} + +other_datasets = [ + "LJSpeech-1.1", + "VCTK-Corpus/wav48", +] + +anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] diff --git a/speaker_encoder/hparams.py b/speaker_encoder/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8c16471903b0c92253b1d70fcd6a61d10e085f --- /dev/null +++ b/speaker_encoder/hparams.py @@ -0,0 +1,31 @@ +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio 
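+# 16 kHz audio is assumed by the speaker encoder; voice_encoder.compute_partial_slices
+# converts mel_window_step (ms) to a frame hop using this rate.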
+sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 \ No newline at end of file diff --git a/speaker_encoder/params_data.py b/speaker_encoder/params_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb1716ed45617f2b127a7fb8885afe6cc74fb71 --- /dev/null +++ b/speaker_encoder/params_data.py @@ -0,0 +1,29 @@ + +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms +# Number of spectrogram frames at inference +inference_n_frames = 80 # 800 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + diff --git a/speaker_encoder/voice_encoder.py b/speaker_encoder/voice_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d4dfda42dd9eb95e13b57dc83e1a6cd69e777e --- /dev/null +++ b/speaker_encoder/voice_encoder.py @@ -0,0 +1,167 @@ + +import torch +import numpy as np + +from torch import nn +from typing import Union, List +from speaker_encoder import audio +from speaker_encoder.hparams import * +from time import perf_counter as timer + +class SpeakerEncoder(nn.Module): + def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True): + """ + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). + If None, defaults to cuda if it is available on your machine, otherwise the model will + run on cpu. Outputs are always returned on the cpu, as numpy arrays. + """ + super().__init__() + + # Define the network + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + # Get the target device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + device = torch.device(device) + self.device = device + + start = timer() + checkpoint = torch.load(weights_fpath, map_location="cpu") + + self.load_state_dict(checkpoint["model_state"], strict=False) + self.to(device) + + if verbose: + print("Loaded the voice encoder model on %s in %.2f seconds." 
% + (device.type, timer() - start)) + + def forward(self, mels: torch.FloatTensor): + """ + Computes the embeddings of a batch of utterance spectrograms. + :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). + Embeddings are positive and L2-normed, thus they lay in the range [0, 1]. + """ + # Pass the input through the LSTM layers and retrieve the final hidden state of the last + # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings. + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + @staticmethod + def compute_partial_slices(n_samples: int, rate, min_coverage): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to + obtain partial utterances of each. Both the waveform and the + mel spectrogram slices are returned, so as to make each partial utterance waveform + correspond to its spectrogram. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wav_slices[-1].stop. + + :param n_samples: the number of samples in the waveform + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. 
+ """ + assert 0 < min_coverage <= 1 + + # Compute how many frames separate two partial utterances + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + assert 0 < frame_step, "The rate is too high" + assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \ + (sampling_rate / (samples_per_frame * partials_n_frames)) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partials_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partials_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + if coverage < min_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75): + """ + Computes an embedding for a single utterance. The utterance is divided in partial + utterances and an embedding is computed for each. The complete utterance embedding is the + L2-normed average embedding of the partial utterances. + + TODO: independent batched version of this function + + :param wav: a preprocessed utterance waveform as a numpy array of float32 + :param return_partials: if True, the partial embeddings will also be returned along with + the wav slices corresponding to each partial utterance. + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. + """ + # Compute where to split the utterance into partials and pad the waveform with zeros if + # the partial utterances cover a larger range. 
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) + max_wave_length = wav_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials and forward them through the model + mel = audio.wav_to_mel_spectrogram(wav) + mels = np.array([mel[s] for s in mel_slices]) + with torch.no_grad(): + mels = torch.from_numpy(mels).to(self.device) + partial_embeds = self(mels).cpu().numpy() + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wav_slices + return embed + + def embed_speaker(self, wavs: List[np.ndarray], **kwargs): + """ + Compute the embedding of a collection of wavs (presumably from the same speaker) by + averaging their embedding and L2-normalizing it. + + :param wavs: list of wavs a numpy arrays of float32. + :param kwargs: extra arguments to embed_utterance() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). + """ + raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \ + for wav in wavs], axis=0) + return raw_embed / np.linalg.norm(raw_embed, 2) \ No newline at end of file diff --git a/wavlm/WavLM.py b/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..09af97bfc48c5ecc3558c03666df9f965d2bef0d --- /dev/null +++ b/wavlm/WavLM.py @@ -0,0 +1,742 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + #padding_mask = padding_mask.all(-1) + padding_mask = padding_mask.any(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, 
D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + 
self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/wavlm/__init__.py b/wavlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wavlm/modules.py b/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..cd360aa8df4a8826199757c91fa5224215f20dd0 --- /dev/null +++ b/wavlm/modules.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 
== 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights \ No newline at end of file
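The diff above adds the speaker-encoder and WavLM sources without usage notes, so a few illustrative sketches follow; they are not part of the diff. First, computing a single utterance embedding with speaker_encoder/voice_encoder.py. The checkpoint path and audio file name are placeholders, and loading/resampling with librosa merely stands in for whatever waveform preprocessing (VAD, loudness normalization from params_data.py) the rest of the repository applies before calling embed_utterance.

import librosa
import numpy as np

from speaker_encoder.voice_encoder import SpeakerEncoder

# Load and resample to the 16 kHz sampling_rate from the hparams (placeholder file name).
wav, _ = librosa.load("utterance.wav", sr=16000)
wav = wav.astype(np.float32)

# Placeholder weights path; SpeakerEncoder.__init__ only expects a checkpoint
# containing a "model_state" entry.
encoder = SpeakerEncoder("checkpoints/speaker_encoder.pt", device="cpu")

embed = encoder.embed_utterance(wav)           # shape (model_embedding_size,) == (256,)
print(embed.shape, np.linalg.norm(embed))      # the embedding is L2-normalized

# Partial embeddings and the wav slices they were computed from, if needed:
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)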
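Because embed_utterance and embed_speaker both return unit-norm vectors, speaker similarity reduces to a dot product. A short continuation of the sketch above, assuming wav_a and wav_b are two preprocessed waveforms and encoder is the SpeakerEncoder instance created there:

# Cosine similarity between two utterance embeddings is just their dot product,
# since both vectors have unit L2 norm; it tends to be higher for same-speaker pairs.
emb_a = encoder.embed_utterance(wav_a)
emb_b = encoder.embed_utterance(wav_b)
similarity = float(np.dot(emb_a, emb_b))

# Averaging several utterances from one speaker into a single reference embedding:
speaker_embed = encoder.embed_speaker([wav_a, wav_b])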
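For the WavLM side, the constructor takes a WavLMConfig and extract_features returns the chosen encoder layer's output plus the padding mask. A minimal sketch, assuming a checkpoint in the upstream microsoft/unilm release format (a dict with "cfg" and "model" entries) at a placeholder path:

import torch
import torch.nn.functional as F

from wavlm.WavLM import WavLM, WavLMConfig

ckpt = torch.load("checkpoints/WavLM-Large.pt", map_location="cpu")  # placeholder path
cfg = WavLMConfig(ckpt["cfg"])
model = WavLM(cfg)
model.load_state_dict(ckpt["model"])
model.eval()

wav = torch.randn(1, 16000)          # one second of 16 kHz audio, shape (batch, samples)
if cfg.normalize:                    # some checkpoints are trained on normalized input
    wav = F.layer_norm(wav, wav.shape)

with torch.no_grad():
    feats, _ = model.extract_features(wav)   # feats: (1, T', encoder_embed_dim)

    # Per-layer outputs, e.g. for a weighted sum over layers downstream:
    (x, layer_results), _ = model.extract_features(
        wav, output_layer=cfg.encoder_layers, ret_layer_results=True
    )
    layer_reps = [h.transpose(0, 1) for h, _ in layer_results]  # each (1, T', D)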
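Finally, compute_mask_indices is only exercised during masked pre-training (via WavLM.apply_mask), but its output is easy to inspect in isolation. A short sketch with arbitrary shapes, using the default "static" span length:

import numpy as np
import torch

from wavlm.WavLM import compute_mask_indices

mask = compute_mask_indices(
    shape=(2, 100),        # (batch, timesteps)
    padding_mask=None,
    mask_prob=0.65,        # cfg.mask_prob
    mask_length=10,        # cfg.mask_length
    mask_type="static",
    min_masks=2,
)
print(mask.shape, mask.dtype)   # (2, 100) bool
print(mask.sum(axis=-1))        # at most ~mask_prob * timesteps True entries per row; overlaps reduce it

# apply_mask converts the result to a tensor before replacing masked frames with mask_emb:
mask_t = torch.from_numpy(mask)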