diff --git a/.gitignore b/.gitignore index 7bdafb91a788dcd6b87ae12d2de2d9f6f9d8969c..685d6448d5a197ed8a434f72d6f2279552613c91 100644 --- a/.gitignore +++ b/.gitignore @@ -175,14 +175,10 @@ lyrebird-audio-codec samples-*/** gradio-outputs/ -models/ samples*/ models-all/ models.zip +audiotools/ +descript-audio-codec/ +# *.pth .git-old -conf/generated/* -runs*/ - - -gtzan.zip -.gtzan_emb_cache diff --git a/README.md b/README.md index 9f63c43e04e5c6c4bf9d1ec12276636ee77a075d..687fb086b0db5ec747e42728d8d25be07f51e7cb 100644 --- a/README.md +++ b/README.md @@ -7,27 +7,16 @@ sdk: gradio sdk_version: 3.36.1 app_file: app.py pinned: false -python_version: 3.9 +duplicated_from: hugggof/vampnet --- # VampNet -This repository contains recipes for training generative music models on top of the Descript Audio Codec. - -## try `unloop` -you can try vampnet in a co-creative looper called unloop. see this link: https://github.com/hugofloresgarcia/unloop +This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec. # Setting up -**Requires Python 3.9**. - -you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15). - -(for example, using conda) -```bash -conda create -n vampnet python=3.9 -conda activate vampnet -``` +Requires Python 3.9 or later. install VampNet @@ -46,7 +35,7 @@ Config files are stored in the `conf/` folder. ### Licensing for Pretrained Models: The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). -Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder. +Download the pretrained models from [this link](https://zenodo.org/record/8136545). Then, extract the models to the `models/` folder. 
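After extracting, the demo expects the checkpoints at the paths referenced in `app.py` below (a `models/vampnet/` layout is assumed here, inferred from those paths). A quick sanity check that the archive landed in the right place:

```python
# Minimal sketch: verify the extracted checkpoints are where the demo looks for them.
# The models/vampnet/ layout is an assumption based on the paths used in app.py.
from pathlib import Path

expected = [
    "models/vampnet/coarse.pth",
    "models/vampnet/c2f.pth",
    "models/vampnet/codec.pth",
    "models/wavebeat.pth",
]
missing = [p for p in expected if not Path(p).exists()]
print("all checkpoints found" if not missing else f"missing: {missing}")
```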
# Usage @@ -100,7 +89,7 @@ python scripts/exp/train.py --args.load conf//c2f.yml launch the interface: ```bash -python app.py --args.load conf/generated//interface.yml +python demo.py --args.load conf/generated//interface.yml ``` diff --git a/app.py b/app.py index 184a0b94bfbdb69ece1474ba81ea10565414dbe8..194a40703b3c3bc74eaf59bddaf95e4c59ee1cde 100644 --- a/app.py +++ b/app.py @@ -1,12 +1,3 @@ -# huggingface space exclusive -import os - -# print("installing pyharp") -# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"') -# print("installing madmom") -os.system('pip install cython') -os.system('pip install madmom') - from pathlib import Path from typing import Tuple import yaml @@ -24,38 +15,27 @@ import gradio as gr from vampnet.interface import Interface from vampnet import mask as pmask -from pyharp import ModelCard, build_endpoint - - - -# loader = AudioLoader() +# Interface = argbind.bind(Interface) # AudioLoader = argbind.bind(at.data.datasets.AudioLoader) -conf = argbind.parse_args() - - -from torch_pitch_shift import pitch_shift, get_fast_shifts -def shift_pitch(signal, interval: int): - signal.samples = pitch_shift( - signal.samples, - shift=interval, - sample_rate=signal.sample_rate - ) - return signal - -def load_interface(): - interface = Interface( - coarse_ckpt="./models/vampnet/coarse.pth", - coarse2fine_ckpt="./models/vampnet/c2f.pth", - codec_ckpt="./models/vampnet/codec.pth", - wavebeat_ckpt="./models/wavebeat.pth", - device="cuda" if torch.cuda.is_available() else "cpu", - ) - return interface +interface = Interface( + coarse_ckpt="./models/vampnet/coarse.pth", + coarse2fine_ckpt="./models/vampnet/c2f.pth", + codec_ckpt="./models/vampnet/codec.pth", + wavebeat_ckpt="./models/wavebeat.pth", + device="cuda" if torch.cuda.is_available() else "cpu", +) +# loader = AudioLoader() +print(f"interface device is {interface.device}") -interface = load_interface() - +# dataset = at.data.datasets.AudioDataset( +# loader, +# sample_rate=interface.codec.sample_rate, +# duration=interface.coarse.chunk_size_s, +# n_examples=5000, +# without_replacement=True, +# ) OUT_DIR = Path("gradio-outputs") OUT_DIR.mkdir(exist_ok=True, parents=True) @@ -70,7 +50,7 @@ def load_audio(file): ) sig = interface.preprocess(sig) - out_dir = OUT_DIR / "tmp" / str(uuid.uuid4()) + out_dir = OUT_DIR / str(uuid.uuid4()) out_dir.mkdir(parents=True, exist_ok=True) sig.write(out_dir / "input.wav") return sig.path_to_file @@ -88,10 +68,6 @@ def _vamp(data, return_mask=False): out_dir = OUT_DIR / str(uuid.uuid4()) out_dir.mkdir() sig = at.AudioSignal(data[input_audio]) - sig = interface.preprocess(sig) - - if data[pitch_shift_amt] != 0: - sig = shift_pitch(sig, data[pitch_shift_amt]) z = interface.encode(sig) @@ -131,58 +107,24 @@ def _vamp(data, return_mask=False): mask = pmask.codebook_unmask(mask, ncc) - print(f"dropout {data[dropout]}") - print(f"masktemp {data[masktemp]}") - print(f"sampletemp {data[sampletemp]}") - print(f"top_p {data[top_p]}") - print(f"prefix_s {data[prefix_s]}") - print(f"suffix_s {data[suffix_s]}") - print(f"rand_mask_intensity {data[rand_mask_intensity]}") - print(f"num_steps {data[num_steps]}") - print(f"periodic_p {data[periodic_p]}") - print(f"periodic_w {data[periodic_w]}") - print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}") - print(f"use_coarse2fine {data[use_coarse2fine]}") - print(f"onset_mask_width {data[onset_mask_width]}") - print(f"beat_mask_width {data[beat_mask_width]}") - print(f"beat_mask_downbeats {data[beat_mask_downbeats]}") 
- print(f"stretch_factor {data[stretch_factor]}") - print(f"seed {data[seed]}") - print(f"pitch_shift_amt {data[pitch_shift_amt]}") - print(f"sample_cutoff {data[sample_cutoff]}") - - - _top_p = data[top_p] if data[top_p] > 0 else None + print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}") # save the mask as a txt file np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy()) - _seed = data[seed] if data[seed] > 0 else None zv, mask_z = interface.coarse_vamp( z, mask=mask, sampling_steps=data[num_steps], - mask_temperature=data[masktemp]*10, - sampling_temperature=data[sampletemp], + temperature=float(data[temp]*10), return_mask=True, typical_filtering=data[typical_filtering], typical_mass=data[typical_mass], typical_min_tokens=data[typical_min_tokens], - top_p=_top_p, gen_fn=interface.coarse.generate, - seed=_seed, - sample_cutoff=data[sample_cutoff], ) if use_coarse2fine: - zv = interface.coarse_to_fine( - zv, - mask_temperature=data[masktemp]*10, - sampling_temperature=data[sampletemp], - mask=mask, - sampling_steps=data[num_steps] // 2, - sample_cutoff=data[sample_cutoff], - seed=_seed, - ) + zv = interface.coarse_to_fine(zv, temperature=data[temp]) sig = interface.to_signal(zv).cpu() print("done") @@ -215,9 +157,7 @@ def save_vamp(data): sig_out.write(out_dir / "output.wav") _data = { - "masktemp": data[masktemp], - "sampletemp": data[sampletemp], - "top_p": data[top_p], + "temp": data[temp], "prefix_s": data[prefix_s], "suffix_s": data[suffix_s], "rand_mask_intensity": data[rand_mask_intensity], @@ -228,8 +168,6 @@ def save_vamp(data): "n_conditioning_codebooks": data[n_conditioning_codebooks], "use_coarse2fine": data[use_coarse2fine], "stretch_factor": data[stretch_factor], - "seed": data[seed], - "samplecutoff": data[sample_cutoff], } # save with yaml @@ -245,54 +183,13 @@ def save_vamp(data): return f"saved! your save code is {out_dir.stem}", zip_path -def harp_vamp(_input_audio, _beat_mask_width, _sampletemp): - - out_dir = OUT_DIR / str(uuid.uuid4()) - out_dir.mkdir() - sig = at.AudioSignal(_input_audio) - sig = interface.preprocess(sig) - - z = interface.encode(sig) - - # build the mask - mask = pmask.linear_random(z, 1.0) - if _beat_mask_width > 0: - beat_mask = interface.make_beat_mask( - sig, - after_beat_s=(_beat_mask_width/1000), - ) - mask = pmask.mask_and(mask, beat_mask) - - # save the mask as a txt file - zv, mask_z = interface.coarse_vamp( - z, - mask=mask, - sampling_temperature=_sampletemp, - return_mask=True, - gen_fn=interface.coarse.generate, - ) - - - zv = interface.coarse_to_fine( - zv, - sampling_temperature=_sampletemp, - mask=mask, - ) - - sig = interface.to_signal(zv).cpu() - print("done") - - sig.write(out_dir / "output.wav") - - return sig.path_to_file - with gr.Blocks() as demo: with gr.Row(): with gr.Column(): - gr.Markdown("# VampNet Audio Vamping") + gr.Markdown("# VampNet") gr.Markdown("""## Description: - This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings. + This is a demo of VampNet, a masked generative music model capable of doing music variations. You can control the extent and nature of variation with a set of manual controls and presets. 
Use this interface to experiment with different mask settings and explore the audio outputs. """) @@ -300,8 +197,8 @@ with gr.Blocks() as demo: gr.Markdown(""" ## Instructions: 1. You can start by uploading some audio, or by loading the example audio. - 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. - 3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio. + 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button. + 3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints. 4. Optionally, you can add some notes and save the result. 5. You can also use the output as the new input and continue experimenting! """) @@ -378,15 +278,9 @@ with gr.Blocks() as demo: "beat-driven variation": { "periodic_p": 0, "onset_mask_width": 0, - "beat_mask_width": 50, + "beat_mask_width": 20, "beat_mask_downbeats": False, }, - "beat-driven variation (downbeats only)": { - "periodic_p": 0, - "onset_mask_width": 0, - "beat_mask_width": 50, - "beat_mask_downbeats": True, - }, "beat-driven variation (downbeats only, strong)": { "periodic_p": 0, "onset_mask_width": 0, @@ -408,20 +302,20 @@ with gr.Blocks() as demo: minimum=0, maximum=128, step=1, - value=3, + value=13, ) onset_mask_width = gr.Slider( label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ", minimum=0, - maximum=100, + maximum=20, step=1, value=5, ) beat_mask_width = gr.Slider( - label="beat prompt (ms)", + label="beat mask width (in milliseconds)", minimum=0, maximum=200, value=0, @@ -433,14 +327,6 @@ with gr.Blocks() as demo: with gr.Accordion("extras ", open=False): - pitch_shift_amt = gr.Slider( - label="pitch shift amount (semitones)", - minimum=-12, - maximum=12, - step=1, - value=0, - ) - rand_mask_intensity = gr.Slider( label="random mask intensity. 
(If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)", minimum=0.0, @@ -500,34 +386,21 @@ with gr.Blocks() as demo: value=0.0 ) - masktemp = gr.Slider( - label="mask temperature", + temp = gr.Slider( + label="temperature", minimum=0.0, - maximum=100.0, - value=1.5 - ) - sampletemp = gr.Slider( - label="sample temperature", - minimum=0.1, maximum=10.0, - value=1.0, - step=0.001 + value=1.8 ) - + with gr.Accordion("sampling settings", open=False): - top_p = gr.Slider( - label="top p (0.0 = off)", - minimum=0.0, - maximum=1.0, - value=0.9 - ) typical_filtering = gr.Checkbox( label="typical filtering ", value=False ) - typical_mass = gr.Slider( + typical_mass = gr.Slider( label="typical mass (should probably stay between 0.1 and 0.5)", minimum=0.01, maximum=0.99, @@ -540,18 +413,10 @@ with gr.Blocks() as demo: step=1, value=64 ) - sample_cutoff = gr.Slider( - label="sample cutoff", - minimum=0.0, - maximum=1.0, - value=0.5, - step=0.01 - ) use_coarse2fine = gr.Checkbox( label="use coarse2fine", - value=True, - visible=False + value=True ) num_steps = gr.Slider( @@ -571,24 +436,8 @@ with gr.Blocks() as demo: ) - seed = gr.Number( - label="seed (0 for random)", - value=0, - precision=0, - ) - - - # mask settings with gr.Column(): - - # lora_choice = gr.Dropdown( - # label="lora choice", - # choices=list(loras.keys()), - # value=LORA_NONE, - # visible=False - # ) - vamp_button = gr.Button("generate (vamp)!!!") output_audio = gr.Audio( label="output audio", @@ -614,9 +463,7 @@ with gr.Blocks() as demo: _inputs = { input_audio, num_steps, - masktemp, - sampletemp, - top_p, + temp, prefix_s, suffix_s, rand_mask_intensity, periodic_p, periodic_w, @@ -629,11 +476,7 @@ with gr.Blocks() as demo: typical_mass, typical_min_tokens, beat_mask_width, - beat_mask_downbeats, - seed, - # lora_choice, - pitch_shift_amt, - sample_cutoff + beat_mask_downbeats } # connect widgets @@ -663,24 +506,4 @@ with gr.Blocks() as demo: outputs=[thank_you, download_file] ) - # harp stuff - harp_inputs = [ - input_audio, - beat_mask_width, - sampletemp, - ] - - build_endpoint( - inputs=harp_inputs, - output=output_audio, - process_fn=harp_vamp, - card=ModelCard( - name="vampnet", - description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. 
", - author="Hugo Flores GarcĂ­a", - tags=["music", "generative"] - ), - visible=False - ) - -demo.launch() +demo.queue().launch() diff --git a/conf/generated-v0/berta-goldman-speech/c2f.yml b/conf/generated-v0/berta-goldman-speech/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..0f5a4cd57e7a801121d7c77a62a0e8767b7fe61c --- /dev/null +++ b/conf/generated-v0/berta-goldman-speech/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +save_path: ./runs/berta-goldman-speech/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 +val/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 diff --git a/conf/generated-v0/berta-goldman-speech/coarse.yml b/conf/generated-v0/berta-goldman-speech/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..7c1207e9cfe83bac59f76fcf21068405cd6c9551 --- /dev/null +++ b/conf/generated-v0/berta-goldman-speech/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +save_path: ./runs/berta-goldman-speech/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 +val/AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 diff --git a/conf/generated-v0/berta-goldman-speech/interface.yml b/conf/generated-v0/berta-goldman-speech/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..d1ba35ec732a0148f3bced5542e27e85575c4d4e --- /dev/null +++ b/conf/generated-v0/berta-goldman-speech/interface.yml @@ -0,0 +1,5 @@ +AudioLoader.sources: +- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 +Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth +Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated-v0/gamelan-xeno-canto/c2f.yml b/conf/generated-v0/gamelan-xeno-canto/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..9e6fec4ddc7dd0a2e02d1be66cc7f6eafa669ed1 --- /dev/null +++ b/conf/generated-v0/gamelan-xeno-canto/c2f.yml @@ -0,0 +1,17 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +save_path: ./runs/gamelan-xeno-canto/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 diff --git a/conf/generated-v0/gamelan-xeno-canto/coarse.yml b/conf/generated-v0/gamelan-xeno-canto/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..7e8d38e18d714cb08db7ed456939737404533c3e --- /dev/null +++ b/conf/generated-v0/gamelan-xeno-canto/coarse.yml @@ -0,0 +1,10 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +save_path: ./runs/gamelan-xeno-canto/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 
+val/AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 diff --git a/conf/generated-v0/gamelan-xeno-canto/interface.yml b/conf/generated-v0/gamelan-xeno-canto/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..e567800477816ac1cc41719744c1ba40562e35b1 --- /dev/null +++ b/conf/generated-v0/gamelan-xeno-canto/interface.yml @@ -0,0 +1,6 @@ +AudioLoader.sources: +- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 +- /media/CHONK/hugo/loras/xeno-canto-2 +Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth +Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated-v0/nasralla/c2f.yml b/conf/generated-v0/nasralla/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..9d9db7bed268c18f3ca4047dcde34dd18a5a2301 --- /dev/null +++ b/conf/generated-v0/nasralla/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +save_path: ./runs/nasralla/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/nasralla +val/AudioLoader.sources: +- /media/CHONK/hugo/nasralla diff --git a/conf/generated-v0/nasralla/coarse.yml b/conf/generated-v0/nasralla/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..43a4d18c7f955e38200ded0d2a4fa0959ddb639e --- /dev/null +++ b/conf/generated-v0/nasralla/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +save_path: ./runs/nasralla/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/nasralla +val/AudioLoader.sources: +- /media/CHONK/hugo/nasralla diff --git a/conf/generated-v0/nasralla/interface.yml b/conf/generated-v0/nasralla/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..c93e872d1e4b66567812755882a996814794ad8f --- /dev/null +++ b/conf/generated-v0/nasralla/interface.yml @@ -0,0 +1,5 @@ +AudioLoader.sources: +- /media/CHONK/hugo/nasralla +Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth +Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/breaks-steps/c2f.yml b/conf/generated/breaks-steps/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..49617a6d52de00a9bc7c82c6e820168076402fac --- /dev/null +++ b/conf/generated/breaks-steps/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/breaks-steps/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/breaks-steps +val/AudioLoader.sources: *id001 diff --git a/conf/generated/breaks-steps/coarse.yml b/conf/generated/breaks-steps/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..71d9b27fbc4aac7d407d3606e98c4eaca35e2d3f --- /dev/null +++ b/conf/generated/breaks-steps/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: 
./models/spotdl/coarse.pth +save_path: ./runs/breaks-steps/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/breaks-steps +val/AudioLoader.sources: *id001 diff --git a/conf/generated/breaks-steps/interface.yml b/conf/generated/breaks-steps/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..b4b5182c4a378884e1614d89bc39abdf78a4eaa2 --- /dev/null +++ b/conf/generated/breaks-steps/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/breaks-steps +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/bulgarian-tv-choir/c2f.yml b/conf/generated/bulgarian-tv-choir/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..7bc54bf54bc8cc5c599a11f30c036822fa4b84c5 --- /dev/null +++ b/conf/generated/bulgarian-tv-choir/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/bulgarian-tv-choir/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/bulgarian-tv-choir/coarse.yml b/conf/generated/bulgarian-tv-choir/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..06f27f140dbd8c6d6315aab0787435ff501f8958 --- /dev/null +++ b/conf/generated/bulgarian-tv-choir/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/bulgarian-tv-choir/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/bulgarian-tv-choir/interface.yml b/conf/generated/bulgarian-tv-choir/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..b56e8d721adf99da361dadf423a669bb576478e1 --- /dev/null +++ b/conf/generated/bulgarian-tv-choir/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/dariacore/c2f.yml b/conf/generated/dariacore/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..e8e52fc05be63fe891d3adf0c2115efd5e06ecef --- /dev/null +++ b/conf/generated/dariacore/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/dariacore/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/dariacore +val/AudioLoader.sources: *id001 diff --git 
a/conf/generated/dariacore/coarse.yml b/conf/generated/dariacore/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..42044d7bbafbf890d6d6bc504beb49edf977c39b --- /dev/null +++ b/conf/generated/dariacore/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/dariacore/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/dariacore +val/AudioLoader.sources: *id001 diff --git a/conf/generated/dariacore/interface.yml b/conf/generated/dariacore/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..29342d2fe9d97f20d9521885869f1cca16d2aeba --- /dev/null +++ b/conf/generated/dariacore/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/dariacore +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/musica-bolero-marimba/c2f.yml b/conf/generated/musica-bolero-marimba/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..cd06c72814deaf9fd41d3dabc8e6046e050ad968 --- /dev/null +++ b/conf/generated/musica-bolero-marimba/c2f.yml @@ -0,0 +1,18 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/musica-bolero-marimba/c2f +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras diff --git a/conf/generated/musica-bolero-marimba/coarse.yml b/conf/generated/musica-bolero-marimba/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..a3e1c0ee8e8593528cb389fb84c56894727cfca5 --- /dev/null +++ b/conf/generated/musica-bolero-marimba/coarse.yml @@ -0,0 +1,11 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/musica-bolero-marimba/coarse +train/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras +val/AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras diff --git a/conf/generated/musica-bolero-marimba/interface.yml b/conf/generated/musica-bolero-marimba/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..08b42e3120a3cedbb5aafb9a39ca879d8958127a --- /dev/null +++ b/conf/generated/musica-bolero-marimba/interface.yml @@ -0,0 +1,8 @@ +AudioLoader.sources: +- /media/CHONK/hugo/loras/boleros +- /media/CHONK/hugo/loras/marimba-honduras +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/panchos/c2f.yml b/conf/generated/panchos/c2f.yml new file mode 100644 index 
0000000000000000000000000000000000000000..4efd6fb4caf409382929dcf61d40ed37e3773eac --- /dev/null +++ b/conf/generated/panchos/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/panchos/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/panchos/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/panchos/coarse.yml b/conf/generated/panchos/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4f21a3f4deb58cd6b98680e82d59ad32098542e --- /dev/null +++ b/conf/generated/panchos/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/panchos/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/panchos/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/panchos/interface.yml b/conf/generated/panchos/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..8bae11c225a0fa49c27efdfc808a63d53c21755a --- /dev/null +++ b/conf/generated/panchos/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/panchos/ +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/titi-monkey/c2f.yml b/conf/generated/titi-monkey/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..456912ab1589eee1dfe6c5768e70ede4e455c828 --- /dev/null +++ b/conf/generated/titi-monkey/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/titi-monkey/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/titi-monkey.mp3 +val/AudioLoader.sources: *id001 diff --git a/conf/generated/titi-monkey/coarse.yml b/conf/generated/titi-monkey/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..c2af934aa5aff33c26ae95a2d7a46eb19f9b7194 --- /dev/null +++ b/conf/generated/titi-monkey/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/titi-monkey/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/titi-monkey.mp3 +val/AudioLoader.sources: *id001 diff --git a/conf/generated/titi-monkey/interface.yml b/conf/generated/titi-monkey/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..cbc4ffad24c7c3b34e930aff08404955348b49a2 --- /dev/null +++ b/conf/generated/titi-monkey/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/titi-monkey.mp3 +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth 
+Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/xeno-canto/c2f.yml b/conf/generated/xeno-canto/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..251b0e361ee15d01f7715608480cb3d5e9fdb122 --- /dev/null +++ b/conf/generated/xeno-canto/c2f.yml @@ -0,0 +1,15 @@ +$include: +- conf/lora/lora.yml +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 +VampNet.embedding_dim: 1280 +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 +VampNet.n_heads: 20 +VampNet.n_layers: 16 +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/c2f.pth +save_path: ./runs/xeno-canto/c2f +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/xeno-canto-2/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/xeno-canto/coarse.yml b/conf/generated/xeno-canto/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..ea151dbb64ff13982b0004685901da2b58c8e596 --- /dev/null +++ b/conf/generated/xeno-canto/coarse.yml @@ -0,0 +1,8 @@ +$include: +- conf/lora/lora.yml +fine_tune: true +fine_tune_checkpoint: ./models/spotdl/coarse.pth +save_path: ./runs/xeno-canto/coarse +train/AudioLoader.sources: &id001 +- /media/CHONK/hugo/loras/xeno-canto-2/ +val/AudioLoader.sources: *id001 diff --git a/conf/generated/xeno-canto/interface.yml b/conf/generated/xeno-canto/interface.yml new file mode 100644 index 0000000000000000000000000000000000000000..1a8b1420f142cef024471073e674cd9db59ffad0 --- /dev/null +++ b/conf/generated/xeno-canto/interface.yml @@ -0,0 +1,7 @@ +AudioLoader.sources: +- - /media/CHONK/hugo/loras/xeno-canto-2/ +Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth +Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth +Interface.coarse_ckpt: ./models/spotdl/coarse.pth +Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth +Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/lora/birds.yml b/conf/lora/birds.yml new file mode 100644 index 0000000000000000000000000000000000000000..de413ec0dec4f974e664923c9319861a1c957e87 --- /dev/null +++ b/conf/lora/birds.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds diff --git a/conf/lora/birdss.yml b/conf/lora/birdss.yml new file mode 100644 index 0000000000000000000000000000000000000000..3526de67d24e296de2cc0a7d2e5ebbc18245a6c8 --- /dev/null +++ b/conf/lora/birdss.yml @@ -0,0 +1,12 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/ + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/birds + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/ diff --git a/conf/lora/constructions.yml b/conf/lora/constructions.yml new file mode 100644 index 0000000000000000000000000000000000000000..f513b4898e06339fa0d0b4af24e98fdf5289094a --- /dev/null +++ b/conf/lora/constructions.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3 diff --git a/conf/lora/ella-baila-sola.yml b/conf/lora/ella-baila-sola.yml new file mode 100644 index 0000000000000000000000000000000000000000..24eeada8013ea0d56d7d6474db52a48c3fd43bc1 --- /dev/null 
+++ b/conf/lora/ella-baila-sola.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3 diff --git a/conf/lora/gas-station.yml b/conf/lora/gas-station.yml new file mode 100644 index 0000000000000000000000000000000000000000..4369f9203232fa3dcfd21667f3e55d0d0fda108e --- /dev/null +++ b/conf/lora/gas-station.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3 diff --git a/conf/lora/lora-is-this-charlie-parker.yml b/conf/lora/lora-is-this-charlie-parker.yml new file mode 100644 index 0000000000000000000000000000000000000000..9cfaa31a421266fafa60a1ee4bb2d45f1c47577c --- /dev/null +++ b/conf/lora/lora-is-this-charlie-parker.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3 diff --git a/conf/lora/lora.yml b/conf/lora/lora.yml index 6acadde3f8a361c6eb56bc93686beb3fc103be75..b901ea00a6008b92f25728d6d01a258c6aba5d1e 100644 --- a/conf/lora/lora.yml +++ b/conf/lora/lora.yml @@ -3,20 +3,20 @@ $include: fine_tune: True -train/AudioDataset.n_examples: 100000000 -val/AudioDataset.n_examples: 500 +train/AudioDataset.n_examples: 10000000 + +val/AudioDataset.n_examples: 10 NoamScheduler.warmup: 500 -batch_size: 6 +batch_size: 7 num_workers: 7 -save_iters: [10000, 20000, 30000, 40000, 50000, 100000] -sample_freq: 1000 -val_freq: 500 +epoch_length: 100 +save_audio_epochs: 10 AdamW.lr: 0.0001 # let's us organize sound classes into folders and choose from those sound classes uniformly AudioDataset.without_replacement: False -num_iters: 500000 \ No newline at end of file +max_epochs: 500 \ No newline at end of file diff --git a/conf/lora/underworld.yml b/conf/lora/underworld.yml new file mode 100644 index 0000000000000000000000000000000000000000..6fd1a6cf1e74220a2b51b1117afb373acda033a7 --- /dev/null +++ b/conf/lora/underworld.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/underworld.mp3 + +val/AudioLoader.sources: + - /media/CHONK/hugo/spotdl/subsets/underworld.mp3 diff --git a/conf/lora/xeno-canto/c2f.yml b/conf/lora/xeno-canto/c2f.yml new file mode 100644 index 0000000000000000000000000000000000000000..94f9906189f0b74b6c492bdd53fa56d58a0fa04d --- /dev/null +++ b/conf/lora/xeno-canto/c2f.yml @@ -0,0 +1,21 @@ +$include: + - conf/lora/lora.yml + +fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 + +val/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 + + +VampNet.n_codebooks: 14 +VampNet.n_conditioning_codebooks: 4 + +VampNet.embedding_dim: 1280 +VampNet.n_layers: 16 +VampNet.n_heads: 20 + +AudioDataset.duration: 3.0 +AudioDataset.loudness_cutoff: -40.0 diff --git a/conf/lora/xeno-canto/coarse.yml b/conf/lora/xeno-canto/coarse.yml new file mode 100644 index 0000000000000000000000000000000000000000..223c8f0f8481f55ac1c33816ed79fe45b50f1495 --- /dev/null +++ b/conf/lora/xeno-canto/coarse.yml @@ -0,0 +1,10 @@ +$include: + - conf/lora/lora.yml + 
+fine_tune: True + +train/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 + +val/AudioLoader.sources: + - /media/CHONK/hugo/xeno-canto-2 diff --git a/conf/vampnet-musdb-drums.yml b/conf/vampnet-musdb-drums.yml new file mode 100644 index 0000000000000000000000000000000000000000..010843d81ec9ac3c832b8e88f30af2f99a56ba99 --- /dev/null +++ b/conf/vampnet-musdb-drums.yml @@ -0,0 +1,22 @@ +$include: + - conf/vampnet.yml + +VampNet.embedding_dim: 512 +VampNet.n_layers: 12 +VampNet.n_heads: 8 + +AudioDataset.duration: 12.0 + +train/AudioDataset.n_examples: 10000000 +train/AudioLoader.sources: + - /data/musdb18hq/train/**/*drums.wav + + +val/AudioDataset.n_examples: 500 +val/AudioLoader.sources: + - /data/musdb18hq/test/**/*drums.wav + + +test/AudioDataset.n_examples: 1000 +test/AudioLoader.sources: + - /data/musdb18hq/test/**/*drums.wav diff --git a/conf/vampnet.yml b/conf/vampnet.yml index 7323925f800891ad3e08420bf130d327e21cebf6..d24df3fc1923eeb98f76f5747a52c3e83ef98795 100644 --- a/conf/vampnet.yml +++ b/conf/vampnet.yml @@ -1,17 +1,21 @@ -codec_ckpt: ./models/vampnet/codec.pth +codec_ckpt: ./models/spotdl/codec.pth save_path: ckpt - -num_iters: 1000000000 -save_iters: [10000, 50000, 100000, 300000, 500000] +max_epochs: 1000 +epoch_length: 1000 +save_audio_epochs: 2 val_idx: [0,1,2,3,4,5,6,7,8,9] -sample_freq: 10000 -val_freq: 1000 + +prefix_amt: 0.0 +suffix_amt: 0.0 +prefix_dropout: 0.1 +suffix_dropout: 0.1 batch_size: 8 num_workers: 10 # Optimization +detect_anomaly: false amp: false CrossEntropyLoss.label_smoothing: 0.1 @@ -21,6 +25,9 @@ AdamW.lr: 0.001 NoamScheduler.factor: 2.0 NoamScheduler.warmup: 10000 +PitchShift.shift_amount: [const, 0] +PitchShift.prob: 0.0 + VampNet.vocab_size: 1024 VampNet.n_codebooks: 4 VampNet.n_conditioning_codebooks: 0 @@ -32,7 +39,7 @@ VampNet.n_heads: 20 VampNet.flash_attn: false VampNet.dropout: 0.1 -AudioLoader.relative_path: "" +AudioLoader.relative_path: /data/ AudioDataset.loudness_cutoff: -30.0 AudioDataset.without_replacement: true AudioLoader.shuffle: true @@ -41,9 +48,12 @@ AudioDataset.duration: 10.0 train/AudioDataset.n_examples: 10000000 train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/audio-train + - /data/spotdl/audio/train val/AudioDataset.n_examples: 2000 val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/audio-val + - /data/spotdl/audio/val +test/AudioDataset.n_examples: 1000 +test/AudioLoader.sources: + - /data/spotdl/audio/test diff --git a/requirements.txt b/requirements.txt index 04be1a8f954a0c7d577aa413e538bc17f6c9d01e..4feab2a44cbfd4d65f5e7fdfbe2c5a3363790c6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ torch argbind>=0.3.2 -numpy==1.23 +numpy==1.22 gradio loralib wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat lac @ git+https://github.com/hugofloresgarcia/lac.git -descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2 --e git+https://github.com/audacitorch/pyharp.git#egg=pyharp -torch_pitch_shift +audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git \ No newline at end of file diff --git a/scripts/exp/fine_tune.py b/scripts/exp/fine_tune.py index af82fcc7f260607a2efff9fad419271ad1a203d8..e2c6c3b768f585242705e5cdabeebe45ced557cf 100644 --- a/scripts/exp/fine_tune.py +++ b/scripts/exp/fine_tune.py @@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str): "AudioDataset.duration": 3.0, "AudioDataset.loudness_cutoff": -40.0, "save_path": f"./runs/{name}/c2f", - "fine_tune_checkpoint": "./models/vampnet/c2f.pth" + 
"fine_tune_checkpoint": "./models/spotdl/c2f.pth" } finetune_coarse_conf = { @@ -44,16 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str): "train/AudioLoader.sources": audio_files_or_folders, "val/AudioLoader.sources": audio_files_or_folders, "save_path": f"./runs/{name}/coarse", - "fine_tune_checkpoint": "./models/vampnet/coarse.pth" + "fine_tune_checkpoint": "./models/spotdl/coarse.pth" } interface_conf = { - "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth", + "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth", + "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth", - "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth", - "Interface.wavebeat_ckpt": "./models/wavebeat.pth", + "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth", + "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth", - "Interface.codec_ckpt": "./models/vampnet/codec.pth", + "Interface.codec_ckpt": "./models/spotdl/codec.pth", "AudioLoader.sources": [audio_files_or_folders], } diff --git a/scripts/exp/train.py b/scripts/exp/train.py index d276bc29421a796346c027b711481921b3e51d6b..79251a529c9512b7bf8c2613e6ae173df21c5c61 100644 --- a/scripts/exp/train.py +++ b/scripts/exp/train.py @@ -1,9 +1,9 @@ import os -import sys +import subprocess +import time import warnings from pathlib import Path from typing import Optional -from dataclasses import dataclass import argbind import audiotools as at @@ -14,7 +14,7 @@ from audiotools.data import transforms from einops import rearrange from rich import pretty from rich.traceback import install -from torch.utils.tensorboard import SummaryWriter +from tensorboardX import SummaryWriter import vampnet from vampnet.modules.transformer import VampNet @@ -23,15 +23,6 @@ from vampnet import mask as pmask # from dac.model.dac import DAC from lac.model.lac import LAC as DAC -from audiotools.ml.decorators import ( - timer, Tracker, when -) - -import loralib as lora - -import torch._dynamo -torch._dynamo.config.verbose=True - # Enable cudnn autotuner to speed up training # (can be altered by the funcs.seed function) @@ -94,7 +85,11 @@ def build_datasets(args, sample_rate: int): ) with argbind.scope(args, "val"): val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform()) - return train_data, val_data + with argbind.scope(args, "test"): + test_data = AudioDataset( + AudioLoader(), sample_rate, transform=build_transform() + ) + return train_data, val_data, test_data def rand_float(shape, low, high, rng): @@ -105,392 +100,16 @@ def flip_coin(shape, p, rng): return rng.draw(shape)[:, 0] < p -def num_params_hook(o, p): - return o + f" {p/1e6:<.3f}M params." 
- - -def add_num_params_repr_hook(model): - import numpy as np - from functools import partial - - for n, m in model.named_modules(): - o = m.extra_repr() - p = sum([np.prod(p.size()) for p in m.parameters()]) - - setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p)) - - -def accuracy( - preds: torch.Tensor, - target: torch.Tensor, - top_k: int = 1, - ignore_index: Optional[int] = None, -) -> torch.Tensor: - # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class) - preds = rearrange(preds, "b p s -> (b s) p") - target = rearrange(target, "b s -> (b s)") - - # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index) - if ignore_index is not None: - # Create a mask for the ignored index - mask = target != ignore_index - # Apply the mask to the target and predictions - preds = preds[mask] - target = target[mask] - - # Get the top-k predicted classes and their indices - _, pred_indices = torch.topk(preds, k=top_k, dim=-1) - - # Determine if the true target is in the top-k predicted classes - correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1) - - # Calculate the accuracy - accuracy = torch.mean(correct.float()) - - return accuracy - -def _metrics(z_hat, r, target, flat_mask, output): - for r_range in [(0, 0.5), (0.5, 1.0)]: - unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX) - masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) - - assert target.shape[0] == r.shape[0] - # grab the indices of the r values that are in the range - r_idx = (r >= r_range[0]) & (r < r_range[1]) - - # grab the target and z_hat values that are in the range - r_unmasked_target = unmasked_target[r_idx] - r_masked_target = masked_target[r_idx] - r_z_hat = z_hat[r_idx] - - for topk in (1, 25): - s, e = r_range - tag = f"accuracy-{s}-{e}/top{topk}" - - output[f"{tag}/unmasked"] = accuracy( - preds=r_z_hat, - target=r_unmasked_target, - ignore_index=IGNORE_INDEX, - top_k=topk, - ) - output[f"{tag}/masked"] = accuracy( - preds=r_z_hat, - target=r_masked_target, - ignore_index=IGNORE_INDEX, - top_k=topk, - ) - - -@dataclass -class State: - model: VampNet - codec: DAC - - optimizer: AdamW - scheduler: NoamScheduler - criterion: CrossEntropyLoss - grad_clip_val: float - - rng: torch.quasirandom.SobolEngine - - train_data: AudioDataset - val_data: AudioDataset - - tracker: Tracker - - -@timer() -def train_loop(state: State, batch: dict, accel: Accelerator): - state.model.train() - batch = at.util.prepare_batch(batch, accel.device) - signal = apply_transform(state.train_data.transform, batch) - - output = {} - vn = accel.unwrap(state.model) - with accel.autocast(): - with torch.inference_mode(): - state.codec.to(accel.device) - z = state.codec.encode(signal.samples, signal.sample_rate)["codes"] - z = z[:, : vn.n_codebooks, :] - - n_batch = z.shape[0] - r = state.rng.draw(n_batch)[:, 0].to(accel.device) - - mask = pmask.random(z, r) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - z_mask_latent = vn.embedding.from_codes(z_mask, state.codec) - - dtype = torch.bfloat16 if accel.amp else None - with accel.autocast(dtype=dtype): - z_hat = state.model(z_mask_latent) - - target = codebook_flatten( - z[:, vn.n_conditioning_codebooks :, :], - ) - - flat_mask = codebook_flatten( - mask[:, vn.n_conditioning_codebooks :, :], - ) - - # replace target with ignore index for masked 
tokens - t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) - output["loss"] = state.criterion(z_hat, t_masked) - - _metrics( - r=r, - z_hat=z_hat, - target=target, - flat_mask=flat_mask, - output=output, - ) - - - accel.backward(output["loss"]) - - output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"] - output["other/batch_size"] = z.shape[0] - - - accel.scaler.unscale_(state.optimizer) - output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_( - state.model.parameters(), state.grad_clip_val - ) - - accel.step(state.optimizer) - state.optimizer.zero_grad() - - state.scheduler.step() - accel.update() - - - return {k: v for k, v in sorted(output.items())} - - -@timer() -@torch.no_grad() -def val_loop(state: State, batch: dict, accel: Accelerator): - state.model.eval() - state.codec.eval() - batch = at.util.prepare_batch(batch, accel.device) - signal = apply_transform(state.val_data.transform, batch) - - vn = accel.unwrap(state.model) - z = state.codec.encode(signal.samples, signal.sample_rate)["codes"] - z = z[:, : vn.n_codebooks, :] - - n_batch = z.shape[0] - r = state.rng.draw(n_batch)[:, 0].to(accel.device) - - mask = pmask.random(z, r) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - z_mask_latent = vn.embedding.from_codes(z_mask, state.codec) - - z_hat = state.model(z_mask_latent) - - target = codebook_flatten( - z[:, vn.n_conditioning_codebooks :, :], - ) - - flat_mask = codebook_flatten( - mask[:, vn.n_conditioning_codebooks :, :] - ) - - output = {} - # replace target with ignore index for masked tokens - t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) - output["loss"] = state.criterion(z_hat, t_masked) - - _metrics( - r=r, - z_hat=z_hat, - target=target, - flat_mask=flat_mask, - output=output, - ) - - return output - - -def validate(state, val_dataloader, accel): - for batch in val_dataloader: - output = val_loop(state, batch, accel) - # Consolidate state dicts if using ZeroRedundancyOptimizer - if hasattr(state.optimizer, "consolidate_state_dict"): - state.optimizer.consolidate_state_dict() - return output - - -def checkpoint(state, save_iters, save_path, fine_tune): - if accel.local_rank != 0: - state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}") - return - - metadata = {"logs": dict(state.tracker.history)} - - tags = ["latest"] - state.tracker.print(f"Saving to {str(Path('.').absolute())}") - - if state.tracker.step in save_iters: - tags.append(f"{state.tracker.step // 1000}k") - - if state.tracker.is_best("val", "loss"): - state.tracker.print(f"Best model so far") - tags.append("best") - - if fine_tune: - for tag in tags: - # save the lora model - (Path(save_path) / tag).mkdir(parents=True, exist_ok=True) - torch.save( - lora.lora_state_dict(accel.unwrap(state.model)), - f"{save_path}/{tag}/lora.pth" - ) - - for tag in tags: - model_extra = { - "optimizer.pth": state.optimizer.state_dict(), - "scheduler.pth": state.scheduler.state_dict(), - "tracker.pth": state.tracker.state_dict(), - "metadata.pth": metadata, - } - - accel.unwrap(state.model).metadata = metadata - accel.unwrap(state.model).save_to_folder( - f"{save_path}/{tag}", model_extra, package=False - ) - - -def save_sampled(state, z, writer): - num_samples = z.shape[0] - - for i in range(num_samples): - sampled = accel.unwrap(state.model).generate( - codec=state.codec, - time_steps=z.shape[-1], - start_tokens=z[i : i + 1], - ) - sampled.cpu().write_audio_to_tb( - 
f"sampled/{i}", - writer, - step=state.tracker.step, - plot_fn=None, - ) - - -def save_imputation(state, z, val_idx, writer): - n_prefix = int(z.shape[-1] * 0.25) - n_suffix = int(z.shape[-1] * 0.25) - - vn = accel.unwrap(state.model) - - mask = pmask.inpaint(z, n_prefix, n_suffix) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - imputed_noisy = vn.to_signal(z_mask, state.codec) - imputed_true = vn.to_signal(z, state.codec) - - imputed = [] - for i in range(len(z)): - imputed.append( - vn.generate( - codec=state.codec, - time_steps=z.shape[-1], - start_tokens=z[i][None, ...], - mask=mask[i][None, ...], - ) - ) - imputed = AudioSignal.batch(imputed) - - for i in range(len(val_idx)): - imputed_noisy[i].cpu().write_audio_to_tb( - f"inpainted_prompt/{i}", - writer, - step=state.tracker.step, - plot_fn=None, - ) - imputed[i].cpu().write_audio_to_tb( - f"inpainted_middle/{i}", - writer, - step=state.tracker.step, - plot_fn=None, - ) - imputed_true[i].cpu().write_audio_to_tb( - f"reconstructed/{i}", - writer, - step=state.tracker.step, - plot_fn=None, - ) - - -@torch.no_grad() -def save_samples(state: State, val_idx: int, writer: SummaryWriter): - state.model.eval() - state.codec.eval() - vn = accel.unwrap(state.model) - - batch = [state.val_data[i] for i in val_idx] - batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device) - - signal = apply_transform(state.val_data.transform, batch) - - z = state.codec.encode(signal.samples, signal.sample_rate)["codes"] - z = z[:, : vn.n_codebooks, :] - - r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device) - - - mask = pmask.random(z, r) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - z_mask_latent = vn.embedding.from_codes(z_mask, state.codec) - - z_hat = state.model(z_mask_latent) - - z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1) - z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks) - z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1) - - generated = vn.to_signal(z_pred, state.codec) - reconstructed = vn.to_signal(z, state.codec) - masked = vn.to_signal(z_mask.squeeze(1), state.codec) - - for i in range(generated.batch_size): - audio_dict = { - "original": signal[i], - "masked": masked[i], - "generated": generated[i], - "reconstructed": reconstructed[i], - } - for k, v in audio_dict.items(): - v.cpu().write_audio_to_tb( - f"onestep/_{i}.r={r[i]:0.2f}/{k}", - writer, - step=state.tracker.step, - plot_fn=None, - ) - - save_sampled(state=state, z=z, writer=writer) - save_imputation(state=state, z=z, val_idx=val_idx, writer=writer) - - - @argbind.bind(without_prefix=True) def load( args, accel: at.ml.Accelerator, - tracker: Tracker, save_path: str, resume: bool = False, tag: str = "latest", + load_weights: bool = False, fine_tune_checkpoint: Optional[str] = None, - grad_clip_val: float = 5.0, -) -> State: +): codec = DAC.load(args["codec_ckpt"], map_location="cpu") codec.eval() @@ -500,9 +119,8 @@ def load( kwargs = { "folder": f"{save_path}/{tag}", "map_location": "cpu", - "package": False, + "package": not load_weights, } - tracker.print(f"Loading checkpoint from {kwargs['folder']}") if (Path(kwargs["folder"]) / "vampnet").exists(): model, v_extra = VampNet.load_from_folder(**kwargs) else: @@ -513,14 +131,11 @@ def load( if args["fine_tune"]: assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint" - model = 
torch.compile( - VampNet.load(location=Path(fine_tune_checkpoint), - map_location="cpu", - ) - ) + model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu") + + model = VampNet() if model is None else model - model = torch.compile(VampNet()) if model is None else model model = accel.prepare_model(model) # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks @@ -532,57 +147,89 @@ def load( scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim) scheduler.step() + trainer_state = {"state_dict": None, "start_idx": 0} + if "optimizer.pth" in v_extra: optimizer.load_state_dict(v_extra["optimizer.pth"]) + if "scheduler.pth" in v_extra: scheduler.load_state_dict(v_extra["scheduler.pth"]) - if "tracker.pth" in v_extra: - tracker.load_state_dict(v_extra["tracker.pth"]) - - criterion = CrossEntropyLoss() + if "trainer.pth" in v_extra: + trainer_state = v_extra["trainer.pth"] - sample_rate = codec.sample_rate + return { + "model": model, + "codec": codec, + "optimizer": optimizer, + "scheduler": scheduler, + "trainer_state": trainer_state, + } - # a better rng for sampling from our schedule - rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"]) - # log a model summary w/ num params - if accel.local_rank == 0: - add_num_params_repr_hook(accel.unwrap(model)) - with open(f"{save_path}/model.txt", "w") as f: - f.write(repr(accel.unwrap(model))) - # load the datasets - train_data, val_data = build_datasets(args, sample_rate) - - return State( - tracker=tracker, - model=model, - codec=codec, - optimizer=optimizer, - scheduler=scheduler, - criterion=criterion, - rng=rng, - train_data=train_data, - val_data=val_data, - grad_clip_val=grad_clip_val, - ) +def num_params_hook(o, p): + return o + f" {p/1e6:<.3f}M params." 
+ + +def add_num_params_repr_hook(model): + import numpy as np + from functools import partial + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + + setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p)) + + +def accuracy( + preds: torch.Tensor, + target: torch.Tensor, + top_k: int = 1, + ignore_index: Optional[int] = None, +) -> torch.Tensor: + # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class) + preds = rearrange(preds, "b p s -> (b s) p") + target = rearrange(target, "b s -> (b s)") + + # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index) + if ignore_index is not None: + # Create a mask for the ignored index + mask = target != ignore_index + # Apply the mask to the target and predictions + preds = preds[mask] + target = target[mask] + + # Get the top-k predicted classes and their indices + _, pred_indices = torch.topk(preds, k=top_k, dim=-1) + + # Determine if the true target is in the top-k predicted classes + correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1) + + # Calculate the accuracy + accuracy = torch.mean(correct.float()) + + return accuracy @argbind.bind(without_prefix=True) def train( args, accel: at.ml.Accelerator, - seed: int = 0, codec_ckpt: str = None, + seed: int = 0, save_path: str = "ckpt", - num_iters: int = int(1000e6), - save_iters: list = [10000, 50000, 100000, 300000, 500000,], - sample_freq: int = 10000, - val_freq: int = 1000, - batch_size: int = 12, + max_epochs: int = int(100e3), + epoch_length: int = 1000, + save_audio_epochs: int = 2, + save_epochs: list = [10, 50, 100, 200, 300, 400,], + batch_size: int = 48, + grad_acc_steps: int = 1, val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], num_workers: int = 10, + detect_anomaly: bool = False, + grad_clip_val: float = 5.0, fine_tune: bool = False, + quiet: bool = False, ): assert codec_ckpt is not None, "codec_ckpt is required" @@ -594,79 +241,376 @@ def train( writer = SummaryWriter(log_dir=f"{save_path}/logs/") argbind.dump_args(args, f"{save_path}/args.yml") - tracker = Tracker( - writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank - ) - # load the codec model - state: State = load( - args=args, - accel=accel, - tracker=tracker, - save_path=save_path) - print("initialized state.") + loaded = load(args, accel, save_path) + model = loaded["model"] + codec = loaded["codec"] + optimizer = loaded["optimizer"] + scheduler = loaded["scheduler"] + trainer_state = loaded["trainer_state"] + sample_rate = codec.sample_rate + + # a better rng for sampling from our schedule + rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed) + + # log a model summary w/ num params + if accel.local_rank == 0: + add_num_params_repr_hook(accel.unwrap(model)) + with open(f"{save_path}/model.txt", "w") as f: + f.write(repr(accel.unwrap(model))) + + # load the datasets + train_data, val_data, _ = build_datasets(args, sample_rate) train_dataloader = accel.prepare_dataloader( - state.train_data, - start_idx=state.tracker.step * batch_size, + train_data, + start_idx=trainer_state["start_idx"], num_workers=num_workers, batch_size=batch_size, - collate_fn=state.train_data.collate, + collate_fn=train_data.collate, ) val_dataloader = accel.prepare_dataloader( - state.val_data, + val_data, start_idx=0, num_workers=num_workers, batch_size=batch_size, - collate_fn=state.val_data.collate, - 
persistent_workers=num_workers > 0, + collate_fn=val_data.collate, ) - print("initialized dataloader.") - + criterion = CrossEntropyLoss() if fine_tune: - lora.mark_only_lora_as_trainable(state.model) - print("marked only lora as trainable.") + import loralib as lora + lora.mark_only_lora_as_trainable(model) + + + class Trainer(at.ml.BaseTrainer): + _last_grad_norm = 0.0 + + def _metrics(self, vn, z_hat, r, target, flat_mask, output): + for r_range in [(0, 0.5), (0.5, 1.0)]: + unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX) + masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + + assert target.shape[0] == r.shape[0] + # grab the indices of the r values that are in the range + r_idx = (r >= r_range[0]) & (r < r_range[1]) + + # grab the target and z_hat values that are in the range + r_unmasked_target = unmasked_target[r_idx] + r_masked_target = masked_target[r_idx] + r_z_hat = z_hat[r_idx] + + for topk in (1, 25): + s, e = r_range + tag = f"accuracy-{s}-{e}/top{topk}" + + output[f"{tag}/unmasked"] = accuracy( + preds=r_z_hat, + target=r_unmasked_target, + ignore_index=IGNORE_INDEX, + top_k=topk, + ) + output[f"{tag}/masked"] = accuracy( + preds=r_z_hat, + target=r_masked_target, + ignore_index=IGNORE_INDEX, + top_k=topk, + ) + + def train_loop(self, engine, batch): + model.train() + batch = at.util.prepare_batch(batch, accel.device) + signal = apply_transform(train_data.transform, batch) + + output = {} + vn = accel.unwrap(model) + with accel.autocast(): + with torch.inference_mode(): + codec.to(accel.device) + z = codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + n_batch = z.shape[0] + r = rng.draw(n_batch)[:, 0].to(accel.device) + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, codec) + + dtype = torch.bfloat16 if accel.amp else None + with accel.autocast(dtype=dtype): + z_hat = model(z_mask_latent, r) + + target = codebook_flatten( + z[:, vn.n_conditioning_codebooks :, :], + ) + + flat_mask = codebook_flatten( + mask[:, vn.n_conditioning_codebooks :, :], + ) + + # replace target with ignore index for masked tokens + t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + output["loss"] = criterion(z_hat, t_masked) + + self._metrics( + vn=vn, + r=r, + z_hat=z_hat, + target=target, + flat_mask=flat_mask, + output=output, + ) + + + accel.backward(output["loss"] / grad_acc_steps) + + output["other/learning_rate"] = optimizer.param_groups[0]["lr"] + output["other/batch_size"] = z.shape[0] + + if ( + (engine.state.iteration % grad_acc_steps == 0) + or (engine.state.iteration % epoch_length == 0) + or (engine.state.iteration % epoch_length == 1) + ): # (or we reached the end of the epoch) + accel.scaler.unscale_(optimizer) + output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_( + model.parameters(), grad_clip_val + ) + self._last_grad_norm = output["other/grad_norm"] + + accel.step(optimizer) + optimizer.zero_grad() + + scheduler.step() + accel.update() + else: + output["other/grad_norm"] = self._last_grad_norm + + return {k: v for k, v in sorted(output.items())} + + @torch.no_grad() + def val_loop(self, engine, batch): + model.eval() + codec.eval() + batch = at.util.prepare_batch(batch, accel.device) + signal = apply_transform(val_data.transform, batch) + + vn = accel.unwrap(model) + z = codec.encode(signal.samples, signal.sample_rate)["codes"] 
+ z = z[:, : vn.n_codebooks, :] - # Wrap the functions so that they neatly track in TensorBoard + progress bars - # and only run when specific conditions are met. - global train_loop, val_loop, validate, save_samples, checkpoint + n_batch = z.shape[0] + r = rng.draw(n_batch)[:, 0].to(accel.device) - train_loop = tracker.log("train", "value", history=False)( - tracker.track("train", num_iters, completed=state.tracker.step)(train_loop) - ) - val_loop = tracker.track("val", len(val_dataloader))(val_loop) - validate = tracker.log("val", "mean")(validate) + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - save_samples = when(lambda: accel.local_rank == 0)(save_samples) - checkpoint = when(lambda: accel.local_rank == 0)(checkpoint) + z_mask_latent = vn.embedding.from_codes(z_mask, codec) - print("starting training loop.") - with tracker.live: - for tracker.step, batch in enumerate(train_dataloader, start=tracker.step): - train_loop(state, batch, accel) + z_hat = model(z_mask_latent, r) - last_iter = ( - tracker.step == num_iters - 1 if num_iters is not None else False + target = codebook_flatten( + z[:, vn.n_conditioning_codebooks :, :], ) - if tracker.step % sample_freq == 0 or last_iter: - save_samples(state, val_idx, writer) + flat_mask = codebook_flatten( + mask[:, vn.n_conditioning_codebooks :, :] + ) - if tracker.step % val_freq == 0 or last_iter: - validate(state, val_dataloader, accel) - checkpoint( - state=state, - save_iters=save_iters, - save_path=save_path, - fine_tune=fine_tune) + output = {} + # replace target with ignore index for masked tokens + t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + output["loss"] = criterion(z_hat, t_masked) + + self._metrics( + vn=vn, + r=r, + z_hat=z_hat, + target=target, + flat_mask=flat_mask, + output=output, + ) - # Reset validation progress bar, print summary since last validation. 
- tracker.done("val", f"Iteration {tracker.step}") + return output - if last_iter: - break + def checkpoint(self, engine): + if accel.local_rank != 0: + print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}") + return + + metadata = {"logs": dict(engine.state.logs["epoch"])} + + if self.state.epoch % save_audio_epochs == 0: + self.save_samples() + + tags = ["latest"] + loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train" + self.print(f"Saving to {str(Path('.').absolute())}") + + if self.state.epoch in save_epochs: + tags.append(f"epoch={self.state.epoch}") + + if self.is_best(engine, loss_key): + self.print(f"Best model so far") + tags.append("best") + + if fine_tune: + for tag in tags: + # save the lora model + (Path(save_path) / tag).mkdir(parents=True, exist_ok=True) + torch.save( + lora.lora_state_dict(accel.unwrap(model)), + f"{save_path}/{tag}/lora.pth" + ) + + for tag in tags: + model_extra = { + "optimizer.pth": optimizer.state_dict(), + "scheduler.pth": scheduler.state_dict(), + "trainer.pth": { + "start_idx": self.state.iteration * batch_size, + "state_dict": self.state_dict(), + }, + "metadata.pth": metadata, + } + + accel.unwrap(model).metadata = metadata + accel.unwrap(model).save_to_folder( + f"{save_path}/{tag}", model_extra, + ) + + def save_sampled(self, z): + num_samples = z.shape[0] + + for i in range(num_samples): + sampled = accel.unwrap(model).generate( + codec=codec, + time_steps=z.shape[-1], + start_tokens=z[i : i + 1], + ) + sampled.cpu().write_audio_to_tb( + f"sampled/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + + + def save_imputation(self, z: torch.Tensor): + n_prefix = int(z.shape[-1] * 0.25) + n_suffix = int(z.shape[-1] * 0.25) + + vn = accel.unwrap(model) + + mask = pmask.inpaint(z, n_prefix, n_suffix) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + imputed_noisy = vn.to_signal(z_mask, codec) + imputed_true = vn.to_signal(z, codec) + + imputed = [] + for i in range(len(z)): + imputed.append( + vn.generate( + codec=codec, + time_steps=z.shape[-1], + start_tokens=z[i][None, ...], + mask=mask[i][None, ...], + ) + ) + imputed = AudioSignal.batch(imputed) + + for i in range(len(val_idx)): + imputed_noisy[i].cpu().write_audio_to_tb( + f"imputed_noisy/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + imputed[i].cpu().write_audio_to_tb( + f"imputed/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + imputed_true[i].cpu().write_audio_to_tb( + f"imputed_true/{i}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + + @torch.no_grad() + def save_samples(self): + model.eval() + codec.eval() + vn = accel.unwrap(model) + + batch = [val_data[i] for i in val_idx] + batch = at.util.prepare_batch(val_data.collate(batch), accel.device) + + signal = apply_transform(val_data.transform, batch) + + z = codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device) + + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, codec) + + z_hat = model(z_mask_latent, r) + + z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1) + z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks) + z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1) + + generated 
= vn.to_signal(z_pred, codec) + reconstructed = vn.to_signal(z, codec) + masked = vn.to_signal(z_mask.squeeze(1), codec) + + for i in range(generated.batch_size): + audio_dict = { + "original": signal[i], + "masked": masked[i], + "generated": generated[i], + "reconstructed": reconstructed[i], + } + for k, v in audio_dict.items(): + v.cpu().write_audio_to_tb( + f"samples/_{i}.r={r[i]:0.2f}/{k}", + self.writer, + step=self.state.epoch, + plot_fn=None, + ) + + self.save_sampled(z) + self.save_imputation(z) + + trainer = Trainer(writer=writer, quiet=quiet) + + if trainer_state["state_dict"] is not None: + trainer.load_state_dict(trainer_state["state_dict"]) + if hasattr(train_dataloader.sampler, "set_epoch"): + train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch) + + trainer.run( + train_dataloader, + val_dataloader, + num_epochs=max_epochs, + epoch_length=epoch_length, + detect_anomaly=detect_anomaly, + ) if __name__ == "__main__": @@ -674,6 +618,4 @@ if __name__ == "__main__": args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0 with argbind.scope(args): with Accelerator() as accel: - if accel.local_rank != 0: - sys.tracebacklimit = 0 train(args, accel) diff --git a/scripts/utils/augment.py b/scripts/utils/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..e971bf8404fb328d556a56f1e997a211e355dbaa --- /dev/null +++ b/scripts/utils/augment.py @@ -0,0 +1,53 @@ +from pathlib import Path + +import audiotools as at +from audiotools import AudioSignal + +import argbind +import tqdm + + +from pedalboard import ( + Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard +) +from pedalboard.io import AudioFile + +# Read in a whole file, resampling to our desired sample rate: +samplerate = 44100.0 +with AudioFile('guitar-input.wav').resampled_to(samplerate) as f: + audio = f.read(f.frames) + +# Make a pretty interesting sounding guitar pedalboard: +board = Pedalboard([ + Compressor(threshold_db=-50, ratio=25), + Gain(gain_db=30), + Chorus(), + LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900), + Phaser(), + Convolution("./guitar_amp.wav", 1.0), + Reverb(room_size=0.25), +]) + + +@argbind.bind(without_prefix=True) +def augment( + audio_folder: Path, + dest_folder: Path, + n_augmentations: int = 10, +): + """ + Augment a folder of audio files by applying audiotools and pedalboard transforms. + + The dest foler will contain a folder for each of the clean dataset's files. + Under each of these folders, there will be a clean file and many augmented files. 
+ """ + + audio_files = at.util.find_audio(audio_folder) + + for audio_file in tqdm.tqdm(audio_files): + subtree = dest_folder / audio_file.relative_to(audio_folder).parent + subdir = subtree / audio_file.stem + subdir.mkdir(parents=True, exist_ok=True) + + # apply pedalboard transforms + for i in range(n_augmentations): diff --git a/scripts/utils/data/augment.py b/scripts/utils/data/augment.py deleted file mode 100644 index 679a4e4979658df843ee61a8698cc6d521fe24a8..0000000000000000000000000000000000000000 --- a/scripts/utils/data/augment.py +++ /dev/null @@ -1,67 +0,0 @@ -from pathlib import Path - -import audiotools as at -from audiotools import AudioSignal - -import argbind -import tqdm -import torch - - -from torch_pitch_shift import pitch_shift, get_fast_shifts -from torch_time_stretch import time_stretch, get_fast_stretches - -from audiotools.core.util import sample_from_dist - - -@argbind.bind(without_prefix=True) -def augment( - audio_folder: Path = None, - dest_folder: Path = None, - n_augmentations: int = 10, -): - """ - Augment a folder of audio files by applying audiotools and pedalboard transforms. - - The dest foler will contain a folder for each of the clean dataset's files. - Under each of these folders, there will be a clean file and many augmented files. - """ - assert audio_folder is not None - assert dest_folder is not None - audio_files = at.util.find_audio(audio_folder) - - for audio_file in tqdm.tqdm(audio_files): - subtree = dest_folder / audio_file.relative_to(audio_folder).parent - subdir = subtree / audio_file.stem - subdir.mkdir(parents=True, exist_ok=True) - - src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu") - - - for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))): - # apply pedalboard transforms - for j in range(n_augmentations): - # pitch shift between -7 and 7 semitones - import random - dst = chunk.clone() - dst.samples = pitch_shift( - dst.samples, - shift=random.choice(get_fast_shifts(src.sample_rate, - condition=lambda x: x >= 0.25 and x <= 1.0)), - sample_rate=src.sample_rate - ) - dst.samples = time_stretch( - dst.samples, - stretch=random.choice(get_fast_stretches(src.sample_rate, - condition=lambda x: x >= 0.667 and x <= 1.5, )), - sample_rate=src.sample_rate, - ) - - dst.cpu().write(subdir / f"{i}-{j}.wav") - - -if __name__ == "__main__": - args = argbind.parse_args() - - with argbind.scope(args): - augment() \ No newline at end of file diff --git a/scripts/utils/gtzan_embeddings.py b/scripts/utils/gtzan_embeddings.py deleted file mode 100644 index 78a6e318fbba98355fb48aa6ea1c74b0b83ff287..0000000000000000000000000000000000000000 --- a/scripts/utils/gtzan_embeddings.py +++ /dev/null @@ -1,263 +0,0 @@ -""" -TODO: train a linear probe -usage: - python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output -""" -from pathlib import Path -from typing import List - -import audiotools as at -from audiotools import AudioSignal -import argbind -import torch -import numpy as np -import zipfile -import json - -from vampnet.interface import Interface -import tqdm - -# bind the Interface to argbind -Interface = argbind.bind(Interface) - -DEBUG = False - -def smart_plotly_export(fig, save_path): - img_format = save_path.split('.')[-1] - if img_format == 'html': - fig.write_html(save_path) - elif img_format == 'bytes': - return fig.to_image(format='png') - #TODO: come back and make this prettier - elif img_format == 'numpy': - 
import io - from PIL import Image - - def plotly_fig2array(fig): - #convert Plotly fig to an array - fig_bytes = fig.to_image(format="png", width=1200, height=700) - buf = io.BytesIO(fig_bytes) - img = Image.open(buf) - return np.asarray(img) - - return plotly_fig2array(fig) - elif img_format == 'jpeg' or 'png' or 'webp': - fig.write_image(save_path) - else: - raise ValueError("invalid image format") - -def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''): - """ - dimensionality reduction for visualization! - saves an html plotly figure to save_path - parameters: - emb (np.ndarray): the samples to be reduces with shape (samples, features) - labels (list): list of labels for embedding - save_path (str): path where u wanna save ur figure - method (str): umap, tsne, or pca - title (str): title for ur figure - returns: - proj (np.ndarray): projection vector with shape (samples, dimensions) - """ - import pandas as pd - import plotly.express as px - if method == 'umap': - reducer = umap.UMAP(n_components=n_components) - elif method == 'tsne': - from sklearn.manifold import TSNE - reducer = TSNE(n_components=n_components) - elif method == 'pca': - from sklearn.decomposition import PCA - reducer = PCA(n_components=n_components) - else: - raise ValueError - - proj = reducer.fit_transform(emb) - - if n_components == 2: - df = pd.DataFrame(dict( - x=proj[:, 0], - y=proj[:, 1], - instrument=labels - )) - fig = px.scatter(df, x='x', y='y', color='instrument', - title=title+f"_{method}") - - elif n_components == 3: - df = pd.DataFrame(dict( - x=proj[:, 0], - y=proj[:, 1], - z=proj[:, 2], - instrument=labels - )) - fig = px.scatter_3d(df, x='x', y='y', z='z', - color='instrument', - title=title) - else: - raise ValueError("cant plot more than 3 components") - - fig.update_traces(marker=dict(size=6, - line=dict(width=1, - color='DarkSlateGrey')), - selector=dict(mode='markers')) - - return smart_plotly_export(fig, save_path) - - - -# per JukeMIR, we want the emebddings from the middle layer? 
-def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10): - with torch.inference_mode(): - # preprocess the signal - sig = interface.preprocess(sig) - - # get the coarse vampnet model - vampnet = interface.coarse - - # get the tokens - z = interface.encode(sig)[:, :vampnet.n_codebooks, :] - z_latents = vampnet.embedding.from_codes(z, interface.codec) - - # do a forward pass through the model, get the embeddings - _z, embeddings = vampnet(z_latents, return_activations=True) - # print(f"got embeddings with shape {embeddings.shape}") - # [layer, batch, time, n_dims] - # [20, 1, 600ish, 768] - - - # squeeze batch dim (1 bc layer should be dim 0) - assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}" - embeddings = embeddings.squeeze(1) - - num_layers = embeddings.shape[0] - assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers" - - # do meanpooling over the time dimension - embeddings = embeddings.mean(dim=-2) - # [20, 768] - - # return the embeddings - return embeddings - -from dataclasses import dataclass, fields -@dataclass -class Embedding: - genre: str - filename: str - embedding: np.ndarray - - def save(self, path): - """Save the Embedding object to a given path as a zip file.""" - with zipfile.ZipFile(path, 'w') as archive: - - # Save numpy array - with archive.open('embedding.npy', 'w') as f: - np.save(f, self.embedding) - - # Save non-numpy data as json - non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'} - with archive.open('data.json', 'w') as f: - f.write(json.dumps(non_numpy_data).encode('utf-8')) - - @classmethod - def load(cls, path): - """Load the Embedding object from a given zip path.""" - with zipfile.ZipFile(path, 'r') as archive: - - # Load numpy array - with archive.open('embedding.npy') as f: - embedding = np.load(f) - - # Load non-numpy data from json - with archive.open('data.json') as f: - data = json.loads(f.read().decode('utf-8')) - - return cls(embedding=embedding, **data) - - -@argbind.bind(without_prefix=True) -def main( - path_to_gtzan: str = None, - cache_dir: str = "./.gtzan_emb_cache", - output_dir: str = "./gtzan_vampnet_embeddings", - layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19] -): - path_to_gtzan = Path(path_to_gtzan) - assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist" - - cache_dir = Path(cache_dir) - output_dir = Path(output_dir) - output_dir.mkdir(exist_ok=True, parents=True) - - # load our interface - # argbind will automatically load the default config, - interface = Interface() - - # gtzan should have a folder for each genre, so let's get the list of genres - genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()] - print(f"Found {len(genres)} genres") - print(f"genres: {genres}") - - # collect audio files, genres, and embeddings - data = [] - for genre in genres: - audio_files = list(at.util.find_audio(path_to_gtzan / genre)) - print(f"Found {len(audio_files)} audio files for genre {genre}") - - for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"): - # check if we have a cached embedding for this file - cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb") - if cached_path.exists(): - # if so, load it - if DEBUG: - print(f"loading cached embedding for {cached_path.stem}") - embedding = Embedding.load(cached_path) - data.append(embedding) - else: - try: - sig = AudioSignal(audio_file) - except Exception as e: - print(f"failed to load 
{audio_file.name} with error {e}") - print(f"skipping {audio_file.name}") - continue - - # gets the embedding - emb = vampnet_embed(sig, interface).cpu().numpy() - - # create an embedding we can save/load - embedding = Embedding( - genre=genre, - filename=audio_file.name, - embedding=emb - ) - - # cache the embeddings - cached_path.parent.mkdir(exist_ok=True, parents=True) - embedding.save(cached_path) - - # now, let's do a dim reduction on the embeddings - # and visualize them. - - # collect a list of embeddings and labels - embeddings = [d.embedding for d in data] - labels = [d.genre for d in data] - - # convert the embeddings to a numpy array - embeddings = np.stack(embeddings) - - # do dimensionality reduction for each layer we're given - for layer in tqdm.tqdm(layers, desc="dim reduction"): - dim_reduce( - embeddings[:, layer, :], labels, - save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'), - n_components=2, method='tsne', - title=f'vampnet-gtzan-layer={layer}' - ) - - - - -if __name__ == "__main__": - args = argbind.parse_args() - with argbind.scope(args): - main() \ No newline at end of file diff --git a/scripts/utils/data/maestro-reorg.py b/scripts/utils/maestro-reorg.py similarity index 100% rename from scripts/utils/data/maestro-reorg.py rename to scripts/utils/maestro-reorg.py diff --git a/scripts/utils/remove_quiet_files.py b/scripts/utils/remove_quiet_files.py deleted file mode 100644 index f557f1574da562203cbdd5334717a699e89196bb..0000000000000000000000000000000000000000 --- a/scripts/utils/remove_quiet_files.py +++ /dev/null @@ -1,29 +0,0 @@ -# removes files with loudness below 24db - -from pathlib import Path -import shutil -import audiotools as at -import argbind - -@argbind.bind(without_prefix=True) -def remove_quiet_files( - src_dir: Path = None, - dest_dir: Path = None, - min_loudness: float = -30, -): - # copy src to dest - dest_dir.mkdir(parents=True, exist_ok=True) - shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True) - - audio_files = at.util.find_audio(dest_dir) - for audio_file in audio_files: - sig = at.AudioSignal(audio_file) - if sig.loudness() < min_loudness: - audio_file.unlink() - print(f"removed {audio_file}") - -if __name__ == "__main__": - args = argbind.parse_args() - - with argbind.scope(args): - remove_quiet_files() \ No newline at end of file diff --git a/scripts/utils/split_long_audio_file.py b/scripts/utils/split_long_audio_file.py deleted file mode 100644 index 8648b2612ebd4f1344357222dff5b430525091c5..0000000000000000000000000000000000000000 --- a/scripts/utils/split_long_audio_file.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path -import argbind - -import audiotools as at -import tqdm - - -@argbind.bind(without_prefix=True) -def split_long_audio_file( - file: str = None, - max_chunk_size_s: int = 60*10 -): - file = Path(file) - output_dir = file.parent / file.stem - output_dir.mkdir() - - sig = at.AudioSignal(file) - - # split into chunks - for i, sig in tqdm.tqdm(enumerate(sig.windows( - window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2, - preprocess=True)) - ): - sig.write(output_dir / f"{i}.wav") - - print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}") - - return output_dir - -if __name__ == "__main__": - args = argbind.parse_args() - - with argbind.scope(args): - split_long_audio_file() \ No newline at end of file diff --git a/scripts/utils/xeno-canto-dl.py b/scripts/utils/xeno-canto-dl.py deleted file mode 100644 index 
89acc822d291bd0a145b656fc62be5209a6f6cfc..0000000000000000000000000000000000000000 --- a/scripts/utils/xeno-canto-dl.py +++ /dev/null @@ -1,234 +0,0 @@ -from xenopy import Query - - -SPECIES = [ - "American Robin", - "Northern Cardinal", - "Mourning Dove", - "American Crow", - "Baltimore Oriole", - "Blue Jay", - "Eastern Bluebird", - "House Finch", - "American Goldfinch", - "House Sparrow", - "Song Sparrow", - "Tufted Titmouse", - "White-breasted Nuthatch", - "European Starling", - "American Redstart", - "Red-winged Blackbird", - "Brown-headed Cowbird", - "Common Grackle", - "Boat-tailed Grackle", - "Common Yellowthroat", - "Northern Mockingbird", - "Carolina Wren", - "Eastern Meadowlark", - "Chipping Sparrow", - "Tree Swallow", - "Barn Swallow", - "Cliff Swallow", - "Pine Siskin", - "Indigo Bunting", - "Eastern Towhee", - "Carolina Chickadee", - "Great Crested Flycatcher", - "Eastern Wood-Pewee", - "Ovenbird", - "Northern Flicker", - "Red-eyed Vireo", - "American Woodcock", - "Eastern Phoebe", - "Downy Woodpecker", - "Scarlet Tanager", - "Yellow Warbler", - "White-eyed Vireo", - "Common Loon", - "White-throated Sparrow", - "Yellow-throated Vireo", - "Great Blue Heron", - "Belted Kingfisher", - "Pied-billed Grebe", - "Wild Turkey", - "Wood Thrush", - "Rose-breasted Grosbeak", - "Field Sparrow", - "Hooded Warbler", - "Northern Parula", - "Chestnut-sided Warbler", - "Blue-winged Warbler", - "Red-bellied Woodpecker", - "Yellow-billed Cuckoo", - "Gray Catbird", - "Northern Saw-whet Owl", - "Osprey", - "Common Nighthawk", - "Broad-winged Hawk", - "Black-throated Green Warbler", - "Great Horned Owl", - "Common Raven", - "Barred Owl", - "Canada Warbler", - "Magnolia Warbler", - "Black-and-white Warbler", - "Eastern Kingbird", - "Swainson's Thrush", - "Worm-eating Warbler", - "Prairie Warbler", - "Baltimore Oriole", - "Black-throated Blue Warbler", - "Louisiana Waterthrush", - "Blackburnian Warbler", - "Black-capped Chickadee", - "Cerulean Warbler", - "Red-shouldered Hawk", - "Cooper's Hawk", - "Yellow-throated Warbler", - "Blue-headed Vireo", - "Blackpoll Warbler", - "Ruffed Grouse", - "Kentucky Warbler", - "Hermit Thrush", - "Cedar Waxwing", - "Eastern Screech-Owl", - "Northern Goshawk", - "Green Heron", - "Red-tailed Hawk", - "Black Vulture", - "Hairy Woodpecker", - "Golden-crowned Kinglet", - "Ruby-crowned Kinglet", - "Bicknell's Thrush", - "Blue-gray Gnatcatcher", - "Veery", - "Pileated Woodpecker", - "Purple Finch", - "White-crowned Sparrow", - "Snow Bunting", - "Pine Grosbeak", - "American Tree Sparrow", - "Dark-eyed Junco", - "Snowy Owl", - "White-winged Crossbill", - "Red Crossbill", - "Common Redpoll", - "Northern Shrike", - "Northern Harrier", - "Rough-legged Hawk", - "Long-eared Owl", - "Evening Grosbeak", - "Northern Pintail", - "American Black Duck", - "Mallard", - "Canvasback", - "Redhead", - "Ring-necked Duck", - "Greater Scaup", - "Lesser Scaup", - "Bufflehead", - "Common Goldeneye", - "Hooded Merganser", - "Common Merganser", - "Red-breasted Merganser", - "Ruddy Duck", - "Wood Duck", - "Gadwall", - "American Wigeon", - "Northern Shoveler", - "Green-winged Teal", - "Blue-winged Teal", - "Cinnamon Teal", - "Ringed Teal", - "Cape Teal", - "Northern Fulmar", - "Yellow-billed Loon", - "Red-throated Loon", - "Arctic Loon", - "Pacific Loon", - "Horned Grebe", - "Red-necked Grebe", - "Eared Grebe", - "Western Grebe", - "Clark's Grebe", - "Double-crested Cormorant", - "Pelagic Cormorant", - "Great Cormorant", - "American White Pelican", - "Brown Pelican", - "Brandt's Cormorant", - "Least 
Bittern", - "Great Egret", - "Snowy Egret", - "Little Blue Heron", - "Tricolored Heron", - "Reddish Egret", - "Black-crowned Night-Heron", - "Yellow-crowned Night-Heron", - "White Ibis", - "Glossy Ibis", - "Roseate Spoonbill", - "Wood Stork", - "Black-bellied Whistling-Duck", - "Fulvous Whistling-Duck", - "Greater White-fronted Goose", - "Snow Goose", - "Ross's Goose", - "Canada Goose", - "Brant", - "Mute Swan", - "Tundra Swan", - "Whooper Swan", - "Sandhill Crane", - "Black-necked Stilt", - "American Avocet", - "Northern Jacana", - "Greater Yellowlegs", - "Lesser Yellowlegs", - "Willet", - "Spotted Sandpiper", - "Upland Sandpiper", - "Whimbrel", - "Long-billed Curlew", - "Marbled Godwit", - "Ruddy Turnstone", - "Red Knot", - "Sanderling", - "Semipalmated Sandpiper", - "Western Sandpiper", - "Least Sandpiper", - "White-rumped Sandpiper", - "Baird's Sandpiper", - "Pectoral Sandpiper", - "Dunlin", - "Buff-breasted Sandpiper", - "Short-billed Dowitcher", - "Long-billed Dowitcher", - "Common Snipe", - "American Woodcock", - "Wilson's Phalarope", - "Red-necked Phalarope", - "Red Phalarope" -] - -from pathlib import Path - -def remove_spaces(s): - return s.replace(" ", "") - -for species in SPECIES: - if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists(): - continue - try: - q = Query( - name=species, q="A", length="10-30", - ) - - # retrieve metadata - metafiles = q.retrieve_meta(verbose=True) - # retrieve recordings - q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/") - - except: - print("Failed to download " + species) - continue \ No newline at end of file diff --git a/setup.py b/setup.py index 0b0097fa81ef9ba6774a08417ca29317b3374d26..2964e0f810f32dab3abc433912a2de128c081761 100644 --- a/setup.py +++ b/setup.py @@ -28,13 +28,12 @@ setup( install_requires=[ "torch", "argbind>=0.3.2", - "numpy==1.23", + "numpy==1.22", "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat", "lac @ git+https://github.com/hugofloresgarcia/lac.git", - "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2", + "audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git", "gradio", + "tensorboardX", "loralib", - "torch_pitch_shift", - "madmom", ], ) diff --git a/vampnet/beats.py b/vampnet/beats.py index 2b03a4e3df705a059cd34e6e01a72752fc4d8a98..317496ef83d7b764fbbc51068c13170ce0c17e13 100644 --- a/vampnet/beats.py +++ b/vampnet/beats.py @@ -9,7 +9,6 @@ from typing import Tuple from typing import Union import librosa -import torch import numpy as np from audiotools import AudioSignal @@ -204,7 +203,7 @@ class WaveBeat(BeatTracker): def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"): from wavebeat.dstcn import dsTCNModel - model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device)) + model = dsTCNModel.load_from_checkpoint(ckpt_path) model.eval() self.device = device diff --git a/vampnet/interface.py b/vampnet/interface.py index 1099de3863c42827d10ff093a45720a428666a77..0a6e39182c9d91c1b76bcb18476f9c018a247543 100644 --- a/vampnet/interface.py +++ b/vampnet/interface.py @@ -22,7 +22,6 @@ def signal_concat( return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate) - def _load_model( ckpt: str, lora_ckpt: str = None, @@ -120,16 +119,17 @@ class Interface(torch.nn.Module): if coarse_ckpt is not None: self.coarse.to("cpu") state_dict = torch.load(coarse_ckpt, map_location="cpu") - print(f"loading coarse from {coarse_ckpt}") + 
self.coarse.load_state_dict(state_dict, strict=False) self.coarse.to(self.device) if c2f_ckpt is not None: self.c2f.to("cpu") state_dict = torch.load(c2f_ckpt, map_location="cpu") - print(f"loading c2f from {c2f_ckpt}") + self.c2f.load_state_dict(state_dict, strict=False) self.c2f.to(self.device) + def s2t(self, seconds: float): """seconds to tokens""" if isinstance(seconds, np.ndarray): @@ -193,8 +193,8 @@ class Interface(torch.nn.Module): def make_beat_mask(self, signal: AudioSignal, - before_beat_s: float = 0.0, - after_beat_s: float = 0.02, + before_beat_s: float = 0.1, + after_beat_s: float = 0.1, mask_downbeats: bool = True, mask_upbeats: bool = True, downbeat_downsample_factor: int = None, @@ -275,44 +275,34 @@ class Interface(torch.nn.Module): def coarse_to_fine( self, - z: torch.Tensor, - mask: torch.Tensor = None, + coarse_z: torch.Tensor, **kwargs ): assert self.c2f is not None, "No coarse2fine model loaded" - length = z.shape[-1] + length = coarse_z.shape[-1] chunk_len = self.s2t(self.c2f.chunk_size_s) - n_chunks = math.ceil(z.shape[-1] / chunk_len) + n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len) # zero pad to chunk_len if length % chunk_len != 0: pad_len = chunk_len - (length % chunk_len) - z = torch.nn.functional.pad(z, (0, pad_len)) - mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None + coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len)) - n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1] + n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1] if n_codebooks_to_append > 0: - z = torch.cat([ - z, - torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device) + coarse_z = torch.cat([ + coarse_z, + torch.zeros(coarse_z.shape[0], n_codebooks_to_append, coarse_z.shape[-1]).long().to(self.device) ], dim=1) - # set the mask to 0 for all conditioning codebooks - if mask is not None: - mask = mask.clone() - mask[:, :self.c2f.n_conditioning_codebooks, :] = 0 - fine_z = [] for i in range(n_chunks): - chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len] - mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None - + chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len] chunk = self.c2f.generate( codec=self.codec, time_steps=chunk_len, start_tokens=chunk, return_signal=False, - mask=mask_chunk, **kwargs ) fine_z.append(chunk) @@ -347,12 +337,6 @@ class Interface(torch.nn.Module): **kwargs ) - # add the fine codes back in - c_vamp = torch.cat( - [c_vamp, z[:, self.coarse.n_codebooks :, :]], - dim=1 - ) - if return_mask: return c_vamp, cz_masked @@ -368,18 +352,17 @@ if __name__ == "__main__": at.util.seed(42) interface = Interface( - coarse_ckpt="./models/vampnet/coarse.pth", - coarse2fine_ckpt="./models/vampnet/c2f.pth", - codec_ckpt="./models/vampnet/codec.pth", + coarse_ckpt="./models/spotdl/coarse.pth", + coarse2fine_ckpt="./models/spotdl/c2f.pth", + codec_ckpt="./models/spotdl/codec.pth", device="cuda", wavebeat_ckpt="./models/wavebeat.pth" ) - sig = at.AudioSignal('assets/example.wav') + sig = at.AudioSignal.zeros(duration=10, sample_rate=44100) z = interface.encode(sig) - breakpoint() # mask = linear_random(z, 1.0) # mask = mask_and( @@ -391,14 +374,13 @@ if __name__ == "__main__": # ) # ) - # mask = interface.make_beat_mask( - # sig, 0.0, 0.075 - # ) + mask = interface.make_beat_mask( + sig, 0.0, 0.075 + ) # mask = dropout(mask, 0.0) # mask = codebook_unmask(mask, 0) - - mask = inpaint(z, n_prefix=100, n_suffix=100) + breakpoint() zv, mask_z = 
interface.coarse_vamp( z, mask=mask, @@ -407,16 +389,16 @@ if __name__ == "__main__": return_mask=True, gen_fn=interface.coarse.generate ) - use_coarse2fine = True if use_coarse2fine: - zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask) - breakpoint() + zv = interface.coarse_to_fine(zv, temperature=0.8) mask = interface.to_signal(mask_z).cpu() sig = interface.to_signal(zv).cpu() print("done") + sig.write("output3.wav") + mask.write("mask.wav") \ No newline at end of file diff --git a/vampnet/mask.py b/vampnet/mask.py index 3edb7b0d782b77abe6534f5eb9c2a920aa6c02c7..1302fd101925fca4403c9532561a0a8dc7b55265 100644 --- a/vampnet/mask.py +++ b/vampnet/mask.py @@ -191,47 +191,29 @@ def onset_mask( width: int = 1 ): import librosa - import madmom - from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor - import tempfile - import numpy as np - - with tempfile.NamedTemporaryFile(suffix='.wav') as f: - sig = sig.clone() - sig.write(f.name) - - proc = RNNOnsetProcessor(online=False) - onsetproc = OnsetPeakPickingProcessor(threshold=0.3, - fps=sig.sample_rate/interface.codec.hop_length) - - act = proc(f.name) - onset_times = onsetproc(act) - - # convert to indices for z array - onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length) - - if onset_indices.shape[0] == 0: - mask = empty_mask(z) - print(f"no onsets found, returning empty mask") - else: - torch.set_printoptions(threshold=1000) - print("onset indices: ", onset_indices) - print("onset times: ", onset_times) - - # create a mask, set onset - mask = torch.ones_like(z) - n_timesteps = z.shape[-1] - - for onset_index in onset_indices: - onset_index = min(onset_index, n_timesteps - 1) - onset_index = max(onset_index, 0) - mask[:, :, onset_index - width:onset_index + width] = 0.0 - - print(mask) + + onset_indices = librosa.onset.onset_detect( + y=sig.clone().to_mono().samples.cpu().numpy()[0, 0], + sr=sig.sample_rate, + hop_length=interface.codec.hop_length, + backtrack=True, + ) + + # create a mask, set onset + mask = torch.ones_like(z) + n_timesteps = z.shape[-1] + + for onset_index in onset_indices: + onset_index = min(onset_index, n_timesteps - 1) + onset_index = max(onset_index, 0) + mask[:, :, onset_index - width:onset_index + width] = 0.0 + + print(mask) return mask if __name__ == "__main__": - pass + torch.set_printoptions(threshold=10000) + diff --git a/vampnet/modules/__init__.py b/vampnet/modules/__init__.py index 3f4c8c226e42d022c60b620e8f21ccaf4e6a57bd..3481f32e0287faa9e79ba219f17d18529a4b57ac 100644 --- a/vampnet/modules/__init__.py +++ b/vampnet/modules/__init__.py @@ -2,5 +2,3 @@ import audiotools audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"] audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"] - -from .transformer import VampNet \ No newline at end of file diff --git a/vampnet/modules/transformer.py b/vampnet/modules/transformer.py index 0858644d363d50c9395b2fbf5177f7ad5659114b..debb483ff3e0d3ced5a81581a9bfe89e6a185863 100644 --- a/vampnet/modules/transformer.py +++ b/vampnet/modules/transformer.py @@ -410,9 +410,7 @@ class TransformerStack(nn.Module): def subsequent_mask(self, size): return torch.ones(1, size, size).tril().bool() - def forward(self, x, x_mask, cond=None, src=None, src_mask=None, - return_activations: bool = False - ): + def forward(self, x, x_mask, cond=None, src=None, src_mask=None): """Computes a full transformer stack Parameters ---------- @@ -439,8 +437,6 @@ class 
TransformerStack(nn.Module): encoder_decoder_position_bias = None # Compute transformer layers - if return_activations: - activations = [] for layer in self.layers: x, position_bias, encoder_decoder_position_bias = layer( x=x, @@ -451,15 +447,8 @@ class TransformerStack(nn.Module): position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, ) - if return_activations: - activations.append(x.detach()) - - out = self.norm(x) if self.norm is not None else x - if return_activations: - return out, torch.stack(activations) - else: - return out + return self.norm(x) if self.norm is not None else x class VampNet(at.ml.BaseModel): @@ -467,7 +456,7 @@ class VampNet(at.ml.BaseModel): self, n_heads: int = 20, n_layers: int = 16, - r_cond_dim: int = 0, + r_cond_dim: int = 64, n_codebooks: int = 9, n_conditioning_codebooks: int = 0, latent_dim: int = 8, @@ -478,7 +467,6 @@ class VampNet(at.ml.BaseModel): dropout: float = 0.1 ): super().__init__() - assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}" self.n_heads = n_heads self.n_layers = n_layers self.r_cond_dim = r_cond_dim @@ -525,25 +513,21 @@ class VampNet(at.ml.BaseModel): ), ) - def forward(self, x, return_activations: bool = False): + def forward(self, x, cond): x = self.embedding(x) x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1) - x = rearrange(x, "b d n -> b n d") - out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations) - if return_activations: - out, activations = out + cond = self.r_embed(cond) + x = rearrange(x, "b d n -> b n d") + out = self.transformer(x=x, x_mask=x_mask, cond=cond) out = rearrange(out, "b n d -> b d n") - out = self.classifier(out, None) # no cond here! + out = self.classifier(out, cond) out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks) - if return_activations: - return out, activations - else: - return out + return out def r_embed(self, r, max_positions=10000): if self.r_cond_dim > 0: @@ -594,23 +578,22 @@ class VampNet(at.ml.BaseModel): self, codec, time_steps: int = 300, - sampling_steps: int = 36, + sampling_steps: int = 24, start_tokens: Optional[torch.Tensor] = None, - sampling_temperature: float = 1.0, mask: Optional[torch.Tensor] = None, - mask_temperature: float = 10.5, + temperature: Union[float, Tuple[float, float]] = 2.5, typical_filtering=False, typical_mass=0.2, typical_min_tokens=1, - top_p=None, return_signal=True, - seed: int = None, - sample_cutoff: float = 1.0, ): - if seed is not None: - at.util.seed(seed) logging.debug(f"beginning generation with {sampling_steps} steps") + ##################### + # resolve temperature # + ##################### + assert isinstance(temperature, float) + logging.debug(f"temperature: {temperature}") ##################### @@ -662,6 +645,9 @@ class VampNet(at.ml.BaseModel): for i in range(sampling_steps): logging.debug(f"step {i} of {sampling_steps}") + # our current temperature + logging.debug(f"temperature: {temperature}") + # our current schedule step r = scalar_to_batch_tensor( (i + 1) / sampling_steps, @@ -676,24 +662,41 @@ class VampNet(at.ml.BaseModel): # infer from latents # NOTE: this collapses the codebook dimension into the sequence dimension - logits = self.forward(latents) # b, prob, seq + logits = self.forward(latents, r) # b, prob, seq logits = logits.permute(0, 2, 1) # b, seq, prob - b = logits.shape[0] + if typical_filtering: + typical_filter(logits, + typical_mass=typical_mass, + typical_min_tokens=typical_min_tokens + ) + 
logging.debug(f"permuted logits with shape: {logits.shape}") - sampled_z, selected_probs = sample_from_logits( - logits, sample=( - (i / sampling_steps) <= sample_cutoff - ), - temperature=sampling_temperature, - typical_filtering=typical_filtering, typical_mass=typical_mass, - typical_min_tokens=typical_min_tokens, - top_k=None, top_p=top_p, return_probs=True, - ) + # logits2probs + probs = torch.softmax(logits, dim=-1) + logging.debug(f"computed probs with shape: {probs.shape}") + + + # sample from logits with multinomial sampling + b = probs.shape[0] + probs = rearrange(probs, "b seq prob -> (b seq) prob") + + sampled_z = torch.multinomial(probs, 1).squeeze(-1) + + sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b) + probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b) logging.debug(f"sampled z with shape: {sampled_z.shape}") + # get the confidences: which tokens did we sample? + selected_probs = ( + torch.take_along_dim( + probs, sampled_z.long().unsqueeze(-1), + dim=-1 + ).squeeze(-1) + ) + # flatten z_masked and mask, so we can deal with the sampling logic # we'll unflatten them at the end of the loop for the next forward pass # remove conditioning codebooks, we'll add them back at the end @@ -730,7 +733,7 @@ class VampNet(at.ml.BaseModel): # get our new mask mask = mask_by_random_topk( - num_to_mask, selected_probs, mask_temperature * (1-r) + num_to_mask, selected_probs, temperature * (1-r) ) # update the mask @@ -763,97 +766,8 @@ class VampNet(at.ml.BaseModel): else: return sampled_z -def sample_from_logits( - logits, - sample: bool = True, - temperature: float = 1.0, - top_k: int = None, - top_p: float = None, - typical_filtering: bool = False, - typical_mass: float = 0.2, - typical_min_tokens: int = 1, - return_probs: bool = False - ): - """Convenience function to sample from a categorial distribution with input as - unnormalized logits. - - Parameters - ---------- - logits : Tensor[..., vocab_size] - config: SamplingConfig - The set of hyperparameters to be used for sampling - sample : bool, optional - Whether to perform multinomial sampling, by default True - temperature : float, optional - Scaling parameter when multinomial samping, by default 1.0 - top_k : int, optional - Restricts sampling to only `top_k` values acc. to probability, - by default None - top_p : float, optional - Restricts sampling to only those values with cumulative - probability = `top_p`, by default None - - Returns - ------- - Tensor[...] 
- Sampled tokens - """ - shp = logits.shape[:-1] - - if typical_filtering: - typical_filter(logits, - typical_mass=typical_mass, - typical_min_tokens=typical_min_tokens - ) - - # Apply top_k sampling - if top_k is not None: - v, _ = logits.topk(top_k) - logits[logits < v[..., [-1]]] = -float("inf") - - # Apply top_p (nucleus) sampling - if top_p is not None and top_p < 1.0: - v, sorted_indices = logits.sort(descending=True) - cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1) - - sorted_indices_to_remove = cumulative_probs > top_p - # Right shift indices_to_remove to keep 1st token over threshold - sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[ - ..., :-1 - ] - - # Compute indices_to_remove in unsorted array - indices_to_remove = sorted_indices_to_remove.scatter( - -1, sorted_indices, sorted_indices_to_remove - ) - - logits[indices_to_remove] = -float("inf") - - # Perform multinomial sampling after normalizing logits - probs = ( - F.softmax(logits / temperature, dim=-1) - if temperature > 0 - else logits.softmax(dim=-1) - ) - token = ( - probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp) - if sample - else logits.argmax(-1) - ) - - if return_probs: - token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1) - return token, token_probs - else: - return token - - -def mask_by_random_topk( - num_to_mask: int, - probs: torch.Tensor, - temperature: float = 1.0, - ): +def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0): """ Args: num_to_mask (int): number of tokens to mask @@ -866,8 +780,7 @@ def mask_by_random_topk( logging.debug(f"temperature: {temperature}") logging.debug("") - noise = gumbel_noise_like(probs) - confidence = torch.log(probs) + temperature * noise + confidence = torch.log(probs) + temperature * gumbel_noise_like(probs) logging.debug(f"confidence shape: {confidence.shape}") sorted_confidence, sorted_idx = confidence.sort(dim=-1) @@ -937,7 +850,7 @@ if __name__ == "__main__": z_mask_latent = torch.rand( batch_size, model.latent_dim * model.n_codebooks, seq_len ).to(device) - z_hat = model(z_mask_latent) + z_hat = model(z_mask_latent, r) pred = z_hat.argmax(dim=1) pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
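
Note on the sampling change above: the rewritten VampNet.generate() in vampnet/modules/transformer.py drops the sample_from_logits helper in favour of inline multinomial sampling over the softmaxed logits, followed by mask_by_random_topk, which re-masks the least-confident tokens using Gumbel-perturbed log-probabilities. The following is a minimal, self-contained PyTorch sketch of that per-step pattern, for reference only; it is not part of the patch, and the function name sample_and_remask, the dummy shapes, and the inline Gumbel draw (standing in for gumbel_noise_like) are illustrative assumptions.

# Reference sketch (not part of the patch): per-step multinomial sampling plus
# confidence-based re-masking, mirroring the rewritten generate() loop above.
import torch

def sample_and_remask(logits: torch.Tensor, num_to_mask: int, temperature: float):
    """logits: (batch, seq, vocab). Returns sampled tokens (batch, seq) and a
    boolean mask marking the `num_to_mask` least-confident positions per item."""
    b, s, v = logits.shape

    # sample each position from its softmax distribution (as in the diff)
    probs = torch.softmax(logits, dim=-1)                     # (b, s, v)
    sampled = torch.multinomial(probs.reshape(b * s, v), 1)   # (b*s, 1)
    sampled = sampled.squeeze(-1).reshape(b, s)               # (b, s)

    # confidence = probability assigned to the token that was actually sampled
    selected = torch.take_along_dim(probs, sampled.unsqueeze(-1), dim=-1).squeeze(-1)

    # Gumbel-perturbed log-confidence, then re-mask the lowest-scoring positions
    gumbel = -torch.log(-torch.log(torch.rand_like(selected).clamp_min(1e-20)))
    confidence = torch.log(selected.clamp_min(1e-20)) + temperature * gumbel
    _, sorted_idx = confidence.sort(dim=-1)                   # ascending: least confident first
    remask = torch.zeros_like(selected, dtype=torch.bool)
    remask.scatter_(1, sorted_idx[:, :num_to_mask], True)
    return sampled, remask

# usage with dummy logits; in the model, the effective temperature is
# temperature * (1 - r), following the schedule step shown in the diff
tokens, remask = sample_and_remask(torch.randn(2, 8, 1024), num_to_mask=3, temperature=2.5)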