.gitignore CHANGED
@@ -175,14 +175,11 @@ lyrebird-audio-codec
  samples-*/**
 
  gradio-outputs/
- models/
  samples*/
  models-all/
  models.zip
+ audiotools/
+ descript-audio-codec/
+ # *.pth
  .git-old
  conf/generated/*
- runs*/
-
-
- gtzan.zip
- .gtzan_emb_cache

README.md CHANGED
@@ -7,7 +7,6 @@ sdk: gradio
  sdk_version: 3.36.1
  app_file: app.py
  pinned: false
- python_version: 3.9
  ---
 
  # VampNet
@@ -19,15 +18,7 @@ you can try vampnet in a co-creative looper called unloop. see this link: https:
 
  # Setting up
 
- **Requires Python 3.9**.
-
- you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
-
- (for example, using conda)
- ```bash
- conda create -n vampnet python=3.9
- conda activate vampnet
- ```
+ Requires Python 3.9 or later.
 
 
  install VampNet
@@ -100,7 +91,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
 
  launch the interface:
  ```bash
- python app.py --args.load conf/generated/<fine_tune_name>/interface.yml
+ python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
  ```
 
 
app.py CHANGED
@@ -1,12 +1,3 @@
1
- # huggingface space exclusive
2
- import os
3
-
4
- # print("installing pyharp")
5
- # os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
6
- # print("installing madmom")
7
- os.system('pip install cython')
8
- os.system('pip install madmom')
9
-
10
  from pathlib import Path
11
  from typing import Tuple
12
  import yaml
@@ -24,38 +15,27 @@ import gradio as gr
24
  from vampnet.interface import Interface
25
  from vampnet import mask as pmask
26
 
27
- from pyharp import ModelCard, build_endpoint
28
-
29
-
30
-
31
- # loader = AudioLoader()
32
  # AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
33
 
34
- conf = argbind.parse_args()
35
-
36
-
37
- from torch_pitch_shift import pitch_shift, get_fast_shifts
38
- def shift_pitch(signal, interval: int):
39
- signal.samples = pitch_shift(
40
- signal.samples,
41
- shift=interval,
42
- sample_rate=signal.sample_rate
43
- )
44
- return signal
45
-
46
- def load_interface():
47
- interface = Interface(
48
- coarse_ckpt="./models/vampnet/coarse.pth",
49
- coarse2fine_ckpt="./models/vampnet/c2f.pth",
50
- codec_ckpt="./models/vampnet/codec.pth",
51
- wavebeat_ckpt="./models/wavebeat.pth",
52
- device="cuda" if torch.cuda.is_available() else "cpu",
53
- )
54
- return interface
55
-
56
 
57
- interface = load_interface()
 
58
 
59
 
60
  OUT_DIR = Path("gradio-outputs")
61
  OUT_DIR.mkdir(exist_ok=True, parents=True)
@@ -70,7 +50,7 @@ def load_audio(file):
70
  )
71
  sig = interface.preprocess(sig)
72
 
73
- out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
74
  out_dir.mkdir(parents=True, exist_ok=True)
75
  sig.write(out_dir / "input.wav")
76
  return sig.path_to_file
@@ -88,10 +68,6 @@ def _vamp(data, return_mask=False):
88
  out_dir = OUT_DIR / str(uuid.uuid4())
89
  out_dir.mkdir()
90
  sig = at.AudioSignal(data[input_audio])
91
- sig = interface.preprocess(sig)
92
-
93
- if data[pitch_shift_amt] != 0:
94
- sig = shift_pitch(sig, data[pitch_shift_amt])
95
 
96
  z = interface.encode(sig)
97
 
@@ -131,58 +107,24 @@ def _vamp(data, return_mask=False):
131
  mask = pmask.codebook_unmask(mask, ncc)
132
 
133
 
134
- print(f"dropout {data[dropout]}")
135
- print(f"masktemp {data[masktemp]}")
136
- print(f"sampletemp {data[sampletemp]}")
137
- print(f"top_p {data[top_p]}")
138
- print(f"prefix_s {data[prefix_s]}")
139
- print(f"suffix_s {data[suffix_s]}")
140
- print(f"rand_mask_intensity {data[rand_mask_intensity]}")
141
- print(f"num_steps {data[num_steps]}")
142
- print(f"periodic_p {data[periodic_p]}")
143
- print(f"periodic_w {data[periodic_w]}")
144
- print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
145
- print(f"use_coarse2fine {data[use_coarse2fine]}")
146
- print(f"onset_mask_width {data[onset_mask_width]}")
147
- print(f"beat_mask_width {data[beat_mask_width]}")
148
- print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
149
- print(f"stretch_factor {data[stretch_factor]}")
150
- print(f"seed {data[seed]}")
151
- print(f"pitch_shift_amt {data[pitch_shift_amt]}")
152
- print(f"sample_cutoff {data[sample_cutoff]}")
153
-
154
-
155
- _top_p = data[top_p] if data[top_p] > 0 else None
156
  # save the mask as a txt file
157
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
158
 
159
- _seed = data[seed] if data[seed] > 0 else None
160
  zv, mask_z = interface.coarse_vamp(
161
  z,
162
  mask=mask,
163
  sampling_steps=data[num_steps],
164
- mask_temperature=data[masktemp]*10,
165
- sampling_temperature=data[sampletemp],
166
  return_mask=True,
167
  typical_filtering=data[typical_filtering],
168
  typical_mass=data[typical_mass],
169
  typical_min_tokens=data[typical_min_tokens],
170
- top_p=_top_p,
171
  gen_fn=interface.coarse.generate,
172
- seed=_seed,
173
- sample_cutoff=data[sample_cutoff],
174
  )
175
 
176
  if use_coarse2fine:
177
- zv = interface.coarse_to_fine(
178
- zv,
179
- mask_temperature=data[masktemp]*10,
180
- sampling_temperature=data[sampletemp],
181
- mask=mask,
182
- sampling_steps=data[num_steps] // 2,
183
- sample_cutoff=data[sample_cutoff],
184
- seed=_seed,
185
- )
186
 
187
  sig = interface.to_signal(zv).cpu()
188
  print("done")
@@ -215,9 +157,7 @@ def save_vamp(data):
215
  sig_out.write(out_dir / "output.wav")
216
 
217
  _data = {
218
- "masktemp": data[masktemp],
219
- "sampletemp": data[sampletemp],
220
- "top_p": data[top_p],
221
  "prefix_s": data[prefix_s],
222
  "suffix_s": data[suffix_s],
223
  "rand_mask_intensity": data[rand_mask_intensity],
@@ -228,8 +168,6 @@ def save_vamp(data):
228
  "n_conditioning_codebooks": data[n_conditioning_codebooks],
229
  "use_coarse2fine": data[use_coarse2fine],
230
  "stretch_factor": data[stretch_factor],
231
- "seed": data[seed],
232
- "samplecutoff": data[sample_cutoff],
233
  }
234
 
235
  # save with yaml
@@ -245,54 +183,13 @@ def save_vamp(data):
245
  return f"saved! your save code is {out_dir.stem}", zip_path
246
 
247
 
248
- def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
249
-
250
- out_dir = OUT_DIR / str(uuid.uuid4())
251
- out_dir.mkdir()
252
- sig = at.AudioSignal(_input_audio)
253
- sig = interface.preprocess(sig)
254
-
255
- z = interface.encode(sig)
256
-
257
- # build the mask
258
- mask = pmask.linear_random(z, 1.0)
259
- if _beat_mask_width > 0:
260
- beat_mask = interface.make_beat_mask(
261
- sig,
262
- after_beat_s=(_beat_mask_width/1000),
263
- )
264
- mask = pmask.mask_and(mask, beat_mask)
265
-
266
- # save the mask as a txt file
267
- zv, mask_z = interface.coarse_vamp(
268
- z,
269
- mask=mask,
270
- sampling_temperature=_sampletemp,
271
- return_mask=True,
272
- gen_fn=interface.coarse.generate,
273
- )
274
-
275
-
276
- zv = interface.coarse_to_fine(
277
- zv,
278
- sampling_temperature=_sampletemp,
279
- mask=mask,
280
- )
281
-
282
- sig = interface.to_signal(zv).cpu()
283
- print("done")
284
-
285
- sig.write(out_dir / "output.wav")
286
-
287
- return sig.path_to_file
288
-
289
  with gr.Blocks() as demo:
290
 
291
  with gr.Row():
292
  with gr.Column():
293
- gr.Markdown("# VampNet Audio Vamping")
294
  gr.Markdown("""## Description:
295
- This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
296
  You can control the extent and nature of variation with a set of manual controls and presets.
297
  Use this interface to experiment with different mask settings and explore the audio outputs.
298
  """)
@@ -300,8 +197,8 @@ with gr.Blocks() as demo:
300
  gr.Markdown("""
301
  ## Instructions:
302
  1. You can start by uploading some audio, or by loading the example audio.
303
- 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
304
- 3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
305
  4. Optionally, you can add some notes and save the result.
306
  5. You can also use the output as the new input and continue experimenting!
307
  """)
@@ -352,25 +249,19 @@ with gr.Blocks() as demo:
352
  "beat_mask_downbeats": False,
353
  },
354
  "slight periodic variation": {
355
- "periodic_p": 5,
356
- "onset_mask_width": 5,
357
- "beat_mask_width": 0,
358
- "beat_mask_downbeats": False,
359
- },
360
- "moderate periodic variation": {
361
- "periodic_p": 13,
362
- "onset_mask_width": 5,
363
  "beat_mask_width": 0,
364
  "beat_mask_downbeats": False,
365
  },
366
  "strong periodic variation": {
367
- "periodic_p": 17,
368
  "onset_mask_width": 5,
369
  "beat_mask_width": 0,
370
  "beat_mask_downbeats": False,
371
  },
372
  "very strong periodic variation": {
373
- "periodic_p": 21,
374
  "onset_mask_width": 5,
375
  "beat_mask_width": 0,
376
  "beat_mask_downbeats": False,
@@ -378,15 +269,9 @@ with gr.Blocks() as demo:
378
  "beat-driven variation": {
379
  "periodic_p": 0,
380
  "onset_mask_width": 0,
381
- "beat_mask_width": 50,
382
  "beat_mask_downbeats": False,
383
  },
384
- "beat-driven variation (downbeats only)": {
385
- "periodic_p": 0,
386
- "onset_mask_width": 0,
387
- "beat_mask_width": 50,
388
- "beat_mask_downbeats": True,
389
- },
390
  "beat-driven variation (downbeats only, strong)": {
391
  "periodic_p": 0,
392
  "onset_mask_width": 0,
@@ -408,20 +293,20 @@ with gr.Blocks() as demo:
408
  minimum=0,
409
  maximum=128,
410
  step=1,
411
- value=3,
412
  )
413
 
414
 
415
  onset_mask_width = gr.Slider(
416
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
417
  minimum=0,
418
- maximum=100,
419
  step=1,
420
  value=5,
421
  )
422
 
423
  beat_mask_width = gr.Slider(
424
- label="beat prompt (ms)",
425
  minimum=0,
426
  maximum=200,
427
  value=0,
@@ -433,14 +318,6 @@ with gr.Blocks() as demo:
433
 
434
 
435
  with gr.Accordion("extras ", open=False):
436
- pitch_shift_amt = gr.Slider(
437
- label="pitch shift amount (semitones)",
438
- minimum=-12,
439
- maximum=12,
440
- step=1,
441
- value=0,
442
- )
443
-
444
  rand_mask_intensity = gr.Slider(
445
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
446
  minimum=0.0,
@@ -500,34 +377,21 @@ with gr.Blocks() as demo:
500
  value=0.0
501
  )
502
 
503
- masktemp = gr.Slider(
504
- label="mask temperature",
505
  minimum=0.0,
506
- maximum=100.0,
507
- value=1.5
508
- )
509
- sampletemp = gr.Slider(
510
- label="sample temperature",
511
- minimum=0.1,
512
  maximum=10.0,
513
- value=1.0,
514
- step=0.001
515
  )
516
-
517
 
518
 
519
  with gr.Accordion("sampling settings", open=False):
520
- top_p = gr.Slider(
521
- label="top p (0.0 = off)",
522
- minimum=0.0,
523
- maximum=1.0,
524
- value=0.9
525
- )
526
  typical_filtering = gr.Checkbox(
527
  label="typical filtering ",
528
  value=False
529
  )
530
- typical_mass = gr.Slider(
531
  label="typical mass (should probably stay between 0.1 and 0.5)",
532
  minimum=0.01,
533
  maximum=0.99,
@@ -540,13 +404,6 @@ with gr.Blocks() as demo:
540
  step=1,
541
  value=64
542
  )
543
- sample_cutoff = gr.Slider(
544
- label="sample cutoff",
545
- minimum=0.0,
546
- maximum=1.0,
547
- value=0.5,
548
- step=0.01
549
- )
550
 
551
  use_coarse2fine = gr.Checkbox(
552
  label="use coarse2fine",
@@ -571,24 +428,8 @@ with gr.Blocks() as demo:
571
  )
572
 
573
 
574
- seed = gr.Number(
575
- label="seed (0 for random)",
576
- value=0,
577
- precision=0,
578
- )
579
-
580
-
581
-
582
  # mask settings
583
  with gr.Column():
584
-
585
- # lora_choice = gr.Dropdown(
586
- # label="lora choice",
587
- # choices=list(loras.keys()),
588
- # value=LORA_NONE,
589
- # visible=False
590
- # )
591
-
592
  vamp_button = gr.Button("generate (vamp)!!!")
593
  output_audio = gr.Audio(
594
  label="output audio",
@@ -614,9 +455,7 @@ with gr.Blocks() as demo:
614
  _inputs = {
615
  input_audio,
616
  num_steps,
617
- masktemp,
618
- sampletemp,
619
- top_p,
620
  prefix_s, suffix_s,
621
  rand_mask_intensity,
622
  periodic_p, periodic_w,
@@ -629,11 +468,7 @@ with gr.Blocks() as demo:
629
  typical_mass,
630
  typical_min_tokens,
631
  beat_mask_width,
632
- beat_mask_downbeats,
633
- seed,
634
- # lora_choice,
635
- pitch_shift_amt,
636
- sample_cutoff
637
  }
638
 
639
  # connect widgets
@@ -663,24 +498,4 @@ with gr.Blocks() as demo:
663
  outputs=[thank_you, download_file]
664
  )
665
 
666
- # harp stuff
667
- harp_inputs = [
668
- input_audio,
669
- beat_mask_width,
670
- sampletemp,
671
- ]
672
-
673
- build_endpoint(
674
- inputs=harp_inputs,
675
- output=output_audio,
676
- process_fn=harp_vamp,
677
- card=ModelCard(
678
- name="vampnet",
679
- description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
680
- author="Hugo Flores García",
681
- tags=["music", "generative"]
682
- ),
683
- visible=False
684
- )
685
-
686
- demo.launch()
1
  from pathlib import Path
2
  from typing import Tuple
3
  import yaml
 
15
  from vampnet.interface import Interface
16
  from vampnet import mask as pmask
17
 
18
+ # Interface = argbind.bind(Interface)
 
 
 
 
19
  # AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
20
 
21
+ interface = Interface(
22
+ coarse_ckpt="./models/vampnet/coarse.pth",
23
+ coarse2fine_ckpt="./models/vampnet/c2f.pth",
24
+ codec_ckpt="./models/vampnet/codec.pth",
25
+ wavebeat_ckpt="./models/wavebeat.pth",
26
+ device="cuda" if torch.cuda.is_available() else "cpu",
27
+ )
28
 
29
+ # loader = AudioLoader()
30
+ print(f"interface device is {interface.device}")
31
 
32
+ # dataset = at.data.datasets.AudioDataset(
33
+ # loader,
34
+ # sample_rate=interface.codec.sample_rate,
35
+ # duration=interface.coarse.chunk_size_s,
36
+ # n_examples=5000,
37
+ # without_replacement=True,
38
+ # )
39
 
40
  OUT_DIR = Path("gradio-outputs")
41
  OUT_DIR.mkdir(exist_ok=True, parents=True)
 
50
  )
51
  sig = interface.preprocess(sig)
52
 
53
+ out_dir = OUT_DIR / str(uuid.uuid4())
54
  out_dir.mkdir(parents=True, exist_ok=True)
55
  sig.write(out_dir / "input.wav")
56
  return sig.path_to_file
 
68
  out_dir = OUT_DIR / str(uuid.uuid4())
69
  out_dir.mkdir()
70
  sig = at.AudioSignal(data[input_audio])
 
 
 
 
71
 
72
  z = interface.encode(sig)
73
 
 
107
  mask = pmask.codebook_unmask(mask, ncc)
108
 
109
 
110
+ print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
111
  # save the mask as a txt file
112
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
113
 
 
114
  zv, mask_z = interface.coarse_vamp(
115
  z,
116
  mask=mask,
117
  sampling_steps=data[num_steps],
118
+ temperature=float(data[temp]*10),
 
119
  return_mask=True,
120
  typical_filtering=data[typical_filtering],
121
  typical_mass=data[typical_mass],
122
  typical_min_tokens=data[typical_min_tokens],
 
123
  gen_fn=interface.coarse.generate,
 
 
124
  )
125
 
126
  if use_coarse2fine:
127
+ zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
128
 
129
  sig = interface.to_signal(zv).cpu()
130
  print("done")
 
157
  sig_out.write(out_dir / "output.wav")
158
 
159
  _data = {
160
+ "temp": data[temp],
 
 
161
  "prefix_s": data[prefix_s],
162
  "suffix_s": data[suffix_s],
163
  "rand_mask_intensity": data[rand_mask_intensity],
 
168
  "n_conditioning_codebooks": data[n_conditioning_codebooks],
169
  "use_coarse2fine": data[use_coarse2fine],
170
  "stretch_factor": data[stretch_factor],
 
 
171
  }
172
 
173
  # save with yaml
 
183
  return f"saved! your save code is {out_dir.stem}", zip_path
184
 
185
 
186
  with gr.Blocks() as demo:
187
 
188
  with gr.Row():
189
  with gr.Column():
190
+ gr.Markdown("# VampNet")
191
  gr.Markdown("""## Description:
192
+ This is a demo of VampNet, a masked generative music model capable of doing music variations.
193
  You can control the extent and nature of variation with a set of manual controls and presets.
194
  Use this interface to experiment with different mask settings and explore the audio outputs.
195
  """)
 
197
  gr.Markdown("""
198
  ## Instructions:
199
  1. You can start by uploading some audio, or by loading the example audio.
200
+ 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button.
201
+ 3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints.
202
  4. Optionally, you can add some notes and save the result.
203
  5. You can also use the output as the new input and continue experimenting!
204
  """)
 
249
  "beat_mask_downbeats": False,
250
  },
251
  "slight periodic variation": {
252
+ "periodic_p": 7,
253
+ "onset_mask_width": 0,
254
  "beat_mask_width": 0,
255
  "beat_mask_downbeats": False,
256
  },
257
  "strong periodic variation": {
258
+ "periodic_p": 13,
259
  "onset_mask_width": 5,
260
  "beat_mask_width": 0,
261
  "beat_mask_downbeats": False,
262
  },
263
  "very strong periodic variation": {
264
+ "periodic_p": 17,
265
  "onset_mask_width": 5,
266
  "beat_mask_width": 0,
267
  "beat_mask_downbeats": False,
 
269
  "beat-driven variation": {
270
  "periodic_p": 0,
271
  "onset_mask_width": 0,
272
+ "beat_mask_width": 20,
273
  "beat_mask_downbeats": False,
274
  },
 
275
  "beat-driven variation (downbeats only, strong)": {
276
  "periodic_p": 0,
277
  "onset_mask_width": 0,
 
293
  minimum=0,
294
  maximum=128,
295
  step=1,
296
+ value=13,
297
  )
298
 
299
 
300
  onset_mask_width = gr.Slider(
301
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
302
  minimum=0,
303
+ maximum=20,
304
  step=1,
305
  value=5,
306
  )
307
 
308
  beat_mask_width = gr.Slider(
309
+ label="beat mask width (in milliseconds)",
310
  minimum=0,
311
  maximum=200,
312
  value=0,
 
318
 
319
 
320
  with gr.Accordion("extras ", open=False):
321
  rand_mask_intensity = gr.Slider(
322
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
323
  minimum=0.0,
 
377
  value=0.0
378
  )
379
 
380
+ temp = gr.Slider(
381
+ label="temperature",
382
  minimum=0.0,
383
  maximum=10.0,
384
+ value=1.8
 
385
  )
386
+
387
 
388
 
389
  with gr.Accordion("sampling settings", open=False):
390
  typical_filtering = gr.Checkbox(
391
  label="typical filtering ",
392
  value=False
393
  )
394
+ typical_mass = gr.Slider(
395
  label="typical mass (should probably stay between 0.1 and 0.5)",
396
  minimum=0.01,
397
  maximum=0.99,
 
404
  step=1,
405
  value=64
406
  )
407
 
408
  use_coarse2fine = gr.Checkbox(
409
  label="use coarse2fine",
 
428
  )
429
 
430
 
431
  # mask settings
432
  with gr.Column():
 
433
  vamp_button = gr.Button("generate (vamp)!!!")
434
  output_audio = gr.Audio(
435
  label="output audio",
 
455
  _inputs = {
456
  input_audio,
457
  num_steps,
458
+ temp,
 
 
459
  prefix_s, suffix_s,
460
  rand_mask_intensity,
461
  periodic_p, periodic_w,
 
468
  typical_mass,
469
  typical_min_tokens,
470
  beat_mask_width,
471
+ beat_mask_downbeats
 
 
 
 
472
  }
473
 
474
  # connect widgets
 
498
  outputs=[thank_you, download_file]
499
  )
500
 
501
+ demo.queue().launch()
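For readers tracing the new `app.py`, the calls shown above compose into a short offline pipeline. Below is a minimal sketch (not part of this commit; the input filename is a placeholder) built only from functions that appear in the diff:

```python
import audiotools as at
import torch

from vampnet import mask as pmask
from vampnet.interface import Interface

# assumes the pretrained checkpoints referenced above are present under ./models/
interface = Interface(
    coarse_ckpt="./models/vampnet/coarse.pth",
    coarse2fine_ckpt="./models/vampnet/c2f.pth",
    codec_ckpt="./models/vampnet/codec.pth",
    wavebeat_ckpt="./models/wavebeat.pth",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

sig = interface.preprocess(at.AudioSignal("input.wav"))  # placeholder input file
z = interface.encode(sig)                  # waveform -> codec tokens
mask = pmask.linear_random(z, 1.0)         # mask every token, like the "unconditional" preset
zv, mask_z = interface.coarse_vamp(z, mask=mask, return_mask=True)
zv = interface.coarse_to_fine(zv, mask=mask)
interface.to_signal(zv).cpu().write("output.wav")
```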
conf/lora/lora.yml CHANGED
@@ -4,16 +4,14 @@ $include:
  fine_tune: True
 
  train/AudioDataset.n_examples: 100000000
- val/AudioDataset.n_examples: 500
+ val/AudioDataset.n_examples: 100
 
 
  NoamScheduler.warmup: 500
 
- batch_size: 6
+ batch_size: 7
  num_workers: 7
- save_iters: [10000, 20000, 30000, 40000, 50000, 100000]
- sample_freq: 1000
- val_freq: 500
+ save_iters: [100000, 200000, 300000, 4000000, 500000]
 
  AdamW.lr: 0.0001
 
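These keys use argbind's `Object.attribute` (and `scope/Object.attribute`) naming, the same convention the training scripts rely on via `--args.load`. A rough sketch of how a top-level key such as `batch_size` reaches Python (the function and defaults here are illustrative, not the repository's):

```python
import argbind

@argbind.bind(without_prefix=True)
def train(batch_size: int = 6, num_workers: int = 2):
    # values from conf/lora/lora.yml override these defaults when the YAML
    # is supplied with --args.load
    print(batch_size, num_workers)

if __name__ == "__main__":
    args = argbind.parse_args()   # e.g. python train.py --args.load conf/lora/lora.yml
    with argbind.scope(args):
        train()
```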
conf/vampnet.yml CHANGED
@@ -32,7 +32,7 @@ VampNet.n_heads: 20
  VampNet.flash_attn: false
  VampNet.dropout: 0.1
 
- AudioLoader.relative_path: ""
+ AudioLoader.relative_path: /data/
  AudioDataset.loudness_cutoff: -30.0
  AudioDataset.without_replacement: true
  AudioLoader.shuffle: true
requirements.txt CHANGED
@@ -1,10 +1,8 @@
  torch
  argbind>=0.3.2
- numpy==1.23
+ numpy==1.22
  gradio
  loralib
  wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
  lac @ git+https://github.com/hugofloresgarcia/lac.git
- descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
- -e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
- torch_pitch_shift
+ audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git

scripts/exp/fine_tune.py CHANGED
@@ -48,10 +48,11 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
  }
 
  interface_conf = {
- "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
+ "Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
+ "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
 
- "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
- "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
+ "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
+ "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
 
  "Interface.codec_ckpt": "./models/vampnet/codec.pth",
  "AudioLoader.sources": [audio_files_or_folders],
scripts/exp/train.py CHANGED
@@ -14,7 +14,7 @@ from audiotools.data import transforms
14
  from einops import rearrange
15
  from rich import pretty
16
  from rich.traceback import install
17
- from torch.utils.tensorboard import SummaryWriter
18
 
19
  import vampnet
20
  from vampnet.modules.transformer import VampNet
@@ -29,9 +29,6 @@ from audiotools.ml.decorators import (
29
 
30
  import loralib as lora
31
 
32
- import torch._dynamo
33
- torch._dynamo.config.verbose=True
34
-
35
 
36
  # Enable cudnn autotuner to speed up training
37
  # (can be altered by the funcs.seed function)
@@ -224,7 +221,7 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
224
 
225
  dtype = torch.bfloat16 if accel.amp else None
226
  with accel.autocast(dtype=dtype):
227
- z_hat = state.model(z_mask_latent)
228
 
229
  target = codebook_flatten(
230
  z[:, vn.n_conditioning_codebooks :, :],
@@ -289,7 +286,7 @@ def val_loop(state: State, batch: dict, accel: Accelerator):
289
 
290
  z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
291
 
292
- z_hat = state.model(z_mask_latent)
293
 
294
  target = codebook_flatten(
295
  z[:, vn.n_conditioning_codebooks :, :],
@@ -408,19 +405,19 @@ def save_imputation(state, z, val_idx, writer):
408
 
409
  for i in range(len(val_idx)):
410
  imputed_noisy[i].cpu().write_audio_to_tb(
411
- f"inpainted_prompt/{i}",
412
  writer,
413
  step=state.tracker.step,
414
  plot_fn=None,
415
  )
416
  imputed[i].cpu().write_audio_to_tb(
417
- f"inpainted_middle/{i}",
418
  writer,
419
  step=state.tracker.step,
420
  plot_fn=None,
421
  )
422
  imputed_true[i].cpu().write_audio_to_tb(
423
- f"reconstructed/{i}",
424
  writer,
425
  step=state.tracker.step,
426
  plot_fn=None,
@@ -450,7 +447,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
450
 
451
  z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
452
 
453
- z_hat = state.model(z_mask_latent)
454
 
455
  z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
456
  z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
@@ -469,7 +466,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
469
  }
470
  for k, v in audio_dict.items():
471
  v.cpu().write_audio_to_tb(
472
- f"onestep/_{i}.r={r[i]:0.2f}/{k}",
473
  writer,
474
  step=state.tracker.step,
475
  plot_fn=None,
@@ -488,6 +485,7 @@ def load(
488
  save_path: str,
489
  resume: bool = False,
490
  tag: str = "latest",
 
491
  fine_tune_checkpoint: Optional[str] = None,
492
  grad_clip_val: float = 5.0,
493
  ) -> State:
@@ -500,7 +498,7 @@ def load(
500
  kwargs = {
501
  "folder": f"{save_path}/{tag}",
502
  "map_location": "cpu",
503
- "package": False,
504
  }
505
  tracker.print(f"Loading checkpoint from {kwargs['folder']}")
506
  if (Path(kwargs["folder"]) / "vampnet").exists():
@@ -513,14 +511,11 @@ def load(
513
 
514
  if args["fine_tune"]:
515
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
516
- model = torch.compile(
517
- VampNet.load(location=Path(fine_tune_checkpoint),
518
- map_location="cpu",
519
- )
520
- )
521
 
522
 
523
- model = torch.compile(VampNet()) if model is None else model
 
524
  model = accel.prepare_model(model)
525
 
526
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
@@ -604,7 +599,7 @@ def train(
604
  accel=accel,
605
  tracker=tracker,
606
  save_path=save_path)
607
- print("initialized state.")
608
 
609
  train_dataloader = accel.prepare_dataloader(
610
  state.train_data,
@@ -619,15 +614,13 @@ def train(
619
  num_workers=num_workers,
620
  batch_size=batch_size,
621
  collate_fn=state.val_data.collate,
622
- persistent_workers=num_workers > 0,
623
  )
624
- print("initialized dataloader.")
625
 
626
 
627
 
628
  if fine_tune:
629
  lora.mark_only_lora_as_trainable(state.model)
630
- print("marked only lora as trainable.")
631
 
632
  # Wrap the functions so that they neatly track in TensorBoard + progress bars
633
  # and only run when specific conditions are met.
@@ -642,7 +635,6 @@ def train(
642
  save_samples = when(lambda: accel.local_rank == 0)(save_samples)
643
  checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
644
 
645
- print("starting training loop.")
646
  with tracker.live:
647
  for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
648
  train_loop(state, batch, accel)
 
14
  from einops import rearrange
15
  from rich import pretty
16
  from rich.traceback import install
17
+ from tensorboardX import SummaryWriter
18
 
19
  import vampnet
20
  from vampnet.modules.transformer import VampNet
 
29
 
30
  import loralib as lora
31
 
 
 
 
32
 
33
  # Enable cudnn autotuner to speed up training
34
  # (can be altered by the funcs.seed function)
 
221
 
222
  dtype = torch.bfloat16 if accel.amp else None
223
  with accel.autocast(dtype=dtype):
224
+ z_hat = state.model(z_mask_latent, r)
225
 
226
  target = codebook_flatten(
227
  z[:, vn.n_conditioning_codebooks :, :],
 
286
 
287
  z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
288
 
289
+ z_hat = state.model(z_mask_latent, r)
290
 
291
  target = codebook_flatten(
292
  z[:, vn.n_conditioning_codebooks :, :],
 
405
 
406
  for i in range(len(val_idx)):
407
  imputed_noisy[i].cpu().write_audio_to_tb(
408
+ f"imputed_noisy/{i}",
409
  writer,
410
  step=state.tracker.step,
411
  plot_fn=None,
412
  )
413
  imputed[i].cpu().write_audio_to_tb(
414
+ f"imputed/{i}",
415
  writer,
416
  step=state.tracker.step,
417
  plot_fn=None,
418
  )
419
  imputed_true[i].cpu().write_audio_to_tb(
420
+ f"imputed_true/{i}",
421
  writer,
422
  step=state.tracker.step,
423
  plot_fn=None,
 
447
 
448
  z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
449
 
450
+ z_hat = state.model(z_mask_latent, r)
451
 
452
  z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
453
  z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
 
466
  }
467
  for k, v in audio_dict.items():
468
  v.cpu().write_audio_to_tb(
469
+ f"samples/_{i}.r={r[i]:0.2f}/{k}",
470
  writer,
471
  step=state.tracker.step,
472
  plot_fn=None,
 
485
  save_path: str,
486
  resume: bool = False,
487
  tag: str = "latest",
488
+ load_weights: bool = False,
489
  fine_tune_checkpoint: Optional[str] = None,
490
  grad_clip_val: float = 5.0,
491
  ) -> State:
 
498
  kwargs = {
499
  "folder": f"{save_path}/{tag}",
500
  "map_location": "cpu",
501
+ "package": not load_weights,
502
  }
503
  tracker.print(f"Loading checkpoint from {kwargs['folder']}")
504
  if (Path(kwargs["folder"]) / "vampnet").exists():
 
511
 
512
  if args["fine_tune"]:
513
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
514
+ model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
 
 
 
 
515
 
516
 
517
+ model = VampNet() if model is None else model
518
+
519
  model = accel.prepare_model(model)
520
 
521
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
 
599
  accel=accel,
600
  tracker=tracker,
601
  save_path=save_path)
602
+
603
 
604
  train_dataloader = accel.prepare_dataloader(
605
  state.train_data,
 
614
  num_workers=num_workers,
615
  batch_size=batch_size,
616
  collate_fn=state.val_data.collate,
617
+ persistent_workers=True,
618
  )
 
619
 
620
 
621
 
622
  if fine_tune:
623
  lora.mark_only_lora_as_trainable(state.model)
 
624
 
625
  # Wrap the functions so that they neatly track in TensorBoard + progress bars
626
  # and only run when specific conditions are met.
 
635
  save_samples = when(lambda: accel.local_rank == 0)(save_samples)
636
  checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
637
 
 
638
  with tracker.live:
639
  for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
640
  train_loop(state, batch, accel)
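The logging backend moves from `torch.utils.tensorboard` to `tensorboardX`; both expose the same `SummaryWriter` surface, so the `write_audio_to_tb` calls above are unaffected. A minimal sketch of the writer the training loop expects (the log directory is a placeholder):

```python
from tensorboardX import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")  # placeholder log dir
writer.add_scalar("train/loss", 0.5, global_step=0)
writer.close()
```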
scripts/utils/{data/augment.py → augment.py} RENAMED
@@ -5,19 +5,34 @@ from audiotools import AudioSignal
5
 
6
  import argbind
7
  import tqdm
8
- import torch
9
 
10
 
11
- from torch_pitch_shift import pitch_shift, get_fast_shifts
12
- from torch_time_stretch import time_stretch, get_fast_stretches
 
 
13
 
14
- from audiotools.core.util import sample_from_dist
15
 
16
 
17
  @argbind.bind(without_prefix=True)
18
  def augment(
19
- audio_folder: Path = None,
20
- dest_folder: Path = None,
21
  n_augmentations: int = 10,
22
  ):
23
  """
@@ -26,8 +41,7 @@ def augment(
26
  The dest foler will contain a folder for each of the clean dataset's files.
27
  Under each of these folders, there will be a clean file and many augmented files.
28
  """
29
- assert audio_folder is not None
30
- assert dest_folder is not None
31
  audio_files = at.util.find_audio(audio_folder)
32
 
33
  for audio_file in tqdm.tqdm(audio_files):
@@ -35,33 +49,5 @@ def augment(
35
  subdir = subtree / audio_file.stem
36
  subdir.mkdir(parents=True, exist_ok=True)
37
 
38
- src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
39
-
40
-
41
- for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
42
- # apply pedalboard transforms
43
- for j in range(n_augmentations):
44
- # pitch shift between -7 and 7 semitones
45
- import random
46
- dst = chunk.clone()
47
- dst.samples = pitch_shift(
48
- dst.samples,
49
- shift=random.choice(get_fast_shifts(src.sample_rate,
50
- condition=lambda x: x >= 0.25 and x <= 1.0)),
51
- sample_rate=src.sample_rate
52
- )
53
- dst.samples = time_stretch(
54
- dst.samples,
55
- stretch=random.choice(get_fast_stretches(src.sample_rate,
56
- condition=lambda x: x >= 0.667 and x <= 1.5, )),
57
- sample_rate=src.sample_rate,
58
- )
59
-
60
- dst.cpu().write(subdir / f"{i}-{j}.wav")
61
-
62
-
63
- if __name__ == "__main__":
64
- args = argbind.parse_args()
65
-
66
- with argbind.scope(args):
67
- augment()
 
5
 
6
  import argbind
7
  import tqdm
 
8
 
9
 
10
+ from pedalboard import (
11
+ Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
12
+ )
13
+ from pedalboard.io import AudioFile
14
 
15
+ # Read in a whole file, resampling to our desired sample rate:
16
+ samplerate = 44100.0
17
+ with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
18
+ audio = f.read(f.frames)
19
+
20
+ # Make a pretty interesting sounding guitar pedalboard:
21
+ board = Pedalboard([
22
+ Compressor(threshold_db=-50, ratio=25),
23
+ Gain(gain_db=30),
24
+ Chorus(),
25
+ LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
26
+ Phaser(),
27
+ Convolution("./guitar_amp.wav", 1.0),
28
+ Reverb(room_size=0.25),
29
+ ])
30
 
31
 
32
  @argbind.bind(without_prefix=True)
33
  def augment(
34
+ audio_folder: Path,
35
+ dest_folder: Path,
36
  n_augmentations: int = 10,
37
  ):
38
  """
 
41
  The dest foler will contain a folder for each of the clean dataset's files.
42
  Under each of these folders, there will be a clean file and many augmented files.
43
  """
44
+
 
45
  audio_files = at.util.find_audio(audio_folder)
46
 
47
  for audio_file in tqdm.tqdm(audio_files):
 
49
  subdir = subtree / audio_file.stem
50
  subdir.mkdir(parents=True, exist_ok=True)
51
 
52
+ # apply pedalboard transforms
53
+ for i in range(n_augmentations):
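The hunk ends before the new loop body. A plausible completion, sketched under the assumption that each augmentation pushes the clean file through `board` and writes the result into `subdir` (this is not the repository's code):

```python
# hypothetical body for `for i in range(n_augmentations):`
with AudioFile(str(audio_file)).resampled_to(samplerate) as f:
    audio = f.read(f.frames)
effected = board(audio, samplerate)  # run the clean audio through the pedalboard chain
with AudioFile(str(subdir / f"{i}.wav"), "w", samplerate, effected.shape[0]) as f:
    f.write(effected)
```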
scripts/utils/gtzan_embeddings.py DELETED
@@ -1,263 +0,0 @@
1
- """
2
- TODO: train a linear probe
3
- usage:
4
- python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
5
- """
6
- from pathlib import Path
7
- from typing import List
8
-
9
- import audiotools as at
10
- from audiotools import AudioSignal
11
- import argbind
12
- import torch
13
- import numpy as np
14
- import zipfile
15
- import json
16
-
17
- from vampnet.interface import Interface
18
- import tqdm
19
-
20
- # bind the Interface to argbind
21
- Interface = argbind.bind(Interface)
22
-
23
- DEBUG = False
24
-
25
- def smart_plotly_export(fig, save_path):
26
- img_format = save_path.split('.')[-1]
27
- if img_format == 'html':
28
- fig.write_html(save_path)
29
- elif img_format == 'bytes':
30
- return fig.to_image(format='png')
31
- #TODO: come back and make this prettier
32
- elif img_format == 'numpy':
33
- import io
34
- from PIL import Image
35
-
36
- def plotly_fig2array(fig):
37
- #convert Plotly fig to an array
38
- fig_bytes = fig.to_image(format="png", width=1200, height=700)
39
- buf = io.BytesIO(fig_bytes)
40
- img = Image.open(buf)
41
- return np.asarray(img)
42
-
43
- return plotly_fig2array(fig)
44
- elif img_format == 'jpeg' or 'png' or 'webp':
45
- fig.write_image(save_path)
46
- else:
47
- raise ValueError("invalid image format")
48
-
49
- def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
50
- """
51
- dimensionality reduction for visualization!
52
- saves an html plotly figure to save_path
53
- parameters:
54
- emb (np.ndarray): the samples to be reduces with shape (samples, features)
55
- labels (list): list of labels for embedding
56
- save_path (str): path where u wanna save ur figure
57
- method (str): umap, tsne, or pca
58
- title (str): title for ur figure
59
- returns:
60
- proj (np.ndarray): projection vector with shape (samples, dimensions)
61
- """
62
- import pandas as pd
63
- import plotly.express as px
64
- if method == 'umap':
65
- reducer = umap.UMAP(n_components=n_components)
66
- elif method == 'tsne':
67
- from sklearn.manifold import TSNE
68
- reducer = TSNE(n_components=n_components)
69
- elif method == 'pca':
70
- from sklearn.decomposition import PCA
71
- reducer = PCA(n_components=n_components)
72
- else:
73
- raise ValueError
74
-
75
- proj = reducer.fit_transform(emb)
76
-
77
- if n_components == 2:
78
- df = pd.DataFrame(dict(
79
- x=proj[:, 0],
80
- y=proj[:, 1],
81
- instrument=labels
82
- ))
83
- fig = px.scatter(df, x='x', y='y', color='instrument',
84
- title=title+f"_{method}")
85
-
86
- elif n_components == 3:
87
- df = pd.DataFrame(dict(
88
- x=proj[:, 0],
89
- y=proj[:, 1],
90
- z=proj[:, 2],
91
- instrument=labels
92
- ))
93
- fig = px.scatter_3d(df, x='x', y='y', z='z',
94
- color='instrument',
95
- title=title)
96
- else:
97
- raise ValueError("cant plot more than 3 components")
98
-
99
- fig.update_traces(marker=dict(size=6,
100
- line=dict(width=1,
101
- color='DarkSlateGrey')),
102
- selector=dict(mode='markers'))
103
-
104
- return smart_plotly_export(fig, save_path)
105
-
106
-
107
-
108
- # per JukeMIR, we want the emebddings from the middle layer?
109
- def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
110
- with torch.inference_mode():
111
- # preprocess the signal
112
- sig = interface.preprocess(sig)
113
-
114
- # get the coarse vampnet model
115
- vampnet = interface.coarse
116
-
117
- # get the tokens
118
- z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
119
- z_latents = vampnet.embedding.from_codes(z, interface.codec)
120
-
121
- # do a forward pass through the model, get the embeddings
122
- _z, embeddings = vampnet(z_latents, return_activations=True)
123
- # print(f"got embeddings with shape {embeddings.shape}")
124
- # [layer, batch, time, n_dims]
125
- # [20, 1, 600ish, 768]
126
-
127
-
128
- # squeeze batch dim (1 bc layer should be dim 0)
129
- assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
130
- embeddings = embeddings.squeeze(1)
131
-
132
- num_layers = embeddings.shape[0]
133
- assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
134
-
135
- # do meanpooling over the time dimension
136
- embeddings = embeddings.mean(dim=-2)
137
- # [20, 768]
138
-
139
- # return the embeddings
140
- return embeddings
141
-
142
- from dataclasses import dataclass, fields
143
- @dataclass
144
- class Embedding:
145
- genre: str
146
- filename: str
147
- embedding: np.ndarray
148
-
149
- def save(self, path):
150
- """Save the Embedding object to a given path as a zip file."""
151
- with zipfile.ZipFile(path, 'w') as archive:
152
-
153
- # Save numpy array
154
- with archive.open('embedding.npy', 'w') as f:
155
- np.save(f, self.embedding)
156
-
157
- # Save non-numpy data as json
158
- non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
159
- with archive.open('data.json', 'w') as f:
160
- f.write(json.dumps(non_numpy_data).encode('utf-8'))
161
-
162
- @classmethod
163
- def load(cls, path):
164
- """Load the Embedding object from a given zip path."""
165
- with zipfile.ZipFile(path, 'r') as archive:
166
-
167
- # Load numpy array
168
- with archive.open('embedding.npy') as f:
169
- embedding = np.load(f)
170
-
171
- # Load non-numpy data from json
172
- with archive.open('data.json') as f:
173
- data = json.loads(f.read().decode('utf-8'))
174
-
175
- return cls(embedding=embedding, **data)
176
-
177
-
178
- @argbind.bind(without_prefix=True)
179
- def main(
180
- path_to_gtzan: str = None,
181
- cache_dir: str = "./.gtzan_emb_cache",
182
- output_dir: str = "./gtzan_vampnet_embeddings",
183
- layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
184
- ):
185
- path_to_gtzan = Path(path_to_gtzan)
186
- assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
187
-
188
- cache_dir = Path(cache_dir)
189
- output_dir = Path(output_dir)
190
- output_dir.mkdir(exist_ok=True, parents=True)
191
-
192
- # load our interface
193
- # argbind will automatically load the default config,
194
- interface = Interface()
195
-
196
- # gtzan should have a folder for each genre, so let's get the list of genres
197
- genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
198
- print(f"Found {len(genres)} genres")
199
- print(f"genres: {genres}")
200
-
201
- # collect audio files, genres, and embeddings
202
- data = []
203
- for genre in genres:
204
- audio_files = list(at.util.find_audio(path_to_gtzan / genre))
205
- print(f"Found {len(audio_files)} audio files for genre {genre}")
206
-
207
- for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
208
- # check if we have a cached embedding for this file
209
- cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
210
- if cached_path.exists():
211
- # if so, load it
212
- if DEBUG:
213
- print(f"loading cached embedding for {cached_path.stem}")
214
- embedding = Embedding.load(cached_path)
215
- data.append(embedding)
216
- else:
217
- try:
218
- sig = AudioSignal(audio_file)
219
- except Exception as e:
220
- print(f"failed to load {audio_file.name} with error {e}")
221
- print(f"skipping {audio_file.name}")
222
- continue
223
-
224
- # gets the embedding
225
- emb = vampnet_embed(sig, interface).cpu().numpy()
226
-
227
- # create an embedding we can save/load
228
- embedding = Embedding(
229
- genre=genre,
230
- filename=audio_file.name,
231
- embedding=emb
232
- )
233
-
234
- # cache the embeddings
235
- cached_path.parent.mkdir(exist_ok=True, parents=True)
236
- embedding.save(cached_path)
237
-
238
- # now, let's do a dim reduction on the embeddings
239
- # and visualize them.
240
-
241
- # collect a list of embeddings and labels
242
- embeddings = [d.embedding for d in data]
243
- labels = [d.genre for d in data]
244
-
245
- # convert the embeddings to a numpy array
246
- embeddings = np.stack(embeddings)
247
-
248
- # do dimensionality reduction for each layer we're given
249
- for layer in tqdm.tqdm(layers, desc="dim reduction"):
250
- dim_reduce(
251
- embeddings[:, layer, :], labels,
252
- save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
253
- n_components=2, method='tsne',
254
- title=f'vampnet-gtzan-layer={layer}'
255
- )
256
-
257
-
258
-
259
-
260
- if __name__ == "__main__":
261
- args = argbind.parse_args()
262
- with argbind.scope(args):
263
- main()
scripts/utils/{data/maestro-reorg.py → maestro-reorg.py} RENAMED
File without changes
scripts/utils/remove_quiet_files.py DELETED
@@ -1,29 +0,0 @@
1
- # removes files with loudness below 24db
2
-
3
- from pathlib import Path
4
- import shutil
5
- import audiotools as at
6
- import argbind
7
-
8
- @argbind.bind(without_prefix=True)
9
- def remove_quiet_files(
10
- src_dir: Path = None,
11
- dest_dir: Path = None,
12
- min_loudness: float = -30,
13
- ):
14
- # copy src to dest
15
- dest_dir.mkdir(parents=True, exist_ok=True)
16
- shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
17
-
18
- audio_files = at.util.find_audio(dest_dir)
19
- for audio_file in audio_files:
20
- sig = at.AudioSignal(audio_file)
21
- if sig.loudness() < min_loudness:
22
- audio_file.unlink()
23
- print(f"removed {audio_file}")
24
-
25
- if __name__ == "__main__":
26
- args = argbind.parse_args()
27
-
28
- with argbind.scope(args):
29
- remove_quiet_files()
scripts/utils/split_long_audio_file.py DELETED
@@ -1,34 +0,0 @@
1
- from pathlib import Path
2
- import argbind
3
-
4
- import audiotools as at
5
- import tqdm
6
-
7
-
8
- @argbind.bind(without_prefix=True)
9
- def split_long_audio_file(
10
- file: str = None,
11
- max_chunk_size_s: int = 60*10
12
- ):
13
- file = Path(file)
14
- output_dir = file.parent / file.stem
15
- output_dir.mkdir()
16
-
17
- sig = at.AudioSignal(file)
18
-
19
- # split into chunks
20
- for i, sig in tqdm.tqdm(enumerate(sig.windows(
21
- window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
22
- preprocess=True))
23
- ):
24
- sig.write(output_dir / f"{i}.wav")
25
-
26
- print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
27
-
28
- return output_dir
29
-
30
- if __name__ == "__main__":
31
- args = argbind.parse_args()
32
-
33
- with argbind.scope(args):
34
- split_long_audio_file()
scripts/utils/xeno-canto-dl.py DELETED
@@ -1,234 +0,0 @@
1
- from xenopy import Query
2
-
3
-
4
- SPECIES = [
5
- "American Robin",
6
- "Northern Cardinal",
7
- "Mourning Dove",
8
- "American Crow",
9
- "Baltimore Oriole",
10
- "Blue Jay",
11
- "Eastern Bluebird",
12
- "House Finch",
13
- "American Goldfinch",
14
- "House Sparrow",
15
- "Song Sparrow",
16
- "Tufted Titmouse",
17
- "White-breasted Nuthatch",
18
- "European Starling",
19
- "American Redstart",
20
- "Red-winged Blackbird",
21
- "Brown-headed Cowbird",
22
- "Common Grackle",
23
- "Boat-tailed Grackle",
24
- "Common Yellowthroat",
25
- "Northern Mockingbird",
26
- "Carolina Wren",
27
- "Eastern Meadowlark",
28
- "Chipping Sparrow",
29
- "Tree Swallow",
30
- "Barn Swallow",
31
- "Cliff Swallow",
32
- "Pine Siskin",
33
- "Indigo Bunting",
34
- "Eastern Towhee",
35
- "Carolina Chickadee",
36
- "Great Crested Flycatcher",
37
- "Eastern Wood-Pewee",
38
- "Ovenbird",
39
- "Northern Flicker",
40
- "Red-eyed Vireo",
41
- "American Woodcock",
42
- "Eastern Phoebe",
43
- "Downy Woodpecker",
44
- "Scarlet Tanager",
45
- "Yellow Warbler",
46
- "White-eyed Vireo",
47
- "Common Loon",
48
- "White-throated Sparrow",
49
- "Yellow-throated Vireo",
50
- "Great Blue Heron",
51
- "Belted Kingfisher",
52
- "Pied-billed Grebe",
53
- "Wild Turkey",
54
- "Wood Thrush",
55
- "Rose-breasted Grosbeak",
56
- "Field Sparrow",
57
- "Hooded Warbler",
58
- "Northern Parula",
59
- "Chestnut-sided Warbler",
60
- "Blue-winged Warbler",
61
- "Red-bellied Woodpecker",
62
- "Yellow-billed Cuckoo",
63
- "Gray Catbird",
64
- "Northern Saw-whet Owl",
65
- "Osprey",
66
- "Common Nighthawk",
67
- "Broad-winged Hawk",
68
- "Black-throated Green Warbler",
69
- "Great Horned Owl",
70
- "Common Raven",
71
- "Barred Owl",
72
- "Canada Warbler",
73
- "Magnolia Warbler",
74
- "Black-and-white Warbler",
75
- "Eastern Kingbird",
76
- "Swainson's Thrush",
77
- "Worm-eating Warbler",
78
- "Prairie Warbler",
79
- "Baltimore Oriole",
80
- "Black-throated Blue Warbler",
81
- "Louisiana Waterthrush",
82
- "Blackburnian Warbler",
83
- "Black-capped Chickadee",
84
- "Cerulean Warbler",
85
- "Red-shouldered Hawk",
86
- "Cooper's Hawk",
87
- "Yellow-throated Warbler",
88
- "Blue-headed Vireo",
89
- "Blackpoll Warbler",
90
- "Ruffed Grouse",
91
- "Kentucky Warbler",
92
- "Hermit Thrush",
93
- "Cedar Waxwing",
94
- "Eastern Screech-Owl",
95
- "Northern Goshawk",
96
- "Green Heron",
97
- "Red-tailed Hawk",
98
- "Black Vulture",
99
- "Hairy Woodpecker",
100
- "Golden-crowned Kinglet",
101
- "Ruby-crowned Kinglet",
102
- "Bicknell's Thrush",
103
- "Blue-gray Gnatcatcher",
104
- "Veery",
105
- "Pileated Woodpecker",
106
- "Purple Finch",
107
- "White-crowned Sparrow",
108
- "Snow Bunting",
109
- "Pine Grosbeak",
110
- "American Tree Sparrow",
111
- "Dark-eyed Junco",
112
- "Snowy Owl",
113
- "White-winged Crossbill",
114
- "Red Crossbill",
115
- "Common Redpoll",
116
- "Northern Shrike",
117
- "Northern Harrier",
118
- "Rough-legged Hawk",
119
- "Long-eared Owl",
120
- "Evening Grosbeak",
121
- "Northern Pintail",
122
- "American Black Duck",
123
- "Mallard",
124
- "Canvasback",
125
- "Redhead",
126
- "Ring-necked Duck",
127
- "Greater Scaup",
128
- "Lesser Scaup",
129
- "Bufflehead",
130
- "Common Goldeneye",
131
- "Hooded Merganser",
132
- "Common Merganser",
133
- "Red-breasted Merganser",
134
- "Ruddy Duck",
135
- "Wood Duck",
136
- "Gadwall",
137
- "American Wigeon",
138
- "Northern Shoveler",
139
- "Green-winged Teal",
140
- "Blue-winged Teal",
141
- "Cinnamon Teal",
142
- "Ringed Teal",
143
- "Cape Teal",
144
- "Northern Fulmar",
145
- "Yellow-billed Loon",
146
- "Red-throated Loon",
147
- "Arctic Loon",
148
- "Pacific Loon",
149
- "Horned Grebe",
150
- "Red-necked Grebe",
151
- "Eared Grebe",
152
- "Western Grebe",
153
- "Clark's Grebe",
154
- "Double-crested Cormorant",
155
- "Pelagic Cormorant",
156
- "Great Cormorant",
157
- "American White Pelican",
158
- "Brown Pelican",
159
- "Brandt's Cormorant",
160
- "Least Bittern",
161
- "Great Egret",
162
- "Snowy Egret",
163
- "Little Blue Heron",
164
- "Tricolored Heron",
165
- "Reddish Egret",
166
- "Black-crowned Night-Heron",
167
- "Yellow-crowned Night-Heron",
168
- "White Ibis",
169
- "Glossy Ibis",
170
- "Roseate Spoonbill",
171
- "Wood Stork",
172
- "Black-bellied Whistling-Duck",
173
- "Fulvous Whistling-Duck",
174
- "Greater White-fronted Goose",
175
- "Snow Goose",
176
- "Ross's Goose",
177
- "Canada Goose",
178
- "Brant",
179
- "Mute Swan",
180
- "Tundra Swan",
181
- "Whooper Swan",
182
- "Sandhill Crane",
183
- "Black-necked Stilt",
184
- "American Avocet",
185
- "Northern Jacana",
186
- "Greater Yellowlegs",
187
- "Lesser Yellowlegs",
188
- "Willet",
189
- "Spotted Sandpiper",
190
- "Upland Sandpiper",
191
- "Whimbrel",
192
- "Long-billed Curlew",
193
- "Marbled Godwit",
194
- "Ruddy Turnstone",
195
- "Red Knot",
196
- "Sanderling",
197
- "Semipalmated Sandpiper",
198
- "Western Sandpiper",
199
- "Least Sandpiper",
200
- "White-rumped Sandpiper",
201
- "Baird's Sandpiper",
202
- "Pectoral Sandpiper",
203
- "Dunlin",
204
- "Buff-breasted Sandpiper",
205
- "Short-billed Dowitcher",
206
- "Long-billed Dowitcher",
207
- "Common Snipe",
208
- "American Woodcock",
209
- "Wilson's Phalarope",
210
- "Red-necked Phalarope",
211
- "Red Phalarope"
212
- ]
213
-
214
- from pathlib import Path
215
-
216
- def remove_spaces(s):
217
- return s.replace(" ", "")
218
-
219
- for species in SPECIES:
220
- if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
221
- continue
222
- try:
223
- q = Query(
224
- name=species, q="A", length="10-30",
225
- )
226
-
227
- # retrieve metadata
228
- metafiles = q.retrieve_meta(verbose=True)
229
- # retrieve recordings
230
- q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
231
-
232
- except:
233
- print("Failed to download " + species)
234
- continue
setup.py CHANGED
@@ -28,13 +28,12 @@ setup(
  install_requires=[
  "torch",
  "argbind>=0.3.2",
- "numpy==1.23",
+ "numpy==1.22",
  "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
  "lac @ git+https://github.com/hugofloresgarcia/lac.git",
  "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
  "gradio",
+ "tensorboardX",
  "loralib",
- "torch_pitch_shift",
- "madmom",
  ],
  )
vampnet/interface.py CHANGED
@@ -65,7 +65,7 @@ class Interface(torch.nn.Module):
  ):
  super().__init__()
  assert codec_ckpt is not None, "must provide a codec checkpoint"
- self.codec = DAC.load(Path(codec_ckpt))
+ self.codec = DAC.load(codec_ckpt)
  self.codec.eval()
  self.codec.to(device)
 
@@ -120,16 +120,17 @@ class Interface(torch.nn.Module):
  if coarse_ckpt is not None:
  self.coarse.to("cpu")
  state_dict = torch.load(coarse_ckpt, map_location="cpu")
- print(f"loading coarse from {coarse_ckpt}")
+
  self.coarse.load_state_dict(state_dict, strict=False)
  self.coarse.to(self.device)
  if c2f_ckpt is not None:
  self.c2f.to("cpu")
  state_dict = torch.load(c2f_ckpt, map_location="cpu")
- print(f"loading c2f from {c2f_ckpt}")
+
  self.c2f.load_state_dict(state_dict, strict=False)
  self.c2f.to(self.device)
 
+
  def s2t(self, seconds: float):
  """seconds to tokens"""
  if isinstance(seconds, np.ndarray):
@@ -193,8 +194,8 @@
 
  def make_beat_mask(self,
  signal: AudioSignal,
- before_beat_s: float = 0.0,
- after_beat_s: float = 0.02,
+ before_beat_s: float = 0.1,
+ after_beat_s: float = 0.1,
  mask_downbeats: bool = True,
  mask_upbeats: bool = True,
  downbeat_downsample_factor: int = None,
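With the wider 0.1 s defaults, a beat-driven prompt can be built the way the Gradio app does; a small sketch using calls that appear elsewhere in this diff (assumes `interface`, `sig`, and `z` are set up as in `app.py`):

```python
from vampnet import mask as pmask

beat_mask = interface.make_beat_mask(
    sig,
    before_beat_s=0.1,   # the new defaults, written out explicitly
    after_beat_s=0.1,
    mask_downbeats=True,
)
mask = pmask.mask_and(pmask.linear_random(z, 1.0), beat_mask)
```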
vampnet/mask.py CHANGED
@@ -191,47 +191,29 @@ def onset_mask(
191
  width: int = 1
192
  ):
193
  import librosa
194
- import madmom
195
- from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
196
- import tempfile
197
- import numpy as np
198
-
199
- with tempfile.NamedTemporaryFile(suffix='.wav') as f:
200
- sig = sig.clone()
201
- sig.write(f.name)
202
-
203
- proc = RNNOnsetProcessor(online=False)
204
- onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
205
- fps=sig.sample_rate/interface.codec.hop_length)
206
-
207
- act = proc(f.name)
208
- onset_times = onsetproc(act)
209
-
210
- # convert to indices for z array
211
- onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
212
-
213
- if onset_indices.shape[0] == 0:
214
- mask = empty_mask(z)
215
- print(f"no onsets found, returning empty mask")
216
- else:
217
- torch.set_printoptions(threshold=1000)
218
- print("onset indices: ", onset_indices)
219
- print("onset times: ", onset_times)
220
-
221
- # create a mask, set onset
222
- mask = torch.ones_like(z)
223
- n_timesteps = z.shape[-1]
224
-
225
- for onset_index in onset_indices:
226
- onset_index = min(onset_index, n_timesteps - 1)
227
- onset_index = max(onset_index, 0)
228
- mask[:, :, onset_index - width:onset_index + width] = 0.0
229
-
230
- print(mask)
231
 
232
  return mask
233
 
234
 
235
 
236
  if __name__ == "__main__":
237
- pass
 
 
191
  width: int = 1
192
  ):
193
  import librosa
194
+
195
+ onset_indices = librosa.onset.onset_detect(
196
+ y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
197
+ sr=sig.sample_rate,
198
+ hop_length=interface.codec.hop_length,
199
+ backtrack=True,
200
+ )
201
+
202
+ # create a mask, set onset
203
+ mask = torch.ones_like(z)
204
+ n_timesteps = z.shape[-1]
205
+
206
+ for onset_index in onset_indices:
207
+ onset_index = min(onset_index, n_timesteps - 1)
208
+ onset_index = max(onset_index, 0)
209
+ mask[:, :, onset_index - width:onset_index + width] = 0.0
210
+
211
+ print(mask)
 
 
 
 
 
212
 
213
  return mask
214
 
215
 
216
 
217
  if __name__ == "__main__":
218
+ torch.set_printoptions(threshold=10000)
219
+
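Editor's note: the `onset_mask` rewrite above drops the madmom dependency in favour of `librosa.onset.onset_detect`. Below is a self-contained sketch of the same idea; the helper name and the (1, 1, n_timesteps) mask shape are illustrative, and only the librosa call mirrors the diff.

```python
import librosa
import numpy as np
import torch

def onsets_to_mask(samples: np.ndarray, sr: int, hop_length: int,
                   n_timesteps: int, width: int = 1) -> torch.Tensor:
    # detect onset frames straight from the waveform; backtrack=True snaps each onset
    # to the preceding energy minimum, as in the code above
    onset_frames = librosa.onset.onset_detect(
        y=samples, sr=sr, hop_length=hop_length, backtrack=True
    )
    # (1, 1, n_timesteps) mask, broadcastable over the codebook dimension
    mask = torch.ones(1, 1, n_timesteps)
    for idx in onset_frames:
        idx = int(min(max(idx, 0), n_timesteps - 1))
        lo = max(idx - width, 0)
        mask[:, :, lo:idx + width] = 0.0   # zero a ±width window around each onset
    return mask
```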
vampnet/modules/transformer.py CHANGED
@@ -410,9 +410,7 @@ class TransformerStack(nn.Module):
410
  def subsequent_mask(self, size):
411
  return torch.ones(1, size, size).tril().bool()
412
 
413
- def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
414
- return_activations: bool = False
415
- ):
416
  """Computes a full transformer stack
417
  Parameters
418
  ----------
@@ -439,8 +437,6 @@ class TransformerStack(nn.Module):
439
  encoder_decoder_position_bias = None
440
 
441
  # Compute transformer layers
442
- if return_activations:
443
- activations = []
444
  for layer in self.layers:
445
  x, position_bias, encoder_decoder_position_bias = layer(
446
  x=x,
@@ -451,15 +447,8 @@ class TransformerStack(nn.Module):
451
  position_bias=position_bias,
452
  encoder_decoder_position_bias=encoder_decoder_position_bias,
453
  )
454
- if return_activations:
455
- activations.append(x.detach())
456
 
457
-
458
- out = self.norm(x) if self.norm is not None else x
459
- if return_activations:
460
- return out, torch.stack(activations)
461
- else:
462
- return out
463
 
464
 
465
  class VampNet(at.ml.BaseModel):
@@ -467,7 +456,7 @@ class VampNet(at.ml.BaseModel):
467
  self,
468
  n_heads: int = 20,
469
  n_layers: int = 16,
470
- r_cond_dim: int = 0,
471
  n_codebooks: int = 9,
472
  n_conditioning_codebooks: int = 0,
473
  latent_dim: int = 8,
@@ -478,7 +467,6 @@ class VampNet(at.ml.BaseModel):
478
  dropout: float = 0.1
479
  ):
480
  super().__init__()
481
- assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
482
  self.n_heads = n_heads
483
  self.n_layers = n_layers
484
  self.r_cond_dim = r_cond_dim
@@ -525,25 +513,21 @@ class VampNet(at.ml.BaseModel):
525
  ),
526
  )
527
 
528
- def forward(self, x, return_activations: bool = False):
529
  x = self.embedding(x)
530
  x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
531
 
532
- x = rearrange(x, "b d n -> b n d")
533
- out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
534
- if return_activations:
535
- out, activations = out
536
 
 
 
537
  out = rearrange(out, "b n d -> b d n")
538
 
539
- out = self.classifier(out, None) # no cond here!
540
 
541
  out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
542
 
543
- if return_activations:
544
- return out, activations
545
- else:
546
- return out
547
 
548
  def r_embed(self, r, max_positions=10000):
549
  if self.r_cond_dim > 0:
@@ -594,23 +578,22 @@ class VampNet(at.ml.BaseModel):
594
  self,
595
  codec,
596
  time_steps: int = 300,
597
- sampling_steps: int = 36,
598
  start_tokens: Optional[torch.Tensor] = None,
599
- sampling_temperature: float = 1.0,
600
  mask: Optional[torch.Tensor] = None,
601
- mask_temperature: float = 10.5,
602
  typical_filtering=False,
603
  typical_mass=0.2,
604
  typical_min_tokens=1,
605
- top_p=None,
606
  return_signal=True,
607
- seed: int = None,
608
- sample_cutoff: float = 1.0,
609
  ):
610
- if seed is not None:
611
- at.util.seed(seed)
612
  logging.debug(f"beginning generation with {sampling_steps} steps")
613
 
 
 
 
 
 
614
 
615
 
616
  #####################
@@ -662,6 +645,9 @@ class VampNet(at.ml.BaseModel):
662
  for i in range(sampling_steps):
663
  logging.debug(f"step {i} of {sampling_steps}")
664
 
 
 
 
665
  # our current schedule step
666
  r = scalar_to_batch_tensor(
667
  (i + 1) / sampling_steps,
@@ -676,24 +662,41 @@ class VampNet(at.ml.BaseModel):
676
 
677
  # infer from latents
678
  # NOTE: this collapses the codebook dimension into the sequence dimension
679
- logits = self.forward(latents) # b, prob, seq
680
  logits = logits.permute(0, 2, 1) # b, seq, prob
681
- b = logits.shape[0]
 
 
 
 
 
682
 
683
  logging.debug(f"permuted logits with shape: {logits.shape}")
684
 
685
- sampled_z, selected_probs = sample_from_logits(
686
- logits, sample=(
687
- (i / sampling_steps) <= sample_cutoff
688
- ),
689
- temperature=sampling_temperature,
690
- typical_filtering=typical_filtering, typical_mass=typical_mass,
691
- typical_min_tokens=typical_min_tokens,
692
- top_k=None, top_p=top_p, return_probs=True,
693
- )
694
 
 
 
 
 
695
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
696
 
 
 
 
697
  # flatten z_masked and mask, so we can deal with the sampling logic
698
  # we'll unflatten them at the end of the loop for the next forward pass
699
  # remove conditioning codebooks, we'll add them back at the end
@@ -730,7 +733,7 @@ class VampNet(at.ml.BaseModel):
730
 
731
  # get our new mask
732
  mask = mask_by_random_topk(
733
- num_to_mask, selected_probs, mask_temperature * (1-r)
734
  )
735
 
736
  # update the mask
@@ -763,97 +766,8 @@ class VampNet(at.ml.BaseModel):
763
  else:
764
  return sampled_z
765
 
766
- def sample_from_logits(
767
- logits,
768
- sample: bool = True,
769
- temperature: float = 1.0,
770
- top_k: int = None,
771
- top_p: float = None,
772
- typical_filtering: bool = False,
773
- typical_mass: float = 0.2,
774
- typical_min_tokens: int = 1,
775
- return_probs: bool = False
776
- ):
777
- """Convenience function to sample from a categorial distribution with input as
778
- unnormalized logits.
779
-
780
- Parameters
781
- ----------
782
- logits : Tensor[..., vocab_size]
783
- config: SamplingConfig
784
- The set of hyperparameters to be used for sampling
785
- sample : bool, optional
786
- Whether to perform multinomial sampling, by default True
787
- temperature : float, optional
788
- Scaling parameter when multinomial samping, by default 1.0
789
- top_k : int, optional
790
- Restricts sampling to only `top_k` values acc. to probability,
791
- by default None
792
- top_p : float, optional
793
- Restricts sampling to only those values with cumulative
794
- probability = `top_p`, by default None
795
-
796
- Returns
797
- -------
798
- Tensor[...]
799
- Sampled tokens
800
- """
801
- shp = logits.shape[:-1]
802
-
803
- if typical_filtering:
804
- typical_filter(logits,
805
- typical_mass=typical_mass,
806
- typical_min_tokens=typical_min_tokens
807
- )
808
-
809
- # Apply top_k sampling
810
- if top_k is not None:
811
- v, _ = logits.topk(top_k)
812
- logits[logits < v[..., [-1]]] = -float("inf")
813
-
814
- # Apply top_p (nucleus) sampling
815
- if top_p is not None and top_p < 1.0:
816
- v, sorted_indices = logits.sort(descending=True)
817
- cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
818
-
819
- sorted_indices_to_remove = cumulative_probs > top_p
820
- # Right shift indices_to_remove to keep 1st token over threshold
821
- sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
822
- ..., :-1
823
- ]
824
-
825
- # Compute indices_to_remove in unsorted array
826
- indices_to_remove = sorted_indices_to_remove.scatter(
827
- -1, sorted_indices, sorted_indices_to_remove
828
- )
829
-
830
- logits[indices_to_remove] = -float("inf")
831
-
832
- # Perform multinomial sampling after normalizing logits
833
- probs = (
834
- F.softmax(logits / temperature, dim=-1)
835
- if temperature > 0
836
- else logits.softmax(dim=-1)
837
- )
838
- token = (
839
- probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
840
- if sample
841
- else logits.argmax(-1)
842
- )
843
-
844
- if return_probs:
845
- token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
846
- return token, token_probs
847
- else:
848
- return token
849
-
850
-
851
 
852
- def mask_by_random_topk(
853
- num_to_mask: int,
854
- probs: torch.Tensor,
855
- temperature: float = 1.0,
856
- ):
857
  """
858
  Args:
859
  num_to_mask (int): number of tokens to mask
@@ -866,8 +780,7 @@ def mask_by_random_topk(
866
  logging.debug(f"temperature: {temperature}")
867
  logging.debug("")
868
 
869
- noise = gumbel_noise_like(probs)
870
- confidence = torch.log(probs) + temperature * noise
871
  logging.debug(f"confidence shape: {confidence.shape}")
872
 
873
  sorted_confidence, sorted_idx = confidence.sort(dim=-1)
@@ -937,7 +850,7 @@ if __name__ == "__main__":
937
  z_mask_latent = torch.rand(
938
  batch_size, model.latent_dim * model.n_codebooks, seq_len
939
  ).to(device)
940
- z_hat = model(z_mask_latent)
941
 
942
  pred = z_hat.argmax(dim=1)
943
  pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
 
410
  def subsequent_mask(self, size):
411
  return torch.ones(1, size, size).tril().bool()
412
 
413
+ def forward(self, x, x_mask, cond=None, src=None, src_mask=None):
 
 
414
  """Computes a full transformer stack
415
  Parameters
416
  ----------
 
437
  encoder_decoder_position_bias = None
438
 
439
  # Compute transformer layers
 
 
440
  for layer in self.layers:
441
  x, position_bias, encoder_decoder_position_bias = layer(
442
  x=x,
 
447
  position_bias=position_bias,
448
  encoder_decoder_position_bias=encoder_decoder_position_bias,
449
  )
 
 
450
 
451
+ return self.norm(x) if self.norm is not None else x
 
 
 
 
 
452
 
453
 
454
  class VampNet(at.ml.BaseModel):
 
456
  self,
457
  n_heads: int = 20,
458
  n_layers: int = 16,
459
+ r_cond_dim: int = 64,
460
  n_codebooks: int = 9,
461
  n_conditioning_codebooks: int = 0,
462
  latent_dim: int = 8,
 
467
  dropout: float = 0.1
468
  ):
469
  super().__init__()
 
470
  self.n_heads = n_heads
471
  self.n_layers = n_layers
472
  self.r_cond_dim = r_cond_dim
 
513
  ),
514
  )
515
 
516
+ def forward(self, x, cond):
517
  x = self.embedding(x)
518
  x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
519
 
520
+ cond = self.r_embed(cond)
 
 
 
521
 
522
+ x = rearrange(x, "b d n -> b n d")
523
+ out = self.transformer(x=x, x_mask=x_mask, cond=cond)
524
  out = rearrange(out, "b n d -> b d n")
525
 
526
+ out = self.classifier(out, cond)
527
 
528
  out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
529
 
530
+ return out
 
 
 
531
 
532
  def r_embed(self, r, max_positions=10000):
533
  if self.r_cond_dim > 0:
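Editor's note: the hunk above re-introduces the `r`/`cond` pathway (`r_cond_dim`, `r_embed`, and a conditioned classifier), but the diff cuts off before the body of `r_embed`. The sketch below is an assumption about one plausible implementation (a sinusoidal embedding of the scalar schedule value), not the repository's code.

```python
import torch

def r_embed(r: torch.Tensor, r_cond_dim: int, max_positions: int = 10000) -> torch.Tensor:
    # r: (batch,) values in [0, 1]; returns (batch, r_cond_dim), r_cond_dim assumed even
    half = r_cond_dim // 2
    pos = r * max_positions                                  # spread r across the position range
    freqs = torch.exp(
        -torch.log(torch.tensor(float(max_positions)))
        * torch.arange(half, dtype=torch.float32) / half
    )
    angles = pos[:, None] * freqs[None, :]                   # (batch, half)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)   # (batch, 2 * half)
```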
 
578
  self,
579
  codec,
580
  time_steps: int = 300,
581
+ sampling_steps: int = 24,
582
  start_tokens: Optional[torch.Tensor] = None,
 
583
  mask: Optional[torch.Tensor] = None,
584
+ temperature: Union[float, Tuple[float, float]] = 2.5,
585
  typical_filtering=False,
586
  typical_mass=0.2,
587
  typical_min_tokens=1,
 
588
  return_signal=True,
 
 
589
  ):
 
 
590
  logging.debug(f"beginning generation with {sampling_steps} steps")
591
 
592
+ #####################
593
+ # resolve temperature #
594
+ #####################
595
+ assert isinstance(temperature, float)
596
+ logging.debug(f"temperature: {temperature}")
597
 
598
 
599
  #####################
 
645
  for i in range(sampling_steps):
646
  logging.debug(f"step {i} of {sampling_steps}")
647
 
648
+ # our current temperature
649
+ logging.debug(f"temperature: {temperature}")
650
+
651
  # our current schedule step
652
  r = scalar_to_batch_tensor(
653
  (i + 1) / sampling_steps,
 
662
 
663
  # infer from latents
664
  # NOTE: this collapses the codebook dimension into the sequence dimension
665
+ logits = self.forward(latents, r) # b, prob, seq
666
  logits = logits.permute(0, 2, 1) # b, seq, prob
667
+ if typical_filtering:
668
+ typical_filter(logits,
669
+ typical_mass=typical_mass,
670
+ typical_min_tokens=typical_min_tokens
671
+ )
672
+
673
 
674
  logging.debug(f"permuted logits with shape: {logits.shape}")
675
 
 
 
676
 
677
+ # logits2probs
678
+ probs = torch.softmax(logits, dim=-1)
679
+ logging.debug(f"computed probs with shape: {probs.shape}")
680
+
681
+
682
+ # sample from logits with multinomial sampling
683
+ b = probs.shape[0]
684
+ probs = rearrange(probs, "b seq prob -> (b seq) prob")
685
+
686
+ sampled_z = torch.multinomial(probs, 1).squeeze(-1)
687
+
688
+ sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
689
+ probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
690
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
691
 
692
+ # get the confidences: which tokens did we sample?
693
+ selected_probs = (
694
+ torch.take_along_dim(
695
+ probs, sampled_z.long().unsqueeze(-1),
696
+ dim=-1
697
+ ).squeeze(-1)
698
+ )
699
+
700
  # flatten z_masked and mask, so we can deal with the sampling logic
701
  # we'll unflatten them at the end of the loop for the next forward pass
702
  # remove conditioning codebooks, we'll add them back at the end
 
733
 
734
  # get our new mask
735
  mask = mask_by_random_topk(
736
+ num_to_mask, selected_probs, temperature * (1-r)
737
  )
738
 
739
  # update the mask
 
766
  else:
767
  return sampled_z
768
 
 
 
 
 
769
 
770
+ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
 
 
 
 
771
  """
772
  Args:
773
  num_to_mask (int): number of tokens to mask
 
780
  logging.debug(f"temperature: {temperature}")
781
  logging.debug("")
782
 
783
+ confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
 
784
  logging.debug(f"confidence shape: {confidence.shape}")
785
 
786
  sorted_confidence, sorted_idx = confidence.sort(dim=-1)
 
850
  z_mask_latent = torch.rand(
851
  batch_size, model.latent_dim * model.n_codebooks, seq_len
852
  ).to(device)
853
+ z_hat = model(z_mask_latent, r)
854
 
855
  pred = z_hat.argmax(dim=1)
856
  pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
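
Editor's note: to close the loop on the confidence-based re-masking used in `generate` (the `mask_by_random_topk` call with `temperature * (1 - r)`), here is a hedged, simplified sketch. `gumbel_noise_like` is not shown in the diff, so a standard Gumbel(0, 1) sampler is assumed, and `num_to_mask` is taken as a plain int rather than a per-batch tensor.

```python
import torch

def gumbel_noise_like(t: torch.Tensor) -> torch.Tensor:
    # standard Gumbel(0, 1) noise with the same shape as t
    u = torch.zeros_like(t).uniform_(1e-20, 1.0)
    return -torch.log(-torch.log(u))

def mask_lowest_confidence(num_to_mask: int, probs: torch.Tensor, temperature: float) -> torch.Tensor:
    # probs: (batch, seq) probability assigned to the token actually sampled at each position
    confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
    sorted_confidence, _ = confidence.sort(dim=-1)              # ascending
    cutoff = sorted_confidence[:, num_to_mask].unsqueeze(-1)    # per-item confidence threshold
    return (confidence < cutoff).long()                         # 1 = re-mask for the next step
```

Because the call site anneals the temperature with `(1 - r)`, the Gumbel noise dominates early steps (near-random re-masking) and fades out as the schedule approaches r = 1.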