diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c85cbb5184755529a7f48e4c63003dcbce689a6e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +.pth filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..685d6448d5a197ed8a434f72d6f2279552613c91 --- /dev/null +++ b/.gitignore @@ -0,0 +1,184 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/env.sh +venv/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Files created by experiments +output/ +snapshot/ +*.m4a +notebooks/scratch.ipynb +notebooks/inspect.ipynb +notebooks/effects.ipynb +notebooks/*.ipynb +notebooks/*.gif +notebooks/*.wav +notebooks/*.mp4 +*runs/ +boards/ +samples/ +*.ipynb + +results.json +metrics.csv +mprofile_* +mem.png + +results/ +mprofile* +*.png +# do not ignore the test wav file +!tests/audio/short_test_audio.wav +!tests/audio/output.wav +*/.DS_Store +.DS_Store +env.sh +_codebraid/ +**/*.html +**/*.exec.md +flagged/ +log.txt +ckpt/ +.syncthing* +tests/assets/ +archived/ + +scratch/ + +runs-archive +lyrebird-audiotools +lyrebird-audio-codec +samples-*/** + +gradio-outputs/ +samples*/ +models-all/ +models.zip +audiotools/ +descript-audio-codec/ +# *.pth +.git-old diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eca3cf9115d04a10dd9c4087fb14940ca3721094 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: +- repo: https://github.com/asottile/reorder_python_imports + rev: v2.5.0 + hooks: + - id: reorder-python-imports +- repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + language_version: python3 +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..53e66ea3e16109c82c94b12b859fdb7ca1c2ee81 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Hugo Flores García and Prem Seetharaman + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a962bd62656c2a689ef666f4a07015c287afc59a --- /dev/null +++ b/README.md @@ -0,0 +1,95 @@ +--- +title: 'VampNet: Music Generation with Masked Transformers' +emoji: 🤖 +colorFrom: gray +colorTo: gray +sdk: gradio +sdk_version: 3.36.1 +app_file: app.py +pinned: false +duplicated_from: hugggof/vampnet +--- + +# VampNet + +This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec. + +# Setting up + +Requires Python 3.9 or later. + + +install VampNet + +```bash +git clone https://github.com/hugofloresgarcia/vampnet.git +pip install -e ./vampnet +``` + +## A note on argbind +This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files. +Config files are stored in the `conf/` folder. + +## Getting the Pretrained Models + +### Licensing for Pretrained Models: +The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). + +Download the pretrained models from [this link](https://zenodo.org/record/8136545). Then, extract the models to the `models/` folder. + + +# Usage + +## Launching the Gradio Interface +You can launch a gradio UI to play with vampnet. + +```bash +python app.py --args.load conf/interface.yml --Interface.device cuda +``` + +# Training / Fine-tuning + +## Training a model + +To train a model, run the following script: + +```bash +python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints +``` + +You can edit `conf/vampnet.yml` to change the dataset paths or any training hyperparameters. + +For coarse2fine models, you can use `conf/c2f.yml` as a starting configuration. + +See `python scripts/exp/train.py -h` for a list of options. + +## Fine-tuning +To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`. +The first two are used to fine-tune the coarse and fine models, respectively. The last one is used to launch the gradio interface. + +```bash +python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" +``` + +This will create a folder under `conf//` with the 3 configuration files. + +The save_paths will be set to `runs//coarse` and `runs//c2f`. + +launch the coarse job: +```bash +python scripts/exp/train.py --args.load conf//coarse.yml +``` + +this will save the coarse model to `runs//coarse/ckpt/best/`. 
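+For reference, the generated `coarse.yml` is a small set of overrides on top of `conf/lora/lora.yml` (compare the checked-in examples under `conf/generated/`). A simplified sketch, with illustrative paths:
+
+```yaml
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/spotdl/coarse.pth
+save_path: ./runs/my-dataset/coarse        # illustrative run name
+train/AudioLoader.sources:
+- /path/to/audio1.mp3                      # illustrative source paths
+val/AudioLoader.sources:
+- /path/to/audio1.mp3
+```
+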
+ +launch the c2f job: +```bash +python scripts/exp/train.py --args.load conf//c2f.yml +``` + +launch the interface: +```bash +python demo.py --args.load conf/generated//interface.yml +``` + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..eb351be8b62b72abab5840a03c033e04cb04cc16 --- /dev/null +++ b/app.py @@ -0,0 +1,511 @@ +from pathlib import Path +from typing import Tuple +import yaml +import tempfile +import uuid +import shutil +from dataclasses import dataclass, asdict + +import numpy as np +import audiotools as at +import argbind +import torch + +import gradio as gr +from vampnet.interface import Interface +from vampnet import mask as pmask + +# Interface = argbind.bind(Interface) +# AudioLoader = argbind.bind(at.data.datasets.AudioLoader) + +interface = Interface( + coarse_ckpt="./models/vampnet/coarse.pth", + coarse2fine_ckpt="./models/vampnet/c2f.pth", + codec_ckpt="./models/vampnet/codec.pth", + device="cuda" if torch.cuda.is_available() else "cpu", +) + +# loader = AudioLoader() +print(f"interface device is {interface.device}") + +# dataset = at.data.datasets.AudioDataset( +# loader, +# sample_rate=interface.codec.sample_rate, +# duration=interface.coarse.chunk_size_s, +# n_examples=5000, +# without_replacement=True, +# ) + +OUT_DIR = Path("gradio-outputs") +OUT_DIR.mkdir(exist_ok=True, parents=True) + + +def load_audio(file): + print(file) + filepath = file.name + sig = at.AudioSignal.salient_excerpt( + filepath, + duration=interface.coarse.chunk_size_s + ) + sig = interface.preprocess(sig) + + out_dir = OUT_DIR / str(uuid.uuid4()) + out_dir.mkdir(parents=True, exist_ok=True) + sig.write(out_dir / "input.wav") + return sig.path_to_file + + +def load_example_audio(): + return "./assets/example.wav" + + +def _vamp(data, return_mask=False): + # remove any old files in the output directory (from previous runs) + shutil.rmtree(OUT_DIR) + OUT_DIR.mkdir() + + out_dir = OUT_DIR / str(uuid.uuid4()) + out_dir.mkdir() + sig = at.AudioSignal(data[input_audio]) + + z = interface.encode(sig) + + ncc = data[n_conditioning_codebooks] + + # build the mask + mask = pmask.linear_random(z, data[rand_mask_intensity]) + mask = pmask.mask_and( + mask, pmask.inpaint( + z, + interface.s2t(data[prefix_s]), + interface.s2t(data[suffix_s]) + ) + ) + mask = pmask.mask_and( + mask, pmask.periodic_mask( + z, + data[periodic_p], + data[periodic_w], + random_roll=True + ) + ) + if data[onset_mask_width] > 0: + mask = pmask.mask_or( + mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width]) + ) + if data[beat_mask_width] > 0: + beat_mask = interface.make_beat_mask( + sig, + after_beat_s=(data[beat_mask_width]/1000), + mask_upbeats=not data[beat_mask_downbeats], + ) + mask = pmask.mask_and(mask, beat_mask) + + # these should be the last two mask ops + mask = pmask.dropout(mask, data[dropout]) + mask = pmask.codebook_unmask(mask, ncc) + + + print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}") + # save the mask as a txt file + np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy()) + + zv, mask_z = interface.coarse_vamp( + z, + mask=mask, + sampling_steps=data[num_steps], + temperature=data[temp]*10, + return_mask=True, + 
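+        # note (comment added for clarity, not in the original source): the typical_*
+        # arguments below control "typical" filtering during sampling -- candidates whose
+        # surprisal is far from the distribution's entropy are filtered out; typical_mass
+        # sets the probability mass kept, and typical_min_tokens appears to set a floor on
+        # how many candidates survive the filter.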
typical_filtering=data[typical_filtering], + typical_mass=data[typical_mass], + typical_min_tokens=data[typical_min_tokens], + gen_fn=interface.coarse.generate, + ) + + if use_coarse2fine: + zv = interface.coarse_to_fine(zv, temperature=data[temp]) + + sig = interface.to_signal(zv).cpu() + print("done") + + + + sig.write(out_dir / "output.wav") + + if return_mask: + mask = interface.to_signal(mask_z).cpu() + mask.write(out_dir / "mask.wav") + return sig.path_to_file, mask.path_to_file + else: + return sig.path_to_file + +def vamp(data): + return _vamp(data, return_mask=True) + +def api_vamp(data): + return _vamp(data, return_mask=False) + +def save_vamp(data): + out_dir = OUT_DIR / "saved" / str(uuid.uuid4()) + out_dir.mkdir(parents=True, exist_ok=True) + + sig_in = at.AudioSignal(data[input_audio]) + sig_out = at.AudioSignal(data[output_audio]) + + sig_in.write(out_dir / "input.wav") + sig_out.write(out_dir / "output.wav") + + _data = { + "temp": data[temp], + "prefix_s": data[prefix_s], + "suffix_s": data[suffix_s], + "rand_mask_intensity": data[rand_mask_intensity], + "num_steps": data[num_steps], + "notes": data[notes_text], + "periodic_period": data[periodic_p], + "periodic_width": data[periodic_w], + "n_conditioning_codebooks": data[n_conditioning_codebooks], + "use_coarse2fine": data[use_coarse2fine], + "stretch_factor": data[stretch_factor], + } + + # save with yaml + with open(out_dir / "data.yaml", "w") as f: + yaml.dump(_data, f) + + import zipfile + zip_path = out_dir.with_suffix(".zip") + with zipfile.ZipFile(zip_path, "w") as zf: + for file in out_dir.iterdir(): + zf.write(file, file.name) + + return f"saved! your save code is {out_dir.stem}", zip_path + + +with gr.Blocks() as demo: + + with gr.Row(): + with gr.Column(): + gr.Markdown("# VampNet Audio Vamping") + gr.Markdown("""## Description: + This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings. + You can control the extent and nature of variation with a set of manual controls and presets. + Use this interface to experiment with different mask settings and explore the audio outputs. + """) + + gr.Markdown(""" + ## Instructions: + 1. You can start by uploading some audio, or by loading the example audio. + 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. + 3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio. + 4. Optionally, you can add some notes and save the result. + 5. You can also use the output as the new input and continue experimenting! 
+ """) + with gr.Row(): + with gr.Column(): + + + manual_audio_upload = gr.File( + label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)", + file_types=["audio"] + ) + load_example_audio_button = gr.Button("or load example audio") + + input_audio = gr.Audio( + label="input audio", + interactive=False, + type="filepath", + ) + + audio_mask = gr.Audio( + label="audio mask (listen to this to hear the mask hints)", + interactive=False, + type="filepath", + ) + + # connect widgets + load_example_audio_button.click( + fn=load_example_audio, + inputs=[], + outputs=[ input_audio] + ) + + manual_audio_upload.change( + fn=load_audio, + inputs=[manual_audio_upload], + outputs=[ input_audio] + ) + + # mask settings + with gr.Column(): + + + presets = { + "unconditional": { + "periodic_p": 0, + "onset_mask_width": 0, + "beat_mask_width": 0, + "beat_mask_downbeats": False, + }, + "slight periodic variation": { + "periodic_p": 5, + "onset_mask_width": 5, + "beat_mask_width": 0, + "beat_mask_downbeats": False, + }, + "moderate periodic variation": { + "periodic_p": 13, + "onset_mask_width": 5, + "beat_mask_width": 0, + "beat_mask_downbeats": False, + }, + "strong periodic variation": { + "periodic_p": 17, + "onset_mask_width": 5, + "beat_mask_width": 0, + "beat_mask_downbeats": False, + }, + "very strong periodic variation": { + "periodic_p": 21, + "onset_mask_width": 5, + "beat_mask_width": 0, + "beat_mask_downbeats": False, + }, + "beat-driven variation": { + "periodic_p": 0, + "onset_mask_width": 0, + "beat_mask_width": 50, + "beat_mask_downbeats": False, + }, + "beat-driven variation (downbeats only)": { + "periodic_p": 0, + "onset_mask_width": 0, + "beat_mask_width": 50, + "beat_mask_downbeats": True, + }, + "beat-driven variation (downbeats only, strong)": { + "periodic_p": 0, + "onset_mask_width": 0, + "beat_mask_width": 20, + "beat_mask_downbeats": True, + }, + } + + preset = gr.Dropdown( + label="preset", + choices=list(presets.keys()), + value="strong periodic variation", + ) + load_preset_button = gr.Button("load_preset") + + with gr.Accordion("manual controls", open=True): + periodic_p = gr.Slider( + label="periodic prompt (0 - unconditional, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)", + minimum=0, + maximum=128, + step=1, + value=3, + ) + + + onset_mask_width = gr.Slider( + label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ", + minimum=0, + maximum=20, + step=1, + value=5, + ) + + beat_mask_width = gr.Slider( + label="beat mask width (in milliseconds)", + minimum=0, + maximum=200, + value=0, + ) + beat_mask_downbeats = gr.Checkbox( + label="beat mask downbeats only?", + value=False + ) + + + with gr.Accordion("extras ", open=False): + rand_mask_intensity = gr.Slider( + label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)", + minimum=0.0, + maximum=1.0, + value=1.0 + ) + + periodic_w = gr.Slider( + label="periodic prompt width (steps, 1 step ~= 10milliseconds)", + minimum=1, + maximum=20, + step=1, + value=1, + ) + n_conditioning_codebooks = gr.Number( + label="number of conditioning codebooks. 
probably 0", + value=0, + precision=0, + ) + + stretch_factor = gr.Slider( + label="time stretch factor", + minimum=0, + maximum=64, + step=1, + value=1, + ) + + preset_outputs = { + periodic_p, + onset_mask_width, + beat_mask_width, + beat_mask_downbeats, + } + + def load_preset(_preset): + return tuple(presets[_preset].values()) + + load_preset_button.click( + fn=load_preset, + inputs=[preset], + outputs=preset_outputs + ) + + + with gr.Accordion("prefix/suffix prompts", open=False): + prefix_s = gr.Slider( + label="prefix hint length (seconds)", + minimum=0.0, + maximum=10.0, + value=0.0 + ) + suffix_s = gr.Slider( + label="suffix hint length (seconds)", + minimum=0.0, + maximum=10.0, + value=0.0 + ) + + temp = gr.Slider( + label="temperature", + minimum=0.0, + maximum=10.0, + value=0.8 + ) + + + + with gr.Accordion("sampling settings", open=False): + typical_filtering = gr.Checkbox( + label="typical filtering ", + value=False + ) + typical_mass = gr.Slider( + label="typical mass (should probably stay between 0.1 and 0.5)", + minimum=0.01, + maximum=0.99, + value=0.15 + ) + typical_min_tokens = gr.Slider( + label="typical min tokens (should probably stay between 1 and 256)", + minimum=1, + maximum=256, + step=1, + value=64 + ) + + use_coarse2fine = gr.Checkbox( + label="use coarse2fine", + value=True + ) + + num_steps = gr.Slider( + label="number of steps (should normally be between 12 and 36)", + minimum=1, + maximum=128, + step=1, + value=36 + ) + + dropout = gr.Slider( + label="mask dropout", + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.0 + ) + + + # mask settings + with gr.Column(): + vamp_button = gr.Button("generate (vamp)!!!") + output_audio = gr.Audio( + label="output audio", + interactive=False, + type="filepath" + ) + + notes_text = gr.Textbox( + label="type any notes about the generated audio here", + value="", + interactive=True + ) + save_button = gr.Button("save vamp") + download_file = gr.File( + label="vamp to download will appear here", + interactive=False + ) + use_as_input_button = gr.Button("use output as input") + + thank_you = gr.Markdown("") + + + _inputs = { + input_audio, + num_steps, + temp, + prefix_s, suffix_s, + rand_mask_intensity, + periodic_p, periodic_w, + n_conditioning_codebooks, + dropout, + use_coarse2fine, + stretch_factor, + onset_mask_width, + typical_filtering, + typical_mass, + typical_min_tokens, + beat_mask_width, + beat_mask_downbeats + } + + # connect widgets + vamp_button.click( + fn=vamp, + inputs=_inputs, + outputs=[output_audio, audio_mask], + ) + + api_vamp_button = gr.Button("api vamp", visible=False) + api_vamp_button.click( + fn=api_vamp, + inputs=_inputs, + outputs=[output_audio], + api_name="vamp" + ) + + use_as_input_button.click( + fn=lambda x: x, + inputs=[output_audio], + outputs=[input_audio] + ) + + save_button.click( + fn=save_vamp, + inputs=_inputs | {notes_text, output_audio}, + outputs=[thank_you, download_file] + ) + +demo.queue().launch() diff --git a/assets/example.wav b/assets/example.wav new file mode 100644 index 0000000000000000000000000000000000000000..a8999e975b85c991614c3ef5b6767f91e09d9756 Binary files /dev/null and b/assets/example.wav differ diff --git a/conf/c2f.yml b/conf/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..c41fd848e884ba6963c7ceb78fa2eb4b97e9239d --- /dev/null +++ b/conf/c2f.yml @@ -0,0 +1,14 @@ +$include: + - conf/vampnet.yml + +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 + +VampNet.embedding_dim: 1280 +VampNet.n_layers: 16 
+VampNet.n_heads: 20 + +AudioDataset.duration: 3.0 + + +AudioDataset.loudness_cutoff: -40.0 diff --git a/conf/generated-v0/berta-goldman-speech/c2f.yml b/conf/generated-v0/berta-goldman-speech/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..0f5a4cd57e7a801121d7c77a62a0e8767b7fe61c --- /dev/null +++ b/conf/generated-v0/berta-goldman-speech/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +save_path: ./runs/berta-goldman-speech/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 +val/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 diff --git a/conf/generated-v0/berta-goldman-speech/coarse.yml b/conf/generated-v0/berta-goldman-speech/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..7c1207e9cfe83bac59f76fcf21068405cd6c9551 --- /dev/null +++ b/conf/generated-v0/berta-goldman-speech/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +save_path: ./runs/berta-goldman-speech/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 +val/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 diff --git a/conf/generated-v0/berta-goldman-speech/interface.yml b/conf/generated-v0/berta-goldman-speech/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..d1ba35ec732a0148f3bced5542e27e85575c4d4e --- /dev/null +++ b/conf/generated-v0/berta-goldman-speech/interface.yml @@ -0,0 +1,5 @@ +AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 +Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth +Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated-v0/gamelan-xeno-canto/c2f.yml b/conf/generated-v0/gamelan-xeno-canto/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..9e6fec4ddc7dd0a2e02d1be66cc7f6eafa669ed1 --- /dev/null +++ b/conf/generated-v0/gamelan-xeno-canto/c2f.yml @@ -0,0 +1,17 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +save_path: ./runs/gamelan-xeno-canto/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 diff --git a/conf/generated-v0/gamelan-xeno-canto/coarse.yml b/conf/generated-v0/gamelan-xeno-canto/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..7e8d38e18d714cb08db7ed456939737404533c3e --- /dev/null +++ b/conf/generated-v0/gamelan-xeno-canto/coarse.yml @@ -0,0 +1,10 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +save_path: ./runs/gamelan-xeno-canto/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound 
Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 diff --git a/conf/generated-v0/gamelan-xeno-canto/interface.yml b/conf/generated-v0/gamelan-xeno-canto/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..e567800477816ac1cc41719744c1ba40562e35b1 --- /dev/null +++ b/conf/generated-v0/gamelan-xeno-canto/interface.yml @@ -0,0 +1,6 @@ +AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 +Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth +Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated-v0/nasralla/c2f.yml b/conf/generated-v0/nasralla/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..9d9db7bed268c18f3ca4047dcde34dd18a5a2301 --- /dev/null +++ b/conf/generated-v0/nasralla/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +save_path: ./runs/nasralla/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/nasralla +val/AudioLoader.sources: +- /media/CHONK/hugo/nasralla diff --git a/conf/generated-v0/nasralla/coarse.yml b/conf/generated-v0/nasralla/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..43a4d18c7f955e38200ded0d2a4fa0959ddb639e --- /dev/null +++ b/conf/generated-v0/nasralla/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +save_path: ./runs/nasralla/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/nasralla +val/AudioLoader.sources: +- /media/CHONK/hugo/nasralla diff --git a/conf/generated-v0/nasralla/interface.yml b/conf/generated-v0/nasralla/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..c93e872d1e4b66567812755882a996814794ad8f --- /dev/null +++ b/conf/generated-v0/nasralla/interface.yml @@ -0,0 +1,5 @@ +AudioLoader.sources: +- /media/CHONK/hugo/nasralla +Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth +Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/breaks-steps/c2f.yml b/conf/generated/breaks-steps/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..49617a6d52de00a9bc7c82c6e820168076402fac --- /dev/null +++ b/conf/generated/breaks-steps/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/breaks-steps/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/breaks-steps +val/AudioLoader.sources: *id001 diff --git a/conf/generated/breaks-steps/coarse.yml b/conf/generated/breaks-steps/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..71d9b27fbc4aac7d407d3606e98c4eaca35e2d3f --- /dev/null +++ b/conf/generated/breaks-steps/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/breaks-steps/coarse 
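+# note (comment added for clarity): the generated configs point train and val at the same
+# source list by defining a YAML anchor (&id001) on the train sources and referencing it
+# with an alias (*id001) for val.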
+train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/breaks-steps +val/AudioLoader.sources: *id001 diff --git a/conf/generated/breaks-steps/interface.yml b/conf/generated/breaks-steps/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..b4b5182c4a378884e1614d89bc39abdf78a4eaa2 --- /dev/null +++ b/conf/generated/breaks-steps/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/breaks-steps +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/bulgarian-tv-choir/c2f.yml b/conf/generated/bulgarian-tv-choir/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..7bc54bf54bc8cc5c599a11f30c036822fa4b84c5 --- /dev/null +++ b/conf/generated/bulgarian-tv-choir/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/bulgarian-tv-choir/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/bulgarian-tv-choir/coarse.yml b/conf/generated/bulgarian-tv-choir/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..06f27f140dbd8c6d6315aab0787435ff501f8958 --- /dev/null +++ b/conf/generated/bulgarian-tv-choir/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/bulgarian-tv-choir/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/bulgarian-tv-choir/interface.yml b/conf/generated/bulgarian-tv-choir/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..b56e8d721adf99da361dadf423a669bb576478e1 --- /dev/null +++ b/conf/generated/bulgarian-tv-choir/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/dariacore/c2f.yml b/conf/generated/dariacore/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..e8e52fc05be63fe891d3adf0c2115efd5e06ecef --- /dev/null +++ b/conf/generated/dariacore/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/dariacore/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/dariacore +val/AudioLoader.sources: *id001 diff --git a/conf/generated/dariacore/coarse.yml b/conf/generated/dariacore/coarse.yml new file mode 100644 
index 0000000000000000000000000000000000000000..42044d7bbafbf890d6d6bc504beb49edf977c39b --- /dev/null +++ b/conf/generated/dariacore/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/dariacore/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/dariacore +val/AudioLoader.sources: *id001 diff --git a/conf/generated/dariacore/interface.yml b/conf/generated/dariacore/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..29342d2fe9d97f20d9521885869f1cca16d2aeba --- /dev/null +++ b/conf/generated/dariacore/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/dariacore +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/musica-bolero-marimba/c2f.yml b/conf/generated/musica-bolero-marimba/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..cd06c72814deaf9fd41d3dabc8e6046e050ad968 --- /dev/null +++ b/conf/generated/musica-bolero-marimba/c2f.yml @@ -0,0 +1,18 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/musica-bolero-marimba/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras diff --git a/conf/generated/musica-bolero-marimba/coarse.yml b/conf/generated/musica-bolero-marimba/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..a3e1c0ee8e8593528cb389fb84c56894727cfca5 --- /dev/null +++ b/conf/generated/musica-bolero-marimba/coarse.yml @@ -0,0 +1,11 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/musica-bolero-marimba/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras diff --git a/conf/generated/musica-bolero-marimba/interface.yml b/conf/generated/musica-bolero-marimba/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..08b42e3120a3cedbb5aafb9a39ca879d8958127a --- /dev/null +++ b/conf/generated/musica-bolero-marimba/interface.yml @@ -0,0 +1,8 @@ +AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/panchos/c2f.yml b/conf/generated/panchos/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..4efd6fb4caf409382929dcf61d40ed37e3773eac --- /dev/null +++ b/conf/generated/panchos/c2f.yml @@ -0,0 +1,15 @@ +$include: +- 
conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/panchos/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/panchos/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/panchos/coarse.yml b/conf/generated/panchos/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4f21a3f4deb58cd6b98680e82d59ad32098542e --- /dev/null +++ b/conf/generated/panchos/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/panchos/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/panchos/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/panchos/interface.yml b/conf/generated/panchos/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..8bae11c225a0fa49c27efdfc808a63d53c21755a --- /dev/null +++ b/conf/generated/panchos/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/panchos/ +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/titi-monkey/c2f.yml b/conf/generated/titi-monkey/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..456912ab1589eee1dfe6c5768e70ede4e455c828 --- /dev/null +++ b/conf/generated/titi-monkey/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/titi-monkey/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/titi-monkey.mp3 +val/AudioLoader.sources: *id001 diff --git a/conf/generated/titi-monkey/coarse.yml b/conf/generated/titi-monkey/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..c2af934aa5aff33c26ae95a2d7a46eb19f9b7194 --- /dev/null +++ b/conf/generated/titi-monkey/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/titi-monkey/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/titi-monkey.mp3 +val/AudioLoader.sources: *id001 diff --git a/conf/generated/titi-monkey/interface.yml b/conf/generated/titi-monkey/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..cbc4ffad24c7c3b34e930aff08404955348b49a2 --- /dev/null +++ b/conf/generated/titi-monkey/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/titi-monkey.mp3 +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/xeno-canto/c2f.yml b/conf/generated/xeno-canto/c2f.yml new file mode 100644 index 
0000000000000000000000000000000000000000..251b0e361ee15d01f7715608480cb3d5e9fdb122 --- /dev/null +++ b/conf/generated/xeno-canto/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/xeno-canto/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/xeno-canto-2/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/xeno-canto/coarse.yml b/conf/generated/xeno-canto/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..ea151dbb64ff13982b0004685901da2b58c8e596 --- /dev/null +++ b/conf/generated/xeno-canto/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/xeno-canto/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/xeno-canto-2/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/xeno-canto/interface.yml b/conf/generated/xeno-canto/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..1a8b1420f142cef024471073e674cd9db59ffad0 --- /dev/null +++ b/conf/generated/xeno-canto/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/xeno-canto-2/ +Interface.coarse2fine_ckpt: ./mod els/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/interface.yml b/conf/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..b574e8991218a2b9ed0822f1771cbf9d43dd1140 --- /dev/null +++ b/conf/interface.yml @@ -0,0 +1,10 @@ +Interface.coarse_ckpt: ./models/vampnet/coarse.pth +Interface.coarse2fine_ckpt: ./models/vampnet/c2f.pth +Interface.codec_ckpt: ./models/vampnet/codec.pth +Interface.coarse_chunk_size_s: 10 +Interface.coarse2fine_chunk_size_s: 3 +Interface.wavebeat_ckpt: ./models/wavebeat.pth + +# AudioLoader.sources: +# - /media/CHONK/null + diff --git a/conf/lora/birds.yml b/conf/lora/birds.yml new file mode 100644 index 0000000000000000000000000000000000000000..de413ec0dec4f974e664923c9319861a1c957e87 --- /dev/null +++ b/conf/lora/birds.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds diff --git a/conf/lora/birdss.yml b/conf/lora/birdss.yml new file mode 100644 index 0000000000000000000000000000000000000000..3526de67d24e296de2cc0a7d2e5ebbc18245a6c8 --- /dev/null +++ b/conf/lora/birdss.yml @@ -0,0 +1,12 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/ + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/ diff --git a/conf/lora/constructions.yml b/conf/lora/constructions.yml new file mode 100644 index 0000000000000000000000000000000000000000..f513b4898e06339fa0d0b4af24e98fdf5289094a --- /dev/null +++ b/conf/lora/constructions.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + 
+fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3 diff --git a/conf/lora/ella-baila-sola.yml b/conf/lora/ella-baila-sola.yml new file mode 100644 index 0000000000000000000000000000000000000000..24eeada8013ea0d56d7d6474db52a48c3fd43bc1 --- /dev/null +++ b/conf/lora/ella-baila-sola.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3 diff --git a/conf/lora/gas-station.yml b/conf/lora/gas-station.yml new file mode 100644 index 0000000000000000000000000000000000000000..4369f9203232fa3dcfd21667f3e55d0d0fda108e --- /dev/null +++ b/conf/lora/gas-station.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3 diff --git a/conf/lora/lora-is-this-charlie-parker.yml b/conf/lora/lora-is-this-charlie-parker.yml new file mode 100644 index 0000000000000000000000000000000000000000..9cfaa31a421266fafa60a1ee4bb2d45f1c47577c --- /dev/null +++ b/conf/lora/lora-is-this-charlie-parker.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3 diff --git a/conf/lora/lora.yml b/conf/lora/lora.yml new file mode 100644 index 0000000000000000000000000000000000000000..b901ea00a6008b92f25728d6d01a258c6aba5d1e --- /dev/null +++ b/conf/lora/lora.yml @@ -0,0 +1,22 @@ +$include: + - conf/vampnet.yml + +fine_tune: True + +train/AudioDataset.n_examples: 10000000 + +val/AudioDataset.n_examples: 10 + + +NoamScheduler.warmup: 500 + +batch_size: 7 +num_workers: 7 +epoch_length: 100 +save_audio_epochs: 10 + +AdamW.lr: 0.0001 + +# let's us organize sound classes into folders and choose from those sound classes uniformly +AudioDataset.without_replacement: False +max_epochs: 500 \ No newline at end of file diff --git a/conf/lora/underworld.yml b/conf/lora/underworld.yml new file mode 100644 index 0000000000000000000000000000000000000000..6fd1a6cf1e74220a2b51b1117afb373acda033a7 --- /dev/null +++ b/conf/lora/underworld.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/underworld.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/underworld.mp3 diff --git a/conf/lora/xeno-canto/c2f.yml b/conf/lora/xeno-canto/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..94f9906189f0b74b6c492bdd53fa56d58a0fa04d --- /dev/null +++ b/conf/lora/xeno-canto/c2f.yml @@ -0,0 +1,21 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 + +val/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 + + +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 + +VampNet.embedding_dim: 1280 +VampNet.n_layers: 16 +VampNet.n_heads: 20 + +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 diff --git a/conf/lora/xeno-canto/coarse.yml b/conf/lora/xeno-canto/coarse.yml new 
file mode 100644 index 0000000000000000000000000000000000000000..223c8f0f8481f55ac1c33816ed79fe45b50f1495 --- /dev/null +++ b/conf/lora/xeno-canto/coarse.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 + +val/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 diff --git a/conf/vampnet-musdb-drums.yml b/conf/vampnet-musdb-drums.yml new file mode 100644 index 0000000000000000000000000000000000000000..010843d81ec9ac3c832b8e88f30af2f99a56ba99 --- /dev/null +++ b/conf/vampnet-musdb-drums.yml @@ -0,0 +1,22 @@ +$include: + - conf/vampnet.yml + +VampNet.embedding_dim: 512 +VampNet.n_layers: 12 +VampNet.n_heads: 8 + +AudioDataset.duration: 12.0 + +train/AudioDataset.n_examples: 10000000 +train/AudioLoader.sources: + - /data/musdb18hq/train/**/*drums.wav + + +val/AudioDataset.n_examples: 500 +val/AudioLoader.sources: + - /data/musdb18hq/test/**/*drums.wav + + +test/AudioDataset.n_examples: 1000 +test/AudioLoader.sources: + - /data/musdb18hq/test/**/*drums.wav diff --git a/conf/vampnet.yml b/conf/vampnet.yml new file mode 100644 index 0000000000000000000000000000000000000000..d24df3fc1923eeb98f76f5747a52c3e83ef98795 --- /dev/null +++ b/conf/vampnet.yml @@ -0,0 +1,59 @@ + +codec_ckpt: ./models/spotdl/codec.pth +save_path: ckpt +max_epochs: 1000 +epoch_length: 1000 +save_audio_epochs: 2 +val_idx: [0,1,2,3,4,5,6,7,8,9] + +prefix_amt: 0.0 +suffix_amt: 0.0 +prefix_dropout: 0.1 +suffix_dropout: 0.1 + +batch_size: 8 +num_workers: 10 + +# Optimization +detect_anomaly: false +amp: false + +CrossEntropyLoss.label_smoothing: 0.1 + +AdamW.lr: 0.001 + +NoamScheduler.factor: 2.0 +NoamScheduler.warmup: 10000 + +PitchShift.shift_amount: [const, 0] +PitchShift.prob: 0.0 + +VampNet.vocab_size: 1024 +VampNet.n_codebooks: 4 +VampNet.n_conditioning_codebooks: 0 +VampNet.r_cond_dim: 0 +VampNet.noise_mode: mask +VampNet.embedding_dim: 1280 +VampNet.n_layers: 20 +VampNet.n_heads: 20 +VampNet.flash_attn: false +VampNet.dropout: 0.1 + +AudioLoader.relative_path: /data/ +AudioDataset.loudness_cutoff: -30.0 +AudioDataset.without_replacement: true +AudioLoader.shuffle: true + +AudioDataset.duration: 10.0 + +train/AudioDataset.n_examples: 10000000 +train/AudioLoader.sources: + - /data/spotdl/audio/train + +val/AudioDataset.n_examples: 2000 +val/AudioLoader.sources: + - /data/spotdl/audio/val + +test/AudioDataset.n_examples: 1000 +test/AudioLoader.sources: + - /data/spotdl/audio/test diff --git a/models/vampnet/c2f.pth b/models/vampnet/c2f.pth new file mode 100644 index 0000000000000000000000000000000000000000..fd42af6fef4bd178a1d3e910e72b958cf6d861bb --- /dev/null +++ b/models/vampnet/c2f.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b10ea2d45459d34edb773cbacd71f40f7baa1f4e75ac8bcd93b022ac69f8fa63 +size 1101898865 diff --git a/models/vampnet/coarse.pth b/models/vampnet/coarse.pth new file mode 100644 index 0000000000000000000000000000000000000000..b0137951d0801e0eeb1bdac47187c93d3e94c977 --- /dev/null +++ b/models/vampnet/coarse.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:78e4ad4f8398e8ec3651bc5e5c6ea2995e1080b6226be186723ccf4320c9756c +size 1332182321 diff --git a/models/vampnet/codec.pth b/models/vampnet/codec.pth new file mode 100644 index 0000000000000000000000000000000000000000..82027c48fce17039bd9fc80bc4e089ac97cfb31e --- /dev/null +++ b/models/vampnet/codec.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:3db3fa43ab5d160439ddb81fc540b5573ad5ae962230de3fc5b47d218845b855 +size 600996465 diff --git a/models/wavebeat.pth b/models/wavebeat.pth new file mode 100644 index 0000000000000000000000000000000000000000..7994a69f199749641f77e4fd4c16c10ac8088d6e --- /dev/null +++ b/models/wavebeat.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ff1066a4470cb98b20edf1e489f6995b19e0435b9cfd5a190bf90a954d0cadb +size 33248861 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4feab2a44cbfd4d65f5e7fdfbe2c5a3363790c6d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch +argbind>=0.3.2 +numpy==1.22 +gradio +loralib +wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat +lac @ git+https://github.com/hugofloresgarcia/lac.git +audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git \ No newline at end of file diff --git a/scripts/exp/eval.py b/scripts/exp/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..47b4cf4ee1a2dcb72bd6fb9797f22d85c2c7dac9 --- /dev/null +++ b/scripts/exp/eval.py @@ -0,0 +1,110 @@ +from pathlib import Path +import os +from functools import partial + +from frechet_audio_distance import FrechetAudioDistance +import pandas +import argbind +import torch +from tqdm import tqdm + +import audiotools +from audiotools import AudioSignal + +@argbind.bind(without_prefix=True) +def eval( + exp_dir: str = None, + baseline_key: str = "baseline", + audio_ext: str = ".wav", +): + assert exp_dir is not None + exp_dir = Path(exp_dir) + assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist" + + # set up our metrics + # sisdr_loss = audiotools.metrics.distance.SISDRLoss() + # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss() + mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss() + frechet = FrechetAudioDistance( + use_pca=False, + use_activation=False, + verbose=True, + audio_load_worker=4, + ) + frechet.model.to("cuda" if torch.cuda.is_available() else "cpu") + + # figure out what conditions we have + conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()] + + assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}" + conditions.remove(baseline_key) + + print(f"Found {len(conditions)} conditions in {exp_dir}") + print(f"conditions: {conditions}") + + baseline_dir = exp_dir / baseline_key + baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem)) + + metrics = [] + for condition in tqdm(conditions): + cond_dir = exp_dir / condition + cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem)) + + print(f"computing fad for {baseline_dir} and {cond_dir}") + frechet_score = frechet.score(baseline_dir, cond_dir) + + # make sure we have the same number of files + num_files = min(len(baseline_files), len(cond_files)) + baseline_files = baseline_files[:num_files] + cond_files = cond_files[:num_files] + assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. 
{len(list(baseline_files))} vs {len(list(cond_files))}" + + def process(baseline_file, cond_file): + # make sure the files match (same name) + assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match" + + # load the files + baseline_sig = AudioSignal(str(baseline_file)) + cond_sig = AudioSignal(str(cond_file)) + + cond_sig.resample(baseline_sig.sample_rate) + cond_sig.truncate_samples(baseline_sig.length) + + # if our condition is inpainting, we need to trim the conditioning off + if "inpaint" in condition: + ctx_amt = float(condition.split("_")[-1]) + ctx_samples = int(ctx_amt * baseline_sig.sample_rate) + print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}") + cond_sig.trim(ctx_samples, ctx_samples) + baseline_sig.trim(ctx_samples, ctx_samples) + + return { + # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(), + # "stft": stft_loss(baseline_sig, cond_sig).item(), + "mel": mel_loss(baseline_sig, cond_sig).item(), + "frechet": frechet_score, + # "visqol": vsq, + "condition": condition, + "file": baseline_file.stem, + } + + print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}") + metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files))) + + metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")] + + + for mk in metric_keys: + stat = pandas.DataFrame(metrics) + stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std']) + stat.to_csv(exp_dir / f"stats-{mk}.csv") + + df = pandas.DataFrame(metrics) + df.to_csv(exp_dir / "metrics-all.csv", index=False) + + +if __name__ == "__main__": + args = argbind.parse_args() + + with argbind.scope(args): + eval() \ No newline at end of file diff --git a/scripts/exp/experiment.py b/scripts/exp/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..d4426736949111ad0195c7373cce7d7d71a71a63 --- /dev/null +++ b/scripts/exp/experiment.py @@ -0,0 +1,254 @@ +from pathlib import Path +import random +from typing import List +import tempfile +import subprocess + +import argbind +from tqdm import tqdm +import torch + +from vampnet.interface import Interface +from vampnet import mask as pmask +import audiotools as at + +Interface: Interface = argbind.bind(Interface) + + + +def calculate_bitrate( + interface, num_codebooks, + downsample_factor + ): + bit_width = 10 + sr = interface.codec.sample_rate + hop = interface.codec.hop_size + rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor) + return rate + +def baseline(sig, interface): + return interface.preprocess(sig) + +def reconstructed(sig, interface): + return interface.to_signal( + interface.encode(sig) + ) + +def coarse2fine(sig, interface): + z = interface.encode(sig) + z = z[:, :interface.c2f.n_conditioning_codebooks, :] + + z = interface.coarse_to_fine(z) + return interface.to_signal(z) + +class CoarseCond: + + def __init__(self, num_conditioning_codebooks, downsample_factor): + self.num_conditioning_codebooks = num_conditioning_codebooks + self.downsample_factor = downsample_factor + + def __call__(self, sig, interface): + z = interface.encode(sig) + mask = pmask.full_mask(z) + mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks) + mask = pmask.periodic_mask(mask, self.downsample_factor) + + zv = interface.coarse_vamp(z, mask) + zv = interface.coarse_to_fine(zv) + return interface.to_signal(zv) + +def opus(sig, interface, bitrate=128): + sig = 
interface.preprocess(sig) + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + sig.write(f.name) + + opus_name = Path(f.name).with_suffix(".opus") + # convert to opus + cmd = [ + "ffmpeg", "-y", "-i", f.name, + "-c:a", "libopus", + "-b:a", f"{bitrate}", + opus_name + ] + subprocess.run(cmd, check=True) + + # convert back to wav + output_name = Path(f"{f.name}-opus").with_suffix(".wav") + cmd = [ + "ffmpeg", "-y", "-i", opus_name, + output_name + ] + + subprocess.run(cmd, check=True) + + sig = at.AudioSignal( + output_name, + sample_rate=sig.sample_rate + ) + return sig + +def mask_ratio_1_step(ratio=1.0): + def wrapper(sig, interface): + z = interface.encode(sig) + mask = pmask.linear_random(z, ratio) + zv = interface.coarse_vamp( + z, + mask, + sampling_steps=1, + ) + + return interface.to_signal(zv) + return wrapper + +def num_sampling_steps(num_steps=1): + def wrapper(sig, interface: Interface): + z = interface.encode(sig) + mask = pmask.periodic_mask(z, 16) + zv = interface.coarse_vamp( + z, + mask, + sampling_steps=num_steps, + ) + + zv = interface.coarse_to_fine(zv) + return interface.to_signal(zv) + return wrapper + +def beat_mask(ctx_time): + def wrapper(sig, interface): + beat_mask = interface.make_beat_mask( + sig, + before_beat_s=ctx_time/2, + after_beat_s=ctx_time/2, + invert=True + ) + + z = interface.encode(sig) + + zv = interface.coarse_vamp( + z, beat_mask + ) + + zv = interface.coarse_to_fine(zv) + return interface.to_signal(zv) + return wrapper + +def inpaint(ctx_time): + def wrapper(sig, interface: Interface): + z = interface.encode(sig) + mask = pmask.inpaint(z, interface.s2t(ctx_time), interface.s2t(ctx_time)) + + zv = interface.coarse_vamp(z, mask) + zv = interface.coarse_to_fine(zv) + + return interface.to_signal(zv) + return wrapper + +def token_noise(noise_amt): + def wrapper(sig, interface: Interface): + z = interface.encode(sig) + mask = pmask.random(z, noise_amt) + z = torch.where( + mask, + torch.randint_like(z, 0, interface.coarse.vocab_size), + z + ) + return interface.to_signal(z) + return wrapper + +EXP_REGISTRY = {} + +EXP_REGISTRY["gen-compression"] = { + "baseline": baseline, + "reconstructed": reconstructed, + "coarse2fine": coarse2fine, + **{ + f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x) + for (n, x) in ( + (1, 1), # 1 codebook, no downsampling + (4, 4), # 4 codebooks, downsampled 4x + (4, 16), # 4 codebooks, downsampled 16x + (4, 32), # 4 codebooks, downsampled 16x + ) + }, + **{ + f"token_noise_{x}": mask_ratio_1_step(ratio=x) + for x in [0.25, 0.5, 0.75] + }, + +} + + +EXP_REGISTRY["sampling-steps"] = { + # "codec": reconstructed, + **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]}, +} + + +EXP_REGISTRY["musical-sampling"] = { + **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]}, + **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right) +} + +@argbind.bind(without_prefix=True) +def main( + sources=[ + "/media/CHONK/hugo/spotdl/val", + ], + output_dir: str = "./samples", + max_excerpts: int = 2000, + exp_type: str = "gen-compression", + seed: int = 0, + ext: str = [".mp3"], + ): + at.util.seed(seed) + interface = Interface() + + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + from audiotools.data.datasets import AudioLoader, AudioDataset + + loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext) + dataset = AudioDataset(loader, + sample_rate=interface.codec.sample_rate, + 
duration=interface.coarse.chunk_size_s, + n_examples=max_excerpts, + without_replacement=True, + ) + + if exp_type in EXP_REGISTRY: + SAMPLE_CONDS = EXP_REGISTRY[exp_type] + else: + raise ValueError(f"Unknown exp_type {exp_type}") + + + indices = list(range(max_excerpts)) + random.shuffle(indices) + for i in tqdm(indices): + # if all our files are already there, skip + done = [] + for name in SAMPLE_CONDS: + o_dir = Path(output_dir) / name + done.append((o_dir / f"{i}.wav").exists()) + if all(done): + continue + + sig = dataset[i]["signal"] + results = { + name: cond(sig, interface).cpu() + for name, cond in SAMPLE_CONDS.items() + } + + for name, sig in results.items(): + o_dir = Path(output_dir) / name + o_dir.mkdir(exist_ok=True, parents=True) + + sig.write(o_dir / f"{i}.wav") + +if __name__ == "__main__": + args = argbind.parse_args() + + with argbind.scope(args): + main() diff --git a/scripts/exp/fine_tune.py b/scripts/exp/fine_tune.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c6c3b768f585242705e5cdabeebe45ced557cf --- /dev/null +++ b/scripts/exp/fine_tune.py @@ -0,0 +1,82 @@ +import argbind +from pathlib import Path +import yaml +from typing import List + + + + +"""example output: (yaml) + +""" + +@argbind.bind(without_prefix=True, positional=True) +def fine_tune(audio_files_or_folders: List[str], name: str): + + conf_dir = Path("conf") + assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?" + + conf_dir = conf_dir / "generated" + conf_dir.mkdir(exist_ok=True) + + finetune_dir = conf_dir / name + finetune_dir.mkdir(exist_ok=True) + + finetune_c2f_conf = { + "$include": ["conf/lora/lora.yml"], + "fine_tune": True, + "train/AudioLoader.sources": audio_files_or_folders, + "val/AudioLoader.sources": audio_files_or_folders, + "VampNet.n_codebooks": 14, + "VampNet.n_conditioning_codebooks": 4, + "VampNet.embedding_dim": 1280, + "VampNet.n_layers": 16, + "VampNet.n_heads": 20, + "AudioDataset.duration": 3.0, + "AudioDataset.loudness_cutoff": -40.0, + "save_path": f"./runs/{name}/c2f", + "fine_tune_checkpoint": "./models/spotdl/c2f.pth" + } + + finetune_coarse_conf = { + "$include": ["conf/lora/lora.yml"], + "fine_tune": True, + "train/AudioLoader.sources": audio_files_or_folders, + "val/AudioLoader.sources": audio_files_or_folders, + "save_path": f"./runs/{name}/coarse", + "fine_tune_checkpoint": "./models/spotdl/coarse.pth" + } + + interface_conf = { + "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth", + "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth", + + "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth", + "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth", + + "Interface.codec_ckpt": "./models/spotdl/codec.pth", + "AudioLoader.sources": [audio_files_or_folders], + } + + # save the confs + with open(finetune_dir / "c2f.yml", "w") as f: + yaml.dump(finetune_c2f_conf, f) + + with open(finetune_dir / "coarse.yml", "w") as f: + yaml.dump(finetune_coarse_conf, f) + + with open(finetune_dir / "interface.yml", "w") as f: + yaml.dump(interface_conf, f) + + + print(f"generated confs in {finetune_dir}. 
run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/.yml` ") + +if __name__ == "__main__": + args = argbind.parse_args() + + with argbind.scope(args): + fine_tune() + + + + \ No newline at end of file diff --git a/scripts/exp/train.py b/scripts/exp/train.py new file mode 100644 index 0000000000000000000000000000000000000000..79251a529c9512b7bf8c2613e6ae173df21c5c61 --- /dev/null +++ b/scripts/exp/train.py @@ -0,0 +1,621 @@ +import os +import subprocess +import time +import warnings +from pathlib import Path +from typing import Optional + +import argbind +import audiotools as at +import torch +import torch.nn as nn +from audiotools import AudioSignal +from audiotools.data import transforms +from einops import rearrange +from rich import pretty +from rich.traceback import install +from tensorboardX import SummaryWriter + +import vampnet +from vampnet.modules.transformer import VampNet +from vampnet.util import codebook_unflatten, codebook_flatten +from vampnet import mask as pmask +# from dac.model.dac import DAC +from lac.model.lac import LAC as DAC + + +# Enable cudnn autotuner to speed up training +# (can be altered by the funcs.seed function) +torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1))) +# Uncomment to trade memory for speed. + +# Install to make things look nice +warnings.filterwarnings("ignore", category=UserWarning) +pretty.install() +install() + +# optim +Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True) +CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss) +AdamW = argbind.bind(torch.optim.AdamW) +NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler) + +# transforms +filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [ + "BaseTransform", + "Compose", + "Choose", +] +tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn) + +# model +VampNet = argbind.bind(VampNet) + + +# data +AudioLoader = argbind.bind(at.datasets.AudioLoader) +AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val") + +IGNORE_INDEX = -100 + + +@argbind.bind("train", "val", without_prefix=True) +def build_transform(): + transform = transforms.Compose( + tfm.VolumeNorm(("const", -24)), + # tfm.PitchShift(), + tfm.RescaleAudio(), + ) + return transform + + +@torch.no_grad() +def apply_transform(transform_fn, batch): + sig: AudioSignal = batch["signal"] + kwargs = batch["transform_args"] + + sig: AudioSignal = transform_fn(sig.clone(), **kwargs) + return sig + + +def build_datasets(args, sample_rate: int): + with argbind.scope(args, "train"): + train_data = AudioDataset( + AudioLoader(), sample_rate, transform=build_transform() + ) + with argbind.scope(args, "val"): + val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform()) + with argbind.scope(args, "test"): + test_data = AudioDataset( + AudioLoader(), sample_rate, transform=build_transform() + ) + return train_data, val_data, test_data + + +def rand_float(shape, low, high, rng): + return rng.draw(shape)[:, 0] * (high - low) + low + + +def flip_coin(shape, p, rng): + return rng.draw(shape)[:, 0] < p + + +@argbind.bind(without_prefix=True) +def load( + args, + accel: at.ml.Accelerator, + save_path: str, + resume: bool = False, + tag: str = "latest", + load_weights: bool = False, + fine_tune_checkpoint: Optional[str] = None, +): + codec = DAC.load(args["codec_ckpt"], map_location="cpu") + codec.eval() + + model, v_extra = None, {} + + if resume: + kwargs = { + "folder": f"{save_path}/{tag}", + 
"map_location": "cpu", + "package": not load_weights, + } + if (Path(kwargs["folder"]) / "vampnet").exists(): + model, v_extra = VampNet.load_from_folder(**kwargs) + else: + raise ValueError( + f"Could not find a VampNet checkpoint in {kwargs['folder']}" + ) + + + if args["fine_tune"]: + assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint" + model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu") + + + model = VampNet() if model is None else model + + model = accel.prepare_model(model) + + # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks + assert ( + accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size + ) + + optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp) + scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim) + scheduler.step() + + trainer_state = {"state_dict": None, "start_idx": 0} + + if "optimizer.pth" in v_extra: + optimizer.load_state_dict(v_extra["optimizer.pth"]) + if "scheduler.pth" in v_extra: + scheduler.load_state_dict(v_extra["scheduler.pth"]) + if "trainer.pth" in v_extra: + trainer_state = v_extra["trainer.pth"] + + return { + "model": model, + "codec": codec, + "optimizer": optimizer, + "scheduler": scheduler, + "trainer_state": trainer_state, + } + + + +def num_params_hook(o, p): + return o + f" {p/1e6:<.3f}M params." + + +def add_num_params_repr_hook(model): + import numpy as np + from functools import partial + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + + setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p)) + + +def accuracy( + preds: torch.Tensor, + target: torch.Tensor, + top_k: int = 1, + ignore_index: Optional[int] = None, +) -> torch.Tensor: + # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class) + preds = rearrange(preds, "b p s -> (b s) p") + target = rearrange(target, "b s -> (b s)") + + # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index) + if ignore_index is not None: + # Create a mask for the ignored index + mask = target != ignore_index + # Apply the mask to the target and predictions + preds = preds[mask] + target = target[mask] + + # Get the top-k predicted classes and their indices + _, pred_indices = torch.topk(preds, k=top_k, dim=-1) + + # Determine if the true target is in the top-k predicted classes + correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1) + + # Calculate the accuracy + accuracy = torch.mean(correct.float()) + + return accuracy + + +@argbind.bind(without_prefix=True) +def train( + args, + accel: at.ml.Accelerator, + codec_ckpt: str = None, + seed: int = 0, + save_path: str = "ckpt", + max_epochs: int = int(100e3), + epoch_length: int = 1000, + save_audio_epochs: int = 2, + save_epochs: list = [10, 50, 100, 200, 300, 400,], + batch_size: int = 48, + grad_acc_steps: int = 1, + val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + num_workers: int = 10, + detect_anomaly: bool = False, + grad_clip_val: float = 5.0, + fine_tune: bool = False, + quiet: bool = False, +): + assert codec_ckpt is not None, "codec_ckpt is required" + + seed = seed + accel.local_rank + at.util.seed(seed) + writer = None + + if accel.local_rank == 0: + writer = SummaryWriter(log_dir=f"{save_path}/logs/") + argbind.dump_args(args, f"{save_path}/args.yml") + + # load the codec model + loaded = load(args, accel, 
save_path) + model = loaded["model"] + codec = loaded["codec"] + optimizer = loaded["optimizer"] + scheduler = loaded["scheduler"] + trainer_state = loaded["trainer_state"] + + sample_rate = codec.sample_rate + + # a better rng for sampling from our schedule + rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed) + + # log a model summary w/ num params + if accel.local_rank == 0: + add_num_params_repr_hook(accel.unwrap(model)) + with open(f"{save_path}/model.txt", "w") as f: + f.write(repr(accel.unwrap(model))) + + # load the datasets + train_data, val_data, _ = build_datasets(args, sample_rate) + train_dataloader = accel.prepare_dataloader( + train_data, + start_idx=trainer_state["start_idx"], + num_workers=num_workers, + batch_size=batch_size, + collate_fn=train_data.collate, + ) + val_dataloader = accel.prepare_dataloader( + val_data, + start_idx=0, + num_workers=num_workers, + batch_size=batch_size, + collate_fn=val_data.collate, + ) + + criterion = CrossEntropyLoss() + + if fine_tune: + import loralib as lora + lora.mark_only_lora_as_trainable(model) + + + class Trainer(at.ml.BaseTrainer): + _last_grad_norm = 0.0 + + def _metrics(self, vn, z_hat, r, target, flat_mask, output): + for r_range in [(0, 0.5), (0.5, 1.0)]: + unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX) + masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + + assert target.shape[0] == r.shape[0] + # grab the indices of the r values that are in the range + r_idx = (r >= r_range[0]) & (r < r_range[1]) + + # grab the target and z_hat values that are in the range + r_unmasked_target = unmasked_target[r_idx] + r_masked_target = masked_target[r_idx] + r_z_hat = z_hat[r_idx] + + for topk in (1, 25): + s, e = r_range + tag = f"accuracy-{s}-{e}/top{topk}" + + output[f"{tag}/unmasked"] = accuracy( + preds=r_z_hat, + target=r_unmasked_target, + ignore_index=IGNORE_INDEX, + top_k=topk, + ) + output[f"{tag}/masked"] = accuracy( + preds=r_z_hat, + target=r_masked_target, + ignore_index=IGNORE_INDEX, + top_k=topk, + ) + + def train_loop(self, engine, batch): + model.train() + batch = at.util.prepare_batch(batch, accel.device) + signal = apply_transform(train_data.transform, batch) + + output = {} + vn = accel.unwrap(model) + with accel.autocast(): + with torch.inference_mode(): + codec.to(accel.device) + z = codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + n_batch = z.shape[0] + r = rng.draw(n_batch)[:, 0].to(accel.device) + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, codec) + + dtype = torch.bfloat16 if accel.amp else None + with accel.autocast(dtype=dtype): + z_hat = model(z_mask_latent, r) + + target = codebook_flatten( + z[:, vn.n_conditioning_codebooks :, :], + ) + + flat_mask = codebook_flatten( + mask[:, vn.n_conditioning_codebooks :, :], + ) + + # replace target with ignore index for masked tokens + t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + output["loss"] = criterion(z_hat, t_masked) + + self._metrics( + vn=vn, + r=r, + z_hat=z_hat, + target=target, + flat_mask=flat_mask, + output=output, + ) + + + accel.backward(output["loss"] / grad_acc_steps) + + output["other/learning_rate"] = optimizer.param_groups[0]["lr"] + output["other/batch_size"] = z.shape[0] + + if ( + (engine.state.iteration % grad_acc_steps == 0) + or (engine.state.iteration % epoch_length == 
0) + or (engine.state.iteration % epoch_length == 1) + ): # (or we reached the end of the epoch) + accel.scaler.unscale_(optimizer) + output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_( + model.parameters(), grad_clip_val + ) + self._last_grad_norm = output["other/grad_norm"] + + accel.step(optimizer) + optimizer.zero_grad() + + scheduler.step() + accel.update() + else: + output["other/grad_norm"] = self._last_grad_norm + + return {k: v for k, v in sorted(output.items())} + + @torch.no_grad() + def val_loop(self, engine, batch): + model.eval() + codec.eval() + batch = at.util.prepare_batch(batch, accel.device) + signal = apply_transform(val_data.transform, batch) + + vn = accel.unwrap(model) + z = codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + n_batch = z.shape[0] + r = rng.draw(n_batch)[:, 0].to(accel.device) + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, codec) + + z_hat = model(z_mask_latent, r) + + target = codebook_flatten( + z[:, vn.n_conditioning_codebooks :, :], + ) + + flat_mask = codebook_flatten( + mask[:, vn.n_conditioning_codebooks :, :] + ) + + output = {} + # replace target with ignore index for masked tokens + t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + output["loss"] = criterion(z_hat, t_masked) + + self._metrics( + vn=vn, + r=r, + z_hat=z_hat, + target=target, + flat_mask=flat_mask, + output=output, + ) + + return output + + def checkpoint(self, engine): + if accel.local_rank != 0: + print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}") + return + + metadata = {"logs": dict(engine.state.logs["epoch"])} + + if self.state.epoch % save_audio_epochs == 0: + self.save_samples() + + tags = ["latest"] + loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train" + self.print(f"Saving to {str(Path('.').absolute())}") + + if self.state.epoch in save_epochs: + tags.append(f"epoch={self.state.epoch}") + + if self.is_best(engine, loss_key): + self.print(f"Best model so far") + tags.append("best") + + if fine_tune: + for tag in tags: + # save the lora model + (Path(save_path) / tag).mkdir(parents=True, exist_ok=True) + torch.save( + lora.lora_state_dict(accel.unwrap(model)), + f"{save_path}/{tag}/lora.pth" + ) + + for tag in tags: + model_extra = { + "optimizer.pth": optimizer.state_dict(), + "scheduler.pth": scheduler.state_dict(), + "trainer.pth": { + "start_idx": self.state.iteration * batch_size, + "state_dict": self.state_dict(), + }, + "metadata.pth": metadata, + } + + accel.unwrap(model).metadata = metadata + accel.unwrap(model).save_to_folder( + f"{save_path}/{tag}", model_extra, + ) + + def save_sampled(self, z): + num_samples = z.shape[0] + + for i in range(num_samples): + sampled = accel.unwrap(model).generate( + codec=codec, + time_steps=z.shape[-1], + start_tokens=z[i : i + 1], + ) + sampled.cpu().write_audio_to_tb( + f"sampled/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + + + def save_imputation(self, z: torch.Tensor): + n_prefix = int(z.shape[-1] * 0.25) + n_suffix = int(z.shape[-1] * 0.25) + + vn = accel.unwrap(model) + + mask = pmask.inpaint(z, n_prefix, n_suffix) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + imputed_noisy = vn.to_signal(z_mask, codec) + imputed_true = vn.to_signal(z, codec) + + imputed = [] + 
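# generate each batch item separately with its inpainting mask; the per-item outputs are batched back together below +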
for i in range(len(z)): + imputed.append( + vn.generate( + codec=codec, + time_steps=z.shape[-1], + start_tokens=z[i][None, ...], + mask=mask[i][None, ...], + ) + ) + imputed = AudioSignal.batch(imputed) + + for i in range(len(val_idx)): + imputed_noisy[i].cpu().write_audio_to_tb( + f"imputed_noisy/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + imputed[i].cpu().write_audio_to_tb( + f"imputed/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + imputed_true[i].cpu().write_audio_to_tb( + f"imputed_true/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + + @torch.no_grad() + def save_samples(self): + model.eval() + codec.eval() + vn = accel.unwrap(model) + + batch = [val_data[i] for i in val_idx] + batch = at.util.prepare_batch(val_data.collate(batch), accel.device) + + signal = apply_transform(val_data.transform, batch) + + z = codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device) + + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, codec) + + z_hat = model(z_mask_latent, r) + + z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1) + z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks) + z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1) + + generated = vn.to_signal(z_pred, codec) + reconstructed = vn.to_signal(z, codec) + masked = vn.to_signal(z_mask.squeeze(1), codec) + + for i in range(generated.batch_size): + audio_dict = { + "original": signal[i], + "masked": masked[i], + "generated": generated[i], + "reconstructed": reconstructed[i], + } + for k, v in audio_dict.items(): + v.cpu().write_audio_to_tb( + f"samples/_{i}.r={r[i]:0.2f}/{k}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + + self.save_sampled(z) + self.save_imputation(z) + + trainer = Trainer(writer=writer, quiet=quiet) + + if trainer_state["state_dict"] is not None: + trainer.load_state_dict(trainer_state["state_dict"]) + if hasattr(train_dataloader.sampler, "set_epoch"): + train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch) + + trainer.run( + train_dataloader, + val_dataloader, + num_epochs=max_epochs, + epoch_length=epoch_length, + detect_anomaly=detect_anomaly, + ) + + +if __name__ == "__main__": + args = argbind.parse_args() + args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0 + with argbind.scope(args): + with Accelerator() as accel: + train(args, accel) diff --git a/scripts/utils/README.md b/scripts/utils/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9fa13389f47b0432d9e20933bf77dfd9dfcb2aa7 --- /dev/null +++ b/scripts/utils/README.md @@ -0,0 +1,28 @@ +# Scripts + +## process_zip.py + +Some requirements that may not be installed in the docker image: +* argbind +* wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@`) + +### zip folder structure + +The zip folder should have the following internal structure: + +``` +base_folder/ + test_case_1/ + before.wav + test_case_2/ + before.wav + ... + test_case_n/ + before.wav +``` + +Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. 
IF you want/need to use a zip file with a different folder structure, adjust this: +https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28 + +### Execution +`python process_zip.py -tag ` diff --git a/scripts/utils/augment.py b/scripts/utils/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..e971bf8404fb328d556a56f1e997a211e355dbaa --- /dev/null +++ b/scripts/utils/augment.py @@ -0,0 +1,53 @@ +from pathlib import Path + +import audiotools as at +from audiotools import AudioSignal + +import argbind +import tqdm + + +from pedalboard import ( + Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard +) +from pedalboard.io import AudioFile + +# Read in a whole file, resampling to our desired sample rate: +samplerate = 44100.0 +with AudioFile('guitar-input.wav').resampled_to(samplerate) as f: + audio = f.read(f.frames) + +# Make a pretty interesting sounding guitar pedalboard: +board = Pedalboard([ + Compressor(threshold_db=-50, ratio=25), + Gain(gain_db=30), + Chorus(), + LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900), + Phaser(), + Convolution("./guitar_amp.wav", 1.0), + Reverb(room_size=0.25), +]) + + +@argbind.bind(without_prefix=True) +def augment( + audio_folder: Path, + dest_folder: Path, + n_augmentations: int = 10, +): + """ + Augment a folder of audio files by applying audiotools and pedalboard transforms. + + The dest foler will contain a folder for each of the clean dataset's files. + Under each of these folders, there will be a clean file and many augmented files. + """ + + audio_files = at.util.find_audio(audio_folder) + + for audio_file in tqdm.tqdm(audio_files): + subtree = dest_folder / audio_file.relative_to(audio_folder).parent + subdir = subtree / audio_file.stem + subdir.mkdir(parents=True, exist_ok=True) + + # apply pedalboard transforms + for i in range(n_augmentations): diff --git a/scripts/utils/maestro-reorg.py b/scripts/utils/maestro-reorg.py new file mode 100644 index 0000000000000000000000000000000000000000..96b65cae514165ad6c286146f94fe84cd305380e --- /dev/null +++ b/scripts/utils/maestro-reorg.py @@ -0,0 +1,39 @@ +from pathlib import Path +import json +import os + +maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0") +output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split") + +# split +with open(maestro_path / "maestro-v3.0.0.json") as f: + maestro = json.load(f) + +breakpoint() +train = [] +validation = [] +test = [] +for key, split in maestro["split"].items(): + audio_filename = maestro['audio_filename'][key] + if split == "train": + train.append(audio_filename) + elif split == "test": + test.append(audio_filename) + elif split == "validation": + validation.append(audio_filename) + else: + raise ValueError(f"Unknown split {split}") + +# symlink all files +for audio_filename in train: + p = output_path / "train" / audio_filename + p.parent.mkdir(parents=True, exist_ok=True) + os.symlink(maestro_path / audio_filename, p) +for audio_filename in validation: + p = output_path / "validation" / audio_filename + p.parent.mkdir(parents=True, exist_ok=True) + os.symlink(maestro_path / audio_filename, p) +for audio_filename in test: + p = output_path / "test" / audio_filename + p.parent.mkdir(parents=True, exist_ok=True) + os.symlink(maestro_path / audio_filename, p) \ No newline at end of file diff --git a/scripts/utils/plots.py b/scripts/utils/plots.py new file mode 100644 index 
0000000000000000000000000000000000000000..955f891505a5d29c0d72d967905e3d62db04f5fa --- /dev/null +++ b/scripts/utils/plots.py @@ -0,0 +1,43 @@ +import matplotlib.pyplot as plt +import seaborn as sns +from pandas.api.types import CategoricalDtype + +def plot_metrics(metrics, condition_to_latex, title, color_palette): + # Add a new column to your dataframe with the latex representation + metrics['condition_latex'] = metrics['condition'].map(condition_to_latex) + + # Order condition_latex as per the condition_to_latex dictionary + cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True) + metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type) + + # Compute mean and std for each condition for each metric + grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std']) + + fig, axs = plt.subplots(2, 1, figsize=(7, 5.25)) + + # Set the main title for the figure + fig.suptitle(title, fontsize=16) + + # Get color for each bar in the plot + bar_colors = [color_palette[condition] for condition in grouped.index] + + # Plot mel + sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False) + axs[0].set_ylabel('Mel Spectrogram Loss \u2190') + axs[0].set_xlabel('') # Remove x-axis label + axs[0].set_xticklabels(grouped.index, rotation=0, ha='center') + + # Plot frechet + axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors) + axs[1].set_ylabel('FAD \u2190') + axs[1].set_xlabel('') # Remove x-axis label + axs[1].set_xticklabels(grouped.index, rotation=0, ha='center') + + # Adjust the space between plots + plt.subplots_adjust(hspace=0.1) + + # Remove any unnecessary space around the plot + plt.tight_layout(rect=[0, 0, 1, 0.96]) + + # Reduce the space between suptitle and the plot + plt.subplots_adjust(top=0.92) \ No newline at end of file diff --git a/scripts/utils/split.py b/scripts/utils/split.py new file mode 100644 index 0000000000000000000000000000000000000000..8ddb9b27b8854b6bf84e8404b56834564996e637 --- /dev/null +++ b/scripts/utils/split.py @@ -0,0 +1,64 @@ +from pathlib import Path +import random +import shutil +import os +import json + +import argbind +from tqdm import tqdm +from tqdm.contrib.concurrent import thread_map + +from audiotools.core import util + + +@argbind.bind(without_prefix=True) +def train_test_split( + audio_folder: str = ".", + test_size: float = 0.2, + seed: int = 42, + pattern: str = "**/*.mp3", +): + print(f"finding audio") + + audio_folder = Path(audio_folder) + audio_files = list(tqdm(audio_folder.glob(pattern))) + print(f"found {len(audio_files)} audio files") + + # split according to test_size + n_test = int(len(audio_files) * test_size) + n_train = len(audio_files) - n_test + + # shuffle + random.seed(seed) + random.shuffle(audio_files) + + train_files = audio_files[:n_train] + test_files = audio_files[n_train:] + + + print(f"Train files: {len(train_files)}") + print(f"Test files: {len(test_files)}") + continue_ = input("Continue [yn]? 
") or "n" + + if continue_ != "y": + return + + for split, files in ( + ("train", train_files), ("test", test_files) + ): + for file in tqdm(files): + out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name + out_file.parent.mkdir(exist_ok=True, parents=True) + os.symlink(file, out_file) + + # save split as json + with open(Path(audio_folder) / f"{split}.json", "w") as f: + json.dump([str(f) for f in files], f) + + + +if __name__ == "__main__": + args = argbind.parse_args() + + with argbind.scope(args): + train_test_split() \ No newline at end of file diff --git a/scripts/utils/stage.py b/scripts/utils/stage.py new file mode 100644 index 0000000000000000000000000000000000000000..253e1d070ccf3754be01578d22b65136858fa697 --- /dev/null +++ b/scripts/utils/stage.py @@ -0,0 +1,30 @@ +import os +import subprocess +from pathlib import Path + +import argbind +import rich +from audiotools.ml import Experiment + + +@argbind.bind(without_prefix=True) +def run( + run_dir: str = os.getenv("PATH_TO_RUNS", "runs"), + name: str = None, + recent: bool = False, +): + if recent: + paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime) + paths = [p.name for p in paths if p.is_dir()] + if paths: + name = paths[-1] + + with Experiment(run_dir, name) as exp: + exp.snapshot() + rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}") + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + run() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..2964e0f810f32dab3abc433912a2de128c081761 --- /dev/null +++ b/setup.py @@ -0,0 +1,39 @@ +from setuptools import find_packages +from setuptools import setup + +with open("README.md") as f: + long_description = f.read() + +setup( + name="vampnet", + version="0.0.1", + classifiers=[ + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python :: 3.7", + "Topic :: Artistic Software", + "Topic :: Multimedia", + "Topic :: Multimedia :: Sound/Audio", + "Topic :: Multimedia :: Sound/Audio :: Editors", + "Topic :: Software Development :: Libraries", + ], + description="Generative Music Modeling.", + long_description=long_description, + long_description_content_type="text/markdown", + author="Hugo Flores García, Prem Seetharaman", + author_email="hfgacrcia@descript.com", + url="https://github.com/hugofloresgarcia/vampnet", + license="MIT", + packages=find_packages(), + install_requires=[ + "torch", + "argbind>=0.3.2", + "numpy==1.22", + "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat", + "lac @ git+https://github.com/hugofloresgarcia/lac.git", + "audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git", + "gradio", + "tensorboardX", + "loralib", + ], +) diff --git a/vampnet/__init__.py b/vampnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a9dd073cacc51dd2996e856cf6c0eb87dd93ce0 --- /dev/null +++ b/vampnet/__init__.py @@ -0,0 +1,6 @@ + +from . import modules +from . 
import scheduler +from .interface import Interface + +__version__ = "0.0.1" diff --git a/vampnet/beats.py b/vampnet/beats.py new file mode 100644 index 0000000000000000000000000000000000000000..317496ef83d7b764fbbc51068c13170ce0c17e13 --- /dev/null +++ b/vampnet/beats.py @@ -0,0 +1,249 @@ +import json +import logging +import warnings +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from typing import List +from typing import Tuple +from typing import Union + +import librosa +import numpy as np +from audiotools import AudioSignal + + +logging.basicConfig(level=logging.INFO) + +################### +# beat sync utils # +################### + +AGGREGATOR_REGISTRY = { + "mean": np.mean, + "median": np.median, + "max": np.max, + "min": np.min, +} + + +def list_aggregators() -> list: + return list(AGGREGATOR_REGISTRY.keys()) + + +@dataclass +class TimeSegment: + start: float + end: float + + @property + def duration(self): + return self.end - self.start + + def __str__(self) -> str: + return f"{self.start} - {self.end}" + + def find_overlapping_segment( + self, segments: List["TimeSegment"] + ) -> Union["TimeSegment", None]: + """Find the first segment that overlaps with this segment, or None if no segment overlaps""" + for s in segments: + if s.start <= self.start and s.end >= self.end: + return s + return None + + +def mkdir(path: Union[Path, str]) -> Path: + p = Path(path) + p.mkdir(parents=True, exist_ok=True) + return p + + + +################### +# beat data # +################### +@dataclass +class BeatSegment(TimeSegment): + downbeat: bool = False # if there's a downbeat on the start_time + + +class Beats: + def __init__(self, beat_times, downbeat_times): + if isinstance(beat_times, np.ndarray): + beat_times = beat_times.tolist() + if isinstance(downbeat_times, np.ndarray): + downbeat_times = downbeat_times.tolist() + self._beat_times = beat_times + self._downbeat_times = downbeat_times + self._use_downbeats = False + + def use_downbeats(self, use_downbeats: bool = True): + """use downbeats instead of beats when calling beat_times""" + self._use_downbeats = use_downbeats + + def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]: + """ + segments a song into time segments corresponding to beats. + the first segment starts at 0 and ends at the first beat time. + the last segment starts at the last beat time and ends at the end of the song. 
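+        (illustrative example: beat times [0.5, 1.0, 1.5] over a 2.0 s signal yield segments (0.0, 0.5), (0.5, 1.0), (1.0, 1.5), (1.5, 2.0); a segment's downbeat flag is True when its start time is also a downbeat time.)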
+ """ + beat_times = self._beat_times.copy() + downbeat_times = self._downbeat_times + beat_times.insert(0, 0) + beat_times.append(signal.signal_duration) + + downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[ + 1 + ] + is_downbeat = [ + True if i in downbeat_ids else False for i in range(len(beat_times)) + ] + segments = [ + BeatSegment(start_time, end_time, downbeat) + for start_time, end_time, downbeat in zip( + beat_times[:-1], beat_times[1:], is_downbeat + ) + ] + return segments + + def get_beats(self) -> np.ndarray: + """returns an array of beat times, in seconds + if downbeats is True, returns an array of downbeat times, in seconds + """ + return np.array( + self._downbeat_times if self._use_downbeats else self._beat_times + ) + + @property + def beat_times(self) -> np.ndarray: + """return beat times""" + return np.array(self._beat_times) + + @property + def downbeat_times(self) -> np.ndarray: + """return downbeat times""" + return np.array(self._downbeat_times) + + def beat_times_to_feature_frames( + self, signal: AudioSignal, features: np.ndarray + ) -> np.ndarray: + """convert beat times to frames, given an array of time-varying features""" + beat_times = self.get_beats() + beat_frames = ( + beat_times * signal.sample_rate / signal.signal_length * features.shape[-1] + ).astype(np.int64) + return beat_frames + + def sync_features( + self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median" + ) -> np.ndarray: + """sync features to beats""" + if aggregate not in AGGREGATOR_REGISTRY: + raise ValueError(f"unknown aggregation method {aggregate}") + + return librosa.util.sync( + features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate] + ) + + def to_json(self) -> dict: + """return beats and downbeats as json""" + return { + "beats": self._beat_times, + "downbeats": self._downbeat_times, + "use_downbeats": self._use_downbeats, + } + + @classmethod + def from_dict(cls, data: dict): + """load beats and downbeats from json""" + inst = cls(data["beats"], data["downbeats"]) + inst.use_downbeats(data["use_downbeats"]) + return inst + + def save(self, output_dir: Path): + """save beats and downbeats to json""" + mkdir(output_dir) + with open(output_dir / "beats.json", "w") as f: + json.dump(self.to_json(), f) + + @classmethod + def load(cls, input_dir: Path): + """load beats and downbeats from json""" + beats_file = Path(input_dir) / "beats.json" + with open(beats_file, "r") as f: + data = json.load(f) + return cls.from_dict(data) + + +################### +# beat tracking # +################### + + +class BeatTracker: + def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]: + """extract beats from an audio signal""" + raise NotImplementedError + + def __call__(self, signal: AudioSignal) -> Beats: + """extract beats from an audio signal + NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio, + it is discarded. This is to avoid empty bins with no beat synced features in the first beat. 
+ Args: + signal (AudioSignal): signal to beat track + Returns: + Tuple[np.ndarray, np.ndarray]: beats and downbeats + """ + beats, downbeats = self.extract_beats(signal) + return Beats(beats, downbeats) + + +class WaveBeat(BeatTracker): + def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"): + from wavebeat.dstcn import dsTCNModel + + model = dsTCNModel.load_from_checkpoint(ckpt_path) + model.eval() + + self.device = device + self.model = model + + def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]: + """returns beat and downbeat times, in seconds""" + # extract beats + beats, downbeats = self.model.predict_beats_from_array( + audio=signal.audio_data.squeeze(0), + sr=signal.sample_rate, + use_gpu=self.device != "cpu", + ) + + return beats, downbeats + + +class MadmomBeats(BeatTracker): + def __init__(self): + raise NotImplementedError + + def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]: + """returns beat and downbeat times, in seconds""" + pass + + +BEAT_TRACKER_REGISTRY = { + "wavebeat": WaveBeat, + "madmom": MadmomBeats, +} + + +def list_beat_trackers() -> list: + return list(BEAT_TRACKER_REGISTRY.keys()) + + +def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker: + if beat_tracker not in BEAT_TRACKER_REGISTRY: + raise ValueError( + f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}" + ) + + return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs) \ No newline at end of file diff --git a/vampnet/interface.py b/vampnet/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6e39182c9d91c1b76bcb18476f9c018a247543 --- /dev/null +++ b/vampnet/interface.py @@ -0,0 +1,404 @@ +import os +from pathlib import Path +import math + +import torch +import numpy as np +from audiotools import AudioSignal +import tqdm + +from .modules.transformer import VampNet +from .beats import WaveBeat +from .mask import * + +# from dac.model.dac import DAC +from lac.model.lac import LAC as DAC + + +def signal_concat( + audio_signals: list, +): + audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1) + + return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate) + +def _load_model( + ckpt: str, + lora_ckpt: str = None, + device: str = "cpu", + chunk_size_s: int = 10, +): + # we need to set strict to False if the model has lora weights to add later + model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False) + + # load lora weights if needed + if lora_ckpt is not None: + if not Path(lora_ckpt).exists(): + should_cont = input( + f"lora checkpoint {lora_ckpt} does not exist. continue? 
(y/n) " + ) + if should_cont != "y": + raise Exception("aborting") + else: + model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False) + + model.to(device) + model.eval() + model.chunk_size_s = chunk_size_s + return model + + + +class Interface(torch.nn.Module): + def __init__( + self, + coarse_ckpt: str = None, + coarse_lora_ckpt: str = None, + coarse2fine_ckpt: str = None, + coarse2fine_lora_ckpt: str = None, + codec_ckpt: str = None, + wavebeat_ckpt: str = None, + device: str = "cpu", + coarse_chunk_size_s: int = 10, + coarse2fine_chunk_size_s: int = 3, + ): + super().__init__() + assert codec_ckpt is not None, "must provide a codec checkpoint" + self.codec = DAC.load(Path(codec_ckpt)) + self.codec.eval() + self.codec.to(device) + + assert coarse_ckpt is not None, "must provide a coarse checkpoint" + self.coarse = _load_model( + ckpt=coarse_ckpt, + lora_ckpt=coarse_lora_ckpt, + device=device, + chunk_size_s=coarse_chunk_size_s, + ) + + # check if we have a coarse2fine ckpt + if coarse2fine_ckpt is not None: + self.c2f = _load_model( + ckpt=coarse2fine_ckpt, + lora_ckpt=coarse2fine_lora_ckpt, + device=device, + chunk_size_s=coarse2fine_chunk_size_s, + ) + else: + self.c2f = None + + if wavebeat_ckpt is not None: + print(f"loading wavebeat from {wavebeat_ckpt}") + self.beat_tracker = WaveBeat(wavebeat_ckpt) + self.beat_tracker.model.to(device) + else: + self.beat_tracker = None + + self.device = device + + def lora_load( + self, + coarse_ckpt: str = None, + c2f_ckpt: str = None, + full_ckpts: bool = False, + ): + if full_ckpts: + if coarse_ckpt is not None: + self.coarse = _load_model( + ckpt=coarse_ckpt, + device=self.device, + chunk_size_s=self.coarse.chunk_size_s, + ) + if c2f_ckpt is not None: + self.c2f = _load_model( + ckpt=c2f_ckpt, + device=self.device, + chunk_size_s=self.c2f.chunk_size_s, + ) + else: + if coarse_ckpt is not None: + self.coarse.to("cpu") + state_dict = torch.load(coarse_ckpt, map_location="cpu") + + self.coarse.load_state_dict(state_dict, strict=False) + self.coarse.to(self.device) + if c2f_ckpt is not None: + self.c2f.to("cpu") + state_dict = torch.load(c2f_ckpt, map_location="cpu") + + self.c2f.load_state_dict(state_dict, strict=False) + self.c2f.to(self.device) + + + def s2t(self, seconds: float): + """seconds to tokens""" + if isinstance(seconds, np.ndarray): + return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length) + else: + return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length) + + def s2t2s(self, seconds: float): + """seconds to tokens to seconds""" + return self.t2s(self.s2t(seconds)) + + def t2s(self, tokens: int): + """tokens to seconds""" + return tokens * self.codec.hop_length / self.codec.sample_rate + + def to(self, device): + self.device = device + self.coarse.to(device) + self.codec.to(device) + + if self.c2f is not None: + self.c2f.to(device) + + if self.beat_tracker is not None: + self.beat_tracker.model.to(device) + return self + + def to_signal(self, z: torch.Tensor): + return self.coarse.to_signal(z, self.codec) + + def preprocess(self, signal: AudioSignal): + signal = ( + signal.clone() + .resample(self.codec.sample_rate) + .to_mono() + .normalize(-24) + .ensure_max_of_audio(1.0) + ) + return signal + + @torch.inference_mode() + def encode(self, signal: AudioSignal): + signal = self.preprocess(signal).to(self.device) + z = self.codec.encode(signal.samples, signal.sample_rate)["codes"] + return z + + def snap_to_beats( + self, + signal: AudioSignal + ): + assert hasattr(self, 
"beat_tracker"), "No beat tracker loaded" + beats, downbeats = self.beat_tracker.extract_beats(signal) + + # trim the signa around the first beat time + samples_begin = int(beats[0] * signal.sample_rate ) + samples_end = int(beats[-1] * signal.sample_rate) + print(beats[0]) + signal = signal.clone().trim(samples_begin, signal.length - samples_end) + + return signal + + def make_beat_mask(self, + signal: AudioSignal, + before_beat_s: float = 0.1, + after_beat_s: float = 0.1, + mask_downbeats: bool = True, + mask_upbeats: bool = True, + downbeat_downsample_factor: int = None, + beat_downsample_factor: int = None, + dropout: float = 0.0, + invert: bool = True, + ): + """make a beat synced mask. that is, make a mask that + places 1s at and around the beat, and 0s everywhere else. + """ + assert self.beat_tracker is not None, "No beat tracker loaded" + + # get the beat times + beats, downbeats = self.beat_tracker.extract_beats(signal) + + # get the beat indices in z + beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats) + + # remove downbeats from beats + beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))] + beats_z = beats_z.tolist() + downbeats_z = downbeats_z.tolist() + + # make the mask + seq_len = self.s2t(signal.duration) + mask = torch.zeros(seq_len, device=self.device) + + mask_b4 = self.s2t(before_beat_s) + mask_after = self.s2t(after_beat_s) + + if beat_downsample_factor is not None: + if beat_downsample_factor < 1: + raise ValueError("mask_beat_downsample_factor must be >= 1 or None") + else: + beat_downsample_factor = 1 + + if downbeat_downsample_factor is not None: + if downbeat_downsample_factor < 1: + raise ValueError("mask_beat_downsample_factor must be >= 1 or None") + else: + downbeat_downsample_factor = 1 + + beats_z = beats_z[::beat_downsample_factor] + downbeats_z = downbeats_z[::downbeat_downsample_factor] + print(f"beats_z: {len(beats_z)}") + print(f"downbeats_z: {len(downbeats_z)}") + + if mask_upbeats: + for beat_idx in beats_z: + _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after) + num_steps = mask[_slice[0]:_slice[1]].shape[0] + _m = torch.ones(num_steps, device=self.device) + _m_mask = torch.bernoulli(_m * (1 - dropout)) + _m = _m * _m_mask.long() + + mask[_slice[0]:_slice[1]] = _m + + if mask_downbeats: + for downbeat_idx in downbeats_z: + _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after) + num_steps = mask[_slice[0]:_slice[1]].shape[0] + _m = torch.ones(num_steps, device=self.device) + _m_mask = torch.bernoulli(_m * (1 - dropout)) + _m = _m * _m_mask.long() + + mask[_slice[0]:_slice[1]] = _m + + mask = mask.clamp(0, 1) + if invert: + mask = 1 - mask + + mask = mask[None, None, :].bool().long() + if self.c2f is not None: + mask = mask.repeat(1, self.c2f.n_codebooks, 1) + else: + mask = mask.repeat(1, self.coarse.n_codebooks, 1) + return mask + + def coarse_to_fine( + self, + coarse_z: torch.Tensor, + **kwargs + ): + assert self.c2f is not None, "No coarse2fine model loaded" + length = coarse_z.shape[-1] + chunk_len = self.s2t(self.c2f.chunk_size_s) + n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len) + + # zero pad to chunk_len + if length % chunk_len != 0: + pad_len = chunk_len - (length % chunk_len) + coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len)) + + n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1] + if n_codebooks_to_append > 0: + coarse_z = torch.cat([ + coarse_z, + torch.zeros(coarse_z.shape[0], n_codebooks_to_append, 
coarse_z.shape[-1]).long().to(self.device) + ], dim=1) + + fine_z = [] + for i in range(n_chunks): + chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len] + chunk = self.c2f.generate( + codec=self.codec, + time_steps=chunk_len, + start_tokens=chunk, + return_signal=False, + **kwargs + ) + fine_z.append(chunk) + + fine_z = torch.cat(fine_z, dim=-1) + return fine_z[:, :, :length].clone() + + def coarse_vamp( + self, + z, + mask, + return_mask=False, + gen_fn=None, + **kwargs + ): + # coarse z + cz = z[:, : self.coarse.n_codebooks, :].clone() + assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}" + + mask = mask[:, : self.coarse.n_codebooks, :] + + cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token) + cz_masked = cz_masked[:, : self.coarse.n_codebooks, :] + + gen_fn = gen_fn or self.coarse.generate + c_vamp = gen_fn( + codec=self.codec, + time_steps=cz.shape[-1], + start_tokens=cz, + mask=mask, + return_signal=False, + **kwargs + ) + + if return_mask: + return c_vamp, cz_masked + + return c_vamp + + +if __name__ == "__main__": + import audiotools as at + import logging + logger = logging.getLogger() + logger.setLevel(logging.INFO) + torch.set_printoptions(threshold=10000) + at.util.seed(42) + + interface = Interface( + coarse_ckpt="./models/spotdl/coarse.pth", + coarse2fine_ckpt="./models/spotdl/c2f.pth", + codec_ckpt="./models/spotdl/codec.pth", + device="cuda", + wavebeat_ckpt="./models/wavebeat.pth" + ) + + + sig = at.AudioSignal.zeros(duration=10, sample_rate=44100) + + z = interface.encode(sig) + + # mask = linear_random(z, 1.0) + # mask = mask_and( + # mask, periodic_mask( + # z, + # 32, + # 1, + # random_roll=True + # ) + # ) + + mask = interface.make_beat_mask( + sig, 0.0, 0.075 + ) + # mask = dropout(mask, 0.0) + # mask = codebook_unmask(mask, 0) + + breakpoint() + zv, mask_z = interface.coarse_vamp( + z, + mask=mask, + sampling_steps=36, + temperature=8.0, + return_mask=True, + gen_fn=interface.coarse.generate + ) + + use_coarse2fine = True + if use_coarse2fine: + zv = interface.coarse_to_fine(zv, temperature=0.8) + + mask = interface.to_signal(mask_z).cpu() + + sig = interface.to_signal(zv).cpu() + print("done") + + sig.write("output3.wav") + mask.write("mask.wav") + \ No newline at end of file diff --git a/vampnet/mask.py b/vampnet/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..1302fd101925fca4403c9532561a0a8dc7b55265 --- /dev/null +++ b/vampnet/mask.py @@ -0,0 +1,219 @@ +from typing import Optional + +import torch +from audiotools import AudioSignal + +from .util import scalar_to_batch_tensor + +def _gamma(r): + return (r * torch.pi / 2).cos().clamp(1e-10, 1.0) + +def _invgamma(y): + if not torch.is_tensor(y): + y = torch.tensor(y)[None] + return 2 * y.acos() / torch.pi + +def full_mask(x: torch.Tensor): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + return torch.ones_like(x).long() + +def empty_mask(x: torch.Tensor): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + return torch.zeros_like(x).long() + +def apply_mask( + x: torch.Tensor, + mask: torch.Tensor, + mask_token: int + ): + assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}" + assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}" + assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}" + 
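# mask == 1 marks positions to replace with mask_token; mask == 0 keeps the original codes (x * (1 - mask) + mask_token * mask) +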
assert ~torch.any(mask > 1), "mask must be binary" + assert ~torch.any(mask < 0), "mask must be binary" + + fill_x = torch.full_like(x, mask_token) + x = x * (1 - mask) + fill_x * mask + + return x, mask + +def random( + x: torch.Tensor, + r: torch.Tensor +): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + if not isinstance(r, torch.Tensor): + r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device) + + r = _gamma(r)[:, None, None] + probs = torch.ones_like(x) * r + + mask = torch.bernoulli(probs) + mask = mask.round().long() + + return mask + +def linear_random( + x: torch.Tensor, + r: torch.Tensor, +): + assert x.ndim == 3, "x must be (batch, n_codebooks, seq)" + if not isinstance(r, torch.Tensor): + r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float() + + probs = torch.ones_like(x).to(x.device).float() + # expand to batch and codebook dims + probs = probs.expand(x.shape[0], x.shape[1], -1) + probs = probs * r + + mask = torch.bernoulli(probs) + mask = mask.round().long() + + return mask + +def inpaint(x: torch.Tensor, + n_prefix, + n_suffix, +): + assert n_prefix is not None + assert n_suffix is not None + + mask = full_mask(x) + + # if we have a prefix or suffix, set their mask prob to 0 + if n_prefix > 0: + if not isinstance(n_prefix, torch.Tensor): + n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device) + for i, n in enumerate(n_prefix): + if n > 0: + mask[i, :, :n] = 0.0 + if n_suffix > 0: + if not isinstance(n_suffix, torch.Tensor): + n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device) + for i, n in enumerate(n_suffix): + if n > 0: + mask[i, :, -n:] = 0.0 + + + return mask + +def periodic_mask(x: torch.Tensor, + period: int, width: int = 1, + random_roll=False, + ): + mask = full_mask(x) + if period == 0: + return mask + + if not isinstance(period, torch.Tensor): + period = scalar_to_batch_tensor(period, x.shape[0]) + for i, factor in enumerate(period): + if factor == 0: + continue + for j in range(mask.shape[-1]): + if j % factor == 0: + # figure out how wide the mask should be + j_start = max(0, j - width // 2 ) + j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1 + # flip a coin for each position in the mask + j_mask = torch.bernoulli(torch.ones(j_end - j_start)) + assert torch.all(j_mask == 1) + j_fill = torch.ones_like(j_mask) * (1 - j_mask) + assert torch.all(j_fill == 0) + # fill + mask[i, :, j_start:j_end] = j_fill + if random_roll: + # add a random offset to the mask + offset = torch.randint(0, period[0], (1,)) + mask = torch.roll(mask, offset.item(), dims=-1) + + return mask + +def codebook_unmask( + mask: torch.Tensor, + n_conditioning_codebooks: int +): + if n_conditioning_codebooks == None: + return mask + # if we have any conditioning codebooks, set their mask to 0 + mask = mask.clone() + mask[:, :n_conditioning_codebooks, :] = 0 + return mask + +def mask_and( + mask1: torch.Tensor, + mask2: torch.Tensor +): + assert mask1.shape == mask2.shape, "masks must be same shape" + return torch.min(mask1, mask2) + +def dropout( + mask: torch.Tensor, + p: float, +): + assert 0 <= p <= 1, "p must be between 0 and 1" + assert mask.max() <= 1, "mask must be binary" + assert mask.min() >= 0, "mask must be binary" + mask = (~mask.bool()).float() + mask = torch.bernoulli(mask * (1 - p)) + mask = ~mask.round().bool() + return mask.long() + +def mask_or( + mask1: torch.Tensor, + mask2: torch.Tensor +): + assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}" + assert mask1.max() <= 
1, "mask1 must be binary" + assert mask2.max() <= 1, "mask2 must be binary" + assert mask1.min() >= 0, "mask1 must be binary" + assert mask2.min() >= 0, "mask2 must be binary" + return (mask1 + mask2).clamp(0, 1) + +def time_stretch_mask( + x: torch.Tensor, + stretch_factor: int, +): + assert stretch_factor >= 1, "stretch factor must be >= 1" + c_seq_len = x.shape[-1] + x = x.repeat_interleave(stretch_factor, dim=-1) + + # trim cz to the original length + x = x[:, :, :c_seq_len] + + mask = periodic_mask(x, stretch_factor, width=1) + return mask + +def onset_mask( + sig: AudioSignal, + z: torch.Tensor, + interface, + width: int = 1 +): + import librosa + + onset_indices = librosa.onset.onset_detect( + y=sig.clone().to_mono().samples.cpu().numpy()[0, 0], + sr=sig.sample_rate, + hop_length=interface.codec.hop_length, + backtrack=True, + ) + + # create a mask, set onset + mask = torch.ones_like(z) + n_timesteps = z.shape[-1] + + for onset_index in onset_indices: + onset_index = min(onset_index, n_timesteps - 1) + onset_index = max(onset_index, 0) + mask[:, :, onset_index - width:onset_index + width] = 0.0 + + print(mask) + + return mask + + + +if __name__ == "__main__": + torch.set_printoptions(threshold=10000) + diff --git a/vampnet/modules/__init__.py b/vampnet/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3481f32e0287faa9e79ba219f17d18529a4b57ac --- /dev/null +++ b/vampnet/modules/__init__.py @@ -0,0 +1,4 @@ +import audiotools + +audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"] +audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"] diff --git a/vampnet/modules/activations.py b/vampnet/modules/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..c013c6302d1569e9c1937915a0ee638632071e51 --- /dev/null +++ b/vampnet/modules/activations.py @@ -0,0 +1,55 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo + (identical to OpenAI GPT). 
Also see the Gaussian Error Linear Units + paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x): + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + +class GatedGELU(nn.Module): + def __init__(self): + super().__init__() + self.gelu = NewGELU() + + def forward(self, x, dim: int = -1): + p1, p2 = x.chunk(2, dim=dim) + return p1 * self.gelu(p2) + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(channels)) + + def forward(self, x): + return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2) + +def get_activation(name: str = "relu"): + if name == "relu": + return nn.ReLU + elif name == "gelu": + return NewGELU + elif name == "geglu": + return GatedGELU + elif name == "snake": + return Snake1d + else: + raise ValueError(f"Unrecognized activation {name}") \ No newline at end of file diff --git a/vampnet/modules/layers.py b/vampnet/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7df97be3b1726ac210da762c83a25dbd3434c7 --- /dev/null +++ b/vampnet/modules/layers.py @@ -0,0 +1,164 @@ +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +def num_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def recurse_children(module, fn): + for child in module.children(): + if isinstance(child, nn.ModuleList): + for c in child: + yield recurse_children(c, fn) + if isinstance(child, nn.ModuleDict): + for c in child.values(): + yield recurse_children(c, fn) + + yield recurse_children(child, fn) + yield fn(child) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class SequentialWithFiLM(nn.Module): + """ + handy wrapper for nn.Sequential that allows FiLM layers to be + inserted in between other layers. 
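+ In forward(), layers for which has_film() returns True are called as
+ layer(x, cond); every other layer is called as layer(x). A minimal,
+ illustrative sketch (the dimensions here are made up, not taken from the model config):
+     seq = SequentialWithFiLM(WNConv1d(64, 64, kernel_size=1), FiLM(16, 64))
+     y = seq(torch.randn(2, 64, 100), cond=torch.randn(2, 16))  # y: (2, 64, 100)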
+ """ + + def __init__(self, *layers): + super().__init__() + self.layers = nn.ModuleList(layers) + + @staticmethod + def has_film(module): + mod_has_film = any( + [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))] + ) + return mod_has_film + + def forward(self, x, cond): + for layer in self.layers: + if self.has_film(layer): + x = layer(x, cond) + else: + x = layer(x) + return x + + +class FiLM(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if input_dim > 0: + self.beta = nn.Linear(input_dim, output_dim) + self.gamma = nn.Linear(input_dim, output_dim) + + def forward(self, x, r): + if self.input_dim == 0: + return x + else: + beta, gamma = self.beta(r), self.gamma(r) + beta, gamma = ( + beta.view(x.size(0), self.output_dim, 1), + gamma.view(x.size(0), self.output_dim, 1), + ) + x = x * (gamma + 1) + beta + return x + + +class CodebookEmbedding(nn.Module): + def __init__( + self, + vocab_size: int, + latent_dim: int, + n_codebooks: int, + emb_dim: int, + special_tokens: Optional[Tuple[str]] = None, + ): + super().__init__() + self.n_codebooks = n_codebooks + self.emb_dim = emb_dim + self.latent_dim = latent_dim + self.vocab_size = vocab_size + + if special_tokens is not None: + for tkn in special_tokens: + self.special = nn.ParameterDict( + { + tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim)) + for tkn in special_tokens + } + ) + self.special_idxs = { + tkn: i + vocab_size for i, tkn in enumerate(special_tokens) + } + + self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1) + + def from_codes(self, codes: torch.Tensor, codec): + """ + get a sequence of continuous embeddings from a sequence of discrete codes. + unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens + necessary for the language model, like . 
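+ codes has shape (batch, n_codebooks, time); each codebook is embedded through
+ its codec lookup table (plus any special-token rows) and the per-codebook
+ latents are concatenated along the channel dimension, giving a
+ (batch, n_codebooks * latent_dim, time) tensor. Hypothetical sketch, assuming
+ codec is a pretrained codec exposing quantizer.quantizers[i].codebook as used
+ below; the sizes are illustrative only:
+     emb = CodebookEmbedding(vocab_size=1024, latent_dim=8, n_codebooks=4,
+                             emb_dim=1280, special_tokens=("MASK",))
+     codes = torch.randint(0, 1024, (1, 4, 100))
+     latents = emb.from_codes(codes, codec)  # (1, 4 * 8, 100)
+     x = emb(latents)                        # (1, 1280, 100)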
+ """ + n_codebooks = codes.shape[1] + latent = [] + for i in range(n_codebooks): + c = codes[:, i, :] + + lookup_table = codec.quantizer.quantizers[i].codebook.weight + if hasattr(self, "special"): + special_lookup = torch.cat( + [self.special[tkn][i : i + 1] for tkn in self.special], dim=0 + ) + lookup_table = torch.cat([lookup_table, special_lookup], dim=0) + + l = F.embedding(c, lookup_table).transpose(1, 2) + latent.append(l) + + latent = torch.cat(latent, dim=1) + return latent + + def forward(self, latents: torch.Tensor): + """ + project a sequence of latents to a sequence of embeddings + """ + x = self.out_proj(latents) + return x + diff --git a/vampnet/modules/transformer.py b/vampnet/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..debb483ff3e0d3ced5a81581a9bfe89e6a185863 --- /dev/null +++ b/vampnet/modules/transformer.py @@ -0,0 +1,866 @@ +import math +import logging +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import loralib as lora +import audiotools as at + +from .activations import get_activation +from .layers import CodebookEmbedding +from .layers import FiLM +from .layers import SequentialWithFiLM +from .layers import WNConv1d +from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten +from ..mask import _gamma + +LORA_R = 8 + +# def log(t, eps=1e-20): +# return torch.log(t + eps) + + +def gumbel_noise_like(t): + noise = torch.zeros_like(t).uniform_(1e-20, 1) + return -torch.log(-torch.log(noise)) + + +def gumbel_sample(t, temperature=1.0, dim=-1): + return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.var_eps = eps + + def forward(self, x): + """Returns root mean square normalized version of input `x` + # T5 uses a layer_norm which only scales and doesn't shift, which is also known + # as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467 + # thus varience is calculated w/o mean and there is no bias + Parameters + ---------- + x : Tensor[B x T x D] + Returns + ------- + Tensor[B x T x D] + """ + var = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var + self.var_eps) + + return self.weight * x + + +class FeedForward(nn.Module): + def __init__( + self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu" + ): + super().__init__() + factor = 2 if activation == "geglu" else 1 + self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R) + self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R) + self.drop = nn.Dropout(dropout) + self.act = get_activation(activation)() + + def forward(self, x): + """Computes position-wise feed-forward layer + Parameters + ---------- + x : Tensor[B x T x D] + Returns + ------- + Tensor[B x T x D] + """ + x = self.w_1(x) + x = self.act(x) + x = self.drop(x) + x = self.w_2(x) + return x + + +class MultiHeadRelativeAttention(nn.Module): + def __init__( + self, + n_head: int = 8, + d_model: int = 512, + dropout: float = 0.1, + bidirectional: bool = True, + has_relative_attention_bias: bool = True, + attention_num_buckets: int = 32, + attention_max_distance: int = 128, + ): + super().__init__() + d_head = d_model // n_head + self.n_head = n_head + self.d_head = d_head + self.bidirectional = bidirectional + 
self.has_relative_attention_bias = has_relative_attention_bias + self.attention_num_buckets = attention_num_buckets + self.attention_max_distance = attention_max_distance + + # Create linear query, key, value projections + self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R) + self.w_ks = nn.Linear(d_model, d_model, bias=False) + self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R) + + # Create linear final output projection + self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R) + + # Dropout for attention output weights + self.dropout = nn.Dropout(dropout) + + # Create relative positional embeddings (if turned on) + if has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head) + + def _relative_position_bucket(self, relative_position): + """Converts unbounded relative position into bounded set of buckets + with half "exact" buckets (1 position = 1 bucket) and half "log-spaced" + buckets + Parameters + ---------- + relative_position : Tensor[T_q x T_kv] + Relative positions between queries and key_value items + Returns + ------- + Tensor[T_q x T_kv] + Input relative positions converted into buckets + """ + relative_buckets = 0 + num_buckets = self.attention_num_buckets + max_distance = self.attention_max_distance + + # Convert relative position for (-inf, inf) to [0, inf] + # Negative relative positions correspond to past + # Positive relative positions correspond to future + if self.bidirectional: + # use half buckets for each side (past / future) + num_buckets //= 2 + + # Shift the position positions by `num_buckets` to wrap around + # negative positions + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + # If not bidirectional, ignore positive positions and wrap + # negative positions to positive + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + + # Allocate half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in + # positions up to `max_distance` + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + + # Clip the max relative position to `num_buckets - 1` + relative_postion_if_large = torch.min( + relative_postion_if_large, + torch.full_like(relative_postion_if_large, num_buckets - 1), + ) + + # Choose relative buckets based on small or large positions + relative_buckets += torch.where( + is_small, relative_position, relative_postion_if_large + ) + + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Computes a position bias scalar for each index in query_length x key_length + Parameters + ---------- + query_length : int + key_length : int + Returns + ------- + Tensor[heads x 1 x T_q x T_kv] + Position bias to be applied on attention logits + """ + + query_position = torch.arange(query_length, dtype=torch.long)[:, None] + key_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = key_position - query_position + + # Convert relative position to buckets + relative_position_bucket = self._relative_position_bucket(relative_position) + relative_position_bucket = relative_position_bucket.to( + self.relative_attention_bias.weight.device + ) + + # 
Index attention bias values + values = self.relative_attention_bias(relative_position_bucket) + values = rearrange(values, "q k h -> h 1 q k") + + return values + + def forward(self, q, k, v, mask=None, position_bias=None): + """Computes attention over (keys, values) for every timestep in query + Parameters + ---------- + q : Tensor[B x T_q x d_model] + Query vectors + k : Tensor[B x T_kv x d_model] + Key vectors to compute attention over + v : Tensor[B x T_kv x d_model] + Value vectors corresponding to the keys + mask : Tensor[B x T_q x T_kv], optional + position_bias: Tensor[head x 1 x T_q x T_kv] + Returns + ------- + Tensor[B x T_q x d_model] + Outputs after attending (key, value) using queries + """ + # Compute query, key, value projections + q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head) + k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head) + v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head) + + # Compute attention matrix + attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1]) + + # Add relative position bias to attention scores + if position_bias is None: + if self.has_relative_attention_bias: + position_bias = self.compute_bias(q.size(-2), k.size(-2)) + else: + position_bias = torch.zeros_like(attn) + attn += position_bias + + # Apply mask to attention scores to prevent looking up invalid locations + if mask is not None: + attn = attn.masked_fill(mask[None] == 0, -1e9) + + # Normalize attention scores and add dropout + attn = torch.softmax(attn, dim=3) + attn = self.dropout(attn) + + # Compute attended outputs (product of attention matrix and values) + output = torch.einsum("hblt,hbtv->hblv", [attn, v]) + output = rearrange(output, "head b l v -> b l (head v)") + output = self.fc(output) + + return output, position_bias + + +class TransformerLayer(nn.Module): + def __init__( + self, + d_model: int = 512, + d_cond: int = 64, + n_heads: int = 8, + bidirectional: bool = True, + is_decoder: bool = False, + has_relative_attention_bias: bool = False, + flash_attn: bool = False, + dropout: float = 0.1, + ): + super().__init__() + # Store args + self.is_decoder = is_decoder + + # Create self-attention layer + self.norm_1 = RMSNorm(d_model) + self.film_1 = FiLM(d_cond, d_model) + self.flash_attn = flash_attn + + if flash_attn: + from flash_attn.flash_attention import FlashMHA + self.self_attn = FlashMHA( + embed_dim=d_model, + num_heads=n_heads, + attention_dropout=dropout, + causal=False, + ) + else: + self.self_attn = MultiHeadRelativeAttention( + n_heads, d_model, dropout, bidirectional, has_relative_attention_bias + ) + + # (Optional) Create cross-attention layer + if is_decoder: + self.norm_2 = RMSNorm(d_model) + self.film_2 = FiLM(d_cond, d_model) + self.cross_attn = MultiHeadRelativeAttention( + n_heads, + d_model, + dropout, + bidirectional=True, + has_relative_attention_bias=False, + ) + + # Create last feed-forward layer + self.norm_3 = RMSNorm(d_model) + self.film_3 = FiLM(d_cond, d_model) + self.feed_forward = FeedForward(d_model=d_model, dropout=dropout) + + # Create dropout + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x, + x_mask, + cond, + src=None, + src_mask=None, + position_bias=None, + encoder_decoder_position_bias=None, + ): + """Computes one transformer layer consisting of self attention, (op) cross attention + and feedforward layer + Parameters + ---------- + x : Tensor[B x T_q x D] + x_mask : Tensor[B x T_q] + src : Tensor[B x T_kv x D], optional + 
src_mask : Tensor[B x T_kv x D], optional + position_bias : Tensor[heads x B x T_q x T_q], optional + Relative position bias for self attention layer + encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional + Relative position bias for cross attention layer + Returns + ------- + Tensor[B x T_q x D] + """ + y = self.norm_1(x) + y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1) + if self.flash_attn: + with torch.autocast(y.device.type, dtype=torch.bfloat16): + y = self.self_attn(y)[0] + else: + y, position_bias = self.self_attn(y, y, y, x_mask, position_bias) + x = x + self.dropout(y) + + if self.is_decoder: + y = self.norm_2(x) + y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1) + y, encoder_decoder_position_bias = self.cross_attn( + y, src, src, src_mask, encoder_decoder_position_bias + ) + x = x + self.dropout(y) + + y = self.norm_3(x) + y = self.film_3( + y.permute( + 0, + 2, + 1, + ), + cond, + ).permute(0, 2, 1) + y = self.feed_forward(y) + x = x + self.dropout(y) + + return x, position_bias, encoder_decoder_position_bias + + +class TransformerStack(nn.Module): + def __init__( + self, + d_model: int = 512, + d_cond: int = 64, + n_heads: int = 8, + n_layers: int = 8, + last_layer: bool = True, + bidirectional: bool = True, + flash_attn: bool = False, + is_decoder: bool = False, + dropout: float = 0.1, + ): + super().__init__() + # Store args + self.bidirectional = bidirectional + self.is_decoder = is_decoder + + # Create transformer layers + # In T5, relative attention bias is shared by all layers in the stack + self.layers = nn.ModuleList( + [ + TransformerLayer( + d_model, + d_cond, + n_heads, + bidirectional, + is_decoder, + has_relative_attention_bias=True if (i == 0) else False, + flash_attn=flash_attn, + dropout=dropout, + ) + for i in range(n_layers) + ] + ) + + # Perform last normalization + self.norm = RMSNorm(d_model) if last_layer else None + + def subsequent_mask(self, size): + return torch.ones(1, size, size).tril().bool() + + def forward(self, x, x_mask, cond=None, src=None, src_mask=None): + """Computes a full transformer stack + Parameters + ---------- + x : Tensor[B x T_q x D] + x_mask : Tensor[B x T_q] + src : Tensor[B x T_kv x D], optional + src_mask : Tensor[B x T_kv], optional + Returns + ------- + Tensor[B x T_q x D] + """ + + # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking + if self.is_decoder: + src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2) + + # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking + x_mask = x_mask.unsqueeze(-2) + if not self.bidirectional: + x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device) + + # Initialize position biases + position_bias = None + encoder_decoder_position_bias = None + + # Compute transformer layers + for layer in self.layers: + x, position_bias, encoder_decoder_position_bias = layer( + x=x, + x_mask=x_mask, + cond=cond, + src=src, + src_mask=src_mask, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + ) + + return self.norm(x) if self.norm is not None else x + + +class VampNet(at.ml.BaseModel): + def __init__( + self, + n_heads: int = 20, + n_layers: int = 16, + r_cond_dim: int = 64, + n_codebooks: int = 9, + n_conditioning_codebooks: int = 0, + latent_dim: int = 8, + embedding_dim: int = 1280, + vocab_size: int = 1024, + flash_attn: bool = True, + noise_mode: str = "mask", + dropout: float = 0.1 + ): + super().__init__() + self.n_heads = n_heads + self.n_layers = n_layers + 
self.r_cond_dim = r_cond_dim + self.n_codebooks = n_codebooks + self.n_conditioning_codebooks = n_conditioning_codebooks + self.embedding_dim = embedding_dim + self.vocab_size = vocab_size + self.latent_dim = latent_dim + self.flash_attn = flash_attn + self.noise_mode = noise_mode + + assert self.noise_mode == "mask", "deprecated" + + self.embedding = CodebookEmbedding( + latent_dim=latent_dim, + n_codebooks=n_codebooks, + vocab_size=vocab_size, + emb_dim=embedding_dim, + special_tokens=["MASK"], + ) + self.mask_token = self.embedding.special_idxs["MASK"] + + self.transformer = TransformerStack( + d_model=embedding_dim, + d_cond=r_cond_dim, + n_heads=n_heads, + n_layers=n_layers, + last_layer=True, + bidirectional=True, + flash_attn=flash_attn, + is_decoder=False, + dropout=dropout, + ) + + # Add final conv layer + self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks + self.classifier = SequentialWithFiLM( + WNConv1d( + embedding_dim, + vocab_size * self.n_predict_codebooks, + kernel_size=1, + padding="same", + # groups=self.n_predict_codebooks, + ), + ) + + def forward(self, x, cond): + x = self.embedding(x) + x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1) + + cond = self.r_embed(cond) + + x = rearrange(x, "b d n -> b n d") + out = self.transformer(x=x, x_mask=x_mask, cond=cond) + out = rearrange(out, "b n d -> b d n") + + out = self.classifier(out, cond) + + out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks) + + return out + + def r_embed(self, r, max_positions=10000): + if self.r_cond_dim > 0: + dtype = r.dtype + + r = _gamma(r) * max_positions + half_dim = self.r_cond_dim // 2 + + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + + if self.r_cond_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + + return emb.to(dtype) + else: + return r + + @torch.no_grad() + def to_signal(self, z, codec): + """ + convert a sequence of latents to a signal. 
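+ Codes are decoded through the codec; any timestep that still holds the MASK
+ token is written out as silence (one hop_length of zeroed samples per masked
+ frame).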
+ """ + assert z.ndim == 3 + + signal = at.AudioSignal( + codec.decode( + codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0] + )["audio"], + codec.sample_rate, + ) + + # find where the mask token is and replace it with silence in the audio + for tstep in range(z.shape[-1]): + if torch.any(z[:, :, tstep] == self.mask_token): + sample_idx_0 = tstep * codec.hop_length + sample_idx_1 = sample_idx_0 + codec.hop_length + signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0 + + return signal + + + @torch.no_grad() + def generate( + self, + codec, + time_steps: int = 300, + sampling_steps: int = 24, + start_tokens: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + temperature: Union[float, Tuple[float, float]] = 2.5, + typical_filtering=False, + typical_mass=0.2, + typical_min_tokens=1, + return_signal=True, + ): + logging.debug(f"beginning generation with {sampling_steps} steps") + + ##################### + # resolve temperature # + ##################### + assert isinstance(temperature, float) + logging.debug(f"temperature: {temperature}") + + + ##################### + # resolve initial z # + ##################### + z = start_tokens + + if z is None: + z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to( + self.device + ) + + logging.debug(f"created z with shape {z.shape}") + + + ################# + # resolve mask # + ################# + + if mask is None: + mask = torch.ones_like(z).to(self.device).int() + mask[:, : self.n_conditioning_codebooks, :] = 0.0 + if mask.ndim == 2: + mask = mask[:, None, :].repeat(1, z.shape[1], 1) + # init_mask = mask.clone() + + logging.debug(f"created mask with shape {mask.shape}") + + + ########### + # set up # + ########## + # apply the mask to z + z_masked = z.masked_fill(mask.bool(), self.mask_token) + # logging.debug(f"z_masked: {z_masked}") + + # how many mask tokens to begin with? + num_mask_tokens_at_start = (z_masked == self.mask_token).sum() + logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}") + + # how many codebooks are we inferring vs conditioning on? 
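+ # (conditioning codebooks are held fixed: they are sliced off before the
+ # flatten/unflatten round trip below and concatenated back at the end of each step)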
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks + logging.debug(f"n infer codebooks: {n_infer_codebooks}") + + ################# + # begin sampling # + ################# + + for i in range(sampling_steps): + logging.debug(f"step {i} of {sampling_steps}") + + # our current temperature + logging.debug(f"temperature: {temperature}") + + # our current schedule step + r = scalar_to_batch_tensor( + (i + 1) / sampling_steps, + z.shape[0] + ).to(z.device) + logging.debug(f"r: {r}") + + # get latents + latents = self.embedding.from_codes(z_masked, codec) + logging.debug(f"computed latents with shape: {latents.shape}") + + + # infer from latents + # NOTE: this collapses the codebook dimension into the sequence dimension + logits = self.forward(latents, r) # b, prob, seq + logits = logits.permute(0, 2, 1) # b, seq, prob + if typical_filtering: + typical_filter(logits, + typical_mass=typical_mass, + typical_min_tokens=typical_min_tokens + ) + + + logging.debug(f"permuted logits with shape: {logits.shape}") + + + # logits2probs + probs = torch.softmax(logits, dim=-1) + logging.debug(f"computed probs with shape: {probs.shape}") + + + # sample from logits with multinomial sampling + b = probs.shape[0] + probs = rearrange(probs, "b seq prob -> (b seq) prob") + + sampled_z = torch.multinomial(probs, 1).squeeze(-1) + + sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b) + probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b) + logging.debug(f"sampled z with shape: {sampled_z.shape}") + + # get the confidences: which tokens did we sample? + selected_probs = ( + torch.take_along_dim( + probs, sampled_z.long().unsqueeze(-1), + dim=-1 + ).squeeze(-1) + ) + + # flatten z_masked and mask, so we can deal with the sampling logic + # we'll unflatten them at the end of the loop for the next forward pass + # remove conditioning codebooks, we'll add them back at the end + z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :]) + + mask = (z_masked == self.mask_token).int() + + # update the mask, remove conditioning codebooks from the mask + logging.debug(f"updated mask with shape: {mask.shape}") + # add z back into sampled z where the mask was false + sampled_z = torch.where( + mask.bool(), sampled_z, z_masked + ) + logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}") + + # ignore any tokens that weren't masked + selected_probs = torch.where( + mask.bool(), selected_probs, torch.inf + ) + + # get the num tokens to mask, according to the schedule + num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long() + logging.debug(f"num to mask: {num_to_mask}") + + if i != (sampling_steps - 1): + num_to_mask = torch.maximum( + torch.tensor(1), + torch.minimum( + mask.sum(dim=-1, keepdim=True) - 1, + num_to_mask + ) + ) + + + # get our new mask + mask = mask_by_random_topk( + num_to_mask, selected_probs, temperature * (1-r) + ) + + # update the mask + z_masked = torch.where( + mask.bool(), self.mask_token, sampled_z + ) + logging.debug(f"updated z_masked with shape: {z_masked.shape}") + + z_masked = codebook_unflatten(z_masked, n_infer_codebooks) + mask = codebook_unflatten(mask, n_infer_codebooks) + logging.debug(f"unflattened z_masked with shape: {z_masked.shape}") + + # add conditioning codebooks back to z_masked + z_masked = torch.cat( + (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1 + ) + logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}") + + + # add conditioning 
codebooks back to sampled_z + sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks) + sampled_z = torch.cat( + (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1 + ) + + logging.debug(f"finished sampling") + + if return_signal: + return self.to_signal(sampled_z, codec) + else: + return sampled_z + + +def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0): + """ + Args: + num_to_mask (int): number of tokens to mask + probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq) + temperature (float, optional): temperature. Defaults to 1.0. + """ + logging.debug(f"masking by random topk") + logging.debug(f"num to mask: {num_to_mask}") + logging.debug(f"probs shape: {probs.shape}") + logging.debug(f"temperature: {temperature}") + logging.debug("") + + confidence = torch.log(probs) + temperature * gumbel_noise_like(probs) + logging.debug(f"confidence shape: {confidence.shape}") + + sorted_confidence, sorted_idx = confidence.sort(dim=-1) + logging.debug(f"sorted confidence shape: {sorted_confidence.shape}") + logging.debug(f"sorted idx shape: {sorted_idx.shape}") + + # get the cut off threshold, given the mask length + cut_off = torch.take_along_dim( + sorted_confidence, num_to_mask, axis=-1 + ) + logging.debug(f"cut off shape: {cut_off.shape}") + + # mask out the tokens + mask = confidence < cut_off + logging.debug(f"mask shape: {mask.shape}") + + return mask + +def typical_filter( + logits, + typical_mass: float = 0.95, + typical_min_tokens: int = 1,): + nb, nt, _ = logits.shape + x_flat = rearrange(logits, "b t l -> (b t ) l") + x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1) + x_flat_norm_p = torch.exp(x_flat_norm) + entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True) + + c_flat_shifted = torch.abs((-x_flat_norm) - entropy) + c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False) + x_flat_cumsum = ( + x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1) + ) + + last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1) + sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather( + 1, last_ind.view(-1, 1) + ) + if typical_min_tokens > 1: + sorted_indices_to_remove[..., :typical_min_tokens] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, x_flat_indices, sorted_indices_to_remove + ) + x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf")) + logits = rearrange(x_flat, "(b t) l -> b t l", t=nt) + return logits + + +if __name__ == "__main__": + # import argbind + from .layers import num_params + + VampNet = argbind.bind(VampNet) + + @argbind.bind(without_prefix=True) + def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0): + seq_len = int(32000 / 512 * seq_len_s) + + model = VampNet().to(device) + + z = torch.randint( + 0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len) + ).to(device) + + r = torch.zeros(batch_size).to(device) + + z_mask_latent = torch.rand( + batch_size, model.latent_dim * model.n_codebooks, seq_len + ).to(device) + z_hat = model(z_mask_latent, r) + + pred = z_hat.argmax(dim=1) + pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks) + + print(f"model has {num_params(model)/1e6:<.3f}M parameters") + print(f"prediction has shape {pred.shape}") + breakpoint() + + args = argbind.parse_args() + with argbind.scope(args): + try_model() + + diff --git a/vampnet/scheduler.py b/vampnet/scheduler.py new file mode 100644 index 
0000000000000000000000000000000000000000..a57108c2af7c974d882b45e092907195ded71c9a --- /dev/null +++ b/vampnet/scheduler.py @@ -0,0 +1,47 @@ +import copy +from typing import List + +import torch + +class NoamScheduler: + """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf + Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + d_model: int = 512, + factor: float = 1.0, + warmup: int = 4000, + ): + # Store hparams + self.warmup = warmup + self.factor = factor + self.d_model = d_model + + # Initialize variables `lr` and `steps` + self.lr = None + self.steps = 0 + + # Store the optimizer + self.optimizer = optimizer + + def state_dict(self): + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + def step(self): + self.steps += 1 + self.lr = self.factor * ( + self.d_model ** (-0.5) + * min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5)) + ) + + for p in self.optimizer.param_groups: + p["lr"] = self.lr + diff --git a/vampnet/util.py b/vampnet/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbf8fb41f1d2e1c0ad95e959acb5ae8655768f7 --- /dev/null +++ b/vampnet/util.py @@ -0,0 +1,46 @@ +import tqdm + +import torch +from einops import rearrange + +def scalar_to_batch_tensor(x, batch_size): + return torch.tensor(x).repeat(batch_size) + + +def parallelize( + fn, + *iterables, + parallel: str = "thread_map", + **kwargs + ): + if parallel == "thread_map": + from tqdm.contrib.concurrent import thread_map + return thread_map( + fn, + *iterables, + **kwargs + ) + elif parallel == "process_map": + from tqdm.contrib.concurrent import process_map + return process_map( + fn, + *iterables, + **kwargs + ) + elif parallel == "single": + return [fn(x) for x in tqdm.tqdm(*iterables)] + else: + raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}") + +def codebook_flatten(tokens: torch.Tensor): + """ + flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time) + """ + return rearrange(tokens, "b c t -> b (t c)") + +def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None): + """ + unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time) + """ + tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c) + return tokens
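A minimal usage sketch for the two modules above; the hyperparameters and shapes are
illustrative, not taken from any config in this diff:

    import torch
    from vampnet.scheduler import NoamScheduler
    from vampnet.util import codebook_flatten, codebook_unflatten

    # NoamScheduler only updates the learning rate; optimizer.step() is still
    # called separately in the training loop.
    model = torch.nn.Linear(8, 8)
    opt = torch.optim.AdamW(model.parameters(), lr=1.0)
    sched = NoamScheduler(opt, d_model=1280, factor=2.0, warmup=10000)
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    opt.step()
    sched.step()  # lr = factor * d_model**-0.5 * min(steps**-0.5, steps * warmup**-1.5)

    # codebook_flatten collapses (batch, codebook, time) into a single token axis,
    # interleaving codebooks within each timestep; codebook_unflatten inverts it.
    tokens = torch.randint(0, 1024, (2, 4, 25))
    flat = codebook_flatten(tokens)             # (2, 100)
    assert torch.equal(codebook_unflatten(flat, n_c=4), tokens)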