hugo flores garcia committed
Commit 2b3cdf0 · 1 Parent(s): e612fff

gitattributes

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ coarse.pth filter=lfs diff=lfs merge=lfs -text
+ c2f.pth filter=lfs diff=lfs merge=lfs -text
+ wavebeat.pth filter=lfs diff=lfs merge=lfs -text
+ codec.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,191 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/env.sh
+ venv/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # Files created by experiments
+ output/
+ snapshot/
+ *.m4a
+ notebooks/scratch.ipynb
+ notebooks/inspect.ipynb
+ notebooks/effects.ipynb
+ notebooks/*.ipynb
+ notebooks/*.gif
+ notebooks/*.wav
+ notebooks/*.mp4
+ *runs/
+ boards/
+ samples/
+ *.ipynb
+
+ results.json
+ metrics.csv
+ mprofile_*
+ mem.png
+
+ results/
+ mprofile*
+ *.png
+ # do not ignore the test wav file
+ !tests/audio/short_test_audio.wav
+ !tests/audio/output.wav
+ */.DS_Store
+ .DS_Store
+ env.sh
+ _codebraid/
+ **/*.html
+ **/*.exec.md
+ flagged/
+ log.txt
+ ckpt/
+ .syncthing*
+ tests/assets/
+ archived/
+
+ scratch/
+
+ runs-archive
+ lyrebird-audiotools
+ lyrebird-audio-codec
+ samples-*/**
+
+ gradio-outputs/
+ samples*/
+ models-all/
+ models.zip
+ .git-old
+ conf/generated/*
+ runs*/
+
+
+ gtzan.zip
+ .gtzan_emb_cache
+ runs
+
+ data/
+ src/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,15 @@
+ repos:
+ - repo: https://github.com/asottile/reorder_python_imports
+   rev: v2.5.0
+   hooks:
+   - id: reorder-python-imports
+ - repo: https://github.com/psf/black
+   rev: 23.1.0
+   hooks:
+   - id: black
+     language_version: python3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.0.1
+   hooks:
+   - id: end-of-file-fixer
+   - id: trailing-whitespace
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Hugo Flores García and Prem Seetharaman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,704 @@
1
+ # huggingface space exclusive
2
+ import os
3
+
4
+ # print("installing pyharp")
5
+ # os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
6
+ # print("installing madmom")
7
+ os.system('pip install cython')
8
+ os.system('pip install madmom')
9
+
10
+ from pathlib import Path
11
+ from typing import Tuple
12
+ import yaml
13
+ import tempfile
14
+ import uuid
15
+ import shutil
16
+ from dataclasses import dataclass, asdict
17
+
18
+ import numpy as np
19
+ import audiotools as at
20
+ import argbind
21
+ import torch
22
+
23
+ import gradio as gr
24
+ from vampnet.interface import Interface
25
+ from vampnet import mask as pmask
26
+
27
+ from pyharp import ModelCard, build_endpoint
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ interface = Interface(
31
+ device=device,
32
+ coarse_ckpt="models/vampvat/coarse.pth",
33
+ coarse2fine_ckpt="models/vampvat/c2f.pth",
34
+ codec_ckpt="models/vampvat/codec.pth",
35
+ )
36
+
37
+ # populate the model choices with any interface.yml files in the generated confs
38
+ MODEL_CHOICES = {
39
+ "default": {
40
+ "Interface.coarse_ckpt": str(interface.coarse_path),
41
+ "Interface.coarse2fine_ckpt": str(interface.c2f_path),
42
+ "Interface.codec_ckpt": str(interface.codec_path),
43
+ }
44
+ }
45
+ generated_confs = Path("conf/generated")
46
+ for conf_file in generated_confs.glob("*/interface.yml"):
47
+ with open(conf_file) as f:
48
+ _conf = yaml.safe_load(f)
49
+ MODEL_CHOICES[conf_file.parent.name] = _conf
50
+
51
+
52
+
53
+ OUT_DIR = Path("gradio-outputs")
54
+ OUT_DIR.mkdir(exist_ok=True, parents=True)
55
+
56
+
57
+ def load_audio(file):
58
+ print(file)
59
+ filepath = file.name
60
+ sig = at.AudioSignal.salient_excerpt(
61
+ filepath,
62
+ duration=interface.coarse.chunk_size_s
63
+ )
64
+ sig = interface.preprocess(sig)
65
+
66
+ out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
67
+ out_dir.mkdir(parents=True, exist_ok=True)
68
+ sig.write(out_dir / "input.wav")
69
+ return sig.path_to_file
70
+
71
+
72
+ def load_example_audio():
73
+ return "./assets/example.wav"
74
+
75
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
76
+ def shift_pitch(signal, interval: int):
77
+ signal.samples = pitch_shift(
78
+ signal.samples,
79
+ shift=interval,
80
+ sample_rate=signal.sample_rate
81
+ )
82
+ return signal
83
+
84
+ def _vamp(data, return_mask=False):
85
+ # remove any old files in the output directory (from previous runs)
86
+ shutil.rmtree(OUT_DIR)
87
+ OUT_DIR.mkdir()
88
+
89
+ out_dir = OUT_DIR / str(uuid.uuid4())
90
+ out_dir.mkdir()
91
+ sig = at.AudioSignal(data[input_audio])
92
+ sig = interface.preprocess(sig)
93
+
94
+ # reload the model if necessary
95
+ interface.reload(
96
+ coarse_ckpt=MODEL_CHOICES[data[model_choice]]["Interface.coarse_ckpt"],
97
+ c2f_ckpt=MODEL_CHOICES[data[model_choice]]["Interface.coarse2fine_ckpt"],
98
+ )
99
+
100
+ loudness = sig.loudness()
101
+ print(f"input loudness is {loudness}")
102
+
103
+ if data[pitch_shift_amt] != 0:
104
+ sig = shift_pitch(sig, data[pitch_shift_amt])
105
+
106
+ z = interface.encode(sig)
107
+
108
+ ncc = data[n_conditioning_codebooks]
109
+
110
+ # build the mask
111
+ mask = pmask.linear_random(z, data[rand_mask_intensity])
112
+ mask = pmask.mask_and(
113
+ mask, pmask.inpaint(
114
+ z,
115
+ interface.s2t(data[prefix_s]),
116
+ interface.s2t(data[suffix_s])
117
+ )
118
+ )
119
+ mask = pmask.mask_and(
120
+ mask, pmask.periodic_mask(
121
+ z,
122
+ data[periodic_p],
123
+ data[periodic_w],
124
+ random_roll=True
125
+ )
126
+ )
127
+ if data[onset_mask_width] > 0:
128
+ mask = pmask.mask_or(
129
+ mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
130
+ )
131
+ if data[beat_mask_width] > 0:
132
+ beat_mask = interface.make_beat_mask(
133
+ sig,
134
+ after_beat_s=(data[beat_mask_width]/1000),
135
+ mask_upbeats=not data[beat_mask_downbeats],
136
+ )
137
+ mask = pmask.mask_and(mask, beat_mask)
138
+
139
+ # these should be the last two mask ops
140
+ mask = pmask.dropout(mask, data[dropout])
141
+ mask = pmask.codebook_unmask(mask, ncc)
142
+ mask = pmask.codebook_mask(mask, int(data[n_mask_codebooks]))
143
+
144
+ print(f"dropout {data[dropout]}")
145
+ print(f"masktemp {data[masktemp]}")
146
+ print(f"sampletemp {data[sampletemp]}")
147
+ print(f"top_p {data[top_p]}")
148
+ print(f"prefix_s {data[prefix_s]}")
149
+ print(f"suffix_s {data[suffix_s]}")
150
+ print(f"rand_mask_intensity {data[rand_mask_intensity]}")
151
+ print(f"num_steps {data[num_steps]}")
152
+ print(f"periodic_p {data[periodic_p]}")
153
+ print(f"periodic_w {data[periodic_w]}")
154
+ print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
155
+ print(f"use_coarse2fine {data[use_coarse2fine]}")
156
+ print(f"onset_mask_width {data[onset_mask_width]}")
157
+ print(f"beat_mask_width {data[beat_mask_width]}")
158
+ print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
159
+ print(f"stretch_factor {data[stretch_factor]}")
160
+ print(f"seed {data[seed]}")
161
+ print(f"pitch_shift_amt {data[pitch_shift_amt]}")
162
+ print(f"sample_cutoff {data[sample_cutoff]}")
163
+
164
+
165
+ _top_p = data[top_p] if data[top_p] > 0 else None
166
+ # save the mask as a txt file
167
+ np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
168
+
169
+ _seed = data[seed] if data[seed] > 0 else None
170
+ print(f"processing coarse...")
171
+ zv, mask_z = interface.coarse_vamp(
172
+ z,
173
+ mask=mask,
174
+ sampling_steps=data[num_steps],
175
+ mask_temperature=data[masktemp]*10,
176
+ sampling_temperature=data[sampletemp],
177
+ return_mask=True,
178
+ typical_filtering=data[typical_filtering],
179
+ typical_mass=data[typical_mass],
180
+ typical_min_tokens=data[typical_min_tokens],
181
+ top_p=_top_p,
182
+ gen_fn=interface.coarse.generate,
183
+ seed=_seed,
184
+ sample_cutoff=data[sample_cutoff],
185
+ )
186
+
187
+ if data[use_coarse2fine]:
188
+ print(f"processing coarse to fine...")
189
+ zv = interface.coarse_to_fine(
190
+ zv,
191
+ mask_temperature=data[masktemp]*10,
192
+ sampling_temperature=data[sampletemp],
193
+ mask=mask,
194
+ sampling_steps=data[num_steps] // 2,
195
+ sample_cutoff=data[sample_cutoff],
196
+ seed=_seed,
197
+ )
198
+
199
+ sig = interface.to_signal(zv).cpu()
200
+ print("done")
201
+
202
+
203
+
204
+ sig.write(out_dir / "output.wav")
205
+
206
+ if return_mask:
207
+ mask = interface.to_signal(mask_z).cpu()
208
+ mask.write(out_dir / "mask.wav")
209
+ return sig.path_to_file, mask.path_to_file
210
+ else:
211
+ return sig.path_to_file
212
+
213
+ def vamp(data):
214
+ return _vamp(data, return_mask=True)
215
+
216
+ def api_vamp(data):
217
+ return _vamp(data, return_mask=False)
218
+
219
+ def save_vamp(data):
220
+ out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
221
+ out_dir.mkdir(parents=True, exist_ok=True)
222
+
223
+ sig_in = at.AudioSignal(data[input_audio])
224
+ sig_out = at.AudioSignal(data[output_audio])
225
+
226
+ sig_in.write(out_dir / "input.wav")
227
+ sig_out.write(out_dir / "output.wav")
228
+
229
+ _data = {
230
+ "masktemp": data[masktemp],
231
+ "sampletemp": data[sampletemp],
232
+ "top_p": data[top_p],
233
+ "prefix_s": data[prefix_s],
234
+ "suffix_s": data[suffix_s],
235
+ "rand_mask_intensity": data[rand_mask_intensity],
236
+ "num_steps": data[num_steps],
237
+ "notes": data[notes_text],
238
+ "periodic_period": data[periodic_p],
239
+ "periodic_width": data[periodic_w],
240
+ "n_conditioning_codebooks": data[n_conditioning_codebooks],
241
+ "use_coarse2fine": data[use_coarse2fine],
242
+ "stretch_factor": data[stretch_factor],
243
+ "seed": data[seed],
244
+ "samplecutoff": data[sample_cutoff],
245
+ }
246
+
247
+ # save with yaml
248
+ with open(out_dir / "data.yaml", "w") as f:
249
+ yaml.dump(_data, f)
250
+
251
+ import zipfile
252
+ zip_path = out_dir.with_suffix(".zip")
253
+ with zipfile.ZipFile(zip_path, "w") as zf:
254
+ for file in out_dir.iterdir():
255
+ zf.write(file, file.name)
256
+
257
+ return f"saved! your save code is {out_dir.stem}", zip_path
258
+
259
+
260
+ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
261
+
262
+ out_dir = OUT_DIR / str(uuid.uuid4())
263
+ out_dir.mkdir()
264
+ sig = at.AudioSignal(_input_audio)
265
+ sig = interface.preprocess(sig)
266
+
267
+ z = interface.encode(sig)
268
+
269
+ # build the mask
270
+ mask = pmask.linear_random(z, 1.0)
271
+ if _beat_mask_width > 0:
272
+ beat_mask = interface.make_beat_mask(
273
+ sig,
274
+ after_beat_s=(_beat_mask_width/1000),
275
+ )
276
+ mask = pmask.mask_and(mask, beat_mask)
277
+
278
+ # save the mask as a txt file
279
+ zv, mask_z = interface.coarse_vamp(
280
+ z,
281
+ mask=mask,
282
+ sampling_temperature=_sampletemp,
283
+ return_mask=True,
284
+ gen_fn=interface.coarse.generate,
285
+ )
286
+
287
+
288
+ zv = interface.coarse_to_fine(
289
+ zv,
290
+ sampling_temperature=_sampletemp,
291
+ mask=mask,
292
+ )
293
+
294
+ sig = interface.to_signal(zv).cpu()
295
+ print("done")
296
+
297
+ sig.write(out_dir / "output.wav")
298
+
299
+ return sig.path_to_file
300
+
301
+ with gr.Blocks() as demo:
302
+
303
+ with gr.Row():
304
+ with gr.Column():
305
+ gr.Markdown("# VampNet Audio Vamping")
306
+ gr.Markdown("""## Description:
307
+ This is a demo of VampNet, a generative audio model that transforms the input audio based on the chosen settings.
308
+ You can control the extent and nature of variation with a set of manual controls and presets.
309
+ Use this interface to experiment with different mask settings and explore the audio outputs.
310
+ """)
311
+
312
+ gr.Markdown("""
313
+ ## Instructions:
314
+ 1. You can start by uploading some audio, or by loading the example audio.
315
+ 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
316
+ 3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
317
+ 4. Optionally, you can add some notes and save the result.
318
+ 5. You can also use the output as the new input and continue experimenting!
319
+ """)
320
+ with gr.Row():
321
+ with gr.Column():
322
+
323
+
324
+ manual_audio_upload = gr.File(
325
+ label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
326
+ file_types=["audio"]
327
+ )
328
+ load_example_audio_button = gr.Button("or load example audio")
329
+
330
+ input_audio = gr.Audio(
331
+ label="input audio",
332
+ interactive=False,
333
+ type="filepath",
334
+ )
335
+
336
+ audio_mask = gr.Audio(
337
+ label="audio mask (listen to this to hear the mask hints)",
338
+ interactive=False,
339
+ type="filepath",
340
+ )
341
+
342
+ # connect widgets
343
+ load_example_audio_button.click(
344
+ fn=load_example_audio,
345
+ inputs=[],
346
+ outputs=[ input_audio]
347
+ )
348
+
349
+ manual_audio_upload.change(
350
+ fn=load_audio,
351
+ inputs=[manual_audio_upload],
352
+ outputs=[ input_audio]
353
+ )
354
+
355
+ # mask settings
356
+ with gr.Column():
357
+
358
+
359
+ presets = {
360
+ "unconditional": {
361
+ "periodic_p": 0,
362
+ "onset_mask_width": 0,
363
+ "beat_mask_width": 0,
364
+ "beat_mask_downbeats": False,
365
+ },
366
+ "slight periodic variation": {
367
+ "periodic_p": 5,
368
+ "onset_mask_width": 5,
369
+ "beat_mask_width": 0,
370
+ "beat_mask_downbeats": False,
371
+ },
372
+ "moderate periodic variation": {
373
+ "periodic_p": 13,
374
+ "onset_mask_width": 5,
375
+ "beat_mask_width": 0,
376
+ "beat_mask_downbeats": False,
377
+ },
378
+ "strong periodic variation": {
379
+ "periodic_p": 17,
380
+ "onset_mask_width": 5,
381
+ "beat_mask_width": 0,
382
+ "beat_mask_downbeats": False,
383
+ },
384
+ "very strong periodic variation": {
385
+ "periodic_p": 21,
386
+ "onset_mask_width": 5,
387
+ "beat_mask_width": 0,
388
+ "beat_mask_downbeats": False,
389
+ },
390
+ "beat-driven variation": {
391
+ "periodic_p": 0,
392
+ "onset_mask_width": 0,
393
+ "beat_mask_width": 50,
394
+ "beat_mask_downbeats": False,
395
+ },
396
+ "beat-driven variation (downbeats only)": {
397
+ "periodic_p": 0,
398
+ "onset_mask_width": 0,
399
+ "beat_mask_width": 50,
400
+ "beat_mask_downbeats": True,
401
+ },
402
+ "beat-driven variation (downbeats only, strong)": {
403
+ "periodic_p": 0,
404
+ "onset_mask_width": 0,
405
+ "beat_mask_width": 20,
406
+ "beat_mask_downbeats": True,
407
+ },
408
+ }
409
+
410
+ preset = gr.Dropdown(
411
+ label="preset",
412
+ choices=list(presets.keys()),
413
+ value="strong periodic variation",
414
+ )
415
+ load_preset_button = gr.Button("load_preset")
416
+
417
+ with gr.Accordion("manual controls", open=True):
418
+ periodic_p = gr.Slider(
419
+ label="periodic prompt (0 - unconditional, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
420
+ minimum=0,
421
+ maximum=128,
422
+ step=1,
423
+ value=3,
424
+ )
425
+
426
+
427
+ onset_mask_width = gr.Slider(
428
+ label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
429
+ minimum=0,
430
+ maximum=100,
431
+ step=1,
432
+ value=5,
433
+ )
434
+
435
+ beat_mask_width = gr.Slider(
436
+ label="beat prompt (ms)",
437
+ minimum=0,
438
+ maximum=200,
439
+ value=0,
440
+ )
441
+ beat_mask_downbeats = gr.Checkbox(
442
+ label="beat mask downbeats only?",
443
+ value=False
444
+ )
445
+
446
+ n_mask_codebooks = gr.Number(
447
+ label="first upper codebook level to mask",
448
+ value=9,
449
+ )
450
+
451
+
452
+ with gr.Accordion("extras ", open=False):
453
+ pitch_shift_amt = gr.Slider(
454
+ label="pitch shift amount (semitones)",
455
+ minimum=-12,
456
+ maximum=12,
457
+ step=1,
458
+ value=0,
459
+ )
460
+
461
+ rand_mask_intensity = gr.Slider(
462
+ label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
463
+ minimum=0.0,
464
+ maximum=1.0,
465
+ value=1.0
466
+ )
467
+
468
+ periodic_w = gr.Slider(
469
+ label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
470
+ minimum=1,
471
+ maximum=20,
472
+ step=1,
473
+ value=1,
474
+ )
475
+ n_conditioning_codebooks = gr.Number(
476
+ label="number of conditioning codebooks. probably 0",
477
+ value=0,
478
+ precision=0,
479
+ )
480
+
481
+ stretch_factor = gr.Slider(
482
+ label="time stretch factor",
483
+ minimum=0,
484
+ maximum=64,
485
+ step=1,
486
+ value=1,
487
+ )
488
+
489
+ preset_outputs = {
490
+ periodic_p,
491
+ onset_mask_width,
492
+ beat_mask_width,
493
+ beat_mask_downbeats,
494
+ }
495
+
496
+ def load_preset(_preset):
497
+ return tuple(presets[_preset].values())
498
+
499
+ load_preset_button.click(
500
+ fn=load_preset,
501
+ inputs=[preset],
502
+ outputs=preset_outputs
503
+ )
504
+
505
+
506
+ with gr.Accordion("prefix/suffix prompts", open=False):
507
+ prefix_s = gr.Slider(
508
+ label="prefix hint length (seconds)",
509
+ minimum=0.0,
510
+ maximum=10.0,
511
+ value=0.0
512
+ )
513
+ suffix_s = gr.Slider(
514
+ label="suffix hint length (seconds)",
515
+ minimum=0.0,
516
+ maximum=10.0,
517
+ value=0.0
518
+ )
519
+
520
+ masktemp = gr.Slider(
521
+ label="mask temperature",
522
+ minimum=0.0,
523
+ maximum=100.0,
524
+ value=1.5
525
+ )
526
+ sampletemp = gr.Slider(
527
+ label="sample temperature",
528
+ minimum=0.1,
529
+ maximum=10.0,
530
+ value=1.0,
531
+ step=0.001
532
+ )
533
+
534
+
535
+
536
+ with gr.Accordion("sampling settings", open=False):
537
+ top_p = gr.Slider(
538
+ label="top p (0.0 = off)",
539
+ minimum=0.0,
540
+ maximum=1.0,
541
+ value=0.9
542
+ )
543
+ typical_filtering = gr.Checkbox(
544
+ label="typical filtering ",
545
+ value=False
546
+ )
547
+ typical_mass = gr.Slider(
548
+ label="typical mass (should probably stay between 0.1 and 0.5)",
549
+ minimum=0.01,
550
+ maximum=0.99,
551
+ value=0.15
552
+ )
553
+ typical_min_tokens = gr.Slider(
554
+ label="typical min tokens (should probably stay between 1 and 256)",
555
+ minimum=1,
556
+ maximum=256,
557
+ step=1,
558
+ value=64
559
+ )
560
+ sample_cutoff = gr.Slider(
561
+ label="sample cutoff",
562
+ minimum=0.0,
563
+ maximum=1.0,
564
+ value=0.5,
565
+ step=0.01
566
+ )
567
+
568
+ use_coarse2fine = gr.Checkbox(
569
+ label="use coarse2fine",
570
+ value=True,
571
+ visible=False
572
+ )
573
+
574
+ num_steps = gr.Slider(
575
+ label="number of steps (should normally be between 12 and 36)",
576
+ minimum=1,
577
+ maximum=128,
578
+ step=1,
579
+ value=36
580
+ )
581
+
582
+ dropout = gr.Slider(
583
+ label="mask dropout",
584
+ minimum=0.0,
585
+ maximum=1.0,
586
+ step=0.01,
587
+ value=0.0
588
+ )
589
+
590
+
591
+ seed = gr.Number(
592
+ label="seed (0 for random)",
593
+ value=0,
594
+ precision=0,
595
+ )
596
+
597
+
598
+
599
+ # mask settings
600
+ with gr.Column():
601
+
602
+ model_choice = gr.Dropdown(
603
+ label="model choice",
604
+ choices=list(MODEL_CHOICES.keys()),
605
+ value="default",
606
+ visible=True
607
+ )
608
+
609
+ vamp_button = gr.Button("generate (vamp)!!!")
610
+ output_audio = gr.Audio(
611
+ label="output audio",
612
+ interactive=False,
613
+ type="filepath"
614
+ )
615
+
616
+ notes_text = gr.Textbox(
617
+ label="type any notes about the generated audio here",
618
+ value="",
619
+ interactive=True
620
+ )
621
+ save_button = gr.Button("save vamp")
622
+ download_file = gr.File(
623
+ label="vamp to download will appear here",
624
+ interactive=False
625
+ )
626
+ use_as_input_button = gr.Button("use output as input")
627
+
628
+ thank_you = gr.Markdown("")
629
+
630
+
631
+ _inputs = {
632
+ input_audio,
633
+ num_steps,
634
+ masktemp,
635
+ sampletemp,
636
+ top_p,
637
+ prefix_s, suffix_s,
638
+ rand_mask_intensity,
639
+ periodic_p, periodic_w,
640
+ n_conditioning_codebooks,
641
+ dropout,
642
+ use_coarse2fine,
643
+ stretch_factor,
644
+ onset_mask_width,
645
+ typical_filtering,
646
+ typical_mass,
647
+ typical_min_tokens,
648
+ beat_mask_width,
649
+ beat_mask_downbeats,
650
+ seed,
651
+ model_choice,
652
+ n_mask_codebooks,
653
+ pitch_shift_amt,
654
+ sample_cutoff
655
+ }
656
+
657
+ # connect widgets
658
+ vamp_button.click(
659
+ fn=vamp,
660
+ inputs=_inputs,
661
+ outputs=[output_audio, audio_mask],
662
+ )
663
+
664
+ api_vamp_button = gr.Button("api vamp", visible=False)
665
+ api_vamp_button.click(
666
+ fn=api_vamp,
667
+ inputs=_inputs,
668
+ outputs=[output_audio],
669
+ api_name="vamp"
670
+ )
671
+
672
+ use_as_input_button.click(
673
+ fn=lambda x: x,
674
+ inputs=[output_audio],
675
+ outputs=[input_audio]
676
+ )
677
+
678
+ save_button.click(
679
+ fn=save_vamp,
680
+ inputs=_inputs | {notes_text, output_audio},
681
+ outputs=[thank_you, download_file]
682
+ )
683
+
684
+ # harp stuff
685
+ harp_inputs = [
686
+ input_audio,
687
+ beat_mask_width,
688
+ sampletemp,
689
+ ]
690
+
691
+ build_endpoint(
692
+ inputs=harp_inputs,
693
+ output=output_audio,
694
+ process_fn=harp_vamp,
695
+ card=ModelCard(
696
+ name="vampnet",
697
+ description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
698
+ author="Hugo Flores García",
699
+ tags=["music", "generative"]
700
+ ),
701
+ visible=False
702
+ )
703
+
704
+ demo.queue().launch()
assets/example.wav ADDED
Binary file (883 kB).
 
conf/c2f.yml ADDED
@@ -0,0 +1,14 @@
+ $include:
+ - conf/vampnet.yml
+
+ VampNet.n_codebooks: 14
+ VampNet.n_conditioning_codebooks: 4
+
+ VampNet.embedding_dim: 1280
+ VampNet.n_layers: 16
+ VampNet.n_heads: 20
+
+ AudioDataset.duration: 3.0
+
+
+ AudioDataset.loudness_cutoff: -40.0
conf/interface.yml ADDED
@@ -0,0 +1,10 @@
+ Interface.coarse_ckpt: ./models/vampnet/coarse.pth
+ Interface.coarse2fine_ckpt: ./models/vampnet/c2f.pth
+ Interface.codec_ckpt: ./models/vampnet/codec.pth
+ Interface.coarse_chunk_size_s: 10
+ Interface.coarse2fine_chunk_size_s: 3
+ Interface.wavebeat_ckpt: ./models/wavebeat.pth
+
+ # AudioLoader.sources:
+ # - /media/CHONK/null
+
conf/lora/lora.yml ADDED
@@ -0,0 +1,22 @@
+ $include:
+ - conf/vampnet.yml
+
+ fine_tune: True
+
+ train/AudioDataset.n_examples: 100000000
+ val/AudioDataset.n_examples: 500
+
+
+ NoamScheduler.warmup: 500
+
+ batch_size: 6
+ num_workers: 7
+ save_iters: [2000, 4000, 10000, 20000, 40000, 100000]
+ sample_freq: 2000
+ val_freq: 1000
+
+ AdamW.lr: 0.0001
+
+ # lets us organize sound classes into folders and choose from those sound classes uniformly
+ AudioDataset.without_replacement: False
+ num_iters: 500000
conf/salad_bowl.yml ADDED
File without changes
conf/vampnet.yml ADDED
@@ -0,0 +1,49 @@
+
+ codec_ckpt: ./models/vampnet/codec.pth
+ save_path: ckpt
+
+ num_iters: 1000000000
+ save_iters: [10000, 50000, 100000, 300000, 500000]
+ val_idx: [0,1,2,3,4,5,6,7,8,9]
+ sample_freq: 10000
+ val_freq: 1000
+
+ batch_size: 8
+ num_workers: 10
+
+ # Optimization
+ amp: false
+
+ CrossEntropyLoss.label_smoothing: 0.1
+
+ AdamW.lr: 0.001
+
+ NoamScheduler.factor: 2.0
+ NoamScheduler.warmup: 10000
+
+ VampNet.vocab_size: 1024
+ VampNet.n_codebooks: 4
+ VampNet.n_conditioning_codebooks: 0
+ VampNet.r_cond_dim: 0
+ VampNet.noise_mode: mask
+ VampNet.embedding_dim: 1280
+ VampNet.n_layers: 20
+ VampNet.n_heads: 20
+ VampNet.flash_attn: false
+ VampNet.dropout: 0.1
+
+ AudioLoader.relative_path: ""
+ AudioDataset.loudness_cutoff: -30.0
+ AudioDataset.without_replacement: true
+ AudioLoader.shuffle: true
+
+ AudioDataset.duration: 10.0
+
+ train/AudioDataset.n_examples: 10000000
+ train/AudioLoader.sources:
+ - /media/CHONK/hugo/spotdl/audio-train
+
+ val/AudioDataset.n_examples: 2000
+ val/AudioLoader.sources:
+ - /media/CHONK/hugo/spotdl/audio-val
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ argbind>=0.3.2
+ numpy==1.23
+ gradio
+ loralib
+ wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
+ lac @ git+https://github.com/hugofloresgarcia/lac.git
+ descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
+ -e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
+ torch_pitch_shift
scripts/exp/eval.py ADDED
@@ -0,0 +1,110 @@
1
+ from pathlib import Path
2
+ import os
3
+ from functools import partial
4
+
5
+ from frechet_audio_distance import FrechetAudioDistance
6
+ import pandas
7
+ import argbind
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ import audiotools
12
+ from audiotools import AudioSignal
13
+
14
+ @argbind.bind(without_prefix=True)
15
+ def eval(
16
+ exp_dir: str = None,
17
+ baseline_key: str = "baseline",
18
+ audio_ext: str = ".wav",
19
+ ):
20
+ assert exp_dir is not None
21
+ exp_dir = Path(exp_dir)
22
+ assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
23
+
24
+ # set up our metrics
25
+ # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
26
+ # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
27
+ mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
28
+ frechet = FrechetAudioDistance(
29
+ use_pca=False,
30
+ use_activation=False,
31
+ verbose=True,
32
+ audio_load_worker=4,
33
+ )
34
+ frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ # figure out what conditions we have
37
+ conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
38
+
39
+ assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
40
+ conditions.remove(baseline_key)
41
+
42
+ print(f"Found {len(conditions)} conditions in {exp_dir}")
43
+ print(f"conditions: {conditions}")
44
+
45
+ baseline_dir = exp_dir / baseline_key
46
+ baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
47
+
48
+ metrics = []
49
+ for condition in tqdm(conditions):
50
+ cond_dir = exp_dir / condition
51
+ cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
52
+
53
+ print(f"computing fad for {baseline_dir} and {cond_dir}")
54
+ frechet_score = frechet.score(baseline_dir, cond_dir)
55
+
56
+ # make sure we have the same number of files
57
+ num_files = min(len(baseline_files), len(cond_files))
58
+ baseline_files = baseline_files[:num_files]
59
+ cond_files = cond_files[:num_files]
60
+ assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
61
+
62
+ def process(baseline_file, cond_file):
63
+ # make sure the files match (same name)
64
+ assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
65
+
66
+ # load the files
67
+ baseline_sig = AudioSignal(str(baseline_file))
68
+ cond_sig = AudioSignal(str(cond_file))
69
+
70
+ cond_sig.resample(baseline_sig.sample_rate)
71
+ cond_sig.truncate_samples(baseline_sig.length)
72
+
73
+ # if our condition is inpainting, we need to trim the conditioning off
74
+ if "inpaint" in condition:
75
+ ctx_amt = float(condition.split("_")[-1])
76
+ ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
77
+ print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
78
+ cond_sig.trim(ctx_samples, ctx_samples)
79
+ baseline_sig.trim(ctx_samples, ctx_samples)
80
+
81
+ return {
82
+ # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
83
+ # "stft": stft_loss(baseline_sig, cond_sig).item(),
84
+ "mel": mel_loss(baseline_sig, cond_sig).item(),
85
+ "frechet": frechet_score,
86
+ # "visqol": vsq,
87
+ "condition": condition,
88
+ "file": baseline_file.stem,
89
+ }
90
+
91
+ print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
92
+ metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
93
+
94
+ metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
95
+
96
+
97
+ for mk in metric_keys:
98
+ stat = pandas.DataFrame(metrics)
99
+ stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
100
+ stat.to_csv(exp_dir / f"stats-{mk}.csv")
101
+
102
+ df = pandas.DataFrame(metrics)
103
+ df.to_csv(exp_dir / "metrics-all.csv", index=False)
104
+
105
+
106
+ if __name__ == "__main__":
107
+ args = argbind.parse_args()
108
+
109
+ with argbind.scope(args):
110
+ eval()
scripts/exp/experiment.py ADDED
@@ -0,0 +1,254 @@
1
+ from pathlib import Path
2
+ import random
3
+ from typing import List
4
+ import tempfile
5
+ import subprocess
6
+
7
+ import argbind
8
+ from tqdm import tqdm
9
+ import torch
10
+
11
+ from vampnet.interface import Interface
12
+ from vampnet import mask as pmask
13
+ import audiotools as at
14
+
15
+ Interface: Interface = argbind.bind(Interface)
16
+
17
+
18
+
19
+ def calculate_bitrate(
20
+ interface, num_codebooks,
21
+ downsample_factor
22
+ ):
23
+ bit_width = 10
24
+ sr = interface.codec.sample_rate
25
+ hop = interface.codec.hop_size
26
+ rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
27
+ return rate
28
+
29
+ def baseline(sig, interface):
30
+ return interface.preprocess(sig)
31
+
32
+ def reconstructed(sig, interface):
33
+ return interface.to_signal(
34
+ interface.encode(sig)
35
+ )
36
+
37
+ def coarse2fine(sig, interface):
38
+ z = interface.encode(sig)
39
+ z = z[:, :interface.c2f.n_conditioning_codebooks, :]
40
+
41
+ z = interface.coarse_to_fine(z)
42
+ return interface.to_signal(z)
43
+
44
+ class CoarseCond:
45
+
46
+ def __init__(self, num_conditioning_codebooks, downsample_factor):
47
+ self.num_conditioning_codebooks = num_conditioning_codebooks
48
+ self.downsample_factor = downsample_factor
49
+
50
+ def __call__(self, sig, interface):
51
+ z = interface.encode(sig)
52
+ mask = pmask.full_mask(z)
53
+ mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks)
54
+ mask = pmask.periodic_mask(mask, self.downsample_factor)
55
+
56
+ zv = interface.coarse_vamp(z, mask)
57
+ zv = interface.coarse_to_fine(zv)
58
+ return interface.to_signal(zv)
59
+
60
+ def opus(sig, interface, bitrate=128):
61
+ sig = interface.preprocess(sig)
62
+
63
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
64
+ sig.write(f.name)
65
+
66
+ opus_name = Path(f.name).with_suffix(".opus")
67
+ # convert to opus
68
+ cmd = [
69
+ "ffmpeg", "-y", "-i", f.name,
70
+ "-c:a", "libopus",
71
+ "-b:a", f"{bitrate}",
72
+ opus_name
73
+ ]
74
+ subprocess.run(cmd, check=True)
75
+
76
+ # convert back to wav
77
+ output_name = Path(f"{f.name}-opus").with_suffix(".wav")
78
+ cmd = [
79
+ "ffmpeg", "-y", "-i", opus_name,
80
+ output_name
81
+ ]
82
+
83
+ subprocess.run(cmd, check=True)
84
+
85
+ sig = at.AudioSignal(
86
+ output_name,
87
+ sample_rate=sig.sample_rate
88
+ )
89
+ return sig
90
+
91
+ def mask_ratio_1_step(ratio=1.0):
92
+ def wrapper(sig, interface):
93
+ z = interface.encode(sig)
94
+ mask = pmask.linear_random(z, ratio)
95
+ zv = interface.coarse_vamp(
96
+ z,
97
+ mask,
98
+ sampling_steps=1,
99
+ )
100
+
101
+ return interface.to_signal(zv)
102
+ return wrapper
103
+
104
+ def num_sampling_steps(num_steps=1):
105
+ def wrapper(sig, interface: Interface):
106
+ z = interface.encode(sig)
107
+ mask = pmask.periodic_mask(z, 16)
108
+ zv = interface.coarse_vamp(
109
+ z,
110
+ mask,
111
+ sampling_steps=num_steps,
112
+ )
113
+
114
+ zv = interface.coarse_to_fine(zv)
115
+ return interface.to_signal(zv)
116
+ return wrapper
117
+
118
+ def beat_mask(ctx_time):
119
+ def wrapper(sig, interface):
120
+ beat_mask = interface.make_beat_mask(
121
+ sig,
122
+ before_beat_s=ctx_time/2,
123
+ after_beat_s=ctx_time/2,
124
+ invert=True
125
+ )
126
+
127
+ z = interface.encode(sig)
128
+
129
+ zv = interface.coarse_vamp(
130
+ z, beat_mask
131
+ )
132
+
133
+ zv = interface.coarse_to_fine(zv)
134
+ return interface.to_signal(zv)
135
+ return wrapper
136
+
137
+ def inpaint(ctx_time):
138
+ def wrapper(sig, interface: Interface):
139
+ z = interface.encode(sig)
140
+ mask = pmask.inpaint(z, interface.s2t(ctx_time), interface.s2t(ctx_time))
141
+
142
+ zv = interface.coarse_vamp(z, mask)
143
+ zv = interface.coarse_to_fine(zv)
144
+
145
+ return interface.to_signal(zv)
146
+ return wrapper
147
+
148
+ def token_noise(noise_amt):
149
+ def wrapper(sig, interface: Interface):
150
+ z = interface.encode(sig)
151
+ mask = pmask.random(z, noise_amt)
152
+ z = torch.where(
153
+ mask,
154
+ torch.randint_like(z, 0, interface.coarse.vocab_size),
155
+ z
156
+ )
157
+ return interface.to_signal(z)
158
+ return wrapper
159
+
160
+ EXP_REGISTRY = {}
161
+
162
+ EXP_REGISTRY["gen-compression"] = {
163
+ "baseline": baseline,
164
+ "reconstructed": reconstructed,
165
+ "coarse2fine": coarse2fine,
166
+ **{
167
+ f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
168
+ for (n, x) in (
169
+ (1, 1), # 1 codebook, no downsampling
170
+ (4, 4), # 4 codebooks, downsampled 4x
171
+ (4, 16), # 4 codebooks, downsampled 16x
172
+ (4, 32), # 4 codebooks, downsampled 32x
173
+ )
174
+ },
175
+ **{
176
+ f"token_noise_{x}": mask_ratio_1_step(ratio=x)
177
+ for x in [0.25, 0.5, 0.75]
178
+ },
179
+
180
+ }
181
+
182
+
183
+ EXP_REGISTRY["sampling-steps"] = {
184
+ # "codec": reconstructed,
185
+ **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
186
+ }
187
+
188
+
189
+ EXP_REGISTRY["musical-sampling"] = {
190
+ **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
191
+ **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
192
+ }
193
+
194
+ @argbind.bind(without_prefix=True)
195
+ def main(
196
+ sources=[
197
+ "/media/CHONK/hugo/spotdl/val",
198
+ ],
199
+ output_dir: str = "./samples",
200
+ max_excerpts: int = 2000,
201
+ exp_type: str = "gen-compression",
202
+ seed: int = 0,
203
+ ext: str = [".mp3"],
204
+ ):
205
+ at.util.seed(seed)
206
+ interface = Interface()
207
+
208
+ output_dir = Path(output_dir)
209
+ output_dir.mkdir(exist_ok=True, parents=True)
210
+
211
+ from audiotools.data.datasets import AudioLoader, AudioDataset
212
+
213
+ loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
214
+ dataset = AudioDataset(loader,
215
+ sample_rate=interface.codec.sample_rate,
216
+ duration=interface.coarse.chunk_size_s,
217
+ n_examples=max_excerpts,
218
+ without_replacement=True,
219
+ )
220
+
221
+ if exp_type in EXP_REGISTRY:
222
+ SAMPLE_CONDS = EXP_REGISTRY[exp_type]
223
+ else:
224
+ raise ValueError(f"Unknown exp_type {exp_type}")
225
+
226
+
227
+ indices = list(range(max_excerpts))
228
+ random.shuffle(indices)
229
+ for i in tqdm(indices):
230
+ # if all our files are already there, skip
231
+ done = []
232
+ for name in SAMPLE_CONDS:
233
+ o_dir = Path(output_dir) / name
234
+ done.append((o_dir / f"{i}.wav").exists())
235
+ if all(done):
236
+ continue
237
+
238
+ sig = dataset[i]["signal"]
239
+ results = {
240
+ name: cond(sig, interface).cpu()
241
+ for name, cond in SAMPLE_CONDS.items()
242
+ }
243
+
244
+ for name, sig in results.items():
245
+ o_dir = Path(output_dir) / name
246
+ o_dir.mkdir(exist_ok=True, parents=True)
247
+
248
+ sig.write(o_dir / f"{i}.wav")
249
+
250
+ if __name__ == "__main__":
251
+ args = argbind.parse_args()
252
+
253
+ with argbind.scope(args):
254
+ main()
scripts/exp/fine_tune.py ADDED
@@ -0,0 +1,81 @@
1
+ import argbind
2
+ from pathlib import Path
3
+ import yaml
4
+ from typing import List
5
+
6
+
7
+
8
+
9
+ """example output: (yaml)
10
+
11
+ """
12
+
13
+ @argbind.bind(without_prefix=True, positional=True)
14
+ def fine_tune(audio_files_or_folders: List[str], name: str):
15
+
16
+ conf_dir = Path("conf")
17
+ assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
18
+
19
+ conf_dir = conf_dir / "generated"
20
+ conf_dir.mkdir(exist_ok=True)
21
+
22
+ finetune_dir = conf_dir / name
23
+ finetune_dir.mkdir(exist_ok=True)
24
+
25
+ finetune_c2f_conf = {
26
+ "$include": ["conf/lora/lora.yml"],
27
+ "fine_tune": True,
28
+ "train/AudioLoader.sources": audio_files_or_folders,
29
+ "val/AudioLoader.sources": audio_files_or_folders,
30
+ "VampNet.n_codebooks": 14,
31
+ "VampNet.n_conditioning_codebooks": 4,
32
+ "VampNet.embedding_dim": 1280,
33
+ "VampNet.n_layers": 16,
34
+ "VampNet.n_heads": 20,
35
+ "AudioDataset.duration": 3.0,
36
+ "AudioDataset.loudness_cutoff": -40.0,
37
+ "save_path": str(finetune_dir / "ckpt/c2f"),
38
+ "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
39
+ }
40
+
41
+ finetune_coarse_conf = {
42
+ "$include": ["conf/lora/lora.yml"],
43
+ "fine_tune": True,
44
+ "train/AudioLoader.sources": audio_files_or_folders,
45
+ "val/AudioLoader.sources": audio_files_or_folders,
46
+ "save_path": str(finetune_dir / "ckpt/coarse"),
47
+ "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
48
+ }
49
+
50
+ interface_conf = {
51
+ "Interface.coarse_ckpt": f"{finetune_dir}/ckpt/coarse/latest/vampnet/weights.pth",
52
+
53
+ "Interface.coarse2fine_ckpt": f"{finetune_dir}/ckpt/c2f/latest/vampnet/weights.pth",
54
+ "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
55
+
56
+ "Interface.codec_ckpt": "./models/vampnet/codec.pth",
57
+ "AudioLoader.sources": [audio_files_or_folders],
58
+ }
59
+
60
+ # save the confs
61
+ with open(finetune_dir / "c2f.yml", "w") as f:
62
+ yaml.dump(finetune_c2f_conf, f)
63
+
64
+ with open(finetune_dir / "coarse.yml", "w") as f:
65
+ yaml.dump(finetune_coarse_conf, f)
66
+
67
+ with open(finetune_dir / "interface.yml", "w") as f:
68
+ yaml.dump(interface_conf, f)
69
+
70
+
71
+ print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
72
+
73
+ if __name__ == "__main__":
74
+ args = argbind.parse_args()
75
+
76
+ with argbind.scope(args):
77
+ fine_tune()
78
+
79
+
80
+
81
+
scripts/exp/train.py ADDED
@@ -0,0 +1,686 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ from dataclasses import dataclass
7
+
8
+ import argbind
9
+ import audiotools as at
10
+ import torch
11
+ import torch.nn as nn
12
+ from audiotools import AudioSignal
13
+ from audiotools.data import transforms as tfm
14
+ from einops import rearrange
15
+ from rich import pretty
16
+ from rich.traceback import install
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import vampnet
20
+ from vampnet.modules.transformer import VampNet
21
+ from vampnet.util import codebook_unflatten, codebook_flatten
22
+ from vampnet import mask as pmask
23
+ # from dac.model.dac import DAC
24
+ from lac.model.lac import LAC as DAC
25
+
26
+ from audiotools.ml.decorators import (
27
+ timer, Tracker, when
28
+ )
29
+
30
+ import loralib as lora
31
+
32
+ import torch._dynamo
33
+ torch._dynamo.config.verbose=True
34
+
35
+
36
+ # Enable cudnn autotuner to speed up training
37
+ # (can be altered by the funcs.seed function)
38
+ torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
39
+ # Uncomment to trade memory for speed.
40
+
41
+ # Install to make things look nice
42
+ warnings.filterwarnings("ignore", category=UserWarning)
43
+ pretty.install()
44
+ install()
45
+
46
+ # optim
47
+ Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
48
+ CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
49
+ AdamW = argbind.bind(torch.optim.AdamW)
50
+ NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
51
+
52
+ # transforms
53
+ filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
54
+ "BaseTransform",
55
+ "Compose",
56
+ "Choose",
57
+ ]
58
+
59
+ # model
60
+ VampNet = argbind.bind(VampNet)
61
+
62
+
63
+ # data
64
+ AudioLoader = argbind.bind(at.datasets.AudioLoader)
65
+ AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
66
+
67
+ IGNORE_INDEX = -100
68
+
69
+
70
+ @argbind.bind("train", "val", without_prefix=True)
71
+ def build_transform():
72
+ transform = tfm.Compose(
73
+ tfm.VolumeNorm(("const", -24)),
74
+ # tfm.PitchShift(),
75
+ tfm.RescaleAudio(),
76
+ )
77
+ return transform
78
+
79
+
80
+ @torch.no_grad()
81
+ def apply_transform(transform_fn, batch):
82
+ sig: AudioSignal = batch["signal"]
83
+ kwargs = batch["transform_args"]
84
+
85
+ sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
86
+ return sig
87
+
88
+
89
+ def build_datasets(args, sample_rate: int):
90
+ with argbind.scope(args, "train"):
91
+ train_data = AudioDataset(
92
+ AudioLoader(), sample_rate, transform=build_transform()
93
+ )
94
+ with argbind.scope(args, "val"):
95
+ val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
96
+ return train_data, val_data
97
+
98
+
99
+ def rand_float(shape, low, high, rng):
100
+ return rng.draw(shape)[:, 0] * (high - low) + low
101
+
102
+
103
+ def flip_coin(shape, p, rng):
104
+ return rng.draw(shape)[:, 0] < p
105
+
106
+
107
+ def num_params_hook(o, p):
108
+ return o + f" {p/1e6:<.3f}M params."
109
+
110
+
111
+ def add_num_params_repr_hook(model):
112
+ import numpy as np
113
+ from functools import partial
114
+
115
+ for n, m in model.named_modules():
116
+ o = m.extra_repr()
117
+ p = sum([np.prod(p.size()) for p in m.parameters()])
118
+
119
+ setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
120
+
121
+
122
+ def accuracy(
123
+ preds: torch.Tensor,
124
+ target: torch.Tensor,
125
+ top_k: int = 1,
126
+ ignore_index: Optional[int] = None,
127
+ ) -> torch.Tensor:
128
+ # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
129
+ preds = rearrange(preds, "b p s -> (b s) p")
130
+ target = rearrange(target, "b s -> (b s)")
131
+
132
+ # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=topk, num_classes=preds.shape[-1], ignore_index=ignore_index)
133
+ if ignore_index is not None:
134
+ # Create a mask for the ignored index
135
+ mask = target != ignore_index
136
+ # Apply the mask to the target and predictions
137
+ preds = preds[mask]
138
+ target = target[mask]
139
+
140
+ # Get the top-k predicted classes and their indices
141
+ _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
142
+
143
+ # Determine if the true target is in the top-k predicted classes
144
+ correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
145
+
146
+ # Calculate the accuracy
147
+ accuracy = torch.mean(correct.float())
148
+
149
+ return accuracy
150
+
151
+ def _metrics(z_hat, r, target, flat_mask, output):
152
+ for r_range in [(0, 0.5), (0.5, 1.0)]:
153
+ unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
154
+ masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
155
+
156
+ assert target.shape[0] == r.shape[0]
157
+ # grab the indices of the r values that are in the range
158
+ r_idx = (r >= r_range[0]) & (r < r_range[1])
159
+
160
+ # grab the target and z_hat values that are in the range
161
+ r_unmasked_target = unmasked_target[r_idx]
162
+ r_masked_target = masked_target[r_idx]
163
+ r_z_hat = z_hat[r_idx]
164
+
165
+ for topk in (1, 25):
166
+ s, e = r_range
167
+ tag = f"accuracy-{s}-{e}/top{topk}"
168
+
169
+ output[f"{tag}/unmasked"] = accuracy(
170
+ preds=r_z_hat,
171
+ target=r_unmasked_target,
172
+ ignore_index=IGNORE_INDEX,
173
+ top_k=topk,
174
+ )
175
+ output[f"{tag}/masked"] = accuracy(
176
+ preds=r_z_hat,
177
+ target=r_masked_target,
178
+ ignore_index=IGNORE_INDEX,
179
+ top_k=topk,
180
+ )
181
+
182
+
183
+ @dataclass
184
+ class State:
185
+ model: VampNet
186
+ codec: DAC
187
+
188
+ optimizer: AdamW
189
+ scheduler: NoamScheduler
190
+ criterion: CrossEntropyLoss
191
+ grad_clip_val: float
192
+
193
+ rng: torch.quasirandom.SobolEngine
194
+
195
+ train_data: AudioDataset
196
+ val_data: AudioDataset
197
+
198
+ tracker: Tracker
199
+
200
+
201
+ @timer()
202
+ def train_loop(state: State, batch: dict, accel: Accelerator):
203
+ state.model.train()
204
+ batch = at.util.prepare_batch(batch, accel.device)
205
+ signal = apply_transform(state.train_data.transform, batch)
206
+
207
+ output = {}
208
+ vn = accel.unwrap(state.model)
209
+ with accel.autocast():
210
+ with torch.inference_mode():
211
+ state.codec.to(accel.device)
212
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
213
+ z = z[:, : vn.n_codebooks, :]
214
+
215
+ n_batch = z.shape[0]
216
+ r = state.rng.draw(n_batch)[:, 0].to(accel.device)
217
+
218
+ mask = pmask.random(z, r)
219
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
220
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
221
+
222
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
223
+
224
+ dtype = torch.bfloat16 if accel.amp else None
225
+ with accel.autocast(dtype=dtype):
226
+ z_hat = state.model(z_mask_latent)
227
+
228
+ target = codebook_flatten(
229
+ z[:, vn.n_conditioning_codebooks :, :],
230
+ )
231
+
232
+ flat_mask = codebook_flatten(
233
+ mask[:, vn.n_conditioning_codebooks :, :],
234
+ )
235
+
236
+ # replace target with ignore index for masked tokens
237
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
238
+ output["loss"] = state.criterion(z_hat, t_masked)
239
+
240
+ _metrics(
241
+ r=r,
242
+ z_hat=z_hat,
243
+ target=target,
244
+ flat_mask=flat_mask,
245
+ output=output,
246
+ )
247
+
248
+
249
+ accel.backward(output["loss"])
250
+
251
+ output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
252
+ output["other/batch_size"] = z.shape[0]
253
+
254
+
255
+ accel.scaler.unscale_(state.optimizer)
256
+ output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
257
+ state.model.parameters(), state.grad_clip_val
258
+ )
259
+
260
+ accel.step(state.optimizer)
261
+ state.optimizer.zero_grad()
262
+
263
+ state.scheduler.step()
264
+ accel.update()
265
+
266
+
267
+ return {k: v for k, v in sorted(output.items())}
268
+
269
+
270
+ @timer()
271
+ @torch.no_grad()
272
+ def val_loop(state: State, batch: dict, accel: Accelerator):
273
+ state.model.eval()
274
+ state.codec.eval()
275
+ batch = at.util.prepare_batch(batch, accel.device)
276
+ signal = apply_transform(state.val_data.transform, batch)
277
+
278
+ vn = accel.unwrap(state.model)
279
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
280
+ z = z[:, : vn.n_codebooks, :]
281
+
282
+ n_batch = z.shape[0]
283
+ r = state.rng.draw(n_batch)[:, 0].to(accel.device)
284
+
285
+ mask = pmask.random(z, r)
286
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
287
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
288
+
289
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
290
+
291
+ z_hat = state.model(z_mask_latent)
292
+
293
+ target = codebook_flatten(
294
+ z[:, vn.n_conditioning_codebooks :, :],
295
+ )
296
+
297
+ flat_mask = codebook_flatten(
298
+ mask[:, vn.n_conditioning_codebooks :, :]
299
+ )
300
+
301
+ output = {}
302
+ # replace target with ignore index for masked tokens
303
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
304
+ output["loss"] = state.criterion(z_hat, t_masked)
305
+
306
+ _metrics(
307
+ r=r,
308
+ z_hat=z_hat,
309
+ target=target,
310
+ flat_mask=flat_mask,
311
+ output=output,
312
+ )
313
+
314
+ return output
315
+
316
+
317
+ def validate(state, val_dataloader, accel):
318
+ for batch in val_dataloader:
319
+ output = val_loop(state, batch, accel)
320
+ # Consolidate state dicts if using ZeroRedundancyOptimizer
321
+ if hasattr(state.optimizer, "consolidate_state_dict"):
322
+ state.optimizer.consolidate_state_dict()
323
+ return output
324
+
325
+
326
+ def checkpoint(state, save_iters, save_path, fine_tune):
327
+ if accel.local_rank != 0:
328
+ state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
329
+ return
330
+
331
+ metadata = {"logs": dict(state.tracker.history)}
332
+
333
+ tags = ["latest"]
334
+ state.tracker.print(f"Saving to {str(Path('.').absolute())}")
335
+
336
+ if state.tracker.step in save_iters:
337
+ tags.append(f"{state.tracker.step // 1000}k")
338
+
339
+ if state.tracker.is_best("val", "loss"):
340
+ state.tracker.print(f"Best model so far")
341
+ tags.append("best")
342
+
343
+ if fine_tune:
344
+ for tag in tags:
345
+ # save the lora model
346
+ (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
347
+ torch.save(
348
+ lora.lora_state_dict(accel.unwrap(state.model)),
349
+ f"{save_path}/{tag}/lora.pth"
350
+ )
351
+
352
+ for tag in tags:
353
+ model_extra = {
354
+ "optimizer.pth": state.optimizer.state_dict(),
355
+ "scheduler.pth": state.scheduler.state_dict(),
356
+ "tracker.pth": state.tracker.state_dict(),
357
+ "metadata.pth": metadata,
358
+ }
359
+
360
+ accel.unwrap(state.model).metadata = metadata
361
+ accel.unwrap(state.model).save_to_folder(
362
+ f"{save_path}/{tag}", model_extra, package=False
363
+ )
364
+
365
+
366
+ def save_sampled(state, z, writer):
367
+ num_samples = z.shape[0]
368
+
369
+ for i in range(num_samples):
370
+ sampled = accel.unwrap(state.model).generate(
371
+ codec=state.codec,
372
+ time_steps=z.shape[-1],
373
+ start_tokens=z[i : i + 1],
374
+ )
375
+ sampled.cpu().write_audio_to_tb(
376
+ f"sampled/{i}",
377
+ writer,
378
+ step=state.tracker.step,
379
+ plot_fn=None,
380
+ )
381
+
382
+
383
+ def save_imputation(state, z, val_idx, writer):
384
+ n_prefix = int(z.shape[-1] * 0.25)
385
+ n_suffix = int(z.shape[-1] * 0.25)
386
+
387
+ vn = accel.unwrap(state.model)
388
+
389
+ mask = pmask.inpaint(z, n_prefix, n_suffix)
390
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
391
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
392
+
393
+ imputed_noisy = vn.to_signal(z_mask, state.codec)
394
+ imputed_true = vn.to_signal(z, state.codec)
395
+
396
+ imputed = []
397
+ for i in range(len(z)):
398
+ imputed.append(
399
+ vn.generate(
400
+ codec=state.codec,
401
+ time_steps=z.shape[-1],
402
+ start_tokens=z[i][None, ...],
403
+ mask=mask[i][None, ...],
404
+ )
405
+ )
406
+ imputed = AudioSignal.batch(imputed)
407
+
408
+ for i in range(len(val_idx)):
409
+ imputed_noisy[i].cpu().write_audio_to_tb(
410
+ f"inpainted_prompt/{i}",
411
+ writer,
412
+ step=state.tracker.step,
413
+ plot_fn=None,
414
+ )
415
+ imputed[i].cpu().write_audio_to_tb(
416
+ f"inpainted_middle/{i}",
417
+ writer,
418
+ step=state.tracker.step,
419
+ plot_fn=None,
420
+ )
421
+ imputed_true[i].cpu().write_audio_to_tb(
422
+ f"reconstructed/{i}",
423
+ writer,
424
+ step=state.tracker.step,
425
+ plot_fn=None,
426
+ )
427
+
428
+
429
+ @torch.no_grad()
430
+ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
431
+ state.model.eval()
432
+ state.codec.eval()
433
+ vn = accel.unwrap(state.model)
434
+
435
+ batch = [state.val_data[i] for i in val_idx]
436
+ batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
437
+
438
+ signal = apply_transform(state.val_data.transform, batch)
439
+
440
+ z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
441
+ z = z[:, : vn.n_codebooks, :]
442
+
443
+ r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
444
+
445
+
446
+ mask = pmask.random(z, r)
447
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
448
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
449
+
450
+ z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
451
+
452
+ z_hat = state.model(z_mask_latent)
453
+
454
+ z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
455
+ z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
456
+ z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
457
+
458
+ generated = vn.to_signal(z_pred, state.codec)
459
+ reconstructed = vn.to_signal(z, state.codec)
460
+ masked = vn.to_signal(z_mask.squeeze(1), state.codec)
461
+
462
+ for i in range(generated.batch_size):
463
+ audio_dict = {
464
+ "original": signal[i],
465
+ "masked": masked[i],
466
+ "generated": generated[i],
467
+ "reconstructed": reconstructed[i],
468
+ }
469
+ for k, v in audio_dict.items():
470
+ v.cpu().write_audio_to_tb(
471
+ f"onestep/_{i}.r={r[i]:0.2f}/{k}",
472
+ writer,
473
+ step=state.tracker.step,
474
+ plot_fn=None,
475
+ )
476
+
477
+ save_sampled(state=state, z=z, writer=writer)
478
+ save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
479
+
480
+
481
+
482
+ @argbind.bind(without_prefix=True)
483
+ def load(
484
+ args,
485
+ accel: at.ml.Accelerator,
486
+ tracker: Tracker,
487
+ save_path: str,
488
+ resume: bool = False,
489
+ tag: str = "latest",
490
+ fine_tune_checkpoint: Optional[str] = None,
491
+ grad_clip_val: float = 5.0,
492
+ ) -> State:
493
+ codec = DAC.load(args["codec_ckpt"], map_location="cpu")
494
+ codec.eval()
495
+
496
+ model, v_extra = None, {}
497
+
498
+ if args["fine_tune"]:
499
+ assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
500
+ model = torch.compile(
501
+ VampNet.load(location=Path(fine_tune_checkpoint),
502
+ map_location="cpu",
503
+ )
504
+ )
505
+
506
+ if resume:
507
+ kwargs = {
508
+ "folder": f"{save_path}/{tag}",
509
+ "map_location": "cpu",
510
+ "package": False,
511
+ }
512
+ tracker.print(f"Loading checkpoint from {kwargs['folder']}")
513
+ if (Path(kwargs["folder"]) / "vampnet").exists():
514
+ model, v_extra = VampNet.load_from_folder(**kwargs)
515
+ else:
516
+ raise ValueError(
517
+ f"Could not find a VampNet checkpoint in {kwargs['folder']}"
518
+ )
519
+
520
+
521
+
522
+
523
+ model = torch.compile(VampNet()) if model is None else model
524
+ model = accel.prepare_model(model)
525
+
526
+ # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
527
+ assert (
528
+ accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
529
+ )
530
+
531
+
532
+ if accel.world_size > 1:
533
+ from torch.distributed.optim import ZeroRedundancyOptimizer
534
+ optimizer = ZeroRedundancyOptimizer(model.parameters(), AdamW)
535
+ print(f"OPTIMIZER LR is {optimizer.param_groups[0]['lr']}")
536
+ else:
537
+ optimizer = AdamW(model.parameters())
538
+
539
+ scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
540
+ scheduler.step()
541
+
542
+ if "optimizer.pth" in v_extra:
543
+ optimizer.load_state_dict(v_extra["optimizer.pth"])
544
+ scheduler.load_state_dict(v_extra["scheduler.pth"])
545
+ if "tracker.pth" in v_extra:
546
+ tracker.load_state_dict(v_extra["tracker.pth"])
547
+
548
+ criterion = CrossEntropyLoss()
549
+
550
+ sample_rate = codec.sample_rate
551
+
552
+ # a better rng for sampling from our schedule
553
+ rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
554
+
555
+ # log a model summary w/ num params
556
+ if accel.local_rank == 0:
557
+ add_num_params_repr_hook(accel.unwrap(model))
558
+ with open(f"{save_path}/model.txt", "w") as f:
559
+ f.write(repr(accel.unwrap(model)))
560
+
561
+ # load the datasets
562
+ train_data, val_data = build_datasets(args, sample_rate)
563
+
564
+ return State(
565
+ tracker=tracker,
566
+ model=model,
567
+ codec=codec,
568
+ optimizer=optimizer,
569
+ scheduler=scheduler,
570
+ criterion=criterion,
571
+ rng=rng,
572
+ train_data=train_data,
573
+ val_data=val_data,
574
+ grad_clip_val=grad_clip_val,
575
+ )
576
+
577
+
578
+ @argbind.bind(without_prefix=True)
579
+ def train(
580
+ args,
581
+ accel: at.ml.Accelerator,
582
+ seed: int = 0,
583
+ codec_ckpt: str = None,
584
+ save_path: str = "ckpt",
585
+ num_iters: int = int(1000e6),
586
+ save_iters: list = [10000, 50000, 100000, 300000, 500000,],
587
+ sample_freq: int = 10000,
588
+ val_freq: int = 1000,
589
+ batch_size: int = 12,
590
+ val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
591
+ num_workers: int = 10,
592
+ fine_tune: bool = False,
593
+ ):
594
+ assert codec_ckpt is not None, "codec_ckpt is required"
595
+
596
+ seed = seed + accel.local_rank
597
+ at.util.seed(seed)
598
+ writer = None
599
+
600
+ if accel.local_rank == 0:
601
+ writer = SummaryWriter(log_dir=f"{save_path}/logs/")
602
+ argbind.dump_args(args, f"{save_path}/args.yml")
603
+
604
+ tracker = Tracker(
605
+ writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
606
+ )
607
+
608
+ # load the codec model
609
+ state: State = load(
610
+ args=args,
611
+ accel=accel,
612
+ tracker=tracker,
613
+ save_path=save_path)
614
+ print("initialized state.")
615
+
616
+ train_dataloader = accel.prepare_dataloader(
617
+ state.train_data,
618
+ start_idx=state.tracker.step * batch_size,
619
+ num_workers=num_workers,
620
+ batch_size=batch_size,
621
+ collate_fn=state.train_data.collate,
622
+ )
623
+ val_dataloader = accel.prepare_dataloader(
624
+ state.val_data,
625
+ start_idx=0,
626
+ num_workers=num_workers,
627
+ batch_size=batch_size,
628
+ collate_fn=state.val_data.collate,
629
+ persistent_workers=num_workers > 0,
630
+ )
631
+ print("initialized dataloader.")
632
+
633
+
634
+
635
+ if fine_tune:
636
+ lora.mark_only_lora_as_trainable(state.model)
637
+ print("marked only lora as trainable.")
638
+
639
+ # Wrap the functions so that they neatly track in TensorBoard + progress bars
640
+ # and only run when specific conditions are met.
641
+ global train_loop, val_loop, validate, save_samples, checkpoint
642
+
643
+ train_loop = tracker.log("train", "value", history=False)(
644
+ tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
645
+ )
646
+ val_loop = tracker.track("val", len(val_dataloader))(val_loop)
647
+ validate = tracker.log("val", "mean")(validate)
648
+
649
+ save_samples = when(lambda: accel.local_rank == 0)(save_samples)
650
+ checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
651
+
652
+ print("starting training loop.")
653
+ with tracker.live:
654
+ for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
655
+ train_loop(state, batch, accel)
656
+
657
+ last_iter = (
658
+ tracker.step == num_iters - 1 if num_iters is not None else False
659
+ )
660
+
661
+ if tracker.step % sample_freq == 0 or last_iter:
662
+ save_samples(state, val_idx, writer)
663
+
664
+ if tracker.step % val_freq == 0 or last_iter:
665
+ validate(state, val_dataloader, accel)
666
+ checkpoint(
667
+ state=state,
668
+ save_iters=save_iters,
669
+ save_path=save_path,
670
+ fine_tune=fine_tune)
671
+
672
+ # Reset validation progress bar, print summary since last validation.
673
+ tracker.done("val", f"Iteration {tracker.step}")
674
+
675
+ if last_iter:
676
+ break
677
+
678
+
679
+ if __name__ == "__main__":
680
+ args = argbind.parse_args()
681
+ args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
682
+ with argbind.scope(args):
683
+ with Accelerator() as accel:
684
+ if accel.local_rank != 0:
685
+ sys.tracebacklimit = 0
686
+ train(args, accel)
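
Both `train_loop` and `val_loop` above score the model only on masked positions by writing an ignore value into the flattened targets before the cross-entropy. A minimal, self-contained sketch of that pattern with toy tensors (the `-100` ignore value and plain `torch.nn.CrossEntropyLoss` are assumptions; the script's own `IGNORE_INDEX` and `CrossEntropyLoss` come from imports not shown here):

```python
import torch

IGNORE_INDEX = -100  # assumed value; -100 is torch's default ignore_index

vocab_size, seq_len = 8, 6
z_hat = torch.randn(1, vocab_size, seq_len)            # logits: (batch, classes, time)
target = torch.randint(0, vocab_size, (1, seq_len))    # flattened codebook targets
flat_mask = torch.tensor([[1, 1, 0, 0, 1, 0]])         # 1 = token was masked, so it should be scored

# replace targets at unmasked positions with the ignore index
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)

criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
loss = criterion(z_hat, t_masked)  # averaged over the masked positions only
print(loss)
```
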
scripts/utils/README.md ADDED
@@ -0,0 +1,28 @@
1
+ # Scripts
2
+
3
+ ## process_zip.py
4
+
5
+ Some requirements that may not be installed in the docker image:
6
+ * argbind
7
+ * wav2wav (`pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git` or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
8
+
9
+ ### zip folder structure
10
+
11
+ The zip folder should have the following internal structure:
12
+
13
+ ```
14
+ base_folder/
15
+ test_case_1/
16
+ before.wav
17
+ test_case_2/
18
+ before.wav
19
+ ...
20
+ test_case_n/
21
+ before.wav
22
+ ```
23
+
24
+ Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. If you want/need to use a zip file with a different folder structure, adjust this:
25
+ https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
26
+
27
+ ### Execution
28
+ `python process_zip.py <path/to/zip> -tag <string>`
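
As a rough illustration of the layout above, here is a short sketch that packs a set of `before.wav` files into the expected structure with the standard-library `zipfile` module (the local folder and zip names are hypothetical):

```python
# hypothetical sketch: build base_folder/test_case_N/before.wav inside a zip
import zipfile
from pathlib import Path

wav_files = sorted(Path("my_test_cases").glob("*/before.wav"))  # hypothetical source folder

with zipfile.ZipFile("test_cases.zip", "w") as zf:
    for i, wav in enumerate(wav_files, start=1):
        zf.write(wav, arcname=f"base_folder/test_case_{i}/before.wav")
```
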
scripts/utils/gtzan_embeddings.py ADDED
@@ -0,0 +1,264 @@
1
+ """
2
+ TODO: train a linear probe
3
+ usage:
4
+ python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
5
+ """
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import audiotools as at
10
+ from audiotools import AudioSignal
11
+ import argbind
12
+ import torch
13
+ import numpy as np
14
+ import zipfile
15
+ import json
16
+
17
+ from vampnet.interface import Interface
18
+ import tqdm
19
+
20
+ # bind the Interface to argbind
21
+ Interface = argbind.bind(Interface)
22
+
23
+ DEBUG = False
24
+
25
+ def smart_plotly_export(fig, save_path):
26
+ img_format = save_path.split('.')[-1]
27
+ if img_format == 'html':
28
+ fig.write_html(save_path)
29
+ elif img_format == 'bytes':
30
+ return fig.to_image(format='png')
31
+ #TODO: come back and make this prettier
32
+ elif img_format == 'numpy':
33
+ import io
34
+ from PIL import Image
35
+
36
+ def plotly_fig2array(fig):
37
+ #convert Plotly fig to an array
38
+ fig_bytes = fig.to_image(format="png", width=1200, height=700)
39
+ buf = io.BytesIO(fig_bytes)
40
+ img = Image.open(buf)
41
+ return np.asarray(img)
42
+
43
+ return plotly_fig2array(fig)
44
+ elif img_format in ('jpeg', 'png', 'webp'):
45
+ fig.write_image(save_path)
46
+ else:
47
+ raise ValueError("invalid image format")
48
+
49
+ def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
50
+ """
51
+ dimensionality reduction for visualization!
52
+ saves an html plotly figure to save_path
53
+ parameters:
54
+ emb (np.ndarray): the samples to be reduced with shape (samples, features)
55
+ labels (list): list of labels for embedding
56
+ save_path (str): path where u wanna save ur figure
57
+ method (str): umap, tsne, or pca
58
+ title (str): title for ur figure
59
+ returns:
60
+ proj (np.ndarray): projection vector with shape (samples, dimensions)
61
+ """
62
+ import pandas as pd
63
+ import plotly.express as px
64
+ if method == 'umap':
65
+ from umap import UMAP
66
+ reducer = UMAP(n_components=n_components)
67
+ elif method == 'tsne':
68
+ from sklearn.manifold import TSNE
69
+ reducer = TSNE(n_components=n_components)
70
+ elif method == 'pca':
71
+ from sklearn.decomposition import PCA
72
+ reducer = PCA(n_components=n_components)
73
+ else:
74
+ raise ValueError
75
+
76
+ proj = reducer.fit_transform(emb)
77
+
78
+ if n_components == 2:
79
+ df = pd.DataFrame(dict(
80
+ x=proj[:, 0],
81
+ y=proj[:, 1],
82
+ instrument=labels
83
+ ))
84
+ fig = px.scatter(df, x='x', y='y', color='instrument',
85
+ title=title+f"_{method}")
86
+
87
+ elif n_components == 3:
88
+ df = pd.DataFrame(dict(
89
+ x=proj[:, 0],
90
+ y=proj[:, 1],
91
+ z=proj[:, 2],
92
+ instrument=labels
93
+ ))
94
+ fig = px.scatter_3d(df, x='x', y='y', z='z',
95
+ color='instrument',
96
+ title=title)
97
+ else:
98
+ raise ValueError("cant plot more than 3 components")
99
+
100
+ fig.update_traces(marker=dict(size=6,
101
+ line=dict(width=1,
102
+ color='DarkSlateGrey')),
103
+ selector=dict(mode='markers'))
104
+
105
+ return smart_plotly_export(fig, save_path)
106
+
107
+
108
+
109
+ # per JukeMIR, we want the embeddings from the middle layer?
110
+ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
111
+ with torch.inference_mode():
112
+ # preprocess the signal
113
+ sig = interface.preprocess(sig)
114
+
115
+ # get the coarse vampnet model
116
+ vampnet = interface.coarse
117
+
118
+ # get the tokens
119
+ z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
120
+ z_latents = vampnet.embedding.from_codes(z, interface.codec)
121
+
122
+ # do a forward pass through the model, get the embeddings
123
+ _z, embeddings = vampnet(z_latents, return_activations=True)
124
+ # print(f"got embeddings with shape {embeddings.shape}")
125
+ # [layer, batch, time, n_dims]
126
+ # [20, 1, 600ish, 768]
127
+
128
+
129
+ # squeeze batch dim (1 bc layer should be dim 0)
130
+ assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
131
+ embeddings = embeddings.squeeze(1)
132
+
133
+ num_layers = embeddings.shape[0]
134
+ assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
135
+
136
+ # do meanpooling over the time dimension
137
+ embeddings = embeddings.mean(dim=-2)
138
+ # [20, 768]
139
+
140
+ # return the embeddings
141
+ return embeddings
142
+
143
+ from dataclasses import dataclass, fields
144
+ @dataclass
145
+ class Embedding:
146
+ genre: str
147
+ filename: str
148
+ embedding: np.ndarray
149
+
150
+ def save(self, path):
151
+ """Save the Embedding object to a given path as a zip file."""
152
+ with zipfile.ZipFile(path, 'w') as archive:
153
+
154
+ # Save numpy array
155
+ with archive.open('embedding.npy', 'w') as f:
156
+ np.save(f, self.embedding)
157
+
158
+ # Save non-numpy data as json
159
+ non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
160
+ with archive.open('data.json', 'w') as f:
161
+ f.write(json.dumps(non_numpy_data).encode('utf-8'))
162
+
163
+ @classmethod
164
+ def load(cls, path):
165
+ """Load the Embedding object from a given zip path."""
166
+ with zipfile.ZipFile(path, 'r') as archive:
167
+
168
+ # Load numpy array
169
+ with archive.open('embedding.npy') as f:
170
+ embedding = np.load(f)
171
+
172
+ # Load non-numpy data from json
173
+ with archive.open('data.json') as f:
174
+ data = json.loads(f.read().decode('utf-8'))
175
+
176
+ return cls(embedding=embedding, **data)
177
+
178
+
179
+ @argbind.bind(without_prefix=True)
180
+ def main(
181
+ path_to_gtzan: str = None,
182
+ cache_dir: str = "./.gtzan_emb_cache",
183
+ output_dir: str = "./gtzan_vampnet_embeddings",
184
+ layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
185
+ ):
186
+ path_to_gtzan = Path(path_to_gtzan)
187
+ assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
188
+
189
+ cache_dir = Path(cache_dir)
190
+ output_dir = Path(output_dir)
191
+ output_dir.mkdir(exist_ok=True, parents=True)
192
+
193
+ # load our interface
194
+ # argbind will automatically load the default config,
195
+ interface = Interface()
196
+
197
+ # gtzan should have a folder for each genre, so let's get the list of genres
198
+ genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
199
+ print(f"Found {len(genres)} genres")
200
+ print(f"genres: {genres}")
201
+
202
+ # collect audio files, genres, and embeddings
203
+ data = []
204
+ for genre in genres:
205
+ audio_files = list(at.util.find_audio(path_to_gtzan / genre))
206
+ print(f"Found {len(audio_files)} audio files for genre {genre}")
207
+
208
+ for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
209
+ # check if we have a cached embedding for this file
210
+ cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
211
+ if cached_path.exists():
212
+ # if so, load it
213
+ if DEBUG:
214
+ print(f"loading cached embedding for {cached_path.stem}")
215
+ embedding = Embedding.load(cached_path)
216
+ else:
217
+ try:
218
+ sig = AudioSignal(audio_file)
219
+ except Exception as e:
220
+ print(f"failed to load {audio_file.name} with error {e}")
221
+ print(f"skipping {audio_file.name}")
222
+ continue
223
+
224
+ # gets the embedding
225
+ emb = vampnet_embed(sig, interface).cpu().numpy()
226
+
227
+ # create an embedding we can save/load
228
+ embedding = Embedding(
229
+ genre=genre,
230
+ filename=audio_file.name,
231
+ embedding=emb
232
+ )
233
+
234
+ # cache the embeddings
235
+ cached_path.parent.mkdir(exist_ok=True, parents=True)
236
+ embedding.save(cached_path)
237
+ data.append(embedding)
238
+
239
+ # now, let's do a dim reduction on the embeddings
240
+ # and visualize them.
241
+
242
+ # collect a list of embeddings and labels
243
+ embeddings = [d.embedding for d in data]
244
+ labels = [d.genre for d in data]
245
+
246
+ # convert the embeddings to a numpy array
247
+ embeddings = np.stack(embeddings)
248
+
249
+ # do dimensionality reduction for each layer we're given
250
+ for layer in tqdm.tqdm(layers, desc="dim reduction"):
251
+ dim_reduce(
252
+ embeddings[:, layer, :], labels,
253
+ save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
254
+ n_components=2, method='tsne',
255
+ title=f'vampnet-gtzan-layer={layer}'
256
+ )
257
+
258
+
259
+
260
+
261
+ if __name__ == "__main__":
262
+ args = argbind.parse_args()
263
+ with argbind.scope(args):
264
+ main()
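
The docstring at the top of this script leaves `train a linear probe` as a TODO. One possible sketch, assuming the cached `.emb` archives already exist and that the `Embedding` dataclass above is importable (the cache path, layer index, and scikit-learn probe are illustrative choices, not part of this repo):

```python
from pathlib import Path

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from gtzan_embeddings import Embedding  # assumes this script is on the import path

layer = 10  # illustrative layer index
embs = [Embedding.load(p) for p in sorted(Path("./.gtzan_emb_cache").glob("*.emb"))]

X = np.stack([e.embedding[layer] for e in embs])  # (n_files, n_dims), meanpooled over time
y = [e.genre for e in embs]

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print(f"probe accuracy: {probe.score(X_te, y_te):.3f}")
```
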
scripts/utils/plots.py ADDED
@@ -0,0 +1,43 @@
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ from pandas.api.types import CategoricalDtype
4
+
5
+ def plot_metrics(metrics, condition_to_latex, title, color_palette):
6
+ # Add a new column to your dataframe with the latex representation
7
+ metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)
8
+
9
+ # Order condition_latex as per the condition_to_latex dictionary
10
+ cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
11
+ metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type)
12
+
13
+ # Compute mean and std for each condition for each metric
14
+ grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])
15
+
16
+ fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
17
+
18
+ # Set the main title for the figure
19
+ fig.suptitle(title, fontsize=16)
20
+
21
+ # Get color for each bar in the plot
22
+ bar_colors = [color_palette[condition] for condition in grouped.index]
23
+
24
+ # Plot mel
25
+ sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
26
+ axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
27
+ axs[0].set_xlabel('') # Remove x-axis label
28
+ axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')
29
+
30
+ # Plot frechet
31
+ axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
32
+ axs[1].set_ylabel('FAD \u2190')
33
+ axs[1].set_xlabel('') # Remove x-axis label
34
+ axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')
35
+
36
+ # Adjust the space between plots
37
+ plt.subplots_adjust(hspace=0.1)
38
+
39
+ # Remove any unnecessary space around the plot
40
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
41
+
42
+ # Reduce the space between suptitle and the plot
43
+ plt.subplots_adjust(top=0.92)
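
A minimal sketch of how `plot_metrics` might be called; the dataframe contents, LaTeX labels, and colors below are made up for illustration, and the function is assumed to be importable from this module:

```python
import matplotlib.pyplot as plt
import pandas as pd

from plots import plot_metrics  # assumes scripts/utils is on the import path

# hypothetical per-example evaluation results, one row per (condition, example)
metrics = pd.DataFrame({
    "condition": ["baseline", "ours", "baseline", "ours"] * 10,
    "mel":       [1.20, 0.90, 1.10, 0.85] * 10,
    "frechet":   [12.0, 8.0, 11.0, 7.5] * 10,
})

condition_to_latex = {"baseline": r"$\mathrm{baseline}$", "ours": r"$\mathrm{ours}$"}
color_palette = dict(zip(condition_to_latex.values(), ["tab:gray", "tab:blue"]))

plot_metrics(metrics, condition_to_latex, title="Reconstruction metrics", color_palette=color_palette)
plt.savefig("metrics.png")
```
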
scripts/utils/remove_quiet_files.py ADDED
@@ -0,0 +1,29 @@
1
+ # removes files with loudness below a threshold (default: -30 dB)
2
+
3
+ from pathlib import Path
4
+ import shutil
5
+ import audiotools as at
6
+ import argbind
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def remove_quiet_files(
10
+ src_dir: Path = None,
11
+ dest_dir: Path = None,
12
+ min_loudness: float = -30,
13
+ ):
14
+ # copy src to dest
15
+ dest_dir.mkdir(parents=True, exist_ok=True)
16
+ shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
17
+
18
+ audio_files = at.util.find_audio(dest_dir)
19
+ for audio_file in audio_files:
20
+ sig = at.AudioSignal(audio_file)
21
+ if sig.loudness() < min_loudness:
22
+ audio_file.unlink()
23
+ print(f"removed {audio_file}")
24
+
25
+ if __name__ == "__main__":
26
+ args = argbind.parse_args()
27
+
28
+ with argbind.scope(args):
29
+ remove_quiet_files()
scripts/utils/split.py ADDED
@@ -0,0 +1,64 @@
1
+ from pathlib import Path
2
+ import random
3
+ import shutil
4
+ import os
5
+ import json
6
+
7
+ import argbind
8
+ from tqdm import tqdm
9
+ from tqdm.contrib.concurrent import thread_map
10
+
11
+ from audiotools.core import util
12
+
13
+
14
+ @argbind.bind(without_prefix=True)
15
+ def train_test_split(
16
+ audio_folder: str = ".",
17
+ test_size: float = 0.2,
18
+ seed: int = 42,
19
+ pattern: str = "**/*.mp3",
20
+ ):
21
+ print(f"finding audio")
22
+
23
+ audio_folder = Path(audio_folder)
24
+ audio_files = list(tqdm(audio_folder.glob(pattern)))
25
+ print(f"found {len(audio_files)} audio files")
26
+
27
+ # split according to test_size
28
+ n_test = int(len(audio_files) * test_size)
29
+ n_train = len(audio_files) - n_test
30
+
31
+ # shuffle
32
+ random.seed(seed)
33
+ random.shuffle(audio_files)
34
+
35
+ train_files = audio_files[:n_train]
36
+ test_files = audio_files[n_train:]
37
+
38
+
39
+ print(f"Train files: {len(train_files)}")
40
+ print(f"Test files: {len(test_files)}")
41
+ continue_ = input("Continue [yn]? ") or "n"
42
+
43
+ if continue_ != "y":
44
+ return
45
+
46
+ for split, files in (
47
+ ("train", train_files), ("test", test_files)
48
+ ):
49
+ for file in tqdm(files):
50
+ out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
51
+ out_file.parent.mkdir(exist_ok=True, parents=True)
52
+ os.symlink(file, out_file)
53
+
54
+ # save split as json
55
+ with open(Path(audio_folder) / f"{split}.json", "w") as f:
56
+ json.dump([str(f) for f in files], f)
57
+
58
+
59
+
60
+ if __name__ == "__main__":
61
+ args = argbind.parse_args()
62
+
63
+ with argbind.scope(args):
64
+ train_test_split()
scripts/utils/split_long_audio_file.py ADDED
@@ -0,0 +1,34 @@
1
+ from pathlib import Path
2
+ import argbind
3
+
4
+ import audiotools as at
5
+ import tqdm
6
+
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def split_long_audio_file(
10
+ file: str = None,
11
+ max_chunk_size_s: int = 60*10
12
+ ):
13
+ file = Path(file)
14
+ output_dir = file.parent / file.stem
15
+ output_dir.mkdir()
16
+
17
+ sig = at.AudioSignal(file)
18
+
19
+ # split into chunks
20
+ for i, sig in tqdm.tqdm(enumerate(sig.windows(
21
+ window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
22
+ preprocess=True))
23
+ ):
24
+ sig.write(output_dir / f"{i}.wav")
25
+
26
+ print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
27
+
28
+ return output_dir
29
+
30
+ if __name__ == "__main__":
31
+ args = argbind.parse_args()
32
+
33
+ with argbind.scope(args):
34
+ split_long_audio_file()
scripts/utils/stage.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import rich
7
+ from audiotools.ml import Experiment
8
+
9
+
10
+ @argbind.bind(without_prefix=True)
11
+ def run(
12
+ run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
13
+ name: str = None,
14
+ recent: bool = False,
15
+ ):
16
+ if recent:
17
+ paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
18
+ paths = [p.name for p in paths if p.is_dir()]
19
+ if paths:
20
+ name = paths[-1]
21
+
22
+ with Experiment(run_dir, name) as exp:
23
+ exp.snapshot()
24
+ rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")
25
+
26
+
27
+ if __name__ == "__main__":
28
+ args = argbind.parse_args()
29
+ with argbind.scope(args):
30
+ run()
scripts/utils/visualize_embeddings.py ADDED
@@ -0,0 +1,265 @@
1
+ """
2
+ TODO: train a linear probe
3
+ usage:
4
+ python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
5
+ """
6
+ from pathlib import Path
7
+ from typing import List
8
+
9
+ import audiotools as at
10
+ from audiotools import AudioSignal
11
+ import argbind
12
+ import torch
13
+ import numpy as np
14
+ import zipfile
15
+ import json
16
+
17
+ from vampnet.interface import Interface
18
+ import tqdm
19
+
20
+ # bind the Interface to argbind
21
+ Interface = argbind.bind(Interface)
22
+
23
+ DEBUG = False
24
+
25
+
26
+ def smart_plotly_export(fig, save_path: Path):
27
+ img_format = save_path.suffix[1:]
28
+ if img_format == "html":
29
+ fig.write_html(save_path)
30
+ elif img_format == 'bytes':
31
+ return fig.to_image(format='png')
32
+ #TODO: come back and make this prettier
33
+ elif img_format == 'numpy':
34
+ import io
35
+ from PIL import Image
36
+
37
+ def plotly_fig2array(fig):
38
+ #convert Plotly fig to an array
39
+ fig_bytes = fig.to_image(format="png", width=1200, height=700)
40
+ buf = io.BytesIO(fig_bytes)
41
+ img = Image.open(buf)
42
+ return np.asarray(img)
43
+
44
+ return plotly_fig2array(fig)
45
+ elif img_format in ('jpeg', 'png', 'webp'):
46
+ fig.write_image(save_path)
47
+ else:
48
+ raise ValueError("invalid image format")
49
+
50
+
51
+ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
52
+ """
53
+ dimensionality reduction for visualization!
54
+ saves an html plotly figure to save_path
55
+ parameters:
56
+ annotated_embeddings (list): AnnotatedEmbedding objects to reduce; each .embedding has shape (layers, features)
57
+ layer (int): which layer of the embeddings to reduce and plot
58
+ output_dir (Path): directory where the figure is saved
59
+ method (str): umap, tsne, or pca
60
+ n_components (int): number of components to reduce to (2 or 3)
61
+ returns:
62
+ the return value of smart_plotly_export (None when the figure is written to an html file)
63
+ """
64
+ import pandas as pd
65
+ import plotly.express as px
66
+
67
+ fig_name = f"vampnet-embeddings-layer={layer}"
68
+ fig_title = f"{fig_name}_{method}"
69
+ save_path = (output_dir / fig_name).with_suffix(".html")
70
+
71
+ if method == "umap":
72
+ from umap import UMAP
73
+ reducer = UMAP(n_components=n_components)
74
+ elif method == "tsne":
75
+ from sklearn.manifold import TSNE
76
+
77
+ reducer = TSNE(n_components=n_components)
78
+ elif method == "pca":
79
+ from sklearn.decomposition import PCA
80
+
81
+ reducer = PCA(n_components=n_components)
82
+ else:
83
+ raise ValueError(f"invalid method: {method}")
84
+
85
+ labels = [emb.label for emb in annotated_embeddings]
86
+ names = [emb.filename for emb in annotated_embeddings]
87
+ embs = [emb.embedding for emb in annotated_embeddings]
88
+ embs_at_layer = np.stack(embs)[:, layer, :]
89
+ projs = reducer.fit_transform(embs_at_layer)
90
+
91
+ df = pd.DataFrame(
92
+ {
93
+ "label": labels,
94
+ "name": names,
95
+ "x": projs[:, 0],
96
+ "y": projs[:, 1],
97
+ }
98
+ )
99
+ if n_components == 2:
100
+ fig = px.scatter(
101
+ df, x="x", y="y", color="label", hover_name="name", title=fig_title,
102
+ )
103
+
104
+ elif n_components == 3:
105
+ df['z'] = projs[:, 2]
106
+ fig = px.scatter_3d(
107
+ df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
108
+ )
109
+ else:
110
+ raise ValueError(f"can't plot {n_components} components")
111
+
112
+ fig.update_traces(
113
+ marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
114
+ selector=dict(mode="markers"),
115
+ )
116
+
117
+ return smart_plotly_export(fig, save_path)
118
+
119
+
120
+
121
+ # per JukeMIR, we want the embeddings from the middle layer?
122
+ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
123
+ with torch.inference_mode():
124
+ # preprocess the signal
125
+ sig = interface.preprocess(sig)
126
+
127
+ # get the coarse vampnet model
128
+ vampnet = interface.coarse
129
+
130
+ # get the tokens
131
+ z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
132
+ z_latents = vampnet.embedding.from_codes(z, interface.codec)
133
+
134
+ # do a forward pass through the model, get the embeddings
135
+ _z, embeddings = vampnet(z_latents, return_activations=True)
136
+ # print(f"got embeddings with shape {embeddings.shape}")
137
+ # [layer, batch, time, n_dims]
138
+ # [20, 1, 600ish, 768]
139
+
140
+
141
+ # squeeze batch dim (1 bc layer should be dim 0)
142
+ assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
143
+ embeddings = embeddings.squeeze(1)
144
+
145
+ num_layers = embeddings.shape[0]
146
+ assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
147
+
148
+ # do meanpooling over the time dimension
149
+ embeddings = embeddings.mean(dim=-2)
150
+ # [20, 768]
151
+
152
+ # return the embeddings
153
+ return embeddings
154
+
155
+ from dataclasses import dataclass, fields
156
+ @dataclass
157
+ class AnnotatedEmbedding:
158
+ label: str
159
+ filename: str
160
+ embedding: np.ndarray
161
+
162
+ def save(self, path):
163
+ """Save the Embedding object to a given path as a zip file."""
164
+ with zipfile.ZipFile(path, 'w') as archive:
165
+
166
+ # Save numpy array
167
+ with archive.open('embedding.npy', 'w') as f:
168
+ np.save(f, self.embedding)
169
+
170
+ # Save non-numpy data as json
171
+ non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
172
+ with archive.open('data.json', 'w') as f:
173
+ f.write(json.dumps(non_numpy_data).encode('utf-8'))
174
+
175
+ @classmethod
176
+ def load(cls, path):
177
+ """Load the Embedding object from a given zip path."""
178
+ with zipfile.ZipFile(path, 'r') as archive:
179
+
180
+ # Load numpy array
181
+ with archive.open('embedding.npy') as f:
182
+ embedding = np.load(f)
183
+
184
+ # Load non-numpy data from json
185
+ with archive.open('data.json') as f:
186
+ data = json.loads(f.read().decode('utf-8'))
187
+
188
+ return cls(embedding=embedding, **data)
189
+
190
+
191
+ @argbind.bind(without_prefix=True)
192
+ def main(
193
+ path_to_audio: str = None,
194
+ cache_dir: str = "./.emb_cache",
195
+ output_dir: str = "./vampnet_embeddings",
196
+ layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
197
+ method: str = "tsne",
198
+ n_components: int = 2,
199
+ ):
200
+ path_to_audio = Path(path_to_audio)
201
+ assert path_to_audio.exists(), f"{path_to_audio} does not exist"
202
+
203
+ cache_dir = Path(cache_dir)
204
+ output_dir = Path(output_dir)
205
+ output_dir.mkdir(exist_ok=True, parents=True)
206
+
207
+ # load our interface
208
+ # argbind will automatically load the default config,
209
+ interface = Interface()
210
+
211
+ # we expect path_to_audio to consist of a folder for each label, so let's get the list of labels
212
+ labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
213
+ print(f"Found {len(labels)} labels")
214
+ print(f"labels: {labels}")
215
+
216
+ # collect audio files, labels, and embeddings
217
+ annotated_embeddings = []
218
+ for label in labels:
219
+ audio_files = list(at.util.find_audio(path_to_audio / label))
220
+ print(f"Found {len(audio_files)} audio files for label {label}")
221
+
222
+ for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
223
+ # check if we have a cached embedding for this file
224
+ cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
225
+ if cached_path.exists():
226
+ # if so, load it
227
+ if DEBUG:
228
+ print(f"loading cached embedding for {cached_path.stem}")
229
+ embedding = AnnotatedEmbedding.load(cached_path)
230
+ else:
231
+ try:
232
+ sig = AudioSignal(audio_file)
233
+ except Exception as e:
234
+ print(f"failed to load {audio_file.name} with error {e}")
235
+ print(f"skipping {audio_file.name}")
236
+ continue
237
+
238
+ # gets the embedding
239
+ emb = vampnet_embed(sig, interface).cpu().numpy()
240
+
241
+ # create an embedding we can save/load
242
+ embedding = AnnotatedEmbedding(
243
+ label=label, filename=audio_file.name, embedding=emb
244
+ )
245
+
246
+ # cache the embeddings
247
+ cached_path.parent.mkdir(exist_ok=True, parents=True)
248
+ embedding.save(cached_path)
249
+ annotated_embeddings.append(embedding)
250
+
251
+ # now, let's do a dim reduction on the embeddings and visualize them.
252
+ for layer in tqdm.tqdm(layers, desc="dim reduction"):
253
+ dim_reduce(
254
+ annotated_embeddings,
255
+ layer,
256
+ output_dir=output_dir,
257
+ n_components=n_components,
258
+ method=method,
259
+ )
260
+
261
+
262
+ if __name__ == "__main__":
263
+ args = argbind.parse_args()
264
+ with argbind.scope(args):
265
+ main()
scripts/utils/xeno-canto-dl.py ADDED
@@ -0,0 +1,234 @@
1
+ from xenopy import Query
2
+
3
+
4
+ SPECIES = [
5
+ "American Robin",
6
+ "Northern Cardinal",
7
+ "Mourning Dove",
8
+ "American Crow",
9
+ "Baltimore Oriole",
10
+ "Blue Jay",
11
+ "Eastern Bluebird",
12
+ "House Finch",
13
+ "American Goldfinch",
14
+ "House Sparrow",
15
+ "Song Sparrow",
16
+ "Tufted Titmouse",
17
+ "White-breasted Nuthatch",
18
+ "European Starling",
19
+ "American Redstart",
20
+ "Red-winged Blackbird",
21
+ "Brown-headed Cowbird",
22
+ "Common Grackle",
23
+ "Boat-tailed Grackle",
24
+ "Common Yellowthroat",
25
+ "Northern Mockingbird",
26
+ "Carolina Wren",
27
+ "Eastern Meadowlark",
28
+ "Chipping Sparrow",
29
+ "Tree Swallow",
30
+ "Barn Swallow",
31
+ "Cliff Swallow",
32
+ "Pine Siskin",
33
+ "Indigo Bunting",
34
+ "Eastern Towhee",
35
+ "Carolina Chickadee",
36
+ "Great Crested Flycatcher",
37
+ "Eastern Wood-Pewee",
38
+ "Ovenbird",
39
+ "Northern Flicker",
40
+ "Red-eyed Vireo",
41
+ "American Woodcock",
42
+ "Eastern Phoebe",
43
+ "Downy Woodpecker",
44
+ "Scarlet Tanager",
45
+ "Yellow Warbler",
46
+ "White-eyed Vireo",
47
+ "Common Loon",
48
+ "White-throated Sparrow",
49
+ "Yellow-throated Vireo",
50
+ "Great Blue Heron",
51
+ "Belted Kingfisher",
52
+ "Pied-billed Grebe",
53
+ "Wild Turkey",
54
+ "Wood Thrush",
55
+ "Rose-breasted Grosbeak",
56
+ "Field Sparrow",
57
+ "Hooded Warbler",
58
+ "Northern Parula",
59
+ "Chestnut-sided Warbler",
60
+ "Blue-winged Warbler",
61
+ "Red-bellied Woodpecker",
62
+ "Yellow-billed Cuckoo",
63
+ "Gray Catbird",
64
+ "Northern Saw-whet Owl",
65
+ "Osprey",
66
+ "Common Nighthawk",
67
+ "Broad-winged Hawk",
68
+ "Black-throated Green Warbler",
69
+ "Great Horned Owl",
70
+ "Common Raven",
71
+ "Barred Owl",
72
+ "Canada Warbler",
73
+ "Magnolia Warbler",
74
+ "Black-and-white Warbler",
75
+ "Eastern Kingbird",
76
+ "Swainson's Thrush",
77
+ "Worm-eating Warbler",
78
+ "Prairie Warbler",
79
+ "Baltimore Oriole",
80
+ "Black-throated Blue Warbler",
81
+ "Louisiana Waterthrush",
82
+ "Blackburnian Warbler",
83
+ "Black-capped Chickadee",
84
+ "Cerulean Warbler",
85
+ "Red-shouldered Hawk",
86
+ "Cooper's Hawk",
87
+ "Yellow-throated Warbler",
88
+ "Blue-headed Vireo",
89
+ "Blackpoll Warbler",
90
+ "Ruffed Grouse",
91
+ "Kentucky Warbler",
92
+ "Hermit Thrush",
93
+ "Cedar Waxwing",
94
+ "Eastern Screech-Owl",
95
+ "Northern Goshawk",
96
+ "Green Heron",
97
+ "Red-tailed Hawk",
98
+ "Black Vulture",
99
+ "Hairy Woodpecker",
100
+ "Golden-crowned Kinglet",
101
+ "Ruby-crowned Kinglet",
102
+ "Bicknell's Thrush",
103
+ "Blue-gray Gnatcatcher",
104
+ "Veery",
105
+ "Pileated Woodpecker",
106
+ "Purple Finch",
107
+ "White-crowned Sparrow",
108
+ "Snow Bunting",
109
+ "Pine Grosbeak",
110
+ "American Tree Sparrow",
111
+ "Dark-eyed Junco",
112
+ "Snowy Owl",
113
+ "White-winged Crossbill",
114
+ "Red Crossbill",
115
+ "Common Redpoll",
116
+ "Northern Shrike",
117
+ "Northern Harrier",
118
+ "Rough-legged Hawk",
119
+ "Long-eared Owl",
120
+ "Evening Grosbeak",
121
+ "Northern Pintail",
122
+ "American Black Duck",
123
+ "Mallard",
124
+ "Canvasback",
125
+ "Redhead",
126
+ "Ring-necked Duck",
127
+ "Greater Scaup",
128
+ "Lesser Scaup",
129
+ "Bufflehead",
130
+ "Common Goldeneye",
131
+ "Hooded Merganser",
132
+ "Common Merganser",
133
+ "Red-breasted Merganser",
134
+ "Ruddy Duck",
135
+ "Wood Duck",
136
+ "Gadwall",
137
+ "American Wigeon",
138
+ "Northern Shoveler",
139
+ "Green-winged Teal",
140
+ "Blue-winged Teal",
141
+ "Cinnamon Teal",
142
+ "Ringed Teal",
143
+ "Cape Teal",
144
+ "Northern Fulmar",
145
+ "Yellow-billed Loon",
146
+ "Red-throated Loon",
147
+ "Arctic Loon",
148
+ "Pacific Loon",
149
+ "Horned Grebe",
150
+ "Red-necked Grebe",
151
+ "Eared Grebe",
152
+ "Western Grebe",
153
+ "Clark's Grebe",
154
+ "Double-crested Cormorant",
155
+ "Pelagic Cormorant",
156
+ "Great Cormorant",
157
+ "American White Pelican",
158
+ "Brown Pelican",
159
+ "Brandt's Cormorant",
160
+ "Least Bittern",
161
+ "Great Egret",
162
+ "Snowy Egret",
163
+ "Little Blue Heron",
164
+ "Tricolored Heron",
165
+ "Reddish Egret",
166
+ "Black-crowned Night-Heron",
167
+ "Yellow-crowned Night-Heron",
168
+ "White Ibis",
169
+ "Glossy Ibis",
170
+ "Roseate Spoonbill",
171
+ "Wood Stork",
172
+ "Black-bellied Whistling-Duck",
173
+ "Fulvous Whistling-Duck",
174
+ "Greater White-fronted Goose",
175
+ "Snow Goose",
176
+ "Ross's Goose",
177
+ "Canada Goose",
178
+ "Brant",
179
+ "Mute Swan",
180
+ "Tundra Swan",
181
+ "Whooper Swan",
182
+ "Sandhill Crane",
183
+ "Black-necked Stilt",
184
+ "American Avocet",
185
+ "Northern Jacana",
186
+ "Greater Yellowlegs",
187
+ "Lesser Yellowlegs",
188
+ "Willet",
189
+ "Spotted Sandpiper",
190
+ "Upland Sandpiper",
191
+ "Whimbrel",
192
+ "Long-billed Curlew",
193
+ "Marbled Godwit",
194
+ "Ruddy Turnstone",
195
+ "Red Knot",
196
+ "Sanderling",
197
+ "Semipalmated Sandpiper",
198
+ "Western Sandpiper",
199
+ "Least Sandpiper",
200
+ "White-rumped Sandpiper",
201
+ "Baird's Sandpiper",
202
+ "Pectoral Sandpiper",
203
+ "Dunlin",
204
+ "Buff-breasted Sandpiper",
205
+ "Short-billed Dowitcher",
206
+ "Long-billed Dowitcher",
207
+ "Common Snipe",
208
+ "American Woodcock",
209
+ "Wilson's Phalarope",
210
+ "Red-necked Phalarope",
211
+ "Red Phalarope"
212
+ ]
213
+
214
+ from pathlib import Path
215
+
216
+ def remove_spaces(s):
217
+ return s.replace(" ", "")
218
+
219
+ for species in SPECIES:
220
+ if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
221
+ continue
222
+ try:
223
+ q = Query(
224
+ name=species, q="A", length="10-30",
225
+ )
226
+
227
+ # retrieve metadata
228
+ metafiles = q.retrieve_meta(verbose=True)
229
+ # retrieve recordings
230
+ q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
231
+
232
+ except Exception as e:
233
+ print("Failed to download " + species)
234
+ continue
setup.py ADDED
@@ -0,0 +1,43 @@
1
+ from setuptools import find_packages
2
+ from setuptools import setup
3
+
4
+ with open("README.md") as f:
5
+ long_description = f.read()
6
+
7
+ setup(
8
+ name="vampnet",
9
+ version="0.0.1",
10
+ classifiers=[
11
+ "Intended Audience :: Developers",
12
+ "Natural Language :: English",
13
+ "Programming Language :: Python :: 3.7",
14
+ "Topic :: Artistic Software",
15
+ "Topic :: Multimedia",
16
+ "Topic :: Multimedia :: Sound/Audio",
17
+ "Topic :: Multimedia :: Sound/Audio :: Editors",
18
+ "Topic :: Software Development :: Libraries",
19
+ ],
20
+ description="Generative Music Modeling.",
21
+ long_description=long_description,
22
+ long_description_content_type="text/markdown",
23
+ author="Hugo Flores García, Prem Seetharaman",
24
+ author_email="hfgacrcia@descript.com",
25
+ url="https://github.com/hugofloresgarcia/vampnet",
26
+ license="MIT",
27
+ packages=find_packages(),
28
+ install_requires=[
29
+ "torch",
30
+ "argbind>=0.3.2",
31
+ "numpy==1.23",
32
+ "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
33
+ "lac @ git+https://github.com/hugofloresgarcia/lac.git",
34
+ "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
35
+ "gradio",
36
+ "loralib",
37
+ "torch_pitch_shift",
38
+ "madmom",
39
+ "pyharp @ git+https://github.com/audacitorch/pyharp.git",
40
+ "plotly",
41
+ "umap_learn",
42
+ ],
43
+ )
vampnet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+
2
+ from . import modules
3
+ from . import scheduler
4
+ from .interface import Interface
5
+
6
+ __version__ = "0.0.1"
vampnet/beats.py ADDED
@@ -0,0 +1,250 @@
1
+ import json
2
+ import logging
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any
7
+ from typing import List
8
+ from typing import Tuple
9
+ from typing import Union
10
+
11
+ import librosa
12
+ import torch
13
+ import numpy as np
14
+ from audiotools import AudioSignal
15
+
16
+
17
+ logging.basicConfig(level=logging.INFO)
18
+
19
+ ###################
20
+ # beat sync utils #
21
+ ###################
22
+
23
+ AGGREGATOR_REGISTRY = {
24
+ "mean": np.mean,
25
+ "median": np.median,
26
+ "max": np.max,
27
+ "min": np.min,
28
+ }
29
+
30
+
31
+ def list_aggregators() -> list:
32
+ return list(AGGREGATOR_REGISTRY.keys())
33
+
34
+
35
+ @dataclass
36
+ class TimeSegment:
37
+ start: float
38
+ end: float
39
+
40
+ @property
41
+ def duration(self):
42
+ return self.end - self.start
43
+
44
+ def __str__(self) -> str:
45
+ return f"{self.start} - {self.end}"
46
+
47
+ def find_overlapping_segment(
48
+ self, segments: List["TimeSegment"]
49
+ ) -> Union["TimeSegment", None]:
50
+ """Find the first segment that overlaps with this segment, or None if no segment overlaps"""
51
+ for s in segments:
52
+ if s.start <= self.start and s.end >= self.end:
53
+ return s
54
+ return None
55
+
56
+
57
+ def mkdir(path: Union[Path, str]) -> Path:
58
+ p = Path(path)
59
+ p.mkdir(parents=True, exist_ok=True)
60
+ return p
61
+
62
+
63
+
64
+ ###################
65
+ # beat data #
66
+ ###################
67
+ @dataclass
68
+ class BeatSegment(TimeSegment):
69
+ downbeat: bool = False # if there's a downbeat on the start_time
70
+
71
+
72
+ class Beats:
73
+ def __init__(self, beat_times, downbeat_times):
74
+ if isinstance(beat_times, np.ndarray):
75
+ beat_times = beat_times.tolist()
76
+ if isinstance(downbeat_times, np.ndarray):
77
+ downbeat_times = downbeat_times.tolist()
78
+ self._beat_times = beat_times
79
+ self._downbeat_times = downbeat_times
80
+ self._use_downbeats = False
81
+
82
+ def use_downbeats(self, use_downbeats: bool = True):
83
+ """use downbeats instead of beats when calling beat_times"""
84
+ self._use_downbeats = use_downbeats
85
+
86
+ def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
87
+ """
88
+ segments a song into time segments corresponding to beats.
89
+ the first segment starts at 0 and ends at the first beat time.
90
+ the last segment starts at the last beat time and ends at the end of the song.
91
+ """
92
+ beat_times = self._beat_times.copy()
93
+ downbeat_times = self._downbeat_times
94
+ beat_times.insert(0, 0)
95
+ beat_times.append(signal.signal_duration)
96
+
97
+ downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
98
+ 1
99
+ ]
100
+ is_downbeat = [
101
+ True if i in downbeat_ids else False for i in range(len(beat_times))
102
+ ]
103
+ segments = [
104
+ BeatSegment(start_time, end_time, downbeat)
105
+ for start_time, end_time, downbeat in zip(
106
+ beat_times[:-1], beat_times[1:], is_downbeat
107
+ )
108
+ ]
109
+ return segments
110
+
111
+ def get_beats(self) -> np.ndarray:
112
+ """returns an array of beat times, in seconds
113
+ if downbeats is True, returns an array of downbeat times, in seconds
114
+ """
115
+ return np.array(
116
+ self._downbeat_times if self._use_downbeats else self._beat_times
117
+ )
118
+
119
+ @property
120
+ def beat_times(self) -> np.ndarray:
121
+ """return beat times"""
122
+ return np.array(self._beat_times)
123
+
124
+ @property
125
+ def downbeat_times(self) -> np.ndarray:
126
+ """return downbeat times"""
127
+ return np.array(self._downbeat_times)
128
+
129
+ def beat_times_to_feature_frames(
130
+ self, signal: AudioSignal, features: np.ndarray
131
+ ) -> np.ndarray:
132
+ """convert beat times to frames, given an array of time-varying features"""
133
+ beat_times = self.get_beats()
134
+ beat_frames = (
135
+ beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
136
+ ).astype(np.int64)
137
+ return beat_frames
138
+
139
+ def sync_features(
140
+ self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
141
+ ) -> np.ndarray:
142
+ """sync features to beats"""
143
+ if aggregate not in AGGREGATOR_REGISTRY:
144
+ raise ValueError(f"unknown aggregation method {aggregate}")
145
+
146
+ return librosa.util.sync(
147
+ features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
148
+ )
149
+
150
+ def to_json(self) -> dict:
151
+ """return beats and downbeats as json"""
152
+ return {
153
+ "beats": self._beat_times,
154
+ "downbeats": self._downbeat_times,
155
+ "use_downbeats": self._use_downbeats,
156
+ }
157
+
158
+ @classmethod
159
+ def from_dict(cls, data: dict):
160
+ """load beats and downbeats from json"""
161
+ inst = cls(data["beats"], data["downbeats"])
162
+ inst.use_downbeats(data["use_downbeats"])
163
+ return inst
164
+
165
+ def save(self, output_dir: Path):
166
+ """save beats and downbeats to json"""
167
+ mkdir(output_dir)
168
+ with open(output_dir / "beats.json", "w") as f:
169
+ json.dump(self.to_json(), f)
170
+
171
+ @classmethod
172
+ def load(cls, input_dir: Path):
173
+ """load beats and downbeats from json"""
174
+ beats_file = Path(input_dir) / "beats.json"
175
+ with open(beats_file, "r") as f:
176
+ data = json.load(f)
177
+ return cls.from_dict(data)
178
+
179
+
180
+ ###################
181
+ # beat tracking #
182
+ ###################
183
+
184
+
185
+ class BeatTracker:
186
+ def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
187
+ """extract beats from an audio signal"""
188
+ raise NotImplementedError
189
+
190
+ def __call__(self, signal: AudioSignal) -> Beats:
191
+ """extract beats from an audio signal
192
+ NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio,
193
+ it is discarded. This is to avoid empty bins with no beat synced features in the first beat.
194
+ Args:
195
+ signal (AudioSignal): signal to beat track
196
+ Returns:
197
+ Tuple[np.ndarray, np.ndarray]: beats and downbeats
198
+ """
199
+ beats, downbeats = self.extract_beats(signal)
200
+ return Beats(beats, downbeats)
201
+
202
+
203
+ class WaveBeat(BeatTracker):
204
+ def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
205
+ from wavebeat.dstcn import dsTCNModel
206
+
207
+ model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
208
+ model.eval()
209
+
210
+ self.device = device
211
+ self.model = model
212
+
213
+ def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
214
+ """returns beat and downbeat times, in seconds"""
215
+ # extract beats
216
+ beats, downbeats = self.model.predict_beats_from_array(
217
+ audio=signal.audio_data.squeeze(0),
218
+ sr=signal.sample_rate,
219
+ use_gpu=self.device != "cpu",
220
+ )
221
+
222
+ return beats, downbeats
223
+
224
+
225
+ class MadmomBeats(BeatTracker):
226
+ def __init__(self):
227
+ raise NotImplementedError
228
+
229
+ def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
230
+ """returns beat and downbeat times, in seconds"""
231
+ pass
232
+
233
+
234
+ BEAT_TRACKER_REGISTRY = {
235
+ "wavebeat": WaveBeat,
236
+ "madmom": MadmomBeats,
237
+ }
238
+
239
+
240
+ def list_beat_trackers() -> list:
241
+ return list(BEAT_TRACKER_REGISTRY.keys())
242
+
243
+
244
+ def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
245
+ if beat_tracker not in BEAT_TRACKER_REGISTRY:
246
+ raise ValueError(
247
+ f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
248
+ )
249
+
250
+ return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
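
A small sketch of the `Beats` container defined above: building it from made-up beat times, toggling `use_downbeats`, and round-tripping it through `beats.json` (the output directory is illustrative, and the import assumes the `vampnet` package and its dependencies are installed):

```python
from pathlib import Path

from vampnet.beats import Beats

# made-up beat grid: a downbeat every fourth beat
beat_times = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
downbeat_times = [0.5, 2.5]

beats = Beats(beat_times, downbeat_times)
print(beats.get_beats())        # all beat times

beats.use_downbeats()
print(beats.get_beats())        # only the downbeat times

out_dir = Path("./beats_demo")  # illustrative output folder
beats.save(out_dir)             # writes out_dir/beats.json
restored = Beats.load(out_dir)
print(restored.to_json())
```
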
vampnet/interface.py ADDED
@@ -0,0 +1,432 @@
1
+ import os
2
+ from pathlib import Path
3
+ import math
4
+
5
+ import torch
6
+ import numpy as np
7
+ from audiotools import AudioSignal
8
+ import tqdm
9
+
10
+ from .modules.transformer import VampNet
11
+ from .beats import WaveBeat
12
+ from .mask import *
13
+
14
+ # from dac.model.dac import DAC
15
+ from lac.model.lac import LAC as DAC
16
+
17
+
18
+ def signal_concat(
19
+ audio_signals: list,
20
+ ):
21
+ audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)
22
+
23
+ return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
24
+
25
+
26
+ def _load_model(
27
+ ckpt: str,
28
+ lora_ckpt: str = None,
29
+ device: str = "cpu",
30
+ chunk_size_s: int = 10,
31
+ ):
32
+ # we need to set strict to False if the model has lora weights to add later
33
+ model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)
34
+
35
+ # load lora weights if needed
36
+ if lora_ckpt is not None:
37
+ if not Path(lora_ckpt).exists():
38
+ should_cont = input(
39
+ f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
40
+ )
41
+ if should_cont != "y":
42
+ raise Exception("aborting")
43
+ else:
44
+ model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
45
+
46
+ model.to(device)
47
+ model.eval()
48
+ model.chunk_size_s = chunk_size_s
49
+ return model
50
+
51
+
52
+
53
+ class Interface(torch.nn.Module):
54
+ def __init__(
55
+ self,
56
+ coarse_ckpt: str = None,
57
+ coarse_lora_ckpt: str = None,
58
+ coarse2fine_ckpt: str = None,
59
+ coarse2fine_lora_ckpt: str = None,
60
+ codec_ckpt: str = None,
61
+ wavebeat_ckpt: str = None,
62
+ device: str = "cpu",
63
+ coarse_chunk_size_s: int = 10,
64
+ coarse2fine_chunk_size_s: int = 3,
65
+ ):
66
+ super().__init__()
67
+ assert codec_ckpt is not None, "must provide a codec checkpoint"
68
+ self.codec = DAC.load(Path(codec_ckpt))
69
+ self.codec.eval()
70
+ self.codec.to(device)
71
+ self.codec_path = Path(codec_ckpt)
72
+
73
+ assert coarse_ckpt is not None, "must provide a coarse checkpoint"
74
+ self.coarse = _load_model(
75
+ ckpt=coarse_ckpt,
76
+ lora_ckpt=coarse_lora_ckpt,
77
+ device=device,
78
+ chunk_size_s=coarse_chunk_size_s,
79
+ )
80
+ self.coarse_path = Path(coarse_ckpt)
81
+
82
+ # check if we have a coarse2fine ckpt
83
+ if coarse2fine_ckpt is not None:
84
+ self.c2f_path = Path(coarse2fine_ckpt)
85
+ self.c2f = _load_model(
86
+ ckpt=coarse2fine_ckpt,
87
+ lora_ckpt=coarse2fine_lora_ckpt,
88
+ device=device,
89
+ chunk_size_s=coarse2fine_chunk_size_s,
90
+ )
91
+ else:
92
+ self.c2f_path = None
93
+ self.c2f = None
94
+
95
+ if wavebeat_ckpt is not None:
96
+ print(f"loading wavebeat from {wavebeat_ckpt}")
97
+ self.beat_tracker = WaveBeat(wavebeat_ckpt)
98
+ self.beat_tracker.model.to(device)
99
+ else:
100
+ self.beat_tracker = None
101
+
102
+ self.device = device
103
+
104
+ def reload(
105
+ self,
106
+ coarse_ckpt: str = None,
107
+ c2f_ckpt: str = None,
108
+ ):
109
+ if coarse_ckpt is not None:
110
+ # check if we already loaded, if so, don't reload
111
+ if self.coarse_path == Path(coarse_ckpt):
112
+ print(f"already loaded {coarse_ckpt}")
113
+ return
114
+ self.coarse = _load_model(
115
+ ckpt=coarse_ckpt,
116
+ device=self.device,
117
+ chunk_size_s=self.coarse.chunk_size_s,
118
+ )
119
+ self.coarse_path = Path(coarse_ckpt)
120
+ print(f"loaded {coarse_ckpt}")
121
+
122
+ if c2f_ckpt is not None:
123
+ if self.c2f_path == Path(c2f_ckpt):
124
+ print(f"already loaded {c2f_ckpt}")
125
+ return
126
+ self.c2f = _load_model(
127
+ ckpt=c2f_ckpt,
128
+ device=self.device,
129
+ chunk_size_s=self.c2f.chunk_size_s,
130
+ )
131
+ self.c2f_path = Path(c2f_ckpt)
132
+ print(f"loaded {c2f_ckpt}")
133
+
134
+ def s2t(self, seconds: float):
135
+ """seconds to tokens"""
136
+ if isinstance(seconds, np.ndarray):
137
+ return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
138
+ else:
139
+ return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
140
+
141
+ def s2t2s(self, seconds: float):
142
+ """seconds to tokens to seconds"""
143
+ return self.t2s(self.s2t(seconds))
144
+
145
+ def t2s(self, tokens: int):
146
+ """tokens to seconds"""
147
+ return tokens * self.codec.hop_length / self.codec.sample_rate
148
+
149
+ def to(self, device):
150
+ self.device = device
151
+ self.coarse.to(device)
152
+ self.codec.to(device)
153
+
154
+ if self.c2f is not None:
155
+ self.c2f.to(device)
156
+
157
+ if self.beat_tracker is not None:
158
+ self.beat_tracker.model.to(device)
159
+ return self
160
+
161
+ def to_signal(self, z: torch.Tensor):
162
+ return self.coarse.to_signal(z, self.codec)
163
+
164
+ def preprocess(self, signal: AudioSignal):
165
+ signal = (
166
+ signal.clone()
167
+ .resample(self.codec.sample_rate)
168
+ .to_mono()
169
+ .normalize(-24)
170
+ .ensure_max_of_audio(1.0)
171
+ )
172
+ return signal
173
+
174
+ @torch.inference_mode()
175
+ def encode(self, signal: AudioSignal):
176
+ signal = self.preprocess(signal).to(self.device)
177
+ z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
178
+ return z
179
+
180
+ def snap_to_beats(
181
+ self,
182
+ signal: AudioSignal
183
+ ):
184
+ assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
185
+ beats, downbeats = self.beat_tracker.extract_beats(signal)
186
+
187
+ # trim the signal around the first beat time
188
+ samples_begin = int(beats[0] * signal.sample_rate )
189
+ samples_end = int(beats[-1] * signal.sample_rate)
190
+ print(beats[0])
191
+ signal = signal.clone().trim(samples_begin, signal.length - samples_end)
192
+
193
+ return signal
194
+
195
+ def make_beat_mask(self,
196
+ signal: AudioSignal,
197
+ before_beat_s: float = 0.0,
198
+ after_beat_s: float = 0.02,
199
+ mask_downbeats: bool = True,
200
+ mask_upbeats: bool = True,
201
+ downbeat_downsample_factor: int = None,
202
+ beat_downsample_factor: int = None,
203
+ dropout: float = 0.0,
204
+ invert: bool = True,
205
+ ):
206
+ """make a beat-synced mask, i.e. a mask that places 1s at and around each beat
207
+ and 0s everywhere else (with the default invert=True the mask is flipped before it is returned).
208
+ """
209
+ assert self.beat_tracker is not None, "No beat tracker loaded"
210
+
211
+ # get the beat times
212
+ beats, downbeats = self.beat_tracker.extract_beats(signal)
213
+
214
+ # get the beat indices in z
215
+ beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
216
+
217
+ # remove downbeats from beats
218
+ beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
219
+ beats_z = beats_z.tolist()
220
+ downbeats_z = downbeats_z.tolist()
221
+
222
+ # make the mask
223
+ seq_len = self.s2t(signal.duration)
224
+ mask = torch.zeros(seq_len, device=self.device)
225
+
226
+ mask_b4 = self.s2t(before_beat_s)
227
+ mask_after = self.s2t(after_beat_s)
228
+
229
+ if beat_downsample_factor is not None:
230
+ if beat_downsample_factor < 1:
231
+ raise ValueError("beat_downsample_factor must be >= 1 or None")
232
+ else:
233
+ beat_downsample_factor = 1
234
+
235
+ if downbeat_downsample_factor is not None:
236
+ if downbeat_downsample_factor < 1:
237
+ raise ValueError("downbeat_downsample_factor must be >= 1 or None")
238
+ else:
239
+ downbeat_downsample_factor = 1
240
+
241
+ beats_z = beats_z[::beat_downsample_factor]
242
+ downbeats_z = downbeats_z[::downbeat_downsample_factor]
243
+ print(f"beats_z: {len(beats_z)}")
244
+ print(f"downbeats_z: {len(downbeats_z)}")
245
+
246
+ if mask_upbeats:
247
+ for beat_idx in beats_z:
248
+ _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
249
+ num_steps = mask[_slice[0]:_slice[1]].shape[0]
250
+ _m = torch.ones(num_steps, device=self.device)
251
+ _m_mask = torch.bernoulli(_m * (1 - dropout))
252
+ _m = _m * _m_mask.long()
253
+
254
+ mask[_slice[0]:_slice[1]] = _m
255
+
256
+ if mask_downbeats:
257
+ for downbeat_idx in downbeats_z:
258
+ _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
259
+ num_steps = mask[_slice[0]:_slice[1]].shape[0]
260
+ _m = torch.ones(num_steps, device=self.device)
261
+ _m_mask = torch.bernoulli(_m * (1 - dropout))
262
+ _m = _m * _m_mask.long()
263
+
264
+ mask[_slice[0]:_slice[1]] = _m
265
+
266
+ mask = mask.clamp(0, 1)
267
+ if invert:
268
+ mask = 1 - mask
269
+
270
+ mask = mask[None, None, :].bool().long()
271
+ if self.c2f is not None:
272
+ mask = mask.repeat(1, self.c2f.n_codebooks, 1)
273
+ else:
274
+ mask = mask.repeat(1, self.coarse.n_codebooks, 1)
275
+ return mask
276
+
277
+ def coarse_to_fine(
278
+ self,
279
+ z: torch.Tensor,
280
+ mask: torch.Tensor = None,
281
+ **kwargs
282
+ ):
283
+ assert self.c2f is not None, "No coarse2fine model loaded"
284
+ length = z.shape[-1]
285
+ chunk_len = self.s2t(self.c2f.chunk_size_s)
286
+ n_chunks = math.ceil(z.shape[-1] / chunk_len)
287
+
288
+ # zero pad to chunk_len
289
+ if length % chunk_len != 0:
290
+ pad_len = chunk_len - (length % chunk_len)
291
+ z = torch.nn.functional.pad(z, (0, pad_len))
292
+ mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
293
+
294
+ n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
295
+ if n_codebooks_to_append > 0:
296
+ z = torch.cat([
297
+ z,
298
+ torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
299
+ ], dim=1)
300
+
301
+ # set the mask to 0 for all conditioning codebooks
302
+ if mask is not None:
303
+ mask = mask.clone()
304
+ mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
305
+
306
+ fine_z = []
307
+ for i in range(n_chunks):
308
+ chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
309
+ mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
310
+
311
+ chunk = self.c2f.generate(
312
+ codec=self.codec,
313
+ time_steps=chunk_len,
314
+ start_tokens=chunk,
315
+ return_signal=False,
316
+ mask=mask_chunk,
317
+ **kwargs
318
+ )
319
+ fine_z.append(chunk)
320
+
321
+ fine_z = torch.cat(fine_z, dim=-1)
322
+ return fine_z[:, :, :length].clone()
323
+
324
+ def coarse_vamp(
325
+ self,
326
+ z,
327
+ mask,
328
+ return_mask=False,
329
+ gen_fn=None,
330
+ **kwargs
331
+ ):
332
+ # coarse z
333
+ cz = z[:, : self.coarse.n_codebooks, :].clone()
334
+ assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must not exceed the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
335
+
336
+ mask = mask[:, : self.coarse.n_codebooks, :]
337
+
338
+ cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
339
+ cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
340
+
341
+ gen_fn = gen_fn or self.coarse.generate
342
+ c_vamp = gen_fn(
343
+ codec=self.codec,
344
+ time_steps=cz.shape[-1],
345
+ start_tokens=cz,
346
+ mask=mask,
347
+ return_signal=False,
348
+ **kwargs
349
+ )
350
+
351
+ # add the fine codes back in
352
+ c_vamp = torch.cat(
353
+ [c_vamp, z[:, self.coarse.n_codebooks :, :]],
354
+ dim=1
355
+ )
356
+
357
+ if return_mask:
358
+ return c_vamp, cz_masked
359
+
360
+ return c_vamp
361
+
362
+ # def chunked_coarse_vamp(
363
+ # self,
364
+ # z,
365
+ # mask,
366
+ # return_mask=False,
367
+ # gen_fn=None,
368
+ # **kwargs
369
+ # )
370
+
371
+
372
+ if __name__ == "__main__":
373
+ import audiotools as at
374
+ import logging
375
+ logger = logging.getLogger()
376
+ logger.setLevel(logging.INFO)
377
+ torch.set_printoptions(threshold=10000)
378
+ at.util.seed(42)
379
+
380
+ interface = Interface(
381
+ coarse_ckpt="./models/vampnet/coarse.pth",
382
+ coarse2fine_ckpt="./models/vampnet/c2f.pth",
383
+ codec_ckpt="./models/vampnet/codec.pth",
384
+ device="cuda",
385
+ wavebeat_ckpt="./models/wavebeat.pth"
386
+ )
387
+
388
+
389
+ sig = at.AudioSignal('assets/example.wav')
390
+
391
+ z = interface.encode(sig)
392
+ breakpoint()
393
+
394
+ # mask = linear_random(z, 1.0)
395
+ # mask = mask_and(
396
+ # mask, periodic_mask(
397
+ # z,
398
+ # 32,
399
+ # 1,
400
+ # random_roll=True
401
+ # )
402
+ # )
403
+
404
+ # mask = interface.make_beat_mask(
405
+ # sig, 0.0, 0.075
406
+ # )
407
+ # mask = dropout(mask, 0.0)
408
+ # mask = codebook_unmask(mask, 0)
409
+
410
+ mask = inpaint(z, n_prefix=100, n_suffix=100)
411
+
412
+ zv, mask_z = interface.coarse_vamp(
413
+ z,
414
+ mask=mask,
415
+ sampling_steps=36,
416
+ temperature=8.0,
417
+ return_mask=True,
418
+ gen_fn=interface.coarse.generate
419
+ )
420
+
421
+
422
+ use_coarse2fine = True
423
+ if use_coarse2fine:
424
+ zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
425
+ breakpoint()
426
+
427
+ mask = interface.to_signal(mask_z).cpu()
428
+
429
+ sig = interface.to_signal(zv).cpu()
430
+ print("done")
431
+
432
+
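
A quick illustration of the seconds/token conversion used by s2t and t2s above (a standalone sketch, not part of the commit; the 44.1 kHz sample rate and 768-sample hop length are assumed values for illustration and should be read from interface.codec in practice):

import math

sample_rate = 44100   # assumed codec sample rate for this sketch
hop_length = 768      # assumed codec hop length for this sketch

def s2t(seconds: float) -> int:
    # mirrors Interface.s2t: round up to whole codec frames
    return math.ceil(seconds * sample_rate / hop_length)

def t2s(tokens: int) -> float:
    # mirrors Interface.t2s
    return tokens * hop_length / sample_rate

n_tokens = s2t(10.0)            # 575 frames with these assumed values
print(n_tokens, t2s(n_tokens))  # t2s(s2t(x)) >= x because of the ceil
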
vampnet/mask.py ADDED
@@ -0,0 +1,242 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from audiotools import AudioSignal
5
+
6
+ from .util import scalar_to_batch_tensor
7
+
8
+ def _gamma(r):
9
+ return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
10
+
11
+ def _invgamma(y):
12
+ if not torch.is_tensor(y):
13
+ y = torch.tensor(y)[None]
14
+ return 2 * y.acos() / torch.pi
15
+
16
+ def full_mask(x: torch.Tensor):
17
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
18
+ return torch.ones_like(x).long()
19
+
20
+ def empty_mask(x: torch.Tensor):
21
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
22
+ return torch.zeros_like(x).long()
23
+
24
+ def apply_mask(
25
+ x: torch.Tensor,
26
+ mask: torch.Tensor,
27
+ mask_token: int
28
+ ):
29
+ assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
30
+ assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
31
+ assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
32
+ assert ~torch.any(mask > 1), "mask must be binary"
33
+ assert ~torch.any(mask < 0), "mask must be binary"
34
+
35
+ fill_x = torch.full_like(x, mask_token)
36
+ x = x * (1 - mask) + fill_x * mask
37
+
38
+ return x, mask
39
+
40
+ def random(
41
+ x: torch.Tensor,
42
+ r: torch.Tensor
43
+ ):
44
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
45
+ if not isinstance(r, torch.Tensor):
46
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
47
+
48
+ r = _gamma(r)[:, None, None]
49
+ probs = torch.ones_like(x) * r
50
+
51
+ mask = torch.bernoulli(probs)
52
+ mask = mask.round().long()
53
+
54
+ return mask
55
+
56
+ def linear_random(
57
+ x: torch.Tensor,
58
+ r: torch.Tensor,
59
+ ):
60
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
61
+ if not isinstance(r, torch.Tensor):
62
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
63
+
64
+ probs = torch.ones_like(x).to(x.device).float()
65
+ # expand to batch and codebook dims
66
+ probs = probs.expand(x.shape[0], x.shape[1], -1)
67
+ probs = probs * r
68
+
69
+ mask = torch.bernoulli(probs)
70
+ mask = mask.round().long()
71
+
72
+ return mask
73
+
74
+ def inpaint(x: torch.Tensor,
75
+ n_prefix,
76
+ n_suffix,
77
+ ):
78
+ assert n_prefix is not None
79
+ assert n_suffix is not None
80
+
81
+ mask = full_mask(x)
82
+
83
+ # if we have a prefix or suffix, set their mask prob to 0
84
+ if n_prefix > 0:
85
+ if not isinstance(n_prefix, torch.Tensor):
86
+ n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
87
+ for i, n in enumerate(n_prefix):
88
+ if n > 0:
89
+ mask[i, :, :n] = 0.0
90
+ if n_suffix > 0:
91
+ if not isinstance(n_suffix, torch.Tensor):
92
+ n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
93
+ for i, n in enumerate(n_suffix):
94
+ if n > 0:
95
+ mask[i, :, -n:] = 0.0
96
+
97
+
98
+ return mask
99
+
100
+ def periodic_mask(x: torch.Tensor,
101
+ period: int, width: int = 1,
102
+ random_roll=False,
103
+ ):
104
+ mask = full_mask(x)
105
+ if period == 0:
106
+ return mask
107
+
108
+ if not isinstance(period, torch.Tensor):
109
+ period = scalar_to_batch_tensor(period, x.shape[0])
110
+ for i, factor in enumerate(period):
111
+ if factor == 0:
112
+ continue
113
+ for j in range(mask.shape[-1]):
114
+ if j % factor == 0:
115
+ # figure out how wide the mask should be
116
+ j_start = max(0, j - width // 2 )
117
+ j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
118
+ # flip a coin for each position in the mask
119
+ j_mask = torch.bernoulli(torch.ones(j_end - j_start))
120
+ assert torch.all(j_mask == 1)
121
+ j_fill = torch.ones_like(j_mask) * (1 - j_mask)
122
+ assert torch.all(j_fill == 0)
123
+ # fill
124
+ mask[i, :, j_start:j_end] = j_fill
125
+ if random_roll:
126
+ # add a random offset to the mask
127
+ offset = torch.randint(0, period[0], (1,))
128
+ mask = torch.roll(mask, offset.item(), dims=-1)
129
+
130
+ return mask
131
+
132
+ def codebook_unmask(
133
+ mask: torch.Tensor,
134
+ n_conditioning_codebooks: int
135
+ ):
136
+ if n_conditioning_codebooks is None:
137
+ return mask
138
+ # if we have any conditioning codebooks, set their mask to 0
139
+ mask = mask.clone()
140
+ mask[:, :n_conditioning_codebooks, :] = 0
141
+ return mask
142
+
143
+ def codebook_mask(mask: torch.Tensor, start: int):
144
+ mask = mask.clone()
145
+ mask[:, start:, :] = 1
146
+ return mask
147
+
148
+ def mask_and(
149
+ mask1: torch.Tensor,
150
+ mask2: torch.Tensor
151
+ ):
152
+ assert mask1.shape == mask2.shape, "masks must be same shape"
153
+ return torch.min(mask1, mask2)
154
+
155
+ def dropout(
156
+ mask: torch.Tensor,
157
+ p: float,
158
+ ):
159
+ assert 0 <= p <= 1, "p must be between 0 and 1"
160
+ assert mask.max() <= 1, "mask must be binary"
161
+ assert mask.min() >= 0, "mask must be binary"
162
+ mask = (~mask.bool()).float()
163
+ mask = torch.bernoulli(mask * (1 - p))
164
+ mask = ~mask.round().bool()
165
+ return mask.long()
166
+
167
+ def mask_or(
168
+ mask1: torch.Tensor,
169
+ mask2: torch.Tensor
170
+ ):
171
+ assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
172
+ assert mask1.max() <= 1, "mask1 must be binary"
173
+ assert mask2.max() <= 1, "mask2 must be binary"
174
+ assert mask1.min() >= 0, "mask1 must be binary"
175
+ assert mask2.min() >= 0, "mask2 must be binary"
176
+ return (mask1 + mask2).clamp(0, 1)
177
+
178
+ def time_stretch_mask(
179
+ x: torch.Tensor,
180
+ stretch_factor: int,
181
+ ):
182
+ assert stretch_factor >= 1, "stretch factor must be >= 1"
183
+ c_seq_len = x.shape[-1]
184
+ x = x.repeat_interleave(stretch_factor, dim=-1)
185
+
186
+ # trim x to the original length
187
+ x = x[:, :, :c_seq_len]
188
+
189
+ mask = periodic_mask(x, stretch_factor, width=1)
190
+ return mask
191
+
192
+ def onset_mask(
193
+ sig: AudioSignal,
194
+ z: torch.Tensor,
195
+ interface,
196
+ width: int = 1
197
+ ):
198
+ import librosa
199
+ import madmom
200
+ from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
201
+ import tempfile
202
+ import numpy as np
203
+
204
+ with tempfile.NamedTemporaryFile(suffix='.wav') as f:
205
+ sig = sig.clone()
206
+ sig.write(f.name)
207
+
208
+ proc = RNNOnsetProcessor(online=False)
209
+ onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
210
+ fps=sig.sample_rate/interface.codec.hop_length)
211
+
212
+ act = proc(f.name)
213
+ onset_times = onsetproc(act)
214
+
215
+ # convert to indices for z array
216
+ onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
217
+
218
+ if onset_indices.shape[0] == 0:
219
+ mask = empty_mask(z)
220
+ print(f"no onsets found, returning empty mask")
221
+ else:
222
+ torch.set_printoptions(threshold=1000)
223
+ print("onset indices: ", onset_indices)
224
+ print("onset times: ", onset_times)
225
+
226
+ # create a mask, set onset
227
+ mask = torch.ones_like(z)
228
+ n_timesteps = z.shape[-1]
229
+
230
+ for onset_index in onset_indices:
231
+ onset_index = min(onset_index, n_timesteps - 1)
232
+ onset_index = max(onset_index, 0)
233
+ mask[:, :, onset_index - width:onset_index + width] = 0.0
234
+
235
+ print(mask)
236
+
237
+ return mask
238
+
239
+
240
+
241
+ if __name__ == "__main__":
242
+ pass
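
The helpers above are designed to be composed into a single binary mask before apply_mask fills the masked positions with the model's MASK token. A small sketch of that composition (shapes and values are illustrative, and 1024 merely stands in for VampNet.mask_token), assuming the repo's dependencies are installed:

import torch
from vampnet.mask import linear_random, periodic_mask, mask_and, apply_mask

z = torch.randint(0, 1024, (1, 4, 100))     # (batch, n_codebooks, seq) of codec tokens
mask = mask_and(
    linear_random(z, 1.0),                  # mask every position...
    periodic_mask(z, 16, width=1),          # ...except one column every 16 timesteps
)
z_masked, mask = apply_mask(z, mask, 1024)  # 1024 stands in for VampNet.mask_token
print(z_masked.shape, mask.float().mean())
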
vampnet/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ import audiotools
2
+
3
+ audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
+ audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
5
+
6
+ from .transformer import VampNet
vampnet/modules/activations.py ADDED
@@ -0,0 +1,55 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class NewGELU(nn.Module):
10
+ """
11
+ Implementation of the GELU activation function currently in Google BERT repo
12
+ (identical to OpenAI GPT). Also see the Gaussian Error Linear Units
13
+ paper: https://arxiv.org/abs/1606.08415
14
+ """
15
+
16
+ def forward(self, x):
17
+ return (
18
+ 0.5
19
+ * x
20
+ * (
21
+ 1.0
22
+ + torch.tanh(
23
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
24
+ )
25
+ )
26
+ )
27
+
28
+ class GatedGELU(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.gelu = NewGELU()
32
+
33
+ def forward(self, x, dim: int = -1):
34
+ p1, p2 = x.chunk(2, dim=dim)
35
+ return p1 * self.gelu(p2)
36
+
37
+ class Snake1d(nn.Module):
38
+ def __init__(self, channels):
39
+ super().__init__()
40
+ self.alpha = nn.Parameter(torch.ones(channels))
41
+
42
+ def forward(self, x):
43
+ return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
44
+
45
+ def get_activation(name: str = "relu"):
46
+ if name == "relu":
47
+ return nn.ReLU
48
+ elif name == "gelu":
49
+ return NewGELU
50
+ elif name == "geglu":
51
+ return GatedGELU
52
+ elif name == "snake":
53
+ return Snake1d
54
+ else:
55
+ raise ValueError(f"Unrecognized activation {name}")
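
One behavioural detail of GatedGELU above: it splits its input in two along the feature dimension and gates one half with the other, so whatever feeds it must produce twice the width it expects back (this is what the geglu `factor` handling in transformer.py's FeedForward accounts for). A small illustrative shape check, not part of the commit:

import torch
from vampnet.modules.activations import get_activation

act = get_activation("geglu")()   # GatedGELU
x = torch.randn(2, 10, 2048)      # e.g. the output of a d_model=512 -> 4*d_model projection
y = act(x)                        # chunked into two 1024-wide halves, one gating the other
print(y.shape)                    # torch.Size([2, 10, 1024])
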
vampnet/modules/layers.py ADDED
@@ -0,0 +1,164 @@
1
+ import time
2
+ from typing import Optional
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+
11
+ # Scripting this brings model speed up 1.4x
12
+ @torch.jit.script
13
+ def snake(x, alpha):
14
+ shape = x.shape
15
+ x = x.reshape(shape[0], shape[1], -1)
16
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
17
+ x = x.reshape(shape)
18
+ return x
19
+
20
+
21
+ class Snake1d(nn.Module):
22
+ def __init__(self, channels):
23
+ super().__init__()
24
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
25
+
26
+ def forward(self, x):
27
+ return snake(x, self.alpha)
28
+
29
+
30
+ def num_params(model):
31
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
32
+
33
+
34
+ def recurse_children(module, fn):
35
+ for child in module.children():
36
+ if isinstance(child, nn.ModuleList):
37
+ for c in child:
38
+ yield recurse_children(c, fn)
39
+ if isinstance(child, nn.ModuleDict):
40
+ for c in child.values():
41
+ yield recurse_children(c, fn)
42
+
43
+ yield recurse_children(child, fn)
44
+ yield fn(child)
45
+
46
+
47
+ def WNConv1d(*args, **kwargs):
48
+ return weight_norm(nn.Conv1d(*args, **kwargs))
49
+
50
+
51
+ def WNConvTranspose1d(*args, **kwargs):
52
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
53
+
54
+
55
+ class SequentialWithFiLM(nn.Module):
56
+ """
57
+ handy wrapper for nn.Sequential that allows FiLM layers to be
58
+ inserted in between other layers.
59
+ """
60
+
61
+ def __init__(self, *layers):
62
+ super().__init__()
63
+ self.layers = nn.ModuleList(layers)
64
+
65
+ @staticmethod
66
+ def has_film(module):
67
+ mod_has_film = any(
68
+ [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
69
+ )
70
+ return mod_has_film
71
+
72
+ def forward(self, x, cond):
73
+ for layer in self.layers:
74
+ if self.has_film(layer):
75
+ x = layer(x, cond)
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class FiLM(nn.Module):
82
+ def __init__(self, input_dim: int, output_dim: int):
83
+ super().__init__()
84
+
85
+ self.input_dim = input_dim
86
+ self.output_dim = output_dim
87
+
88
+ if input_dim > 0:
89
+ self.beta = nn.Linear(input_dim, output_dim)
90
+ self.gamma = nn.Linear(input_dim, output_dim)
91
+
92
+ def forward(self, x, r):
93
+ if self.input_dim == 0:
94
+ return x
95
+ else:
96
+ beta, gamma = self.beta(r), self.gamma(r)
97
+ beta, gamma = (
98
+ beta.view(x.size(0), self.output_dim, 1),
99
+ gamma.view(x.size(0), self.output_dim, 1),
100
+ )
101
+ x = x * (gamma + 1) + beta
102
+ return x
103
+
104
+
105
+ class CodebookEmbedding(nn.Module):
106
+ def __init__(
107
+ self,
108
+ vocab_size: int,
109
+ latent_dim: int,
110
+ n_codebooks: int,
111
+ emb_dim: int,
112
+ special_tokens: Optional[Tuple[str]] = None,
113
+ ):
114
+ super().__init__()
115
+ self.n_codebooks = n_codebooks
116
+ self.emb_dim = emb_dim
117
+ self.latent_dim = latent_dim
118
+ self.vocab_size = vocab_size
119
+
120
+ if special_tokens is not None:
121
+ for tkn in special_tokens:
122
+ self.special = nn.ParameterDict(
123
+ {
124
+ tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
125
+ for tkn in special_tokens
126
+ }
127
+ )
128
+ self.special_idxs = {
129
+ tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
130
+ }
131
+
132
+ self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
133
+
134
+ def from_codes(self, codes: torch.Tensor, codec):
135
+ """
136
+ get a sequence of continuous embeddings from a sequence of discrete codes.
137
+ unlike its counterpart in the original VQ-VAE, this function adds embeddings for any special tokens
138
+ necessary for the language model, like <MASK>.
139
+ """
140
+ n_codebooks = codes.shape[1]
141
+ latent = []
142
+ for i in range(n_codebooks):
143
+ c = codes[:, i, :]
144
+
145
+ lookup_table = codec.quantizer.quantizers[i].codebook.weight
146
+ if hasattr(self, "special"):
147
+ special_lookup = torch.cat(
148
+ [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
149
+ )
150
+ lookup_table = torch.cat([lookup_table, special_lookup], dim=0)
151
+
152
+ l = F.embedding(c, lookup_table).transpose(1, 2)
153
+ latent.append(l)
154
+
155
+ latent = torch.cat(latent, dim=1)
156
+ return latent
157
+
158
+ def forward(self, latents: torch.Tensor):
159
+ """
160
+ project a sequence of latents to a sequence of embeddings
161
+ """
162
+ x = self.out_proj(latents)
163
+ return x
164
+
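
CodebookEmbedding above works in two stages: from_codes looks up each codebook's discrete codes in the codec's tables (plus the MASK rows) and stacks them to (batch, n_codebooks * latent_dim, seq), and forward projects that stack to the transformer width with a 1x1 convolution. A shape-only sketch of the second stage; the hyperparameter values here are illustrative (latent_dim and emb_dim happen to match the VampNet defaults in transformer.py) and the random latents stand in for what from_codes would return:

import torch
from vampnet.modules.layers import CodebookEmbedding

emb = CodebookEmbedding(vocab_size=1024, latent_dim=8, n_codebooks=4,
                        emb_dim=1280, special_tokens=("MASK",))
latents = torch.randn(2, 4 * 8, 100)   # what from_codes(...) would produce from real codes
x = emb(latents)
print(x.shape)                         # torch.Size([2, 1280, 100])
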
vampnet/modules/transformer.py ADDED
@@ -0,0 +1,953 @@
1
+ import math
2
+ import logging
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ import loralib as lora
11
+ import audiotools as at
12
+
13
+ from .activations import get_activation
14
+ from .layers import CodebookEmbedding
15
+ from .layers import FiLM
16
+ from .layers import SequentialWithFiLM
17
+ from .layers import WNConv1d
18
+ from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
19
+ from ..mask import _gamma
20
+
21
+ LORA_R = 8
22
+
23
+ # def log(t, eps=1e-20):
24
+ # return torch.log(t + eps)
25
+
26
+
27
+ def gumbel_noise_like(t):
28
+ noise = torch.zeros_like(t).uniform_(1e-20, 1)
29
+ return -torch.log(-torch.log(noise))
30
+
31
+
32
+ def gumbel_sample(t, temperature=1.0, dim=-1):
33
+ return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
34
+
35
+
36
+ class RMSNorm(nn.Module):
37
+ def __init__(self, hidden_size: int, eps=1e-6):
38
+ super().__init__()
39
+ self.weight = nn.Parameter(torch.ones(hidden_size))
40
+ self.var_eps = eps
41
+
42
+ def forward(self, x):
43
+ """Returns root mean square normalized version of input `x`
44
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known
45
+ # as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467
46
+ # thus variance is calculated w/o mean and there is no bias
47
+ Parameters
48
+ ----------
49
+ x : Tensor[B x T x D]
50
+ Returns
51
+ -------
52
+ Tensor[B x T x D]
53
+ """
54
+ var = x.pow(2).mean(-1, keepdim=True)
55
+ x = x * torch.rsqrt(var + self.var_eps)
56
+
57
+ return self.weight * x
58
+
59
+
60
+ class FeedForward(nn.Module):
61
+ def __init__(
62
+ self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
63
+ ):
64
+ super().__init__()
65
+ factor = 2 if activation == "geglu" else 1
66
+ self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
67
+ self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
68
+ self.drop = nn.Dropout(dropout)
69
+ self.act = get_activation(activation)()
70
+
71
+ def forward(self, x):
72
+ """Computes position-wise feed-forward layer
73
+ Parameters
74
+ ----------
75
+ x : Tensor[B x T x D]
76
+ Returns
77
+ -------
78
+ Tensor[B x T x D]
79
+ """
80
+ x = self.w_1(x)
81
+ x = self.act(x)
82
+ x = self.drop(x)
83
+ x = self.w_2(x)
84
+ return x
85
+
86
+
87
+ class MultiHeadRelativeAttention(nn.Module):
88
+ def __init__(
89
+ self,
90
+ n_head: int = 8,
91
+ d_model: int = 512,
92
+ dropout: float = 0.1,
93
+ bidirectional: bool = True,
94
+ has_relative_attention_bias: bool = True,
95
+ attention_num_buckets: int = 32,
96
+ attention_max_distance: int = 128,
97
+ ):
98
+ super().__init__()
99
+ d_head = d_model // n_head
100
+ self.n_head = n_head
101
+ self.d_head = d_head
102
+ self.bidirectional = bidirectional
103
+ self.has_relative_attention_bias = has_relative_attention_bias
104
+ self.attention_num_buckets = attention_num_buckets
105
+ self.attention_max_distance = attention_max_distance
106
+
107
+ # Create linear query, key, value projections
108
+ self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
109
+ self.w_ks = nn.Linear(d_model, d_model, bias=False)
110
+ self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
111
+
112
+ # Create linear final output projection
113
+ self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
114
+
115
+ # Dropout for attention output weights
116
+ self.dropout = nn.Dropout(dropout)
117
+
118
+ # Create relative positional embeddings (if turned on)
119
+ if has_relative_attention_bias:
120
+ self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
121
+
122
+ def _relative_position_bucket(self, relative_position):
123
+ """Converts unbounded relative position into bounded set of buckets
124
+ with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
125
+ buckets
126
+ Parameters
127
+ ----------
128
+ relative_position : Tensor[T_q x T_kv]
129
+ Relative positions between queries and key_value items
130
+ Returns
131
+ -------
132
+ Tensor[T_q x T_kv]
133
+ Input relative positions converted into buckets
134
+ """
135
+ relative_buckets = 0
136
+ num_buckets = self.attention_num_buckets
137
+ max_distance = self.attention_max_distance
138
+
139
+ # Convert relative position for (-inf, inf) to [0, inf]
140
+ # Negative relative positions correspond to past
141
+ # Positive relative positions correspond to future
142
+ if self.bidirectional:
143
+ # use half buckets for each side (past / future)
144
+ num_buckets //= 2
145
+
146
+ # Shift the positive positions by `num_buckets` to wrap around
147
+ # negative positions
148
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
149
+ relative_position = torch.abs(relative_position)
150
+ else:
151
+ # If not bidirectional, ignore positive positions and wrap
152
+ # negative positions to positive
153
+ relative_position = -torch.min(
154
+ relative_position, torch.zeros_like(relative_position)
155
+ )
156
+
157
+ # Allocate half of the buckets for exact increments in positions
158
+ max_exact = num_buckets // 2
159
+ is_small = relative_position < max_exact
160
+
161
+ # The other half of the buckets are for logarithmically bigger bins in
162
+ # positions up to `max_distance`
163
+ relative_position_if_large = max_exact + (
164
+ torch.log(relative_position.float() / max_exact)
165
+ / math.log(max_distance / max_exact)
166
+ * (num_buckets - max_exact)
167
+ ).to(torch.long)
168
+
169
+ # Clip the max relative position to `num_buckets - 1`
170
+ relative_position_if_large = torch.min(
171
+ relative_position_if_large,
172
+ torch.full_like(relative_position_if_large, num_buckets - 1),
173
+ )
174
+
175
+ # Choose relative buckets based on small or large positions
176
+ relative_buckets += torch.where(
177
+ is_small, relative_position, relative_position_if_large
178
+ )
179
+
180
+ return relative_buckets
181
+
182
+ def compute_bias(self, query_length, key_length):
183
+ """Computes a position bias scalar for each index in query_length x key_length
184
+ Parameters
185
+ ----------
186
+ query_length : int
187
+ key_length : int
188
+ Returns
189
+ -------
190
+ Tensor[heads x 1 x T_q x T_kv]
191
+ Position bias to be applied on attention logits
192
+ """
193
+
194
+ query_position = torch.arange(query_length, dtype=torch.long)[:, None]
195
+ key_position = torch.arange(key_length, dtype=torch.long)[None, :]
196
+ relative_position = key_position - query_position
197
+
198
+ # Convert relative position to buckets
199
+ relative_position_bucket = self._relative_position_bucket(relative_position)
200
+ relative_position_bucket = relative_position_bucket.to(
201
+ self.relative_attention_bias.weight.device
202
+ )
203
+
204
+ # Index attention bias values
205
+ values = self.relative_attention_bias(relative_position_bucket)
206
+ values = rearrange(values, "q k h -> h 1 q k")
207
+
208
+ return values
209
+
210
+ def forward(self, q, k, v, mask=None, position_bias=None):
211
+ """Computes attention over (keys, values) for every timestep in query
212
+ Parameters
213
+ ----------
214
+ q : Tensor[B x T_q x d_model]
215
+ Query vectors
216
+ k : Tensor[B x T_kv x d_model]
217
+ Key vectors to compute attention over
218
+ v : Tensor[B x T_kv x d_model]
219
+ Value vectors corresponding to the keys
220
+ mask : Tensor[B x T_q x T_kv], optional
221
+ position_bias: Tensor[head x 1 x T_q x T_kv]
222
+ Returns
223
+ -------
224
+ Tensor[B x T_q x d_model]
225
+ Outputs after attending (key, value) using queries
226
+ """
227
+ # Compute query, key, value projections
228
+ q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
229
+ k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
230
+ v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
231
+
232
+ # Compute attention matrix
233
+ attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
234
+
235
+ # Add relative position bias to attention scores
236
+ if position_bias is None:
237
+ if self.has_relative_attention_bias:
238
+ position_bias = self.compute_bias(q.size(-2), k.size(-2))
239
+ else:
240
+ position_bias = torch.zeros_like(attn)
241
+ attn += position_bias
242
+
243
+ # Apply mask to attention scores to prevent looking up invalid locations
244
+ if mask is not None:
245
+ attn = attn.masked_fill(mask[None] == 0, -1e9)
246
+
247
+ # Normalize attention scores and add dropout
248
+ attn = torch.softmax(attn, dim=3)
249
+ attn = self.dropout(attn)
250
+
251
+ # Compute attended outputs (product of attention matrix and values)
252
+ output = torch.einsum("hblt,hbtv->hblv", [attn, v])
253
+ output = rearrange(output, "head b l v -> b l (head v)")
254
+ output = self.fc(output)
255
+
256
+ return output, position_bias
257
+
258
+
259
+ class TransformerLayer(nn.Module):
260
+ def __init__(
261
+ self,
262
+ d_model: int = 512,
263
+ d_cond: int = 64,
264
+ n_heads: int = 8,
265
+ bidirectional: bool = True,
266
+ is_decoder: bool = False,
267
+ has_relative_attention_bias: bool = False,
268
+ flash_attn: bool = False,
269
+ dropout: float = 0.1,
270
+ ):
271
+ super().__init__()
272
+ # Store args
273
+ self.is_decoder = is_decoder
274
+
275
+ # Create self-attention layer
276
+ self.norm_1 = RMSNorm(d_model)
277
+ self.film_1 = FiLM(d_cond, d_model)
278
+ self.flash_attn = flash_attn
279
+
280
+ if flash_attn:
281
+ from flash_attn.flash_attention import FlashMHA
282
+ self.self_attn = FlashMHA(
283
+ embed_dim=d_model,
284
+ num_heads=n_heads,
285
+ attention_dropout=dropout,
286
+ causal=False,
287
+ )
288
+ else:
289
+ self.self_attn = MultiHeadRelativeAttention(
290
+ n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
291
+ )
292
+
293
+ # (Optional) Create cross-attention layer
294
+ if is_decoder:
295
+ self.norm_2 = RMSNorm(d_model)
296
+ self.film_2 = FiLM(d_cond, d_model)
297
+ self.cross_attn = MultiHeadRelativeAttention(
298
+ n_heads,
299
+ d_model,
300
+ dropout,
301
+ bidirectional=True,
302
+ has_relative_attention_bias=False,
303
+ )
304
+
305
+ # Create last feed-forward layer
306
+ self.norm_3 = RMSNorm(d_model)
307
+ self.film_3 = FiLM(d_cond, d_model)
308
+ self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
309
+
310
+ # Create dropout
311
+ self.dropout = nn.Dropout(dropout)
312
+
313
+ def forward(
314
+ self,
315
+ x,
316
+ x_mask,
317
+ cond,
318
+ src=None,
319
+ src_mask=None,
320
+ position_bias=None,
321
+ encoder_decoder_position_bias=None,
322
+ ):
323
+ """Computes one transformer layer consisting of self attention, (op) cross attention
324
+ and feedforward layer
325
+ Parameters
326
+ ----------
327
+ x : Tensor[B x T_q x D]
328
+ x_mask : Tensor[B x T_q]
329
+ src : Tensor[B x T_kv x D], optional
330
+ src_mask : Tensor[B x T_kv x D], optional
331
+ position_bias : Tensor[heads x B x T_q x T_q], optional
332
+ Relative position bias for self attention layer
333
+ encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
334
+ Relative position bias for cross attention layer
335
+ Returns
336
+ -------
337
+ Tensor[B x T_q x D]
338
+ """
339
+ y = self.norm_1(x)
340
+ y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
341
+ if self.flash_attn:
342
+ with torch.autocast(y.device.type, dtype=torch.bfloat16):
343
+ y = self.self_attn(y)[0]
344
+ else:
345
+ y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
346
+ x = x + self.dropout(y)
347
+
348
+ if self.is_decoder:
349
+ y = self.norm_2(x)
350
+ y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
351
+ y, encoder_decoder_position_bias = self.cross_attn(
352
+ y, src, src, src_mask, encoder_decoder_position_bias
353
+ )
354
+ x = x + self.dropout(y)
355
+
356
+ y = self.norm_3(x)
357
+ y = self.film_3(
358
+ y.permute(
359
+ 0,
360
+ 2,
361
+ 1,
362
+ ),
363
+ cond,
364
+ ).permute(0, 2, 1)
365
+ y = self.feed_forward(y)
366
+ x = x + self.dropout(y)
367
+
368
+ return x, position_bias, encoder_decoder_position_bias
369
+
370
+
371
+ class TransformerStack(nn.Module):
372
+ def __init__(
373
+ self,
374
+ d_model: int = 512,
375
+ d_cond: int = 64,
376
+ n_heads: int = 8,
377
+ n_layers: int = 8,
378
+ last_layer: bool = True,
379
+ bidirectional: bool = True,
380
+ flash_attn: bool = False,
381
+ is_decoder: bool = False,
382
+ dropout: float = 0.1,
383
+ ):
384
+ super().__init__()
385
+ # Store args
386
+ self.bidirectional = bidirectional
387
+ self.is_decoder = is_decoder
388
+
389
+ # Create transformer layers
390
+ # In T5, relative attention bias is shared by all layers in the stack
391
+ self.layers = nn.ModuleList(
392
+ [
393
+ TransformerLayer(
394
+ d_model,
395
+ d_cond,
396
+ n_heads,
397
+ bidirectional,
398
+ is_decoder,
399
+ has_relative_attention_bias=True if (i == 0) else False,
400
+ flash_attn=flash_attn,
401
+ dropout=dropout,
402
+ )
403
+ for i in range(n_layers)
404
+ ]
405
+ )
406
+
407
+ # Perform last normalization
408
+ self.norm = RMSNorm(d_model) if last_layer else None
409
+
410
+ def subsequent_mask(self, size):
411
+ return torch.ones(1, size, size).tril().bool()
412
+
413
+ def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
414
+ return_activations: bool = False
415
+ ):
416
+ """Computes a full transformer stack
417
+ Parameters
418
+ ----------
419
+ x : Tensor[B x T_q x D]
420
+ x_mask : Tensor[B x T_q]
421
+ src : Tensor[B x T_kv x D], optional
422
+ src_mask : Tensor[B x T_kv], optional
423
+ Returns
424
+ -------
425
+ Tensor[B x T_q x D]
426
+ """
427
+
428
+ # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
429
+ if self.is_decoder:
430
+ src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
431
+
432
+ # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
433
+ x_mask = x_mask.unsqueeze(-2)
434
+ if not self.bidirectional:
435
+ x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
436
+
437
+ # Initialize position biases
438
+ position_bias = None
439
+ encoder_decoder_position_bias = None
440
+
441
+ # Compute transformer layers
442
+ if return_activations:
443
+ activations = []
444
+ for layer in self.layers:
445
+ x, position_bias, encoder_decoder_position_bias = layer(
446
+ x=x,
447
+ x_mask=x_mask,
448
+ cond=cond,
449
+ src=src,
450
+ src_mask=src_mask,
451
+ position_bias=position_bias,
452
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
453
+ )
454
+ if return_activations:
455
+ activations.append(x.detach())
456
+
457
+
458
+ out = self.norm(x) if self.norm is not None else x
459
+ if return_activations:
460
+ return out, torch.stack(activations)
461
+ else:
462
+ return out
463
+
464
+
465
+ class VampNet(at.ml.BaseModel):
466
+ def __init__(
467
+ self,
468
+ n_heads: int = 20,
469
+ n_layers: int = 16,
470
+ r_cond_dim: int = 0,
471
+ n_codebooks: int = 9,
472
+ n_conditioning_codebooks: int = 0,
473
+ latent_dim: int = 8,
474
+ embedding_dim: int = 1280,
475
+ vocab_size: int = 1024,
476
+ flash_attn: bool = True,
477
+ noise_mode: str = "mask",
478
+ dropout: float = 0.1
479
+ ):
480
+ super().__init__()
481
+ assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
482
+ self.n_heads = n_heads
483
+ self.n_layers = n_layers
484
+ self.r_cond_dim = r_cond_dim
485
+ self.n_codebooks = n_codebooks
486
+ self.n_conditioning_codebooks = n_conditioning_codebooks
487
+ self.embedding_dim = embedding_dim
488
+ self.vocab_size = vocab_size
489
+ self.latent_dim = latent_dim
490
+ self.flash_attn = flash_attn
491
+ self.noise_mode = noise_mode
492
+
493
+ assert self.noise_mode == "mask", "deprecated"
494
+
495
+ self.embedding = CodebookEmbedding(
496
+ latent_dim=latent_dim,
497
+ n_codebooks=n_codebooks,
498
+ vocab_size=vocab_size,
499
+ emb_dim=embedding_dim,
500
+ special_tokens=["MASK"],
501
+ )
502
+ self.mask_token = self.embedding.special_idxs["MASK"]
503
+
504
+ self.transformer = TransformerStack(
505
+ d_model=embedding_dim,
506
+ d_cond=r_cond_dim,
507
+ n_heads=n_heads,
508
+ n_layers=n_layers,
509
+ last_layer=True,
510
+ bidirectional=True,
511
+ flash_attn=flash_attn,
512
+ is_decoder=False,
513
+ dropout=dropout,
514
+ )
515
+
516
+ # Add final conv layer
517
+ self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
518
+ self.classifier = SequentialWithFiLM(
519
+ WNConv1d(
520
+ embedding_dim,
521
+ vocab_size * self.n_predict_codebooks,
522
+ kernel_size=1,
523
+ padding="same",
524
+ # groups=self.n_predict_codebooks,
525
+ ),
526
+ )
527
+
528
+ def forward(self, x, return_activations: bool = False):
529
+ x = self.embedding(x)
530
+ x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
531
+
532
+ x = rearrange(x, "b d n -> b n d")
533
+ out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
534
+ if return_activations:
535
+ out, activations = out
536
+
537
+ out = rearrange(out, "b n d -> b d n")
538
+
539
+ out = self.classifier(out, None) # no cond here!
540
+
541
+ out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
542
+
543
+ if return_activations:
544
+ return out, activations
545
+ else:
546
+ return out
547
+
548
+ def r_embed(self, r, max_positions=10000):
549
+ if self.r_cond_dim > 0:
550
+ dtype = r.dtype
551
+
552
+ r = _gamma(r) * max_positions
553
+ half_dim = self.r_cond_dim // 2
554
+
555
+ emb = math.log(max_positions) / (half_dim - 1)
556
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
557
+
558
+ emb = r[:, None] * emb[None, :]
559
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
560
+
561
+ if self.r_cond_dim % 2 == 1: # zero pad
562
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
563
+
564
+ return emb.to(dtype)
565
+ else:
566
+ return r
567
+
568
+ @torch.no_grad()
569
+ def to_signal(self, z, codec):
570
+ """
571
+ convert a sequence of latents to a signal.
572
+ """
573
+ assert z.ndim == 3
574
+
575
+ signal = at.AudioSignal(
576
+ codec.decode(
577
+ codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
578
+ )["audio"],
579
+ codec.sample_rate,
580
+ )
581
+
582
+ # find where the mask token is and replace it with silence in the audio
583
+ for tstep in range(z.shape[-1]):
584
+ if torch.any(z[:, :, tstep] == self.mask_token):
585
+ sample_idx_0 = tstep * codec.hop_length
586
+ sample_idx_1 = sample_idx_0 + codec.hop_length
587
+ signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
588
+
589
+ return signal
590
+
591
+
592
+ @torch.no_grad()
593
+ def generate(
594
+ self,
595
+ codec,
596
+ time_steps: int = 300,
597
+ sampling_steps: int = 36,
598
+ start_tokens: Optional[torch.Tensor] = None,
599
+ sampling_temperature: float = 1.0,
600
+ mask: Optional[torch.Tensor] = None,
601
+ mask_temperature: float = 10.5,
602
+ typical_filtering=False,
603
+ typical_mass=0.2,
604
+ typical_min_tokens=1,
605
+ top_p=None,
606
+ return_signal=True,
607
+ seed: int = None,
608
+ sample_cutoff: float = 1.0,
609
+ ):
610
+ if seed is not None:
611
+ at.util.seed(seed)
612
+ logging.debug(f"beginning generation with {sampling_steps} steps")
613
+
614
+
615
+
616
+ #####################
617
+ # resolve initial z #
618
+ #####################
619
+ z = start_tokens
620
+
621
+ if z is None:
622
+ z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
623
+ self.device
624
+ )
625
+
626
+ logging.debug(f"created z with shape {z.shape}")
627
+
628
+
629
+ #################
630
+ # resolve mask #
631
+ #################
632
+
633
+ if mask is None:
634
+ mask = torch.ones_like(z).to(self.device).int()
635
+ mask[:, : self.n_conditioning_codebooks, :] = 0.0
636
+ if mask.ndim == 2:
637
+ mask = mask[:, None, :].repeat(1, z.shape[1], 1)
638
+ # init_mask = mask.clone()
639
+
640
+ logging.debug(f"created mask with shape {mask.shape}")
641
+
642
+
643
+ ###########
644
+ # set up #
645
+ ##########
646
+ # apply the mask to z
647
+ z_masked = z.masked_fill(mask.bool(), self.mask_token)
648
+ # logging.debug(f"z_masked: {z_masked}")
649
+
650
+ # how many mask tokens to begin with?
651
+ num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
652
+ logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
653
+
654
+ # how many codebooks are we inferring vs conditioning on?
655
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
656
+ logging.debug(f"n infer codebooks: {n_infer_codebooks}")
657
+
658
+ #################
659
+ # begin sampling #
660
+ #################
661
+
662
+ for i in range(sampling_steps):
663
+ logging.debug(f"step {i} of {sampling_steps}")
664
+
665
+ # our current schedule step
666
+ r = scalar_to_batch_tensor(
667
+ (i + 1) / sampling_steps,
668
+ z.shape[0]
669
+ ).to(z.device)
670
+ logging.debug(f"r: {r}")
671
+
672
+ # get latents
673
+ latents = self.embedding.from_codes(z_masked, codec)
674
+ logging.debug(f"computed latents with shape: {latents.shape}")
675
+
676
+
677
+ # infer from latents
678
+ # NOTE: this collapses the codebook dimension into the sequence dimension
679
+ logits = self.forward(latents) # b, prob, seq
680
+ logits = logits.permute(0, 2, 1) # b, seq, prob
681
+ b = logits.shape[0]
682
+
683
+ logging.debug(f"permuted logits with shape: {logits.shape}")
684
+
685
+ sampled_z, selected_probs = sample_from_logits(
686
+ logits, sample=(
687
+ (i / sampling_steps) <= sample_cutoff
688
+ ),
689
+ temperature=sampling_temperature,
690
+ typical_filtering=typical_filtering, typical_mass=typical_mass,
691
+ typical_min_tokens=typical_min_tokens,
692
+ top_k=None, top_p=top_p, return_probs=True,
693
+ )
694
+
695
+ logging.debug(f"sampled z with shape: {sampled_z.shape}")
696
+
697
+ # flatten z_masked and mask, so we can deal with the sampling logic
698
+ # we'll unflatten them at the end of the loop for the next forward pass
699
+ # remove conditioning codebooks, we'll add them back at the end
700
+ z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
701
+
702
+ mask = (z_masked == self.mask_token).int()
703
+
704
+ # update the mask, remove conditioning codebooks from the mask
705
+ logging.debug(f"updated mask with shape: {mask.shape}")
706
+ # add z back into sampled z where the mask was false
707
+ sampled_z = torch.where(
708
+ mask.bool(), sampled_z, z_masked
709
+ )
710
+ logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
711
+
712
+ # ignore any tokens that weren't masked
713
+ selected_probs = torch.where(
714
+ mask.bool(), selected_probs, torch.inf
715
+ )
716
+
717
+ # get the num tokens to mask, according to the schedule
718
+ num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
719
+ logging.debug(f"num to mask: {num_to_mask}")
720
+
721
+ if i != (sampling_steps - 1):
722
+ num_to_mask = torch.maximum(
723
+ torch.tensor(1),
724
+ torch.minimum(
725
+ mask.sum(dim=-1, keepdim=True) - 1,
726
+ num_to_mask
727
+ )
728
+ )
729
+
730
+
731
+ # get our new mask
732
+ mask = mask_by_random_topk(
733
+ num_to_mask, selected_probs, mask_temperature * (1-r)
734
+ )
735
+
736
+ # update the mask
737
+ z_masked = torch.where(
738
+ mask.bool(), self.mask_token, sampled_z
739
+ )
740
+ logging.debug(f"updated z_masked with shape: {z_masked.shape}")
741
+
742
+ z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
743
+ mask = codebook_unflatten(mask, n_infer_codebooks)
744
+ logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
745
+
746
+ # add conditioning codebooks back to z_masked
747
+ z_masked = torch.cat(
748
+ (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
749
+ )
750
+ logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
751
+
752
+
753
+ # add conditioning codebooks back to sampled_z
754
+ sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
755
+ sampled_z = torch.cat(
756
+ (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
757
+ )
758
+
759
+ logging.debug(f"finished sampling")
760
+
761
+ if return_signal:
762
+ return self.to_signal(sampled_z, codec)
763
+ else:
764
+ return sampled_z
765
+
766
+ def sample_from_logits(
767
+ logits,
768
+ sample: bool = True,
769
+ temperature: float = 1.0,
770
+ top_k: int = None,
771
+ top_p: float = None,
772
+ typical_filtering: bool = False,
773
+ typical_mass: float = 0.2,
774
+ typical_min_tokens: int = 1,
775
+ return_probs: bool = False
776
+ ):
777
+ """Convenience function to sample from a categorical distribution with input as
778
+ unnormalized logits.
779
+
780
+ Parameters
781
+ ----------
782
+ logits : Tensor[..., vocab_size]
783
+ config: SamplingConfig
784
+ The set of hyperparameters to be used for sampling
785
+ sample : bool, optional
786
+ Whether to perform multinomial sampling, by default True
787
+ temperature : float, optional
788
+ Scaling parameter when multinomial sampling, by default 1.0
789
+ top_k : int, optional
790
+ Restricts sampling to only `top_k` values acc. to probability,
791
+ by default None
792
+ top_p : float, optional
793
+ Restricts sampling to only those values with cumulative
794
+ probability = `top_p`, by default None
795
+
796
+ Returns
797
+ -------
798
+ Tensor[...]
799
+ Sampled tokens
800
+ """
801
+ shp = logits.shape[:-1]
802
+
803
+ if typical_filtering:
804
+ logits = typical_filter(logits,
805
+ typical_mass=typical_mass,
806
+ typical_min_tokens=typical_min_tokens
807
+ )
808
+
809
+ # Apply top_k sampling
810
+ if top_k is not None:
811
+ v, _ = logits.topk(top_k)
812
+ logits[logits < v[..., [-1]]] = -float("inf")
813
+
814
+ # Apply top_p (nucleus) sampling
815
+ if top_p is not None and top_p < 1.0:
816
+ v, sorted_indices = logits.sort(descending=True)
817
+ cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
818
+
819
+ sorted_indices_to_remove = cumulative_probs > top_p
820
+ # Right shift indices_to_remove to keep 1st token over threshold
821
+ sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
822
+ ..., :-1
823
+ ]
824
+
825
+ # Compute indices_to_remove in unsorted array
826
+ indices_to_remove = sorted_indices_to_remove.scatter(
827
+ -1, sorted_indices, sorted_indices_to_remove
828
+ )
829
+
830
+ logits[indices_to_remove] = -float("inf")
831
+
832
+ # Perform multinomial sampling after normalizing logits
833
+ probs = (
834
+ F.softmax(logits / temperature, dim=-1)
835
+ if temperature > 0
836
+ else logits.softmax(dim=-1)
837
+ )
838
+ token = (
839
+ probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
840
+ if sample
841
+ else logits.argmax(-1)
842
+ )
843
+
844
+ if return_probs:
845
+ token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
846
+ return token, token_probs
847
+ else:
848
+ return token
849
+
850
+
851
+
852
+ def mask_by_random_topk(
853
+ num_to_mask: int,
854
+ probs: torch.Tensor,
855
+ temperature: float = 1.0,
856
+ ):
857
+ """
858
+ Args:
859
+ num_to_mask (int): number of tokens to mask
860
+ probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
861
+ temperature (float, optional): temperature. Defaults to 1.0.
862
+ """
863
+ logging.debug(f"masking by random topk")
864
+ logging.debug(f"num to mask: {num_to_mask}")
865
+ logging.debug(f"probs shape: {probs.shape}")
866
+ logging.debug(f"temperature: {temperature}")
867
+ logging.debug("")
868
+
869
+ noise = gumbel_noise_like(probs)
870
+ confidence = torch.log(probs) + temperature * noise
871
+ logging.debug(f"confidence shape: {confidence.shape}")
872
+
873
+ sorted_confidence, sorted_idx = confidence.sort(dim=-1)
874
+ logging.debug(f"sorted confidence shape: {sorted_confidence.shape}")
875
+ logging.debug(f"sorted idx shape: {sorted_idx.shape}")
876
+
877
+ # get the cut off threshold, given the mask length
878
+ cut_off = torch.take_along_dim(
879
+ sorted_confidence, num_to_mask, axis=-1
880
+ )
881
+ logging.debug(f"cut off shape: {cut_off.shape}")
882
+
883
+ # mask out the tokens
884
+ mask = confidence < cut_off
885
+ logging.debug(f"mask shape: {mask.shape}")
886
+
887
+ return mask
888
+
889
+ def typical_filter(
890
+ logits,
891
+ typical_mass: float = 0.95,
892
+ typical_min_tokens: int = 1,):
893
+ nb, nt, _ = logits.shape
894
+ x_flat = rearrange(logits, "b t l -> (b t ) l")
895
+ x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
896
+ x_flat_norm_p = torch.exp(x_flat_norm)
897
+ entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
898
+
899
+ c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
900
+ c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
901
+ x_flat_cumsum = (
902
+ x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
903
+ )
904
+
905
+ last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
906
+ sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
907
+ 1, last_ind.view(-1, 1)
908
+ )
909
+ if typical_min_tokens > 1:
910
+ sorted_indices_to_remove[..., :typical_min_tokens] = 0
911
+ indices_to_remove = sorted_indices_to_remove.scatter(
912
+ 1, x_flat_indices, sorted_indices_to_remove
913
+ )
914
+ x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
915
+ logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
916
+ return logits
917
+
918
+
919
+ if __name__ == "__main__":
920
+ import argbind
921
+ from .layers import num_params
922
+
923
+ VampNet = argbind.bind(VampNet)
924
+
925
+ @argbind.bind(without_prefix=True)
926
+ def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
927
+ seq_len = int(32000 / 512 * seq_len_s)
928
+
929
+ model = VampNet().to(device)
930
+
931
+ z = torch.randint(
932
+ 0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
933
+ ).to(device)
934
+
935
+ r = torch.zeros(batch_size).to(device)
936
+
937
+ z_mask_latent = torch.rand(
938
+ batch_size, model.latent_dim * model.n_codebooks, seq_len
939
+ ).to(device)
940
+ z_hat = model(z_mask_latent)
941
+
942
+ pred = z_hat.argmax(dim=1)
943
+ pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
944
+
945
+ print(f"model has {num_params(model)/1e6:<.3f}M parameters")
946
+ print(f"prediction has shape {pred.shape}")
947
+ breakpoint()
948
+
949
+ args = argbind.parse_args()
950
+ with argbind.scope(args):
951
+ try_model()
952
+
953
+
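
The generate loop above re-masks fewer and fewer tokens per iteration: at step i it keeps the most confident samples and re-masks floor(gamma(r) * N) of the originally masked positions, where r = (i + 1) / sampling_steps and gamma(r) = cos(pi * r / 2) is the cosine schedule (_gamma) from vampnet/mask.py. A standalone sketch of just that schedule, before the per-step clamping the loop applies; the step count and token count are arbitrary:

import math

sampling_steps = 12
num_mask_tokens_at_start = 400   # e.g. every timestep of every inferred codebook was masked

for i in range(sampling_steps):
    r = (i + 1) / sampling_steps
    num_to_mask = math.floor(math.cos(r * math.pi / 2) * num_mask_tokens_at_start)
    print(f"step {i:2d}: r={r:.2f}  re-mask {num_to_mask:3d} / {num_mask_tokens_at_start}")
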
vampnet/scheduler.py ADDED
@@ -0,0 +1,47 @@
1
+ import copy
2
+ from typing import List
3
+
4
+ import torch
5
+
6
+ class NoamScheduler:
7
+ """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
8
+ Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
9
+ """
10
+
11
+ def __init__(
12
+ self,
13
+ optimizer: torch.optim.Optimizer,
14
+ d_model: int = 512,
15
+ factor: float = 1.0,
16
+ warmup: int = 4000,
17
+ ):
18
+ # Store hparams
19
+ self.warmup = warmup
20
+ self.factor = factor
21
+ self.d_model = d_model
22
+
23
+ # Initialize variables `lr` and `steps`
24
+ self.lr = None
25
+ self.steps = 0
26
+
27
+ # Store the optimizer
28
+ self.optimizer = optimizer
29
+
30
+ def state_dict(self):
31
+ return {
32
+ key: value for key, value in self.__dict__.items() if key != "optimizer"
33
+ }
34
+
35
+ def load_state_dict(self, state_dict):
36
+ self.__dict__.update(state_dict)
37
+
38
+ def step(self):
39
+ self.steps += 1
40
+ self.lr = self.factor * (
41
+ self.d_model ** (-0.5)
42
+ * min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
43
+ )
44
+
45
+ for p in self.optimizer.param_groups:
46
+ p["lr"] = self.lr
47
+
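
The step() above implements the usual Noam schedule, lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), which ramps the learning rate up for `warmup` steps and then decays it like 1/sqrt(step). A minimal usage sketch; the parameter, d_model, factor and warmup values are placeholders, not the training configuration of this repo:

import torch
from vampnet.scheduler import NoamScheduler

params = [torch.nn.Parameter(torch.zeros(1))]   # placeholder parameters
opt = torch.optim.AdamW(params)                 # the scheduler overwrites the lr each step
sched = NoamScheduler(opt, d_model=1280, factor=2.0, warmup=10000)

for step in range(1, 4):
    opt.step()    # normally preceded by a backward pass
    sched.step()
    print(step, sched.lr)
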
vampnet/util.py ADDED
@@ -0,0 +1,46 @@
1
+ import tqdm
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ def scalar_to_batch_tensor(x, batch_size):
7
+ return torch.tensor(x).repeat(batch_size)
8
+
9
+
10
+ def parallelize(
11
+ fn,
12
+ *iterables,
13
+ parallel: str = "thread_map",
14
+ **kwargs
15
+ ):
16
+ if parallel == "thread_map":
17
+ from tqdm.contrib.concurrent import thread_map
18
+ return thread_map(
19
+ fn,
20
+ *iterables,
21
+ **kwargs
22
+ )
23
+ elif parallel == "process_map":
24
+ from tqdm.contrib.concurrent import process_map
25
+ return process_map(
26
+ fn,
27
+ *iterables,
28
+ **kwargs
29
+ )
30
+ elif parallel == "single":
31
+ return [fn(x) for x in tqdm.tqdm(*iterables)]
32
+ else:
33
+ raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
34
+
35
+ def codebook_flatten(tokens: torch.Tensor):
36
+ """
37
+ flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
38
+ """
39
+ return rearrange(tokens, "b c t -> b (t c)")
40
+
41
+ def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
42
+ """
43
+ unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
44
+ """
45
+ tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
46
+ return tokens
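
codebook_flatten and codebook_unflatten above interleave the codebooks within each timestep ("b c t -> b (t c)"), which is the layout VampNet.generate relies on when it flattens z_masked for sampling and unflattens the result. A small round-trip sketch for illustration:

import torch
from vampnet.util import codebook_flatten, codebook_unflatten

tokens = torch.arange(2 * 3 * 4).reshape(2, 3, 4)   # (batch=2, n_codebooks=3, time=4)
flat = codebook_flatten(tokens)                     # (2, 12): t0's 3 codes, then t1's, ...
restored = codebook_unflatten(flat, 3)
assert torch.equal(tokens, restored)
print(flat[0].tolist())                             # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
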