diff --git a/.gitignore b/.gitignore index 685d6448d5a197ed8a434f72d6f2279552613c91..81d14bc405b5624e69bceda90e746519ff9930a5 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,4 @@ audiotools/ descript-audio-codec/ # *.pth .git-old +conf/generated/* diff --git a/README.md b/README.md index 687fb086b0db5ec747e42728d8d25be07f51e7cb..231407a1cc1ca55dc8ea22d351de801f7dd069f4 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,14 @@ sdk: gradio sdk_version: 3.36.1 app_file: app.py pinned: false -duplicated_from: hugggof/vampnet --- # VampNet -This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec. +This repository contains recipes for training generative music models on top of the Descript Audio Codec. + +## Try `unloop` +You can try VampNet in a co-creative looper called unloop; see this link: https://github.com/hugofloresgarcia/unloop # Setting up @@ -35,7 +37,7 @@ Config files are stored in the `conf/` folder. ### Licensing for Pretrained Models: The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). -Download the pretrained models from [this link](https://zenodo.org/record/8136545). Then, extract the models to the `models/` folder. +Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder. # Usage diff --git a/app.py b/app.py index dfa3ff80e64cacf3e395fe74855f2a3f7370156f..de30215d0e4c43a5731ec014ae27412cd9ee80fa 100644 --- a/app.py +++ b/app.py @@ -124,7 +124,7 @@ def _vamp(data, return_mask=False): ) if use_coarse2fine: - zv = interface.coarse_to_fine(zv, temperature=data[temp]) + zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask) sig = interface.to_signal(zv).cpu() print("done") @@ -407,7 +407,8 @@ with gr.Blocks() as demo: use_coarse2fine = gr.Checkbox( label="use coarse2fine", - value=True + value=True, + visible=False ) num_steps = gr.Slider( diff --git a/conf/generated-v0/berta-goldman-speech/c2f.yml b/conf/generated-v0/berta-goldman-speech/c2f.yml deleted file mode 100644 index 0f5a4cd57e7a801121d7c77a62a0e8767b7fe61c..0000000000000000000000000000000000000000 --- a/conf/generated-v0/berta-goldman-speech/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -save_path: ./runs/berta-goldman-speech/c2f -train/AudioLoader.sources: -- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 -val/AudioLoader.sources: -- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 diff --git a/conf/generated-v0/berta-goldman-speech/coarse.yml b/conf/generated-v0/berta-goldman-speech/coarse.yml deleted file mode 100644 index 7c1207e9cfe83bac59f76fcf21068405cd6c9551..0000000000000000000000000000000000000000 --- a/conf/generated-v0/berta-goldman-speech/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -save_path: ./runs/berta-goldman-speech/coarse -train/AudioLoader.sources: -- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 -val/AudioLoader.sources: -- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 diff --git
a/conf/generated-v0/berta-goldman-speech/interface.yml b/conf/generated-v0/berta-goldman-speech/interface.yml deleted file mode 100644 index d1ba35ec732a0148f3bced5542e27e85575c4d4e..0000000000000000000000000000000000000000 --- a/conf/generated-v0/berta-goldman-speech/interface.yml +++ /dev/null @@ -1,5 +0,0 @@ -AudioLoader.sources: -- /media/CHONK/hugo/Berta-Caceres-2015-Goldman-Speech.mp3 -Interface.coarse2fine_ckpt: ./runs/berta-goldman-speech/c2f/best/vampnet/weights.pth -Interface.coarse_ckpt: ./runs/berta-goldman-speech/coarse/best/vampnet/weights.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated-v0/gamelan-xeno-canto/c2f.yml b/conf/generated-v0/gamelan-xeno-canto/c2f.yml deleted file mode 100644 index 9e6fec4ddc7dd0a2e02d1be66cc7f6eafa669ed1..0000000000000000000000000000000000000000 --- a/conf/generated-v0/gamelan-xeno-canto/c2f.yml +++ /dev/null @@ -1,17 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -save_path: ./runs/gamelan-xeno-canto/c2f -train/AudioLoader.sources: -- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 -- /media/CHONK/hugo/loras/xeno-canto-2 -val/AudioLoader.sources: -- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 -- /media/CHONK/hugo/loras/xeno-canto-2 diff --git a/conf/generated-v0/gamelan-xeno-canto/coarse.yml b/conf/generated-v0/gamelan-xeno-canto/coarse.yml deleted file mode 100644 index 7e8d38e18d714cb08db7ed456939737404533c3e..0000000000000000000000000000000000000000 --- a/conf/generated-v0/gamelan-xeno-canto/coarse.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -save_path: ./runs/gamelan-xeno-canto/coarse -train/AudioLoader.sources: -- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 -- /media/CHONK/hugo/loras/xeno-canto-2 -val/AudioLoader.sources: -- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 -- /media/CHONK/hugo/loras/xeno-canto-2 diff --git a/conf/generated-v0/gamelan-xeno-canto/interface.yml b/conf/generated-v0/gamelan-xeno-canto/interface.yml deleted file mode 100644 index e567800477816ac1cc41719744c1ba40562e35b1..0000000000000000000000000000000000000000 --- a/conf/generated-v0/gamelan-xeno-canto/interface.yml +++ /dev/null @@ -1,6 +0,0 @@ -AudioLoader.sources: -- /media/CHONK/hugo/loras/Sound Tracker - Gamelan (Indonesia) [UEWCCSuHsuQ].mp3 -- /media/CHONK/hugo/loras/xeno-canto-2 -Interface.coarse2fine_ckpt: ./runs/gamelan-xeno-canto/c2f/best/vampnet/weights.pth -Interface.coarse_ckpt: ./runs/gamelan-xeno-canto/coarse/best/vampnet/weights.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated-v0/nasralla/c2f.yml b/conf/generated-v0/nasralla/c2f.yml deleted file mode 100644 index 9d9db7bed268c18f3ca4047dcde34dd18a5a2301..0000000000000000000000000000000000000000 --- a/conf/generated-v0/nasralla/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -save_path: ./runs/nasralla/c2f -train/AudioLoader.sources: -- /media/CHONK/hugo/nasralla -val/AudioLoader.sources: -- /media/CHONK/hugo/nasralla diff --git 
a/conf/generated-v0/nasralla/coarse.yml b/conf/generated-v0/nasralla/coarse.yml deleted file mode 100644 index 43a4d18c7f955e38200ded0d2a4fa0959ddb639e..0000000000000000000000000000000000000000 --- a/conf/generated-v0/nasralla/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -save_path: ./runs/nasralla/coarse -train/AudioLoader.sources: -- /media/CHONK/hugo/nasralla -val/AudioLoader.sources: -- /media/CHONK/hugo/nasralla diff --git a/conf/generated-v0/nasralla/interface.yml b/conf/generated-v0/nasralla/interface.yml deleted file mode 100644 index c93e872d1e4b66567812755882a996814794ad8f..0000000000000000000000000000000000000000 --- a/conf/generated-v0/nasralla/interface.yml +++ /dev/null @@ -1,5 +0,0 @@ -AudioLoader.sources: -- /media/CHONK/hugo/nasralla -Interface.coarse2fine_ckpt: ./runs/nasralla/c2f/best/vampnet/weights.pth -Interface.coarse_ckpt: ./runs/nasralla/coarse/best/vampnet/weights.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/breaks-steps/c2f.yml b/conf/generated/breaks-steps/c2f.yml deleted file mode 100644 index 49617a6d52de00a9bc7c82c6e820168076402fac..0000000000000000000000000000000000000000 --- a/conf/generated/breaks-steps/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/breaks-steps/c2f -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/breaks-steps -val/AudioLoader.sources: *id001 diff --git a/conf/generated/breaks-steps/coarse.yml b/conf/generated/breaks-steps/coarse.yml deleted file mode 100644 index 71d9b27fbc4aac7d407d3606e98c4eaca35e2d3f..0000000000000000000000000000000000000000 --- a/conf/generated/breaks-steps/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/breaks-steps/coarse -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/breaks-steps -val/AudioLoader.sources: *id001 diff --git a/conf/generated/breaks-steps/interface.yml b/conf/generated/breaks-steps/interface.yml deleted file mode 100644 index b4b5182c4a378884e1614d89bc39abdf78a4eaa2..0000000000000000000000000000000000000000 --- a/conf/generated/breaks-steps/interface.yml +++ /dev/null @@ -1,7 +0,0 @@ -AudioLoader.sources: -- - /media/CHONK/hugo/breaks-steps -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/bulgarian-tv-choir/c2f.yml b/conf/generated/bulgarian-tv-choir/c2f.yml deleted file mode 100644 index 7bc54bf54bc8cc5c599a11f30c036822fa4b84c5..0000000000000000000000000000000000000000 --- a/conf/generated/bulgarian-tv-choir/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/bulgarian-tv-choir/c2f -train/AudioLoader.sources: &id001 -- 
/media/CHONK/hugo/loras/bulgarian-female-tv-choir/ -val/AudioLoader.sources: *id001 diff --git a/conf/generated/bulgarian-tv-choir/coarse.yml b/conf/generated/bulgarian-tv-choir/coarse.yml deleted file mode 100644 index 06f27f140dbd8c6d6315aab0787435ff501f8958..0000000000000000000000000000000000000000 --- a/conf/generated/bulgarian-tv-choir/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/bulgarian-tv-choir/coarse -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ -val/AudioLoader.sources: *id001 diff --git a/conf/generated/bulgarian-tv-choir/interface.yml b/conf/generated/bulgarian-tv-choir/interface.yml deleted file mode 100644 index b56e8d721adf99da361dadf423a669bb576478e1..0000000000000000000000000000000000000000 --- a/conf/generated/bulgarian-tv-choir/interface.yml +++ /dev/null @@ -1,7 +0,0 @@ -AudioLoader.sources: -- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/ -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/dariacore/c2f.yml b/conf/generated/dariacore/c2f.yml deleted file mode 100644 index e8e52fc05be63fe891d3adf0c2115efd5e06ecef..0000000000000000000000000000000000000000 --- a/conf/generated/dariacore/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/dariacore/c2f -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/dariacore -val/AudioLoader.sources: *id001 diff --git a/conf/generated/dariacore/coarse.yml b/conf/generated/dariacore/coarse.yml deleted file mode 100644 index 42044d7bbafbf890d6d6bc504beb49edf977c39b..0000000000000000000000000000000000000000 --- a/conf/generated/dariacore/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/dariacore/coarse -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/dariacore -val/AudioLoader.sources: *id001 diff --git a/conf/generated/dariacore/interface.yml b/conf/generated/dariacore/interface.yml deleted file mode 100644 index 29342d2fe9d97f20d9521885869f1cca16d2aeba..0000000000000000000000000000000000000000 --- a/conf/generated/dariacore/interface.yml +++ /dev/null @@ -1,7 +0,0 @@ -AudioLoader.sources: -- - /media/CHONK/hugo/loras/dariacore -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: ./runs/dariacore/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/dariacore/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/musica-bolero-marimba/c2f.yml b/conf/generated/musica-bolero-marimba/c2f.yml deleted file mode 100644 index cd06c72814deaf9fd41d3dabc8e6046e050ad968..0000000000000000000000000000000000000000 --- a/conf/generated/musica-bolero-marimba/c2f.yml +++ /dev/null @@ -1,18 +0,0 @@ -$include: -- conf/lora/lora.yml 
-AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/musica-bolero-marimba/c2f -train/AudioLoader.sources: -- /media/CHONK/hugo/loras/boleros -- /media/CHONK/hugo/loras/marimba-honduras -val/AudioLoader.sources: -- /media/CHONK/hugo/loras/boleros -- /media/CHONK/hugo/loras/marimba-honduras diff --git a/conf/generated/musica-bolero-marimba/coarse.yml b/conf/generated/musica-bolero-marimba/coarse.yml deleted file mode 100644 index a3e1c0ee8e8593528cb389fb84c56894727cfca5..0000000000000000000000000000000000000000 --- a/conf/generated/musica-bolero-marimba/coarse.yml +++ /dev/null @@ -1,11 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/musica-bolero-marimba/coarse -train/AudioLoader.sources: -- /media/CHONK/hugo/loras/boleros -- /media/CHONK/hugo/loras/marimba-honduras -val/AudioLoader.sources: -- /media/CHONK/hugo/loras/boleros -- /media/CHONK/hugo/loras/marimba-honduras diff --git a/conf/generated/musica-bolero-marimba/interface.yml b/conf/generated/musica-bolero-marimba/interface.yml deleted file mode 100644 index 08b42e3120a3cedbb5aafb9a39ca879d8958127a..0000000000000000000000000000000000000000 --- a/conf/generated/musica-bolero-marimba/interface.yml +++ /dev/null @@ -1,8 +0,0 @@ -AudioLoader.sources: -- /media/CHONK/hugo/loras/boleros -- /media/CHONK/hugo/loras/marimba-honduras -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/panchos/c2f.yml b/conf/generated/panchos/c2f.yml deleted file mode 100644 index 4efd6fb4caf409382929dcf61d40ed37e3773eac..0000000000000000000000000000000000000000 --- a/conf/generated/panchos/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/panchos/c2f -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/panchos/ -val/AudioLoader.sources: *id001 diff --git a/conf/generated/panchos/coarse.yml b/conf/generated/panchos/coarse.yml deleted file mode 100644 index c4f21a3f4deb58cd6b98680e82d59ad32098542e..0000000000000000000000000000000000000000 --- a/conf/generated/panchos/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/panchos/coarse -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/panchos/ -val/AudioLoader.sources: *id001 diff --git a/conf/generated/panchos/interface.yml b/conf/generated/panchos/interface.yml deleted file mode 100644 index 8bae11c225a0fa49c27efdfc808a63d53c21755a..0000000000000000000000000000000000000000 --- a/conf/generated/panchos/interface.yml +++ /dev/null @@ -1,7 +0,0 @@ -AudioLoader.sources: -- - /media/CHONK/hugo/loras/panchos/ -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: 
./runs/panchos/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/titi-monkey/c2f.yml b/conf/generated/titi-monkey/c2f.yml deleted file mode 100644 index 456912ab1589eee1dfe6c5768e70ede4e455c828..0000000000000000000000000000000000000000 --- a/conf/generated/titi-monkey/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/titi-monkey/c2f -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/titi-monkey.mp3 -val/AudioLoader.sources: *id001 diff --git a/conf/generated/titi-monkey/coarse.yml b/conf/generated/titi-monkey/coarse.yml deleted file mode 100644 index c2af934aa5aff33c26ae95a2d7a46eb19f9b7194..0000000000000000000000000000000000000000 --- a/conf/generated/titi-monkey/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/titi-monkey/coarse -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/titi-monkey.mp3 -val/AudioLoader.sources: *id001 diff --git a/conf/generated/titi-monkey/interface.yml b/conf/generated/titi-monkey/interface.yml deleted file mode 100644 index cbc4ffad24c7c3b34e930aff08404955348b49a2..0000000000000000000000000000000000000000 --- a/conf/generated/titi-monkey/interface.yml +++ /dev/null @@ -1,7 +0,0 @@ -AudioLoader.sources: -- - /media/CHONK/hugo/loras/titi-monkey.mp3 -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/generated/xeno-canto/c2f.yml b/conf/generated/xeno-canto/c2f.yml deleted file mode 100644 index 251b0e361ee15d01f7715608480cb3d5e9fdb122..0000000000000000000000000000000000000000 --- a/conf/generated/xeno-canto/c2f.yml +++ /dev/null @@ -1,15 +0,0 @@ -$include: -- conf/lora/lora.yml -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 -VampNet.embedding_dim: 1280 -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 -VampNet.n_heads: 20 -VampNet.n_layers: 16 -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/c2f.pth -save_path: ./runs/xeno-canto/c2f -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/xeno-canto-2/ -val/AudioLoader.sources: *id001 diff --git a/conf/generated/xeno-canto/coarse.yml b/conf/generated/xeno-canto/coarse.yml deleted file mode 100644 index ea151dbb64ff13982b0004685901da2b58c8e596..0000000000000000000000000000000000000000 --- a/conf/generated/xeno-canto/coarse.yml +++ /dev/null @@ -1,8 +0,0 @@ -$include: -- conf/lora/lora.yml -fine_tune: true -fine_tune_checkpoint: ./models/spotdl/coarse.pth -save_path: ./runs/xeno-canto/coarse -train/AudioLoader.sources: &id001 -- /media/CHONK/hugo/loras/xeno-canto-2/ -val/AudioLoader.sources: *id001 diff --git a/conf/generated/xeno-canto/interface.yml b/conf/generated/xeno-canto/interface.yml deleted file mode 100644 index 1a8b1420f142cef024471073e674cd9db59ffad0..0000000000000000000000000000000000000000 --- 
a/conf/generated/xeno-canto/interface.yml +++ /dev/null @@ -1,7 +0,0 @@ -AudioLoader.sources: -- - /media/CHONK/hugo/loras/xeno-canto-2/ -Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth -Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth -Interface.coarse_ckpt: ./models/spotdl/coarse.pth -Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth -Interface.codec_ckpt: ./models/spotdl/codec.pth diff --git a/conf/lora/birds.yml b/conf/lora/birds.yml deleted file mode 100644 index de413ec0dec4f974e664923c9319861a1c957e87..0000000000000000000000000000000000000000 --- a/conf/lora/birds.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/birds - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/birds diff --git a/conf/lora/birdss.yml b/conf/lora/birdss.yml deleted file mode 100644 index 3526de67d24e296de2cc0a7d2e5ebbc18245a6c8..0000000000000000000000000000000000000000 --- a/conf/lora/birdss.yml +++ /dev/null @@ -1,12 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/birds - - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/ - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/birds - - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/ diff --git a/conf/lora/constructions.yml b/conf/lora/constructions.yml deleted file mode 100644 index f513b4898e06339fa0d0b4af24e98fdf5289094a..0000000000000000000000000000000000000000 --- a/conf/lora/constructions.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3 - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/constructions/third.mp3 diff --git a/conf/lora/ella-baila-sola.yml b/conf/lora/ella-baila-sola.yml deleted file mode 100644 index 24eeada8013ea0d56d7d6474db52a48c3fd43bc1..0000000000000000000000000000000000000000 --- a/conf/lora/ella-baila-sola.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3 - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/ella-baila-sola.mp3 diff --git a/conf/lora/gas-station.yml b/conf/lora/gas-station.yml deleted file mode 100644 index 4369f9203232fa3dcfd21667f3e55d0d0fda108e..0000000000000000000000000000000000000000 --- a/conf/lora/gas-station.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3 - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3 diff --git a/conf/lora/lora-is-this-charlie-parker.yml b/conf/lora/lora-is-this-charlie-parker.yml deleted file mode 100644 index 9cfaa31a421266fafa60a1ee4bb2d45f1c47577c..0000000000000000000000000000000000000000 --- a/conf/lora/lora-is-this-charlie-parker.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3 - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/this-is-charlie-parker/Charlie Parker - Donna Lee.mp3 diff --git a/conf/lora/lora.yml b/conf/lora/lora.yml index
b901ea00a6008b92f25728d6d01a258c6aba5d1e..c6abe7e0bddac557ea3885309f3425877541cfe9 100644 --- a/conf/lora/lora.yml +++ b/conf/lora/lora.yml @@ -3,20 +3,18 @@ $include: fine_tune: True -train/AudioDataset.n_examples: 10000000 - -val/AudioDataset.n_examples: 10 +train/AudioDataset.n_examples: 100000000 +val/AudioDataset.n_examples: 100 NoamScheduler.warmup: 500 batch_size: 7 num_workers: 7 -epoch_length: 100 -save_audio_epochs: 10 +save_iters: [100000, 200000, 300000, 400000, 500000] AdamW.lr: 0.0001 # lets us organize sound classes into folders and choose from those sound classes uniformly AudioDataset.without_replacement: False -max_epochs: 500 \ No newline at end of file +num_iters: 500000 \ No newline at end of file diff --git a/conf/lora/underworld.yml b/conf/lora/underworld.yml deleted file mode 100644 index 6fd1a6cf1e74220a2b51b1117afb373acda033a7..0000000000000000000000000000000000000000 --- a/conf/lora/underworld.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/underworld.mp3 - -val/AudioLoader.sources: - - /media/CHONK/hugo/spotdl/subsets/underworld.mp3 diff --git a/conf/lora/xeno-canto/c2f.yml b/conf/lora/xeno-canto/c2f.yml deleted file mode 100644 index 94f9906189f0b74b6c492bdd53fa56d58a0fa04d..0000000000000000000000000000000000000000 --- a/conf/lora/xeno-canto/c2f.yml +++ /dev/null @@ -1,21 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/xeno-canto-2 - -val/AudioLoader.sources: - - /media/CHONK/hugo/xeno-canto-2 - - -VampNet.n_codebooks: 14 -VampNet.n_conditioning_codebooks: 4 - -VampNet.embedding_dim: 1280 -VampNet.n_layers: 16 -VampNet.n_heads: 20 - -AudioDataset.duration: 3.0 -AudioDataset.loudness_cutoff: -40.0 diff --git a/conf/lora/xeno-canto/coarse.yml b/conf/lora/xeno-canto/coarse.yml deleted file mode 100644 index 223c8f0f8481f55ac1c33816ed79fe45b50f1495..0000000000000000000000000000000000000000 --- a/conf/lora/xeno-canto/coarse.yml +++ /dev/null @@ -1,10 +0,0 @@ -$include: - - conf/lora/lora.yml - -fine_tune: True - -train/AudioLoader.sources: - - /media/CHONK/hugo/xeno-canto-2 - -val/AudioLoader.sources: - - /media/CHONK/hugo/xeno-canto-2 diff --git a/conf/vampnet-musdb-drums.yml b/conf/vampnet-musdb-drums.yml deleted file mode 100644 index 010843d81ec9ac3c832b8e88f30af2f99a56ba99..0000000000000000000000000000000000000000 --- a/conf/vampnet-musdb-drums.yml +++ /dev/null @@ -1,22 +0,0 @@ -$include: - - conf/vampnet.yml - -VampNet.embedding_dim: 512 -VampNet.n_layers: 12 -VampNet.n_heads: 8 - -AudioDataset.duration: 12.0 - -train/AudioDataset.n_examples: 10000000 -train/AudioLoader.sources: - - /data/musdb18hq/train/**/*drums.wav - - -val/AudioDataset.n_examples: 500 -val/AudioLoader.sources: - - /data/musdb18hq/test/**/*drums.wav - - -test/AudioDataset.n_examples: 1000 -test/AudioLoader.sources: - - /data/musdb18hq/test/**/*drums.wav diff --git a/conf/vampnet.yml b/conf/vampnet.yml index d24df3fc1923eeb98f76f5747a52c3e83ef98795..6157577f3435c86e302bee1aa62f48d855128490 100644 --- a/conf/vampnet.yml +++ b/conf/vampnet.yml @@ -1,21 +1,17 @@ -codec_ckpt: ./models/spotdl/codec.pth +codec_ckpt: ./models/vampnet/codec.pth save_path: ckpt -max_epochs: 1000 -epoch_length: 1000 -save_audio_epochs: 2 -val_idx: [0,1,2,3,4,5,6,7,8,9] -prefix_amt: 0.0 -suffix_amt: 0.0 -prefix_dropout: 0.1 -suffix_dropout: 0.1 +num_iters: 1000000000 +save_iters: [10000, 50000, 100000, 300000, 500000] +val_idx:
[0,1,2,3,4,5,6,7,8,9] +sample_freq: 10000 +val_freq: 1000 batch_size: 8 num_workers: 10 # Optimization -detect_anomaly: false amp: false CrossEntropyLoss.label_smoothing: 0.1 @@ -25,9 +21,6 @@ AdamW.lr: 0.001 NoamScheduler.factor: 2.0 NoamScheduler.warmup: 10000 -PitchShift.shift_amount: [const, 0] -PitchShift.prob: 0.0 - VampNet.vocab_size: 1024 VampNet.n_codebooks: 4 VampNet.n_conditioning_codebooks: 0 @@ -48,12 +41,9 @@ AudioDataset.duration: 10.0 train/AudioDataset.n_examples: 10000000 train/AudioLoader.sources: - - /data/spotdl/audio/train + - /media/CHONK/hugo/spotdl/audio-train val/AudioDataset.n_examples: 2000 val/AudioLoader.sources: - - /data/spotdl/audio/val + - /media/CHONK/hugo/spotdl/audio-val -test/AudioDataset.n_examples: 1000 -test/AudioLoader.sources: - - /data/spotdl/audio/test diff --git a/scripts/exp/fine_tune.py b/scripts/exp/fine_tune.py index e2c6c3b768f585242705e5cdabeebe45ced557cf..d3145378c574ee293c96e4973ec6f33ee3cb8713 100644 --- a/scripts/exp/fine_tune.py +++ b/scripts/exp/fine_tune.py @@ -35,7 +35,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str): "AudioDataset.duration": 3.0, "AudioDataset.loudness_cutoff": -40.0, "save_path": f"./runs/{name}/c2f", - "fine_tune_checkpoint": "./models/spotdl/c2f.pth" + "fine_tune_checkpoint": "./models/vampnet/c2f.pth" } finetune_coarse_conf = { @@ -44,17 +44,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str): "train/AudioLoader.sources": audio_files_or_folders, "val/AudioLoader.sources": audio_files_or_folders, "save_path": f"./runs/{name}/coarse", - "fine_tune_checkpoint": "./models/spotdl/coarse.pth" + "fine_tune_checkpoint": "./models/vampnet/coarse.pth" } interface_conf = { - "Interface.coarse_ckpt": f"./models/spotdl/coarse.pth", + "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth", "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth", - "Interface.coarse2fine_ckpt": f"./models/spotdl/c2f.pth", + "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth", "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth", - "Interface.codec_ckpt": "./models/spotdl/codec.pth", + "Interface.codec_ckpt": "./models/vampnet/codec.pth", "AudioLoader.sources": [audio_files_or_folders], } diff --git a/scripts/exp/train.py b/scripts/exp/train.py index 79251a529c9512b7bf8c2613e6ae173df21c5c61..68ddd88221710645c2e500203aae64dfd8d09257 100644 --- a/scripts/exp/train.py +++ b/scripts/exp/train.py @@ -1,9 +1,9 @@ import os -import subprocess -import time +import sys import warnings from pathlib import Path from typing import Optional +from dataclasses import dataclass import argbind import audiotools as at @@ -23,6 +23,12 @@ from vampnet import mask as pmask # from dac.model.dac import DAC from lac.model.lac import LAC as DAC +from audiotools.ml.decorators import ( + timer, Tracker, when +) + +import loralib as lora + # Enable cudnn autotuner to speed up training # (can be altered by the funcs.seed function) @@ -85,11 +91,7 @@ def build_datasets(args, sample_rate: int): ) with argbind.scope(args, "val"): val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform()) - with argbind.scope(args, "test"): - test_data = AudioDataset( - AudioLoader(), sample_rate, transform=build_transform() - ) - return train_data, val_data, test_data + return train_data, val_data def rand_float(shape, low, high, rng): @@ -100,16 +102,393 @@ def flip_coin(shape, p, rng): return rng.draw(shape)[:, 0] < p +def num_params_hook(o, p): + return o + f" {p/1e6:<.3f}M params." 
+ + +def add_num_params_repr_hook(model): + import numpy as np + from functools import partial + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + + setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p)) + + +def accuracy( + preds: torch.Tensor, + target: torch.Tensor, + top_k: int = 1, + ignore_index: Optional[int] = None, +) -> torch.Tensor: + # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class) + preds = rearrange(preds, "b p s -> (b s) p") + target = rearrange(target, "b s -> (b s)") + + # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index) + if ignore_index is not None: + # Create a mask for the ignored index + mask = target != ignore_index + # Apply the mask to the target and predictions + preds = preds[mask] + target = target[mask] + + # Get the top-k predicted classes and their indices + _, pred_indices = torch.topk(preds, k=top_k, dim=-1) + + # Determine if the true target is in the top-k predicted classes + correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1) + + # Calculate the accuracy + accuracy = torch.mean(correct.float()) + + return accuracy + +def _metrics(z_hat, r, target, flat_mask, output): + for r_range in [(0, 0.5), (0.5, 1.0)]: + unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX) + masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + + assert target.shape[0] == r.shape[0] + # grab the indices of the r values that are in the range + r_idx = (r >= r_range[0]) & (r < r_range[1]) + + # grab the target and z_hat values that are in the range + r_unmasked_target = unmasked_target[r_idx] + r_masked_target = masked_target[r_idx] + r_z_hat = z_hat[r_idx] + + for topk in (1, 25): + s, e = r_range + tag = f"accuracy-{s}-{e}/top{topk}" + + output[f"{tag}/unmasked"] = accuracy( + preds=r_z_hat, + target=r_unmasked_target, + ignore_index=IGNORE_INDEX, + top_k=topk, + ) + output[f"{tag}/masked"] = accuracy( + preds=r_z_hat, + target=r_masked_target, + ignore_index=IGNORE_INDEX, + top_k=topk, + ) + + +@dataclass +class State: + model: VampNet + codec: DAC + + optimizer: AdamW + scheduler: NoamScheduler + criterion: CrossEntropyLoss + grad_clip_val: float + + rng: torch.quasirandom.SobolEngine + + train_data: AudioDataset + val_data: AudioDataset + + tracker: Tracker + + +@timer() +def train_loop(state: State, batch: dict, accel: Accelerator): + state.model.train() + batch = at.util.prepare_batch(batch, accel.device) + signal = apply_transform(state.train_data.transform, batch) + + output = {} + vn = accel.unwrap(state.model) + with accel.autocast(): + with torch.inference_mode(): + state.codec.to(accel.device) + z = state.codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + n_batch = z.shape[0] + r = state.rng.draw(n_batch)[:, 0].to(accel.device) + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, state.codec) + + dtype = torch.bfloat16 if accel.amp else None + with accel.autocast(dtype=dtype): + z_hat = state.model(z_mask_latent, r) + + target = codebook_flatten( + z[:, vn.n_conditioning_codebooks :, :], + ) + + flat_mask = codebook_flatten( + mask[:, vn.n_conditioning_codebooks :, :], + ) + + # replace target with ignore index for masked 
tokens + t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + output["loss"] = state.criterion(z_hat, t_masked) + + _metrics( + r=r, + z_hat=z_hat, + target=target, + flat_mask=flat_mask, + output=output, + ) + + + accel.backward(output["loss"]) + + output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"] + output["other/batch_size"] = z.shape[0] + + + accel.scaler.unscale_(state.optimizer) + output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_( + state.model.parameters(), state.grad_clip_val + ) + + accel.step(state.optimizer) + state.optimizer.zero_grad() + + state.scheduler.step() + accel.update() + + + return {k: v for k, v in sorted(output.items())} + + +@timer() +@torch.no_grad() +def val_loop(state: State, batch: dict, accel: Accelerator): + state.model.eval() + state.codec.eval() + batch = at.util.prepare_batch(batch, accel.device) + signal = apply_transform(state.val_data.transform, batch) + + vn = accel.unwrap(state.model) + z = state.codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + n_batch = z.shape[0] + r = state.rng.draw(n_batch)[:, 0].to(accel.device) + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, state.codec) + + z_hat = state.model(z_mask_latent, r) + + target = codebook_flatten( + z[:, vn.n_conditioning_codebooks :, :], + ) + + flat_mask = codebook_flatten( + mask[:, vn.n_conditioning_codebooks :, :] + ) + + output = {} + # replace target with ignore index for masked tokens + t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) + output["loss"] = state.criterion(z_hat, t_masked) + + _metrics( + r=r, + z_hat=z_hat, + target=target, + flat_mask=flat_mask, + output=output, + ) + + return output + + +def validate(state, val_dataloader, accel): + for batch in val_dataloader: + output = val_loop(state, batch, accel) + # Consolidate state dicts if using ZeroRedundancyOptimizer + if hasattr(state.optimizer, "consolidate_state_dict"): + state.optimizer.consolidate_state_dict() + return output + + +def checkpoint(state, save_iters, save_path, fine_tune): + if accel.local_rank != 0: + state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}") + return + + metadata = {"logs": dict(state.tracker.history)} + + tags = ["latest"] + state.tracker.print(f"Saving to {str(Path('.').absolute())}") + + if state.tracker.step in save_iters: + tags.append(f"{state.tracker.step // 1000}k") + + if state.tracker.is_best("val", "loss"): + state.tracker.print(f"Best model so far") + tags.append("best") + + if fine_tune: + for tag in tags: + # save the lora model + (Path(save_path) / tag).mkdir(parents=True, exist_ok=True) + torch.save( + lora.lora_state_dict(accel.unwrap(state.model)), + f"{save_path}/{tag}/lora.pth" + ) + + for tag in tags: + model_extra = { + "optimizer.pth": state.optimizer.state_dict(), + "scheduler.pth": state.scheduler.state_dict(), + "tracker.pth": state.tracker.state_dict(), + "metadata.pth": metadata, + } + + accel.unwrap(state.model).metadata = metadata + accel.unwrap(state.model).save_to_folder( + f"{save_path}/{tag}", model_extra, package=False + ) + + +def save_sampled(state, z, writer): + num_samples = z.shape[0] + + for i in range(num_samples): + sampled = accel.unwrap(state.model).generate( + codec=state.codec, + time_steps=z.shape[-1], + start_tokens=z[i : i + 1], + ) + sampled.cpu().write_audio_to_tb( + 
f"sampled/{i}", + writer, + step=state.tracker.step, + plot_fn=None, + ) + + +def save_imputation(state, z, val_idx, writer): + n_prefix = int(z.shape[-1] * 0.25) + n_suffix = int(z.shape[-1] * 0.25) + + vn = accel.unwrap(state.model) + + mask = pmask.inpaint(z, n_prefix, n_suffix) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + imputed_noisy = vn.to_signal(z_mask, state.codec) + imputed_true = vn.to_signal(z, state.codec) + + imputed = [] + for i in range(len(z)): + imputed.append( + vn.generate( + codec=state.codec, + time_steps=z.shape[-1], + start_tokens=z[i][None, ...], + mask=mask[i][None, ...], + ) + ) + imputed = AudioSignal.batch(imputed) + + for i in range(len(val_idx)): + imputed_noisy[i].cpu().write_audio_to_tb( + f"imputed_noisy/{i}", + writer, + step=state.tracker.step, + plot_fn=None, + ) + imputed[i].cpu().write_audio_to_tb( + f"imputed/{i}", + writer, + step=state.tracker.step, + plot_fn=None, + ) + imputed_true[i].cpu().write_audio_to_tb( + f"imputed_true/{i}", + writer, + step=state.tracker.step, + plot_fn=None, + ) + + +@torch.no_grad() +def save_samples(state: State, val_idx: int, writer: SummaryWriter): + state.model.eval() + state.codec.eval() + vn = accel.unwrap(state.model) + + batch = [state.val_data[i] for i in val_idx] + batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device) + + signal = apply_transform(state.val_data.transform, batch) + + z = state.codec.encode(signal.samples, signal.sample_rate)["codes"] + z = z[:, : vn.n_codebooks, :] + + r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device) + + + mask = pmask.random(z, r) + mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) + z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + + z_mask_latent = vn.embedding.from_codes(z_mask, state.codec) + + z_hat = state.model(z_mask_latent, r) + + z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1) + z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks) + z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1) + + generated = vn.to_signal(z_pred, state.codec) + reconstructed = vn.to_signal(z, state.codec) + masked = vn.to_signal(z_mask.squeeze(1), state.codec) + + for i in range(generated.batch_size): + audio_dict = { + "original": signal[i], + "masked": masked[i], + "generated": generated[i], + "reconstructed": reconstructed[i], + } + for k, v in audio_dict.items(): + v.cpu().write_audio_to_tb( + f"samples/_{i}.r={r[i]:0.2f}/{k}", + writer, + step=state.tracker.step, + plot_fn=None, + ) + + save_sampled(state=state, z=z, writer=writer) + save_imputation(state=state, z=z, val_idx=val_idx, writer=writer) + + + @argbind.bind(without_prefix=True) def load( args, accel: at.ml.Accelerator, + tracker: Tracker, save_path: str, resume: bool = False, tag: str = "latest", load_weights: bool = False, fine_tune_checkpoint: Optional[str] = None, -): + grad_clip_val: float = 5.0, +) -> State: codec = DAC.load(args["codec_ckpt"], map_location="cpu") codec.eval() @@ -121,6 +500,7 @@ def load( "map_location": "cpu", "package": not load_weights, } + tracker.print(f"Loading checkpoint from {kwargs['folder']}") if (Path(kwargs["folder"]) / "vampnet").exists(): model, v_extra = VampNet.load_from_folder(**kwargs) else: @@ -147,89 +527,57 @@ def load( scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim) scheduler.step() - trainer_state = {"state_dict": None, "start_idx": 0} - if "optimizer.pth" in v_extra: 
optimizer.load_state_dict(v_extra["optimizer.pth"]) - if "scheduler.pth" in v_extra: scheduler.load_state_dict(v_extra["scheduler.pth"]) - if "trainer.pth" in v_extra: - trainer_state = v_extra["trainer.pth"] - - return { - "model": model, - "codec": codec, - "optimizer": optimizer, - "scheduler": scheduler, - "trainer_state": trainer_state, - } - - - -def num_params_hook(o, p): - return o + f" {p/1e6:<.3f}M params." - - -def add_num_params_repr_hook(model): - import numpy as np - from functools import partial - - for n, m in model.named_modules(): - o = m.extra_repr() - p = sum([np.prod(p.size()) for p in m.parameters()]) - - setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p)) - - -def accuracy( - preds: torch.Tensor, - target: torch.Tensor, - top_k: int = 1, - ignore_index: Optional[int] = None, -) -> torch.Tensor: - # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class) - preds = rearrange(preds, "b p s -> (b s) p") - target = rearrange(target, "b s -> (b s)") - - # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index) - if ignore_index is not None: - # Create a mask for the ignored index - mask = target != ignore_index - # Apply the mask to the target and predictions - preds = preds[mask] - target = target[mask] + if "tracker.pth" in v_extra: + tracker.load_state_dict(v_extra["tracker.pth"]) + + criterion = CrossEntropyLoss() - # Get the top-k predicted classes and their indices - _, pred_indices = torch.topk(preds, k=top_k, dim=-1) + sample_rate = codec.sample_rate - # Determine if the true target is in the top-k predicted classes - correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1) + # a better rng for sampling from our schedule + rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"]) - # Calculate the accuracy - accuracy = torch.mean(correct.float()) + # log a model summary w/ num params + if accel.local_rank == 0: + add_num_params_repr_hook(accel.unwrap(model)) + with open(f"{save_path}/model.txt", "w") as f: + f.write(repr(accel.unwrap(model))) - return accuracy + # load the datasets + train_data, val_data = build_datasets(args, sample_rate) + + return State( + tracker=tracker, + model=model, + codec=codec, + optimizer=optimizer, + scheduler=scheduler, + criterion=criterion, + rng=rng, + train_data=train_data, + val_data=val_data, + grad_clip_val=grad_clip_val, + ) @argbind.bind(without_prefix=True) def train( args, accel: at.ml.Accelerator, - codec_ckpt: str = None, seed: int = 0, + codec_ckpt: str = None, save_path: str = "ckpt", - max_epochs: int = int(100e3), - epoch_length: int = 1000, - save_audio_epochs: int = 2, - save_epochs: list = [10, 50, 100, 200, 300, 400,], - batch_size: int = 48, - grad_acc_steps: int = 1, + num_iters: int = int(1000e6), + save_iters: list = [10000, 50000, 100000, 300000, 500000,], + sample_freq: int = 10000, + val_freq: int = 1000, + batch_size: int = 12, val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], num_workers: int = 10, - detect_anomaly: bool = False, - grad_clip_val: float = 5.0, fine_tune: bool = False, - quiet: bool = False, ): assert codec_ckpt is not None, "codec_ckpt is required" @@ -241,376 +589,76 @@ def train( writer = SummaryWriter(log_dir=f"{save_path}/logs/") argbind.dump_args(args, f"{save_path}/args.yml") - # load the codec model - loaded = load(args, accel, save_path) - model = loaded["model"] - codec = loaded["codec"] - optimizer = loaded["optimizer"] - 
scheduler = loaded["scheduler"] - trainer_state = loaded["trainer_state"] - - sample_rate = codec.sample_rate + tracker = Tracker( + writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank + ) - # a better rng for sampling from our schedule - rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=seed) + # load the codec model + state: State = load( + args=args, + accel=accel, + tracker=tracker, + save_path=save_path) - # log a model summary w/ num params - if accel.local_rank == 0: - add_num_params_repr_hook(accel.unwrap(model)) - with open(f"{save_path}/model.txt", "w") as f: - f.write(repr(accel.unwrap(model))) - # load the datasets - train_data, val_data, _ = build_datasets(args, sample_rate) train_dataloader = accel.prepare_dataloader( - train_data, - start_idx=trainer_state["start_idx"], + state.train_data, + start_idx=state.tracker.step * batch_size, num_workers=num_workers, batch_size=batch_size, - collate_fn=train_data.collate, + collate_fn=state.train_data.collate, ) val_dataloader = accel.prepare_dataloader( - val_data, + state.val_data, start_idx=0, num_workers=num_workers, batch_size=batch_size, - collate_fn=val_data.collate, + collate_fn=state.val_data.collate, + persistent_workers=True, ) - criterion = CrossEntropyLoss() + if fine_tune: - import loralib as lora - lora.mark_only_lora_as_trainable(model) - - - class Trainer(at.ml.BaseTrainer): - _last_grad_norm = 0.0 - - def _metrics(self, vn, z_hat, r, target, flat_mask, output): - for r_range in [(0, 0.5), (0.5, 1.0)]: - unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX) - masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) - - assert target.shape[0] == r.shape[0] - # grab the indices of the r values that are in the range - r_idx = (r >= r_range[0]) & (r < r_range[1]) - - # grab the target and z_hat values that are in the range - r_unmasked_target = unmasked_target[r_idx] - r_masked_target = masked_target[r_idx] - r_z_hat = z_hat[r_idx] - - for topk in (1, 25): - s, e = r_range - tag = f"accuracy-{s}-{e}/top{topk}" - - output[f"{tag}/unmasked"] = accuracy( - preds=r_z_hat, - target=r_unmasked_target, - ignore_index=IGNORE_INDEX, - top_k=topk, - ) - output[f"{tag}/masked"] = accuracy( - preds=r_z_hat, - target=r_masked_target, - ignore_index=IGNORE_INDEX, - top_k=topk, - ) - - def train_loop(self, engine, batch): - model.train() - batch = at.util.prepare_batch(batch, accel.device) - signal = apply_transform(train_data.transform, batch) - - output = {} - vn = accel.unwrap(model) - with accel.autocast(): - with torch.inference_mode(): - codec.to(accel.device) - z = codec.encode(signal.samples, signal.sample_rate)["codes"] - z = z[:, : vn.n_codebooks, :] - - n_batch = z.shape[0] - r = rng.draw(n_batch)[:, 0].to(accel.device) - - mask = pmask.random(z, r) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - z_mask_latent = vn.embedding.from_codes(z_mask, codec) - - dtype = torch.bfloat16 if accel.amp else None - with accel.autocast(dtype=dtype): - z_hat = model(z_mask_latent, r) - - target = codebook_flatten( - z[:, vn.n_conditioning_codebooks :, :], - ) - - flat_mask = codebook_flatten( - mask[:, vn.n_conditioning_codebooks :, :], - ) - - # replace target with ignore index for masked tokens - t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) - output["loss"] = criterion(z_hat, t_masked) - - self._metrics( - vn=vn, - r=r, - z_hat=z_hat, - target=target, - flat_mask=flat_mask, - output=output, 
- ) - - - accel.backward(output["loss"] / grad_acc_steps) - - output["other/learning_rate"] = optimizer.param_groups[0]["lr"] - output["other/batch_size"] = z.shape[0] - - if ( - (engine.state.iteration % grad_acc_steps == 0) - or (engine.state.iteration % epoch_length == 0) - or (engine.state.iteration % epoch_length == 1) - ): # (or we reached the end of the epoch) - accel.scaler.unscale_(optimizer) - output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_( - model.parameters(), grad_clip_val - ) - self._last_grad_norm = output["other/grad_norm"] - - accel.step(optimizer) - optimizer.zero_grad() - - scheduler.step() - accel.update() - else: - output["other/grad_norm"] = self._last_grad_norm - - return {k: v for k, v in sorted(output.items())} - - @torch.no_grad() - def val_loop(self, engine, batch): - model.eval() - codec.eval() - batch = at.util.prepare_batch(batch, accel.device) - signal = apply_transform(val_data.transform, batch) - - vn = accel.unwrap(model) - z = codec.encode(signal.samples, signal.sample_rate)["codes"] - z = z[:, : vn.n_codebooks, :] - - n_batch = z.shape[0] - r = rng.draw(n_batch)[:, 0].to(accel.device) - - mask = pmask.random(z, r) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) + lora.mark_only_lora_as_trainable(state.model) - z_mask_latent = vn.embedding.from_codes(z_mask, codec) + # Wrap the functions so that they neatly track in TensorBoard + progress bars + # and only run when specific conditions are met. + global train_loop, val_loop, validate, save_samples, checkpoint - z_hat = model(z_mask_latent, r) + train_loop = tracker.log("train", "value", history=False)( + tracker.track("train", num_iters, completed=state.tracker.step)(train_loop) + ) + val_loop = tracker.track("val", len(val_dataloader))(val_loop) + validate = tracker.log("val", "mean")(validate) - target = codebook_flatten( - z[:, vn.n_conditioning_codebooks :, :], - ) + save_samples = when(lambda: accel.local_rank == 0)(save_samples) + checkpoint = when(lambda: accel.local_rank == 0)(checkpoint) - flat_mask = codebook_flatten( - mask[:, vn.n_conditioning_codebooks :, :] - ) + with tracker.live: + for tracker.step, batch in enumerate(train_dataloader, start=tracker.step): + train_loop(state, batch, accel) - output = {} - # replace target with ignore index for masked tokens - t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX) - output["loss"] = criterion(z_hat, t_masked) - - self._metrics( - vn=vn, - r=r, - z_hat=z_hat, - target=target, - flat_mask=flat_mask, - output=output, + last_iter = ( + tracker.step == num_iters - 1 if num_iters is not None else False ) - return output + if tracker.step % sample_freq == 0 or last_iter: + save_samples(state, val_idx, writer) - def checkpoint(self, engine): - if accel.local_rank != 0: - print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}") - return - - metadata = {"logs": dict(engine.state.logs["epoch"])} - - if self.state.epoch % save_audio_epochs == 0: - self.save_samples() - - tags = ["latest"] - loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train" - self.print(f"Saving to {str(Path('.').absolute())}") - - if self.state.epoch in save_epochs: - tags.append(f"epoch={self.state.epoch}") - - if self.is_best(engine, loss_key): - self.print(f"Best model so far") - tags.append("best") - - if fine_tune: - for tag in tags: - # save the lora model - (Path(save_path) / tag).mkdir(parents=True, exist_ok=True) - torch.save( - 
lora.lora_state_dict(accel.unwrap(model)), - f"{save_path}/{tag}/lora.pth" - ) - - for tag in tags: - model_extra = { - "optimizer.pth": optimizer.state_dict(), - "scheduler.pth": scheduler.state_dict(), - "trainer.pth": { - "start_idx": self.state.iteration * batch_size, - "state_dict": self.state_dict(), - }, - "metadata.pth": metadata, - } - - accel.unwrap(model).metadata = metadata - accel.unwrap(model).save_to_folder( - f"{save_path}/{tag}", model_extra, - ) - - def save_sampled(self, z): - num_samples = z.shape[0] - - for i in range(num_samples): - sampled = accel.unwrap(model).generate( - codec=codec, - time_steps=z.shape[-1], - start_tokens=z[i : i + 1], - ) - sampled.cpu().write_audio_to_tb( - f"sampled/{i}", - self.writer, - step=self.state.epoch, - plot_fn=None, - ) - - - def save_imputation(self, z: torch.Tensor): - n_prefix = int(z.shape[-1] * 0.25) - n_suffix = int(z.shape[-1] * 0.25) - - vn = accel.unwrap(model) - - mask = pmask.inpaint(z, n_prefix, n_suffix) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - imputed_noisy = vn.to_signal(z_mask, codec) - imputed_true = vn.to_signal(z, codec) - - imputed = [] - for i in range(len(z)): - imputed.append( - vn.generate( - codec=codec, - time_steps=z.shape[-1], - start_tokens=z[i][None, ...], - mask=mask[i][None, ...], - ) - ) - imputed = AudioSignal.batch(imputed) - - for i in range(len(val_idx)): - imputed_noisy[i].cpu().write_audio_to_tb( - f"imputed_noisy/{i}", - self.writer, - step=self.state.epoch, - plot_fn=None, - ) - imputed[i].cpu().write_audio_to_tb( - f"imputed/{i}", - self.writer, - step=self.state.epoch, - plot_fn=None, - ) - imputed_true[i].cpu().write_audio_to_tb( - f"imputed_true/{i}", - self.writer, - step=self.state.epoch, - plot_fn=None, - ) - - @torch.no_grad() - def save_samples(self): - model.eval() - codec.eval() - vn = accel.unwrap(model) - - batch = [val_data[i] for i in val_idx] - batch = at.util.prepare_batch(val_data.collate(batch), accel.device) - - signal = apply_transform(val_data.transform, batch) - - z = codec.encode(signal.samples, signal.sample_rate)["codes"] - z = z[:, : vn.n_codebooks, :] - - r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device) - - - mask = pmask.random(z, r) - mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks) - z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token) - - z_mask_latent = vn.embedding.from_codes(z_mask, codec) - - z_hat = model(z_mask_latent, r) + if tracker.step % val_freq == 0 or last_iter: + validate(state, val_dataloader, accel) + checkpoint( + state=state, + save_iters=save_iters, + save_path=save_path, + fine_tune=fine_tune) - z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1) - z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks) - z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1) + # Reset validation progress bar, print summary since last validation. 
+ tracker.done("val", f"Iteration {tracker.step}") - generated = vn.to_signal(z_pred, codec) - reconstructed = vn.to_signal(z, codec) - masked = vn.to_signal(z_mask.squeeze(1), codec) - - for i in range(generated.batch_size): - audio_dict = { - "original": signal[i], - "masked": masked[i], - "generated": generated[i], - "reconstructed": reconstructed[i], - } - for k, v in audio_dict.items(): - v.cpu().write_audio_to_tb( - f"samples/_{i}.r={r[i]:0.2f}/{k}", - self.writer, - step=self.state.epoch, - plot_fn=None, - ) - - self.save_sampled(z) - self.save_imputation(z) - - trainer = Trainer(writer=writer, quiet=quiet) - - if trainer_state["state_dict"] is not None: - trainer.load_state_dict(trainer_state["state_dict"]) - if hasattr(train_dataloader.sampler, "set_epoch"): - train_dataloader.sampler.set_epoch(trainer.trainer.state.epoch) - - trainer.run( - train_dataloader, - val_dataloader, - num_epochs=max_epochs, - epoch_length=epoch_length, - detect_anomaly=detect_anomaly, - ) + if last_iter: + break if __name__ == "__main__": @@ -618,4 +666,6 @@ if __name__ == "__main__": args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0 with argbind.scope(args): with Accelerator() as accel: + if accel.local_rank != 0: + sys.tracebacklimit = 0 train(args, accel) diff --git a/setup.py b/setup.py index 2964e0f810f32dab3abc433912a2de128c081761..fb4d211908e8c6e89a4a59b4cb99a0508fbbfc80 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( "numpy==1.22", "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat", "lac @ git+https://github.com/hugofloresgarcia/lac.git", - "audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git", + "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2", "gradio", "tensorboardX", "loralib", diff --git a/vampnet/beats.py b/vampnet/beats.py index 317496ef83d7b764fbbc51068c13170ce0c17e13..2b03a4e3df705a059cd34e6e01a72752fc4d8a98 100644 --- a/vampnet/beats.py +++ b/vampnet/beats.py @@ -9,6 +9,7 @@ from typing import Tuple from typing import Union import librosa +import torch import numpy as np from audiotools import AudioSignal @@ -203,7 +204,7 @@ class WaveBeat(BeatTracker): def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"): from wavebeat.dstcn import dsTCNModel - model = dsTCNModel.load_from_checkpoint(ckpt_path) + model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device)) model.eval() self.device = device diff --git a/vampnet/interface.py b/vampnet/interface.py index 0a6e39182c9d91c1b76bcb18476f9c018a247543..39e313e949d1721df04ab112f3fb40bebba37b61 100644 --- a/vampnet/interface.py +++ b/vampnet/interface.py @@ -22,6 +22,7 @@ def signal_concat( return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate) + def _load_model( ckpt: str, lora_ckpt: str = None, @@ -64,7 +65,7 @@ class Interface(torch.nn.Module): ): super().__init__() assert codec_ckpt is not None, "must provide a codec checkpoint" - self.codec = DAC.load(Path(codec_ckpt)) + self.codec = DAC.load(codec_ckpt) self.codec.eval() self.codec.to(device) @@ -275,34 +276,44 @@ class Interface(torch.nn.Module): def coarse_to_fine( self, - coarse_z: torch.Tensor, + z: torch.Tensor, + mask: torch.Tensor = None, **kwargs ): assert self.c2f is not None, "No coarse2fine model loaded" - length = coarse_z.shape[-1] + length = z.shape[-1] chunk_len = self.s2t(self.c2f.chunk_size_s) - n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len) + n_chunks = math.ceil(z.shape[-1] / chunk_len) # zero pad to chunk_len if 
length % chunk_len != 0: pad_len = chunk_len - (length % chunk_len) - coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len)) + z = torch.nn.functional.pad(z, (0, pad_len)) + mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None - n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1] + n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1] if n_codebooks_to_append > 0: - coarse_z = torch.cat([ - coarse_z, - torch.zeros(coarse_z.shape[0], n_codebooks_to_append, coarse_z.shape[-1]).long().to(self.device) + z = torch.cat([ + z, + torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device) ], dim=1) + # set the mask to 0 for all conditioning codebooks + if mask is not None: + mask = mask.clone() + mask[:, :self.c2f.n_conditioning_codebooks, :] = 0 + fine_z = [] for i in range(n_chunks): - chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len] + chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len] + mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None + chunk = self.c2f.generate( codec=self.codec, time_steps=chunk_len, start_tokens=chunk, return_signal=False, + mask=mask_chunk, **kwargs ) fine_z.append(chunk) @@ -337,6 +348,12 @@ class Interface(torch.nn.Module): **kwargs ) + # add the fine codes back in + c_vamp = torch.cat( + [c_vamp, z[:, self.coarse.n_codebooks :, :]], + dim=1 + ) + if return_mask: return c_vamp, cz_masked @@ -352,17 +369,18 @@ if __name__ == "__main__": at.util.seed(42) interface = Interface( - coarse_ckpt="./models/spotdl/coarse.pth", - coarse2fine_ckpt="./models/spotdl/c2f.pth", - codec_ckpt="./models/spotdl/codec.pth", + coarse_ckpt="./models/vampnet/coarse.pth", + coarse2fine_ckpt="./models/vampnet/c2f.pth", + codec_ckpt="./models/vampnet/codec.pth", device="cuda", wavebeat_ckpt="./models/wavebeat.pth" ) - sig = at.AudioSignal.zeros(duration=10, sample_rate=44100) + sig = at.AudioSignal('assets/example.wav') z = interface.encode(sig) + breakpoint() # mask = linear_random(z, 1.0) # mask = mask_and( @@ -374,13 +392,14 @@ if __name__ == "__main__": # ) # ) - mask = interface.make_beat_mask( - sig, 0.0, 0.075 - ) + # mask = interface.make_beat_mask( + # sig, 0.0, 0.075 + # ) # mask = dropout(mask, 0.0) # mask = codebook_unmask(mask, 0) + + mask = inpaint(z, n_prefix=100, n_suffix=100) - breakpoint() zv, mask_z = interface.coarse_vamp( z, mask=mask, @@ -389,16 +408,16 @@ if __name__ == "__main__": return_mask=True, gen_fn=interface.coarse.generate ) + use_coarse2fine = True if use_coarse2fine: - zv = interface.coarse_to_fine(zv, temperature=0.8) + zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask) + breakpoint() mask = interface.to_signal(mask_z).cpu() sig = interface.to_signal(zv).cpu() print("done") - sig.write("output3.wav") - mask.write("mask.wav") \ No newline at end of file diff --git a/vampnet/modules/__init__.py b/vampnet/modules/__init__.py index 3481f32e0287faa9e79ba219f17d18529a4b57ac..3f4c8c226e42d022c60b620e8f21ccaf4e6a57bd 100644 --- a/vampnet/modules/__init__.py +++ b/vampnet/modules/__init__.py @@ -2,3 +2,5 @@ import audiotools audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"] audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"] + +from .transformer import VampNet \ No newline at end of file