Hugo Flores Garcia commited on
Commit
e3ca5f7
1 Parent(s): 7aa3063

refactor masking, interface, demo

Browse files
.dockerignore DELETED
@@ -1,2 +0,0 @@
1
- *.wav
2
- runs/
 
 
 
README.md CHANGED
@@ -27,7 +27,7 @@ git clone https://github.com/hugofloresgarcia/vampnet2.git
27
  pip install -e ./vampnet2
28
  ```
29
 
30
- ## A note on Argbind
31
  This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
32
  Config files are stored in the `conf/` folder.
33
 
@@ -56,6 +56,10 @@ You just need to provide a list of audio files // folders to fine-tune on, then
56
  python scripts/exp/train.py --args.load conf/lora/birds.yml --save_path /path/to/checkpoints
57
  ```
58
 
 
 
 
 
59
  ## Launching the Gradio Interface
60
  ```bash
61
  python demo.py --args.load conf/interface/spotdl.yml --Interface.device cuda
 
27
  pip install -e ./vampnet2
28
  ```
29
 
30
+ ## A note on argbind
31
  This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
32
  Config files are stored in the `conf/` folder.
33
 
 
56
  python scripts/exp/train.py --args.load conf/lora/birds.yml --save_path /path/to/checkpoints
57
  ```
58
 
59
+ ## Getting the Pretrained Models
60
+
61
+
62
+
63
  ## Launching the Gradio Interface
64
  ```bash
65
  python demo.py --args.load conf/interface/spotdl.yml --Interface.device cuda
demo.py CHANGED
@@ -3,6 +3,7 @@ from typing import Tuple
3
  import yaml
4
  import tempfile
5
  import uuid
 
6
 
7
  import numpy as np
8
  import audiotools as at
@@ -10,6 +11,7 @@ import argbind
10
 
11
  import gradio as gr
12
  from vampnet.interface import Interface
 
13
 
14
  Interface = argbind.bind(Interface)
15
  AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
@@ -60,132 +62,61 @@ def load_random_audio():
60
  return sig.path_to_file
61
 
62
 
63
- def ez_vamp(
64
- input_audio, init_temp, final_temp,
65
- mask_periodic_amt, mask_periodic_width, num_steps,
66
- stretch_factor,
67
- ):
68
- print(input_audio)
69
- sig = at.AudioSignal(input_audio)
70
-
71
- print(f"running standard vampnet with {num_vamps} vamps")
72
- zv = interface.coarse_vamp(
73
- sig,
74
- sampling_steps=num_steps,
75
- temperature=(init_temp, final_temp),
76
- prefix_dur_s=0.0,
77
- suffix_dur_s=0.0,
78
- num_vamps=1,
79
- downsample_factor=mask_periodic_amt,
80
- stretch_factor=stretch_factor,
81
- periodic_width=mask_periodic_width,
82
- periodic_dropout=0.0,
83
- periodic_width_dropout=0.0,
84
- n_conditioning_codebooks=None,
85
- intensity=1.0,
86
- ext_mask=None,
87
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- zv = interface.coarse_to_fine(zv)
90
 
91
  sig = interface.to_signal(zv).cpu()
92
  print("done")
93
 
94
  out_dir = OUT_DIR / str(uuid.uuid4())
95
  out_dir.mkdir()
96
- sig.write(out_dir / "output.wav")
97
- # mask.write(out_dir / "mask.wav")
98
- # return sig.path_to_file, mask.path_to_file
99
- return sig.path_to_file
100
-
101
 
102
- def vamp(
103
- input_audio, init_temp, final_temp,
104
- prefix_s, suffix_s, rand_mask_intensity,
105
- mask_periodic_amt, beat_unmask_dur,
106
- mask_dwn_chk, dwn_factor,
107
- mask_up_chk, up_factor,
108
- num_vamps, mode, use_beats, num_steps, snap_to_beats,
109
- beat_unmask_drop, mask_periodic_width,
110
- mask_periodic_dropout, mask_periodic_width_dropout,
111
- n_conditioning_codebooks, use_coarse2fine, stretch_factor,
112
- ):
113
- # try:
114
- print(input_audio)
115
-
116
- sig = at.AudioSignal(input_audio)
117
-
118
- if snap_to_beats:
119
- old_sig = sig.clone()
120
- sig = interface.snap_to_beats(sig)
121
- if sig.duration < (sig.duration / 4): # we cut off too much
122
- sig = old_sig
123
- print(f"new sig duration is {sig.duration} which is too short, reverting to old sig")
124
- print(f"new sig duration is {sig.duration}")
125
-
126
- if beat_unmask_dur > 0.0 and use_beats:
127
- beat_mask = interface.make_beat_mask(
128
- sig,
129
- before_beat_s=0.0,
130
- after_beat_s=beat_unmask_dur,
131
- mask_downbeats=mask_dwn_chk,
132
- mask_upbeats=mask_up_chk,
133
- downbeat_downsample_factor=dwn_factor if dwn_factor > 0 else None,
134
- beat_downsample_factor=up_factor if up_factor > 0 else None,
135
- dropout=beat_unmask_drop,
136
- invert=True
137
- )
138
- print(beat_mask)
139
- else:
140
- beat_mask = None
141
-
142
- if mode == "standard":
143
- print(f"running standard vampnet with {num_vamps} vamps")
144
- zv, mask_z = interface.coarse_vamp(
145
- sig,
146
- sampling_steps=num_steps,
147
- temperature=(init_temp, final_temp),
148
- prefix_dur_s=prefix_s,
149
- suffix_dur_s=suffix_s,
150
- num_vamps=num_vamps,
151
- downsample_factor=mask_periodic_amt,
152
- stretch_factor=stretch_factor,
153
- periodic_width=mask_periodic_width,
154
- periodic_dropout=mask_periodic_dropout,
155
- periodic_width_dropout=mask_periodic_width_dropout,
156
- n_conditioning_codebooks=n_conditioning_codebooks if n_conditioning_codebooks > 0 else None,
157
- intensity=rand_mask_intensity,
158
- ext_mask=beat_mask,
159
- verbose=True,
160
- return_mask=True
161
- )
162
-
163
- if use_coarse2fine:
164
- zv = interface.coarse_to_fine(zv)
165
- mask = interface.to_signal(mask_z).cpu()
166
-
167
- sig = interface.to_signal(zv).cpu()
168
- print("done")
169
-
170
- out_dir = OUT_DIR / str(uuid.uuid4())
171
- out_dir.mkdir()
172
- sig.write(out_dir / "output.wav")
173
- mask.write(out_dir / "mask.wav")
174
- return sig.path_to_file, mask.path_to_file
175
- # return sig.path_to_file, mask_z
176
- # except Exception as e:
177
- # raise gr.Error(f"failed with error: {e}")
178
 
179
- def save_vamp(
180
- input_audio, init_temp, final_temp,
181
- prefix_s, suffix_s, rand_mask_intensity,
182
- mask_periodic_amt, beat_unmask_dur,
183
- mask_dwn_chk, dwn_factor,
184
- mask_up_chk, up_factor,
185
- num_vamps, mode, output_audio, notes, use_beats, num_steps, snap_to_beats,
186
- beat_unmask_drop, mask_periodic_width, mask_periodic_dropout, mask_periodic_width_dropout,
187
- n_conditioning_codebooks, use_coarse2fine, stretch_factor
188
- ):
189
  out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
190
  out_dir.mkdir(parents=True, exist_ok=True)
191
 
@@ -196,30 +127,18 @@ def save_vamp(
196
  sig_out.write(out_dir / "output.wav")
197
 
198
  data = {
199
- "init_temp": init_temp,
200
- "final_temp": final_temp,
201
- "prefix_s": prefix_s,
202
- "suffix_s": suffix_s,
203
- "rand_mask_intensity": rand_mask_intensity,
204
- "mask_periodic_amt": mask_periodic_amt,
205
- "use_beats": use_beats,
206
- "beat_unmask_dur": beat_unmask_dur,
207
- "mask_dwn_chk": mask_dwn_chk,
208
- "dwn_factor": dwn_factor,
209
- "mask_up_chk": mask_up_chk,
210
- "up_factor": up_factor,
211
- "num_vamps": num_vamps,
212
- "num_steps": num_steps,
213
- "snap_to_beats": snap_to_beats,
214
- "mode": mode,
215
- "notes": notes,
216
- "beat_unmask_drop": beat_unmask_drop,
217
- "mask_periodic_width": mask_periodic_width,
218
- "mask_periodic_dropout": mask_periodic_dropout,
219
- "mask_periodic_width_dropout": mask_periodic_width_dropout,
220
- "n_conditioning_codebooks": n_conditioning_codebooks,
221
- "use_coarse2fine": use_coarse2fine,
222
- "stretch_factor": stretch_factor,
223
  }
224
 
225
  # save with yaml
@@ -234,68 +153,16 @@ def save_vamp(
234
 
235
  return f"saved! your save code is {out_dir.stem}", zip_path
236
 
237
- with gr.Blocks() as demo:
238
-
239
- with gr.Row():
240
- # input audio
241
- with gr.Column():
242
- gr.Markdown("""
243
- # Vampnet
244
- **Instructions**:
245
- 1. Upload some audio (or click the load random audio button)
246
- 2. Adjust the mask hints. The more hints, the more the generated music will follow the input music
247
- 3. Adjust the vampnet parameters. The more vamps, the longer the generated music will be
248
- 4. Click the "vamp" button
249
- 5. Listen to the generated audio
250
- 6. If you noticed something you liked, write some notes, click the "save vamp" button, and copy the save code
251
-
252
- """)
253
- gr.Markdown("## Input Audio")
254
 
255
- with gr.Column():
256
- gr.Markdown("""
257
- ### Tips
258
- - use the beat hint button so the output audio has the same beat structure as the input audio
259
- - if you want more beat structure:
260
- - enable beat hints
261
- - if you want a more "random" generation:
262
- - increase the periodic unmasking to 12 or more
263
- - increase the temperatures!
264
- - uncheck the beat hint button (or reduce the beat unmask duration)
265
- - if you want the generated audio to sound like the original, but with a different beat structure:
266
- - uncheck the beat hint button
267
- - decrease the periodic unmasking to anywhere from 2 to 20
268
- - slightly decrease the random intensity, to like .95
269
-
270
-
271
- """)
272
- with gr.Column():
273
- gr.Markdown("""
274
- ## Mask Hints
275
- - most of the original audio will be masked and replaced with audio generated by vampnet
276
- - mask hints are used to guide vampnet to generate audio that sounds like the original
277
- - the more hints you give, the more the generated audio will sound like the original
278
-
279
- """)
280
 
 
281
 
282
  with gr.Row():
283
  with gr.Column():
284
- mode = gr.Radio(
285
- label="**mode**. note that loop mode requires a prefix and suffix longer than 0",
286
- choices=["standard",],
287
- value="standard"
288
- )
289
  use_coarse2fine = gr.Checkbox(
290
  label="use coarse2fine",
291
  value=True
292
  )
293
- num_vamps = gr.Number(
294
- label="number of vamps. more vamps = longer generated audio",
295
- value=1,
296
- precision=0,
297
- visible=False
298
- )
299
 
300
  manual_audio_upload = gr.File(
301
  label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
@@ -328,57 +195,46 @@ with gr.Blocks() as demo:
328
  outputs=[ input_audio]
329
  )
330
 
331
-
332
  # mask settings
333
  with gr.Column():
334
 
335
- n_conditioning_codebooks = gr.Number(
336
- label="number of conditioning codebooks. probably 0",
337
- value=0,
338
- precision=0,
339
- )
340
-
341
- stretch_factor = gr.Slider(
342
- label="time stretch factor",
343
- minimum=0,
344
- maximum=64,
345
- step=1,
346
- value=1,
347
  )
348
 
349
- mask_periodic_amt = gr.Slider(
350
- label="periodic hint (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
351
  minimum=0,
352
- maximum=64,
353
  step=1,
354
  value=9,
355
  )
356
- mask_periodic_width = gr.Slider(
357
- label="periodic hint width (steps, 1 step ~= 10milliseconds",
358
  minimum=1,
359
- maximum=100,
360
  step=1,
361
  value=1,
362
  )
363
- mask_periodic_dropout = gr.Slider(
364
- label="periodic hint dropout (0.0 means no dropout, 1.0 means all dropout)",
365
- minimum=0.0,
366
- maximum=1.0,
367
- value=0.0,
368
- )
369
- mask_periodic_width_dropout = gr.Slider(
370
- label="periodic hint width dropout (0.0 means no dropout, 1.0 means all dropout)",
371
- minimum=0.0,
372
- maximum=1.0,
373
- value=0.0,
374
- )
375
 
376
- rand_mask_intensity = gr.Slider(
377
- label="random mask intensity. (If this is less than 1, scatters tiny hints throughout the audio, should be between 0.9 and 1.0)",
378
- minimum=0.8,
379
- maximum=1.0,
380
- value=1.0
381
- )
 
 
 
 
 
 
 
 
 
382
 
383
  with gr.Accordion("prefix/suffix hints", open=False):
384
  prefix_s = gr.Slider(
@@ -408,15 +264,6 @@ with gr.Blocks() as demo:
408
  value=1.0
409
  )
410
 
411
- use_beats = gr.Checkbox(
412
- label="use beat hints (helps the output stick to the beat structure of the input)",
413
- value=False
414
- )
415
-
416
- snap_to_beats = gr.Checkbox(
417
- label="trim to beat markers (uncheck if the output audio is too short.)",
418
- value=True
419
- )
420
 
421
  num_steps = gr.Slider(
422
  label="number of steps (should normally be between 12 and 36)",
@@ -426,6 +273,14 @@ with gr.Blocks() as demo:
426
  value=36
427
  )
428
 
 
 
 
 
 
 
 
 
429
  vamp_button = gr.Button("vamp!!!")
430
 
431
  output_audio = gr.Audio(
@@ -434,59 +289,26 @@ with gr.Blocks() as demo:
434
  type="filepath"
435
  )
436
 
437
-
438
- # gr.Markdown("**NOTE**: for loop mode, both prefix and suffix must be greater than 0.")
439
- # compute_mask_button = gr.Button("compute mask")
440
- # mask_output = gr.Audio(
441
- # label="masked audio",
442
- # interactive=False,
443
- # visible=False
444
- # )
445
- # mask_output_viz = gr.Video(
446
- # label="masked audio",
447
- # interactive=False
448
- # )
449
 
450
- with gr.Column():
451
- with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
 
 
 
 
 
 
 
 
 
452
 
453
- beat_unmask_dur = gr.Slider(
454
- label="duration",
455
- minimum=0.0,
456
- maximum=3.0,
457
- value=0.07
458
- )
459
- beat_unmask_drop = gr.Slider(
460
- label="dropout (within beat)",
461
- minimum=0.0,
462
- maximum=1.0,
463
- value=0.0
464
- )
465
 
466
- with gr.Accordion("downbeat settings", open=False):
467
- mask_dwn_chk = gr.Checkbox(
468
- label="hint downbeats",
469
- value=True
470
- )
471
- dwn_factor = gr.Slider(
472
- label="downbeat downsample factor (hint only every Nth downbeat)",
473
- value=0,
474
- minimum=0,
475
- maximum=16,
476
- step=1
477
- )
478
- with gr.Accordion("upbeat settings", open=False):
479
- mask_up_chk = gr.Checkbox(
480
- label="hint upbeats",
481
- value=True
482
- )
483
- up_factor = gr.Slider(
484
- label="upbeat downsample factor (hint only every Nth upbeat)",
485
- value=0,
486
- minimum=0,
487
- maximum=16,
488
- step=1
489
- )
490
 
491
  notes_text = gr.Textbox(
492
  label="type any notes about the generated audio here",
@@ -499,52 +321,43 @@ with gr.Blocks() as demo:
499
  interactive=False
500
  )
501
 
502
-
503
  thank_you = gr.Markdown("")
504
-
505
-
506
  # connect widgets
507
  vamp_button.click(
508
  fn=vamp,
509
- inputs=[input_audio, init_temp,final_temp,
510
- prefix_s, suffix_s, rand_mask_intensity,
511
- mask_periodic_amt, beat_unmask_dur,
512
- mask_dwn_chk, dwn_factor,
513
- mask_up_chk, up_factor,
514
- num_vamps, mode, use_beats, num_steps, snap_to_beats,
515
- beat_unmask_drop, mask_periodic_width,
516
- mask_periodic_dropout, mask_periodic_width_dropout,
517
- n_conditioning_codebooks, use_coarse2fine, stretch_factor
518
- ],
 
 
519
  outputs=[output_audio, audio_mask],
520
  api_name="vamp"
521
  )
522
 
523
  save_button.click(
524
  fn=save_vamp,
525
- inputs=[
526
- input_audio, init_temp, final_temp,
527
- prefix_s, suffix_s, rand_mask_intensity,
528
- mask_periodic_amt, beat_unmask_dur,
529
- mask_dwn_chk, dwn_factor,
530
- mask_up_chk, up_factor,
531
- num_vamps, mode,
532
- output_audio,
533
- notes_text, use_beats, num_steps, snap_to_beats,
534
- beat_unmask_drop, mask_periodic_width,
535
- mask_periodic_dropout, mask_periodic_width_dropout,
536
- n_conditioning_codebooks, use_coarse2fine, stretch_factor
537
- ],
538
  outputs=[thank_you, download_file]
539
  )
540
 
541
- ez_vamp_button = gr.Button("ez vamp")
542
- ez_vamp_button.click(
543
- fn=ez_vamp,
544
- inputs=[input_audio, init_temp, final_temp, mask_periodic_amt,
545
- mask_periodic_width, num_steps, stretch_factor ],
546
- outputs=[output_audio],
547
- api_name="ez_vamp"
548
- )
549
-
550
  demo.launch(share=True, enable_queue=False, debug=True)
 
3
  import yaml
4
  import tempfile
5
  import uuid
6
+ from dataclasses import dataclass, asdict
7
 
8
  import numpy as np
9
  import audiotools as at
 
11
 
12
  import gradio as gr
13
  from vampnet.interface import Interface
14
+ from vampnet import mask as pmask
15
 
16
  Interface = argbind.bind(Interface)
17
  AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
 
62
  return sig.path_to_file
63
 
64
 
65
+ def vamp(data):
66
+ print(data[input_audio])
67
+ sig = at.AudioSignal(data[input_audio])
68
+
69
+ z = interface.encode(sig)
70
+
71
+ ncc = data[n_conditioning_codebooks]
72
+
73
+ # build the mask
74
+ mask = pmask.linear_random(z, data[rand_mask_intensity])
75
+ mask = pmask.mask_and(
76
+ mask, pmask.inpaint(
77
+ z,
78
+ interface.s2t(data[prefix_s]),
79
+ interface.s2t(data[suffix_s])
80
+ )
 
 
 
 
 
 
 
 
81
  )
82
+ mask = pmask.mask_and(
83
+ mask, pmask.periodic_mask(
84
+ z,
85
+ data[periodic_p],
86
+ data[periodic_w],
87
+ random_roll=True
88
+ )
89
+ )
90
+ mask = pmask.dropout(mask, data[dropout])
91
+ mask = pmask.codebook_unmask(mask, ncc)
92
+
93
+ print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}")
94
+
95
+ zv, mask_z = interface.coarse_vamp(
96
+ z,
97
+ mask=mask,
98
+ sampling_steps=data[num_steps],
99
+ temperature=(data[init_temp], data[final_temp]),
100
+ return_mask=True
101
+ )
102
+
103
+ if use_coarse2fine:
104
+ zv = interface.coarse_to_fine(zv)
105
+
106
 
107
+ mask = interface.to_signal(mask_z).cpu()
108
 
109
  sig = interface.to_signal(zv).cpu()
110
  print("done")
111
 
112
  out_dir = OUT_DIR / str(uuid.uuid4())
113
  out_dir.mkdir()
 
 
 
 
 
114
 
115
+ sig.write(out_dir / "output.wav")
116
+ mask.write(out_dir / "mask.wav")
117
+ return sig.path_to_file, mask.path_to_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ def save_vamp(data):
 
 
 
 
 
 
 
 
 
120
  out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
121
  out_dir.mkdir(parents=True, exist_ok=True)
122
 
 
127
  sig_out.write(out_dir / "output.wav")
128
 
129
  data = {
130
+ "init_temp": data[init_temp],
131
+ "final_temp": data[final_temp],
132
+ "prefix_s": data[prefix_s],
133
+ "suffix_s": data[suffix_s],
134
+ "rand_mask_intensity": data[rand_mask_intensity],
135
+ "num_steps": data[num_steps],
136
+ "notes": data[notes_text],
137
+ "periodic_period": data[periodic_p],
138
+ "periodic_width": data[periodic_w],
139
+ "n_conditioning_codebooks": data[n_conditioning_codebooks],
140
+ "use_coarse2fine": data[use_coarse2fine],
141
+ "stretch_factor": data[stretch_factor],
 
 
 
 
 
 
 
 
 
 
 
 
142
  }
143
 
144
  # save with yaml
 
153
 
154
  return f"saved! your save code is {out_dir.stem}", zip_path
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ with gr.Blocks() as demo:
159
 
160
  with gr.Row():
161
  with gr.Column():
 
 
 
 
 
162
  use_coarse2fine = gr.Checkbox(
163
  label="use coarse2fine",
164
  value=True
165
  )
 
 
 
 
 
 
166
 
167
  manual_audio_upload = gr.File(
168
  label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
 
195
  outputs=[ input_audio]
196
  )
197
 
 
198
  # mask settings
199
  with gr.Column():
200
 
201
+ rand_mask_intensity = gr.Slider(
202
+ label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
203
+ minimum=0.0,
204
+ maximum=1.0,
205
+ value=1.0
 
 
 
 
 
 
 
206
  )
207
 
208
+ periodic_p = gr.Slider(
209
+ label="periodic prompt (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
210
  minimum=0,
211
+ maximum=128,
212
  step=1,
213
  value=9,
214
  )
215
+ periodic_w = gr.Slider(
216
+ label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
217
  minimum=1,
218
+ maximum=20,
219
  step=1,
220
  value=1,
221
  )
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
+ with gr.Accordion("extras ", open=False):
224
+ n_conditioning_codebooks = gr.Number(
225
+ label="number of conditioning codebooks. probably 0",
226
+ value=0,
227
+ precision=0,
228
+ )
229
+
230
+ stretch_factor = gr.Slider(
231
+ label="time stretch factor",
232
+ minimum=0,
233
+ maximum=64,
234
+ step=1,
235
+ value=1,
236
+ )
237
+
238
 
239
  with gr.Accordion("prefix/suffix hints", open=False):
240
  prefix_s = gr.Slider(
 
264
  value=1.0
265
  )
266
 
 
 
 
 
 
 
 
 
 
267
 
268
  num_steps = gr.Slider(
269
  label="number of steps (should normally be between 12 and 36)",
 
273
  value=36
274
  )
275
 
276
+ dropout = gr.Slider(
277
+ label="mask dropout",
278
+ minimum=0.0,
279
+ maximum=1.0,
280
+ step=0.01,
281
+ value=0.0
282
+ )
283
+
284
  vamp_button = gr.Button("vamp!!!")
285
 
286
  output_audio = gr.Audio(
 
289
  type="filepath"
290
  )
291
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ # with gr.Column():
294
+ # with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
295
+ # use_beats = gr.Checkbox(
296
+ # label="use beat hints (helps the output stick to the beat structure of the input)",
297
+ # value=False
298
+ # )
299
+
300
+ # snap_to_beats = gr.Checkbox(
301
+ # label="trim to beat markers (uncheck if the output audio is too short.)",
302
+ # value=True
303
+ # )
304
 
305
+ # beat_unmask_dur = gr.Slider(
306
+ # label="duration",
307
+ # minimum=0.0,
308
+ # maximum=3.0,
309
+ # value=0.07
310
+ # )
 
 
 
 
 
 
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  notes_text = gr.Textbox(
314
  label="type any notes about the generated audio here",
 
321
  interactive=False
322
  )
323
 
 
324
  thank_you = gr.Markdown("")
325
+
 
326
  # connect widgets
327
  vamp_button.click(
328
  fn=vamp,
329
+ inputs={
330
+ input_audio,
331
+ num_steps,
332
+ init_temp, final_temp,
333
+ prefix_s, suffix_s,
334
+ rand_mask_intensity,
335
+ periodic_p, periodic_w,
336
+ n_conditioning_codebooks,
337
+ dropout,
338
+ use_coarse2fine,
339
+ stretch_factor
340
+ },
341
  outputs=[output_audio, audio_mask],
342
  api_name="vamp"
343
  )
344
 
345
  save_button.click(
346
  fn=save_vamp,
347
+ inputs={
348
+ input_audio,
349
+ num_steps,
350
+ init_temp, final_temp,
351
+ prefix_s, suffix_s,
352
+ rand_mask_intensity,
353
+ periodic_p, periodic_w,
354
+ n_conditioning_codebooks,
355
+ dropout,
356
+ use_coarse2fine,
357
+ stretch_factor,
358
+ notes_text
359
+ },
360
  outputs=[thank_you, download_file]
361
  )
362
 
 
 
 
 
 
 
 
 
 
363
  demo.launch(share=True, enable_queue=False, debug=True)
env/alias.sh DELETED
@@ -1,3 +0,0 @@
1
- alias cleanup="pkill python && echo -en '\e[?25h'"
2
- alias stage="python ./scripts/utils/stage.py"
3
- alias fix_cursor="echo -en '\e[?25h'"
 
 
 
 
env/data.sh DELETED
@@ -1,36 +0,0 @@
1
- export PATH_TO_DATA=~/data
2
-
3
- if [[ $(hostname) == "oon17" ]]; then
4
- export PATH_TO_DATA=/data/
5
- fi
6
-
7
- if [[ $(hostname) == "oon19" ]]; then
8
- export PATH_TO_DATA=/home/prem/shared/data/
9
- fi
10
-
11
- if [[ $(hostname) == "lucas-ssound-trt-vm" ]]; then
12
- export PATH_TO_DATA=~/data
13
- fi
14
-
15
- if [[ $(hostname) == "a100-ssound" ]]; then
16
- export PATH_TO_DATA=~/data
17
- fi
18
-
19
- if [[ $(hostname) == "oon25" ]]; then
20
- export PATH_TO_DATA=/data
21
- fi
22
-
23
- if [[ $(hostname) == "macbook-pro-2.lan" ]]; then
24
- export PATH_TO_DATA=~/data
25
- fi
26
-
27
- if [[ $(hostname) == "oon11" ]]; then
28
- export PATH_TO_DATA=/data2/syncthing_lucas/data
29
- fi
30
-
31
- if [[ $(hostname) == "oon12" ]]; then
32
- export PATH_TO_DATA=/data
33
- fi
34
- if [[ $(hostname) == "oon26" ]]; then
35
- export PATH_TO_DATA=/data
36
- fi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
env/entry_script.sh DELETED
@@ -1,41 +0,0 @@
1
- #!/bin/bash
2
- set -e
3
-
4
- if [ -z "${USER}" ]; then
5
- echo "We need USER to be set!"; exit 100
6
- fi
7
-
8
- # check if host uid and gid are set
9
- if [ -z "${HOST_USER_ID}" ]; then
10
- echo "Please set HOST_USER_ID env. variables to continue." ; exit 0
11
- fi
12
-
13
- if [ -z "${HOST_USER_GID}" ]; then
14
- echo "Please set HOST_USER_GID env. variables to continue." ; exit 0
15
- fi
16
-
17
- USER_ID=$HOST_USER_ID
18
- USER_GID=$HOST_USER_GID
19
- USER_HOME=/u/home
20
-
21
- # modify uid and gid to match host
22
- sed -i -e "s/^${USER}:\([^:]*\):[0-9]*:[0-9]*/${USER}:\1:${USER_ID}:${USER_GID}/" /etc/passwd
23
-
24
- # create a group for host gid
25
- groupadd -f --gid "${USER_GID}" "host_group"
26
-
27
- chown $USER_ID $USER_HOME
28
- chown $USER_ID /u/home/.zshrc
29
- chown $USER_ID /u/home/.oh-my-zsh
30
-
31
- mkdir -p /u/home/.cache
32
- chown -R $USER_ID:$USER_GID /u/home/.cache/
33
-
34
- _term() {
35
- echo "Caught SIGTERM signal!"
36
- kill -TERM "$child" 2>/dev/null
37
- }
38
-
39
- trap _term SIGTERM
40
-
41
- su -p "${USER}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
env/setup.py DELETED
@@ -1,94 +0,0 @@
1
- # This script guides the user through setting up their env.sh
2
- # if env.sh does not exist. Should have no dependencies other
3
- # than Python standard library.
4
- import shlex
5
- import socket
6
- import subprocess
7
- import textwrap
8
-
9
-
10
- def run(cmd):
11
- return subprocess.check_output(shlex.split(cmd)).decode("utf-8")
12
-
13
-
14
-
15
-
16
- print()
17
- print("4. Setting up paths.")
18
- print("--------------------")
19
-
20
- PATH_TO_RUNS = input("Where runs should go (default:./runs/): ") or "./runs/"
21
- TENSORBOARD_PATH = (
22
- input("Bucket/dir for tensorboard logs (default=PATH_TO_RUNS): ") or PATH_TO_RUNS
23
- )
24
-
25
- with open("env/data.sh") as f:
26
- data_script = f.read()
27
-
28
- write_to_data_sh = False
29
- if socket.gethostname() not in data_script:
30
- print("Looks like the data path for this machine is not setup.")
31
- PATH_TO_DATA = input(f"Path to data on {socket.gethostname()}: ") or "~/data"
32
-
33
- data_command = f"""
34
- if [[ $(hostname) == "{socket.gethostname()}" ]]; then
35
- export PATH_TO_DATA={PATH_TO_DATA}
36
- fi
37
- """
38
- write_to_data_sh = True
39
-
40
-
41
- print()
42
- print("5. Setting up Papaya")
43
- print("-----------------------------------------")
44
-
45
- PAPAYA_USER_TOKEN = input("Papaya user token: ") or "undefined"
46
-
47
- env_script = f"""
48
- source env/alias.sh
49
- source env/data.sh
50
- export GITHUB_TOKEN={GITHUB_TOKEN}
51
-
52
- export PAPAYA_USER_TOKEN={PAPAYA_USER_TOKEN}
53
-
54
- export HOST_USER_ID=$(id -u)
55
- export HOST_USER_GID=$(id -g)
56
-
57
- export JUPYTER_TOKEN={JUPYTER_TOKEN}
58
- export JUPYTER_PORT={JUPYTER_PORT}
59
- export TENSORBOARD_PORT={TENSORBOARD_PORT}
60
-
61
- export PATH_TO_RUNS={PATH_TO_RUNS}
62
- export TENSORBOARD_PATH={TENSORBOARD_PATH}
63
- """
64
-
65
- print()
66
- print("6. Potential file contents.")
67
- print("---------------------------")
68
-
69
- print("env/env.sh: \n")
70
- print("##################")
71
- print(env_script)
72
- print("##################")
73
-
74
- if write_to_data_sh:
75
- data_script += data_command
76
-
77
- print("env/data.sh:")
78
- print("##################")
79
- print(data_script)
80
- print("##################")
81
-
82
- print()
83
- write_to_files = input("Write to file [yn]? ") or "n"
84
- if write_to_files == "y":
85
- with open("env/env.sh", "w") as f:
86
- f.write(env_script.strip())
87
- with open("env/data.sh", "w") as f:
88
- f.write(data_script.strip())
89
-
90
- print()
91
- print("8. Finalize setup.")
92
- print("------------------")
93
- print("Run the following command to complete setup.")
94
- print("source env/env.sh")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/exp/train.py CHANGED
@@ -18,6 +18,8 @@ from tensorboardX import SummaryWriter
18
 
19
  import vampnet
20
  from vampnet.modules.transformer import VampNet
 
 
21
  from lac.model.lac import LAC
22
 
23
 
@@ -322,7 +324,10 @@ def train(
322
  n_batch = z.shape[0]
323
  r = rng.draw(n_batch)[:, 0].to(accel.device)
324
 
325
- z_mask, mask = vn.add_noise(z, r)
 
 
 
326
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
327
 
328
  dtype = torch.bfloat16 if accel.amp else None
@@ -331,14 +336,12 @@ def train(
331
  # for mask mode
332
  z_hat = vn.add_truth_to_logits(z, z_hat, mask)
333
 
334
- target = vn.embedding.flatten(
335
  z[:, vn.n_conditioning_codebooks :, :],
336
- n_codebooks=vn.n_predict_codebooks,
337
  )
338
 
339
- flat_mask = vn.embedding.flatten(
340
  mask[:, vn.n_conditioning_codebooks :, :],
341
- n_codebooks=vn.n_predict_codebooks,
342
  )
343
 
344
  if vn.noise_mode == "mask":
@@ -398,21 +401,22 @@ def train(
398
  n_batch = z.shape[0]
399
  r = rng.draw(n_batch)[:, 0].to(accel.device)
400
 
401
- z_mask, mask = vn.add_noise(z, r)
 
 
 
402
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
403
 
404
  z_hat = model(z_mask_latent, r)
405
  # for mask mode
406
  z_hat = vn.add_truth_to_logits(z, z_hat, mask)
407
 
408
- target = vn.embedding.flatten(
409
  z[:, vn.n_conditioning_codebooks :, :],
410
- n_codebooks=vn.n_predict_codebooks,
411
  )
412
 
413
- flat_mask = vn.embedding.flatten(
414
- mask[:, vn.n_conditioning_codebooks :, :],
415
- n_codebooks=vn.n_predict_codebooks,
416
  )
417
 
418
  output = {}
@@ -514,14 +518,12 @@ def train(
514
  def save_imputation(self, z: torch.Tensor):
515
  n_prefix = int(z.shape[-1] * 0.25)
516
  n_suffix = int(z.shape[-1] * 0.25)
517
- downsample_factor = None
518
 
519
  vn = accel.unwrap(model)
520
 
521
- z_mask, mask = vn.add_noise(
522
- z, r=0.0, n_prefix=n_prefix, n_suffix=n_suffix,
523
- downsample_factor=downsample_factor
524
- )
525
 
526
  imputed_noisy = vn.to_signal(z_mask, codec)
527
  imputed_true = vn.to_signal(z, codec)
@@ -574,9 +576,11 @@ def train(
574
 
575
  r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
576
 
577
- n_batch = z.shape[0]
578
 
579
- z_mask, mask = vn.add_noise(z, r)
 
 
 
580
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
581
 
582
  z_hat = model(z_mask_latent, r)
@@ -584,7 +588,7 @@ def train(
584
  z_hat = vn.add_truth_to_logits(z, z_hat, mask)
585
 
586
  z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
587
- z_pred = vn.embedding.unflatten(z_pred, n_codebooks=vn.n_predict_codebooks)
588
  z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
589
 
590
  generated = vn.to_signal(z_pred, codec)
 
18
 
19
  import vampnet
20
  from vampnet.modules.transformer import VampNet
21
+ from vampnet.util import codebook_unflatten, codebook_flatten
22
+ from vampnet import mask as pmask
23
  from lac.model.lac import LAC
24
 
25
 
 
324
  n_batch = z.shape[0]
325
  r = rng.draw(n_batch)[:, 0].to(accel.device)
326
 
327
+ mask = pmask.random(z, r)
328
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
329
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
330
+
331
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
332
 
333
  dtype = torch.bfloat16 if accel.amp else None
 
336
  # for mask mode
337
  z_hat = vn.add_truth_to_logits(z, z_hat, mask)
338
 
339
+ target = codebook_flatten(
340
  z[:, vn.n_conditioning_codebooks :, :],
 
341
  )
342
 
343
+ flat_mask = codebook_flatten(
344
  mask[:, vn.n_conditioning_codebooks :, :],
 
345
  )
346
 
347
  if vn.noise_mode == "mask":
 
401
  n_batch = z.shape[0]
402
  r = rng.draw(n_batch)[:, 0].to(accel.device)
403
 
404
+ mask = pmask.random(z, r)
405
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
406
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
407
+
408
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
409
 
410
  z_hat = model(z_mask_latent, r)
411
  # for mask mode
412
  z_hat = vn.add_truth_to_logits(z, z_hat, mask)
413
 
414
+ target = codebook_flatten(
415
  z[:, vn.n_conditioning_codebooks :, :],
 
416
  )
417
 
418
+ flat_mask = codebook_flatten(
419
+ mask[:, vn.n_conditioning_codebooks :, :]
 
420
  )
421
 
422
  output = {}
 
518
  def save_imputation(self, z: torch.Tensor):
519
  n_prefix = int(z.shape[-1] * 0.25)
520
  n_suffix = int(z.shape[-1] * 0.25)
 
521
 
522
  vn = accel.unwrap(model)
523
 
524
+ mask = pmask.inpaint(z, n_prefix, n_suffix)
525
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
526
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
 
527
 
528
  imputed_noisy = vn.to_signal(z_mask, codec)
529
  imputed_true = vn.to_signal(z, codec)
 
576
 
577
  r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
578
 
 
579
 
580
+ mask = pmask.random(z, r)
581
+ mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
582
+ z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
583
+
584
  z_mask_latent = vn.embedding.from_codes(z_mask, codec)
585
 
586
  z_hat = model(z_mask_latent, r)
 
588
  z_hat = vn.add_truth_to_logits(z, z_hat, mask)
589
 
590
  z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
591
+ z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
592
  z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
593
 
594
  generated = vn.to_signal(z_pred, codec)
scripts/utils/vamp_folder.py CHANGED
@@ -95,19 +95,6 @@ def opus(sig, interface, bitrate=128):
95
  )
96
  return sig
97
 
98
- def token_noise(ratio=1.0):
99
- def wrapper(sig, interface):
100
- z = interface.encode(sig)
101
- r = interface.coarse.invgamma(ratio).to(interface.device)
102
- print(f'adding noise with ratio {ratio}')
103
- z, mask = interface.coarse.add_noise(
104
- z,
105
- r,
106
- noise_mode="random"
107
- )
108
- return interface.to_signal(z)
109
- return wrapper
110
-
111
  def mask_ratio_1_step(ratio=1.0):
112
  def wrapper(sig, interface):
113
  r = interface.coarse.invgamma(ratio).to(interface.device)
 
95
  )
96
  return sig
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def mask_ratio_1_step(ratio=1.0):
99
  def wrapper(sig, interface):
100
  r = interface.coarse.invgamma(ratio).to(interface.device)
vampnet/interface.py CHANGED
@@ -9,6 +9,8 @@ import tqdm
9
 
10
  from .modules.transformer import VampNet
11
  from .beats import WaveBeat
 
 
12
  from lac.model.lac import LAC
13
 
14
 
@@ -20,14 +22,6 @@ def signal_concat(
20
  return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
21
 
22
 
23
- class SignalPrompt:
24
-
25
- def __init__(self, signal: AudioSignal):
26
- self.sig = signal
27
-
28
-
29
-
30
-
31
  class Interface(torch.nn.Module):
32
  def __init__(
33
  self,
@@ -100,10 +94,6 @@ class Interface(torch.nn.Module):
100
  def to_signal(self, z: torch.Tensor):
101
  return self.coarse.to_signal(z, self.codec)
102
 
103
- def autoencode(self, signal: AudioSignal):
104
- z = self.encode(signal)
105
- return self.to_signal(z)
106
-
107
  def preprocess(self, signal: AudioSignal):
108
  signal = (
109
  signal.clone()
@@ -249,182 +239,30 @@ class Interface(torch.nn.Module):
249
  fine_z = torch.cat(fine_z, dim=-1)
250
  return fine_z[:, :, :length].clone()
251
 
252
-
253
  def coarse_vamp(
254
  self,
255
- signal,
256
- prefix_dur_s: float = 0.0,
257
- suffix_dur_s: float = 0.0,
258
- num_vamps: int = 1,
259
- downsample_factor: int = None,
260
- stretch_factor: int = None,
261
- periodic_width: int = 1,
262
- periodic_dropout=0.0,
263
- periodic_width_dropout=0.0,
264
- intensity: float = 1.0,
265
- debug=False,
266
- swap_prefix_suffix=False,
267
- ext_mask=None,
268
- n_conditioning_codebooks=None,
269
- verbose=False,
270
  return_mask=False,
271
  **kwargs
272
  ):
273
- z = self.encode(signal)
274
-
275
  # coarse z
276
  cz = z[:, : self.coarse.n_codebooks, :].clone()
277
- c_seq_len = cz.shape[-1]
278
- n_prefix = self.s2t(prefix_dur_s)
279
- n_suffix = self.s2t(suffix_dur_s)
280
-
281
-
282
- # hmm, should be a better way to do this? think we just need a mask builder class
283
- add_random_periodic_offset = True
284
-
285
- if stretch_factor is not None and stretch_factor > 1:
286
- print(f"stretching by {stretch_factor}")
287
- assert stretch_factor >= 1, "stretch factor must be >= 1"
288
- cz = cz.repeat_interleave(stretch_factor, dim=-1)
289
-
290
- # the downsample factor is now relative to the stretched sequence
291
- assert downsample_factor is None or downsample_factor <= 2, "downsample_factor must be None when stretch_factor is not None"
292
-
293
- downsample_factor = stretch_factor
294
- add_random_periodic_offset = False
295
-
296
- assert n_prefix == 0 and n_suffix == 0, "prefix and suffix must be 0 when stretch_factor is not None"
297
- assert ext_mask is None, "ext_mask must be None when stretch_factor is not None"
298
-
299
- # trim cz to the original length
300
- cz = cz[:, :, :c_seq_len]
301
-
302
-
303
  assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
304
- assert n_prefix + n_suffix < c_seq_len, "prefix and suffix must be smaller than the chunk size"
305
-
306
- if swap_prefix_suffix:
307
- # swap the prefix and suffix
308
- assert n_prefix == n_suffix, "prefix and suffix must be the same size for now"
309
- cz[:, :, :n_prefix], cz[:, :, c_seq_len-n_suffix:] = cz[:, :, c_seq_len-n_suffix:], cz[:, :, :n_prefix].clone()
310
-
311
- # we'll keep the final codes sequence here
312
- c_vamp = {
313
- 'prefix': [cz[:, :, :n_prefix].clone()],
314
- 'suffix': [cz[:, :, c_seq_len-n_suffix:].clone()]
315
- }
316
-
317
- _cz = cz.clone()
318
- cz_mask = None
319
- range_fn = tqdm.trange if verbose else range
320
- for _ in range_fn(num_vamps):
321
- # add noise
322
- cz_masked, cz_mask = self.coarse.add_noise(
323
- _cz, r=1.0-intensity,
324
- n_prefix=n_prefix,
325
- n_suffix=n_suffix,
326
- downsample_factor=downsample_factor,
327
- periodic_width=periodic_width,
328
- periodic_dropout=periodic_dropout,
329
- add_random_periodic_offset=add_random_periodic_offset,
330
- periodic_width_dropout=periodic_width_dropout,
331
- mask=cz_mask,
332
- ext_mask=ext_mask,
333
- n_conditioning_codebooks=n_conditioning_codebooks
334
- )
335
- if debug:
336
- print("tokens to infer")
337
- self.to_signal(cz_masked).cpu().widget()
338
-
339
- # sample!
340
- if debug:
341
- print(f"mask: {cz_mask[:,0,:]}")
342
- print(f"z: {_cz[:,0,:]}")
343
- cz_sampled = self.coarse.sample(
344
- codec=self.codec,
345
- time_steps=_cz.shape[-1],
346
- start_tokens=_cz,
347
- mask=cz_mask,
348
- return_signal=False,
349
- **kwargs
350
- )
351
-
352
- if debug:
353
- print("tokens sampled")
354
- self.to_signal(cz_sampled).cpu().widget()
355
-
356
- # the z that was generated
357
- cz_generated = cz_sampled[:, :, n_prefix:c_seq_len-n_suffix].clone()
358
- n_generated = cz_generated.shape[-1]
359
-
360
- # create the new prefix and suffix
361
- # we'll make sure that the number of prefix and suffix
362
- # tokens is the same as the original
363
- # but we do want to advance the sequence as much as we can
364
- if n_prefix > 0 and n_suffix > 0:
365
- # we have both prefix and suffix, so we'll split the generated
366
- # codes in two halves
367
- prefix_start_idx = n_generated // 2
368
- prefix_stop_idx = prefix_start_idx + n_prefix
369
- assert prefix_start_idx >= 0, "internal error"
370
-
371
- suffix_start_idx = n_prefix + n_generated // 2
372
- suffix_stop_idx = suffix_start_idx + n_suffix
373
- assert suffix_stop_idx <= cz_sampled.shape[-1], "internal error"
374
-
375
- cz_new_prefix = cz_sampled[:, :, prefix_start_idx:prefix_stop_idx].clone()
376
- cz_new_suffix = cz_sampled[:, :, suffix_start_idx:suffix_stop_idx].clone()
377
-
378
- c_vamp['prefix'].append(cz_generated[:,:,:n_generated//2])
379
- c_vamp['suffix'].insert(0, cz_generated[:,:,n_generated//2:])
380
-
381
- elif n_prefix > 0:
382
- # we only have a prefix
383
- prefix_start_idx = n_generated
384
- prefix_stop_idx = prefix_start_idx + n_prefix
385
-
386
- cz_new_prefix = cz_sampled[:, :, prefix_start_idx:prefix_stop_idx].clone()
387
- cz_new_suffix = _cz[:, :, :0].clone()
388
-
389
-
390
- c_vamp['prefix'].append(cz_generated)
391
 
392
- elif n_suffix > 0:
393
- # we only have a suffix, so everything starting at 0 is generated
394
- suffix_stop_idx = max(n_generated, n_suffix)
395
- suffix_start_idx = suffix_stop_idx - n_suffix
396
 
397
- cz_new_prefix = _cz[:, :, :0].clone()
398
- cz_new_suffix = cz_sampled[:, :, suffix_start_idx:suffix_stop_idx].clone()
399
 
400
- c_vamp['suffix'].insert(0, cz_generated)
401
-
402
- else:
403
- # we have no prefix or suffix, so we'll just use the generated
404
- # codes as the new prefix and suffix
405
- cz_new_prefix = cz_generated.clone()
406
- cz_new_suffix = _cz[:, :, :0].clone()
407
-
408
- c_vamp['prefix'].append(cz_generated)
409
-
410
-
411
- n_to_insert = c_seq_len - (cz_new_prefix.shape[-1] + cz_new_suffix.shape[-1])
412
- to_insert = torch.zeros(cz_new_prefix.shape[0], cz_new_prefix.shape[1], n_to_insert).long().to(self.device)
413
- _cz = torch.cat([cz_new_prefix, to_insert, cz_new_suffix], dim=-1)
414
-
415
- to_insert_mask = torch.zeros_like(_cz).long().to(self.device)
416
- to_insert_mask[:, :, cz_new_prefix.shape[-1]:cz_new_prefix.shape[-1]+n_to_insert] = 1
417
- cz_mask = (cz_mask + to_insert_mask).bool().long()
418
-
419
-
420
- if debug:
421
- print("tokens to infer next round (area to insert in the middle)")
422
- self.to_signal(_cz).cpu().widget()
423
-
424
-
425
- prefix_codes = torch.cat(c_vamp['prefix'], dim=-1)
426
- suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
427
- c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
428
 
429
  # replace the mask token in cz_masked with random tokens
430
  # so that we can decode it
@@ -433,132 +271,61 @@ class Interface(torch.nn.Module):
433
 
434
  return c_vamp
435
 
436
- # create a variation of an audio signal
437
- def variation(
438
- self,
439
- signal: AudioSignal,
440
- verbose: bool = False,
441
- beat_mask: bool = False,
442
- beat_mask_kwargs: dict = {},
443
- **kwargs
444
- ):
445
- signal = signal.clone()
446
 
447
- # autoencode first, so the samples get rounded up to the nearest tokens
448
- signal = self.autoencode(signal).cpu()
449
 
450
- # pad the signal to the nearest chunk size
451
- req_len = (
452
- math.ceil(signal.duration / self.coarse.chunk_size_s)
453
- * self.coarse.chunk_size_s
454
- )
455
- # eventually we DO want overlap, but we want overlap-replace not
456
- # overlap-add
457
- overlap_hop_ratio = 1.0
458
- hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
459
- original_length = signal.length
460
-
461
- signal.zero_pad_to(req_len)
462
-
463
- # window the signal
464
- signal = signal.collect_windows(
465
- window_duration=self.coarse.chunk_size_s,
466
- hop_duration=hop_duration,
467
- )
468
 
469
- # output = []
470
- range_fn = range if not verbose else tqdm.trange
471
- for i in range_fn(signal.batch_size):
472
- sig = AudioSignal(
473
- signal.samples[i,...], signal.sample_rate
474
- )
475
- sig.to(self.device)
476
-
477
- if beat_mask:
478
- ext_mask = self.make_beat_mask(sig, **beat_mask_kwargs)
479
- else:
480
- ext_mask = None
481
-
482
- out_z = self.coarse_vamp(
483
- sig,
484
- num_vamps=1,
485
- swap_prefix_suffix=False,
486
- ext_mask=ext_mask,
487
- verbose=verbose,
488
- **kwargs
489
- )
490
- if self.c2f is not None:
491
- out_z = self.coarse_to_fine(out_z)
492
- out_sig = self.to_signal(out_z).cpu()
493
 
494
- signal.samples[i] = out_sig.samples
495
 
496
- output = signal.overlap_and_add(hop_duration)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
497
 
498
- output.truncate_samples(original_length)
499
- return output
 
 
 
 
 
500
 
501
- # create a loop of a single region with variations
502
- # TODO: this would work nicer if we could trim at the beat
503
- # otherwise the model has to awkwardly fill up space that won't match
504
- # the beat unless the signal is exactly the right length
505
- def loop(
506
- self,
507
- signal: AudioSignal,
508
- prefix_dur_s: float = 0.0,
509
- suffix_dur_s: float = 0.0,
510
- num_loops: int = 4,
511
- # overlap_hop_ratio: float = 1.0, # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap add
512
- verbose: bool = False,
513
- return_mask: bool = False,
514
- **kwargs,
515
- ):
516
- assert prefix_dur_s >= 0.0, "prefix duration must be >= 0"
517
- assert suffix_dur_s >= 0.0, "suffix duration must be >= 0"
518
- signal = self.preprocess(signal)
519
-
520
- suffix_len_samples = int(suffix_dur_s * signal.sample_rate)
521
- prefix_len_tokens = self.s2t(prefix_dur_s)
522
- suffix_len_tokens = self.s2t(suffix_dur_s)
523
-
524
- loops = [
525
- # add everything but the suffix a the beggining
526
- self.encode(signal.clone().trim(before=0, after=suffix_len_samples))
527
- ]
528
- range_fn = range if not verbose else tqdm.trange
529
- for i in range_fn(num_loops):
530
- is_flipped = i % 2 == 0
531
- vamped = self.coarse_vamp(
532
- signal,
533
- prefix_dur_s=prefix_dur_s,
534
- suffix_dur_s=suffix_dur_s,
535
- swap_prefix_suffix=is_flipped,
536
- return_mask=return_mask,
537
- **kwargs
538
- )
539
- if return_mask:
540
- vamped, mask = vamped
541
-
542
- # if we're flipped, we trim the prefix off of the end
543
- # otherwise we trim the suffix off of the end
544
- trim_len = prefix_len_tokens if is_flipped else suffix_len_tokens
545
- vamped = vamped[:, :, :vamped.shape[-1]-trim_len]
546
-
547
- loops.append(vamped)
548
-
549
- if is_flipped:
550
- loops.append(
551
- # add everything but the prefix at the end
552
- self.encode(signal.clone())
553
- )
554
-
555
- if self.c2f is not None:
556
- loops = [self.coarse_to_fine(l) for l in loops]
557
 
558
- loops = [self.to_signal(l) for l in loops]
 
559
 
560
- if return_mask:
561
- return signal_concat(loops), self.to_signal(mask)
562
-
563
- return signal_concat(loops)
564
 
 
 
 
 
9
 
10
  from .modules.transformer import VampNet
11
  from .beats import WaveBeat
12
+ from .mask import *
13
+
14
  from lac.model.lac import LAC
15
 
16
 
 
22
  return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
23
 
24
 
 
 
 
 
 
 
 
 
25
  class Interface(torch.nn.Module):
26
  def __init__(
27
  self,
 
94
  def to_signal(self, z: torch.Tensor):
95
  return self.coarse.to_signal(z, self.codec)
96
 
 
 
 
 
97
  def preprocess(self, signal: AudioSignal):
98
  signal = (
99
  signal.clone()
 
239
  fine_z = torch.cat(fine_z, dim=-1)
240
  return fine_z[:, :, :length].clone()
241
 
 
242
  def coarse_vamp(
243
  self,
244
+ z,
245
+ mask,
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  return_mask=False,
247
  **kwargs
248
  ):
 
 
249
  # coarse z
250
  cz = z[:, : self.coarse.n_codebooks, :].clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
+ mask = mask[:, : self.coarse.n_codebooks, :]
 
 
 
254
 
255
+ cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
256
+ cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
257
 
258
+ c_vamp = self.coarse.sample(
259
+ codec=self.codec,
260
+ time_steps=cz.shape[-1],
261
+ start_tokens=cz,
262
+ mask=mask,
263
+ return_signal=False,
264
+ **kwargs
265
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  # replace the mask token in cz_masked with random tokens
268
  # so that we can decode it
 
271
 
272
  return c_vamp
273
 
 
 
 
 
 
 
 
 
 
 
274
 
275
+ if __name__ == "__main__":
276
+ import audiotools as at
277
 
278
+ interface = Interface(
279
+ coarse_ckpt="./models/spotdl/coarse.pth",
280
+ coarse2fine_ckpt="./models/spotdl/c2f.pth",
281
+ codec_ckpt="./models/spotdl/codec.pth",
282
+ device="cpu"
283
+ )
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
+ sig = at.AudioSignal('cali.mp3', duration=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
+ z = interface.encode(sig)
288
 
289
+ mask = linear_random(z, 0.8)
290
+ print(mask)
291
+ mask = mask_and(
292
+ mask, inpaint(
293
+ z,
294
+ interface.s2t(3),
295
+ interface.s2t(3)
296
+ )
297
+ )
298
+ print(mask)
299
+ mask = mask_and(
300
+ mask, periodic_mask(
301
+ z,
302
+ 7,
303
+ 1,
304
+ random_roll=True
305
+ )
306
+ )
307
+ mask = dropout(mask, 0.0)
308
+ mask = codebook_unmask(mask, 0)
309
+
310
 
311
+ zv, mask_z = interface.coarse_vamp(
312
+ z,
313
+ mask=mask,
314
+ sampling_steps=1,
315
+ temperature=(0.8,1),
316
+ return_mask=True
317
+ )
318
 
319
+ use_coarse2fine = False
320
+ if use_coarse2fine:
321
+ zv = interface.coarse_to_fine(zv)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ print(mask_z)
324
+ mask = interface.to_signal(mask_z).cpu()
325
 
326
+ sig = interface.to_signal(zv).cpu()
327
+ print("done")
 
 
328
 
329
+ sig.write("output.wav")
330
+ mask.write("mask.wav")
331
+
vampnet/mask.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ from .util import scalar_to_batch_tensor
6
+
7
+ def _gamma(r):
8
+ return (r * torch.pi / 2).cos()
9
+
10
+ def _invgamma(y):
11
+ if not torch.is_tensor(y):
12
+ y = torch.tensor(y)[None]
13
+ return 2 * y.acos() / torch.pi
14
+
15
+ def full_mask(x: torch.Tensor):
16
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
17
+ return torch.ones_like(x).long()
18
+
19
+ def empty_mask(x: torch.Tensor):
20
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
21
+ return torch.zeros_like(x).long()
22
+
23
+ def apply_mask(
24
+ x: torch.Tensor,
25
+ mask: torch.Tensor,
26
+ mask_token: int
27
+ ):
28
+ assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
29
+ assert mask.shape == x.shape, "mask must be same shape as x"
30
+ assert mask.dtype == torch.long, "mask must be long dtype"
31
+ assert ~torch.any(mask > 1), "mask must be binary"
32
+ assert ~torch.any(mask < 0), "mask must be binary"
33
+
34
+ fill_x = torch.full_like(x, mask_token)
35
+ x = x * (1 - mask) + fill_x * mask
36
+
37
+ return x, mask
38
+
39
+ def random(
40
+ x: torch.Tensor,
41
+ r: torch.Tensor
42
+ ):
43
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
44
+ if not isinstance(r, torch.Tensor):
45
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
46
+
47
+ r = _gamma(r)[:, None, None]
48
+ probs = torch.ones_like(x) * r
49
+
50
+ mask = torch.bernoulli(probs)
51
+ mask = mask.round().long()
52
+
53
+ return mask
54
+
55
+ def linear_random(
56
+ x: torch.Tensor,
57
+ r: torch.Tensor,
58
+ ):
59
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
60
+ if not isinstance(r, torch.Tensor):
61
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
62
+
63
+ probs = torch.ones_like(x).to(x.device).float()
64
+ # expand to batch and codebook dims
65
+ probs = probs.expand(x.shape[0], x.shape[1], -1)
66
+ probs = probs * r
67
+
68
+ mask = torch.bernoulli(probs)
69
+ mask = mask.round().long()
70
+
71
+ return mask
72
+
73
+ def inpaint(x: torch.Tensor,
74
+ n_prefix,
75
+ n_suffix,
76
+ ):
77
+ assert n_prefix is not None
78
+ assert n_suffix is not None
79
+
80
+ mask = full_mask(x)
81
+
82
+ # if we have a prefix or suffix, set their mask prob to 0
83
+ if n_prefix > 0:
84
+ if not isinstance(n_prefix, torch.Tensor):
85
+ n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
86
+ for i, n in enumerate(n_prefix):
87
+ if n > 0:
88
+ mask[i, :, :n] = 0.0
89
+ if n_suffix > 0:
90
+ if not isinstance(n_suffix, torch.Tensor):
91
+ n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
92
+ for i, n in enumerate(n_suffix):
93
+ if n > 0:
94
+ mask[i, :, -n:] = 0.0
95
+
96
+
97
+ return mask
98
+
99
+ def periodic_mask(x: torch.Tensor,
100
+ period: int, width: int = 1,
101
+ random_roll=False,
102
+ ):
103
+ mask = full_mask(x)
104
+ if period == 0:
105
+ return mask
106
+
107
+ if not isinstance(period, torch.Tensor):
108
+ period = scalar_to_batch_tensor(period, x.shape[0])
109
+ for i, factor in enumerate(period):
110
+ if factor == 0:
111
+ continue
112
+ for j in range(mask.shape[-1]):
113
+ if j % factor == 0:
114
+ # figure out how wide the mask should be
115
+ j_start = max(0, j - width // 2 )
116
+ j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
117
+ # flip a coin for each position in the mask
118
+ j_mask = torch.bernoulli(torch.ones(j_end - j_start))
119
+ assert torch.all(j_mask == 1)
120
+ j_fill = torch.ones_like(j_mask) * (1 - j_mask)
121
+ assert torch.all(j_fill == 0)
122
+ # fill
123
+ mask[i, :, j_start:j_end] = j_fill
124
+ if random_roll:
125
+ # add a random offset to the mask
126
+ offset = torch.randint(0, period[0], (1,))
127
+ mask = torch.roll(mask, offset.item(), dims=-1)
128
+
129
+ return mask
130
+
131
+ def codebook_unmask(
132
+ mask: torch.Tensor,
133
+ n_conditioning_codebooks: int
134
+ ):
135
+ if n_conditioning_codebooks == None:
136
+ return mask
137
+ # if we have any conditioning codebooks, set their mask to 0
138
+ mask = mask.clone()
139
+ mask[:, :n_conditioning_codebooks, :] = 0
140
+ return mask
141
+
142
+ def mask_and(
143
+ mask1: torch.Tensor,
144
+ mask2: torch.Tensor
145
+ ):
146
+ assert mask1.shape == mask2.shape, "masks must be same shape"
147
+ return torch.min(mask1, mask2)
148
+
149
+ def dropout(
150
+ mask: torch.Tensor,
151
+ p: float,
152
+ ):
153
+ return torch.bernoulli((torch.ones_like(mask) * (1-p)).float()).long() * mask
154
+
155
+ def mask_or(
156
+ mask1: torch.Tensor,
157
+ mask2: torch.Tensor
158
+ ):
159
+ assert mask1.shape == mask2.shape, "masks must be same shape"
160
+ assert mask1.max() <= 1, "mask1 must be binary"
161
+ assert mask2.max() <= 1, "mask2 must be binary"
162
+ assert mask1.min() >= 0, "mask1 must be binary"
163
+ assert mask2.min() >= 0, "mask2 must be binary"
164
+ return (mask1 + mask2).clamp(0, 1)
165
+
166
+ def time_stretch_mask(
167
+ x: torch.Tensor,
168
+ stretch_factor: int,
169
+ mask_token: int
170
+ ):
171
+ assert stretch_factor >= 1, "stretch factor must be >= 1"
172
+ c_seq_len = x.shape[-1]
173
+ x = x.repeat_interleave(stretch_factor, dim=-1)
174
+
175
+ # trim cz to the original length
176
+ x = x[:, :, :c_seq_len]
177
+
178
+ mask = periodic_mask(x, stretch_factor, width=1)
179
+ return apply_mask(x, mask, mask_token)
180
+
181
+
182
+ if __name__ == "__main__":
183
+ torch.set_printoptions(threshold=10000)
184
+
vampnet/modules/base.py DELETED
@@ -1,412 +0,0 @@
1
- import math
2
- from typing import Optional
3
- from typing import Tuple
4
- from typing import Union
5
-
6
- import audiotools as at
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from einops import rearrange
11
- from tqdm import tqdm
12
-
13
- from ..util import scalar_to_batch_tensor
14
-
15
-
16
- def log(t, eps=1e-20):
17
- return torch.log(t + eps)
18
-
19
-
20
- def gumbel_noise(t):
21
- noise = torch.zeros_like(t).uniform_(0, 1)
22
- return -log(-log(noise))
23
-
24
-
25
- def gumbel_sample(t, temperature=1.0, dim=-1):
26
- return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
27
-
28
-
29
- class VampBase(at.ml.BaseModel):
30
- def forward(self, x: torch.Tensor, r: torch.Tensor):
31
- raise NotImplementedError
32
-
33
- def add_noise(
34
- self,
35
- x: torch.Tensor,
36
- r: torch.Tensor,
37
- random_x: Optional[torch.Tensor] = None,
38
- mask: Optional[torch.Tensor] = None,
39
- ext_mask: Optional[torch.Tensor] = None,
40
- n_prefix: Optional[torch.Tensor] = None,
41
- n_suffix: Optional[torch.Tensor] = None,
42
- downsample_factor: Optional[int] = None,
43
- periodic_width: int = 1,
44
- periodic_width_dropout: float = 0.0,
45
- periodic_dropout: float = 0.0,
46
- add_random_periodic_offset: bool = False, # TODO: should be always false lol this is hacky
47
- n_conditioning_codebooks: Optional[int] = None,
48
- noise_mode: str = None,
49
- ) -> Tuple[torch.Tensor, torch.Tensor]:
50
- assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
51
-
52
- if mask is None:
53
- if not isinstance(r, torch.Tensor):
54
- r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
55
- r = self.gamma(r)[:, None, None]
56
- probs = torch.ones_like(x) * r
57
-
58
- # if we have a prefix or suffix, set their mask prob to 0
59
- if n_prefix is not None:
60
- if not isinstance(n_prefix, torch.Tensor):
61
- n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
62
- for i, n in enumerate(n_prefix):
63
- if n > 0:
64
- probs[i, :, :n] = 0.0
65
- if n_suffix is not None:
66
- if not isinstance(n_suffix, torch.Tensor):
67
- n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
68
- for i, n in enumerate(n_suffix):
69
- if n > 0:
70
- probs[i, :, -n:] = 0.0
71
-
72
- # if we have a downsample factor, set the mask prob to 0
73
- if downsample_factor is not None and downsample_factor > 0:
74
- if not isinstance(downsample_factor, torch.Tensor):
75
- downsample_factor = scalar_to_batch_tensor(downsample_factor, x.shape[0])
76
- for i, factor in enumerate(downsample_factor):
77
- if factor == 0:
78
- continue
79
- for j in range(probs.shape[-1]):
80
- if j % factor == 0:
81
- # if we have periodic dropout
82
- if periodic_dropout > 0:
83
- # flip a coin
84
- if torch.bernoulli(torch.tensor(periodic_dropout)).item() == 1:
85
- # if we win, skip
86
- continue
87
-
88
- # figure out how wide the mask should be
89
- j_start = max(0, j - periodic_width // 2)
90
- j_end = min(probs.shape[-1] - 1, j + periodic_width // 2) + 1
91
- # flip a coin for each position in the mask
92
- j_mask = torch.bernoulli(torch.ones(j_end - j_start) * periodic_width_dropout)
93
- j_fill = torch.ones_like(j_mask) * (1 - j_mask)
94
- # fill
95
- probs[i, :, j_start:j_end] = 1 - j_fill
96
- if add_random_periodic_offset:
97
- # add a random offset to the mask
98
- offset = torch.randint(0, downsample_factor[0], (1,))
99
- probs = torch.roll(probs, offset.item(), dims=-1)
100
-
101
- mask = torch.bernoulli(probs)
102
- mask = mask.round().long()
103
-
104
- # if we have any conditioning codebooks, set their mask to 0
105
- n_conditioning_codebooks = n_conditioning_codebooks or self.n_conditioning_codebooks
106
- mask[:, :n_conditioning_codebooks, :] = 0
107
- else:
108
- assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
109
- assert mask.shape == x.shape, "mask must be same shape as x"
110
-
111
- if random_x is None:
112
- random_x = torch.randint_like(x, 0, self.vocab_size)
113
-
114
- noise_mode = noise_mode if noise_mode is not None else self.noise_mode
115
- if noise_mode == "mask":
116
- random_x = torch.full_like(x, self.mask_token)
117
- elif noise_mode == "random":
118
- if random_x is None:
119
- random_x = torch.randint_like(x, 0, self.vocab_size)
120
- else:
121
- raise ValueError(f"invalid noise mode {noise_mode}")
122
-
123
- # add the external mask if we were given one
124
- if ext_mask is not None:
125
- assert ext_mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
126
- mask = (mask * ext_mask).bool().long()
127
-
128
- x = x * (1 - mask) + random_x * mask
129
- return x, mask
130
-
131
- def add_truth_to_logits(
132
- self,
133
- z_true,
134
- z_hat,
135
- mask,
136
- ):
137
- if self.noise_mode == "mask":
138
- z_true = z_true[:, self.n_conditioning_codebooks :, :]
139
- mask = mask[:, self.n_conditioning_codebooks :, :]
140
-
141
- truth = F.one_hot(z_true, self.vocab_size)
142
- mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
143
- z_hat = rearrange(
144
- z_hat,
145
- "b p (t c) -> b c t p",
146
- c=self.n_codebooks - self.n_conditioning_codebooks,
147
- )
148
-
149
- z_hat = z_hat * mask + truth * (1 - mask)
150
-
151
- z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
152
- else:
153
- raise ValueError(f"invalid noise mode for adding truth to logits {self.noise_mode}")
154
-
155
- return z_hat
156
-
157
- def gamma(self, r):
158
- return (r * torch.pi / 2).cos()
159
-
160
- def invgamma(self, y):
161
- if not torch.is_tensor(y):
162
- y = torch.tensor(y)[None]
163
- return 2 * y.acos() / torch.pi
164
-
165
- def r_embed(self, r, max_positions=10000):
166
- """ """
167
- assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"
168
-
169
- if self.r_cond_dim > 0:
170
- dtype = r.dtype
171
-
172
- r = self.gamma(r) * max_positions
173
- half_dim = self.r_cond_dim // 2
174
-
175
- emb = math.log(max_positions) / (half_dim - 1)
176
- emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
177
-
178
- emb = r[:, None] * emb[None, :]
179
- emb = torch.cat([emb.sin(), emb.cos()], dim=1)
180
-
181
- if self.r_cond_dim % 2 == 1: # zero pad
182
- emb = nn.functional.pad(emb, (0, 1), mode="constant")
183
-
184
- return emb.to(dtype)
185
- else:
186
- return r
187
-
188
- @torch.no_grad()
189
- def to_signal(self, z, codec):
190
- """
191
- convert a sequence of latents to a signal.
192
- """
193
- if z.ndim == 2:
194
- z = self.embedding.unflatten(z)
195
- assert z.ndim == 3
196
-
197
- signal = at.AudioSignal(
198
- codec.decode(
199
- codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
200
- )["audio"],
201
- codec.sample_rate,
202
- )
203
-
204
- # find where the mask token is and replace it with silence in the audio
205
- for tstep in range(z.shape[-1]):
206
- if torch.any(z[:, :, tstep] == self.mask_token):
207
- sample_idx_0 = tstep * codec.hop_length
208
- sample_idx_1 = sample_idx_0 + codec.hop_length
209
- signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
210
-
211
- return signal
212
-
213
- @torch.no_grad()
214
- def sample(
215
- self,
216
- codec,
217
- time_steps: int = 300,
218
- sampling_steps: int = 36,
219
- start_tokens: Optional[torch.Tensor] = None,
220
- mask: Optional[torch.Tensor] = None,
221
- temperature: Union[float, Tuple[float, float]] = 0.8,
222
- top_k: int = None,
223
- sample: str = "gumbel",
224
- typical_filtering=True,
225
- typical_mass=0.2,
226
- typical_min_tokens=1,
227
- return_signal=True,
228
- ):
229
- if isinstance(temperature, float):
230
- temperature = torch.tensor(temperature).repeat(sampling_steps)
231
- elif isinstance(temperature, tuple):
232
- assert len(temperature) == 2
233
- l, h = temperature
234
- temperature = torch.linspace(l, h, sampling_steps)
235
- else:
236
- raise TypeError(f"invalid type for temperature")
237
-
238
- def flatten(codes):
239
- return rearrange(codes, "b c t -> b (t c)")
240
-
241
- def unflatten(codes, c):
242
- return rearrange(codes, "b (t c) -> b c t", c=c)
243
-
244
- z = start_tokens
245
-
246
- if z is None:
247
- z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
248
- self.device
249
- )
250
-
251
- if mask is None:
252
- mask = torch.ones_like(z).to(self.device).int()
253
- mask[:, : self.n_conditioning_codebooks, :] = 0.0
254
- if mask.ndim == 2:
255
- mask = mask[:, None, :].repeat(1, z.shape[1], 1)
256
-
257
- # figure out which timesteps we're keeping
258
- keep_mask = 1 - mask
259
-
260
- # any conditioning codebook levels need to be in the keep mask
261
- # if self.n_conditioning_codebooks > 0:
262
- # cond_mask = torch.ones(z.shape[0], self.n_conditioning_codebooks, z.shape[-1]).to(z.device)
263
- # keep_mask = torch.cat([cond_mask, keep_mask], dim=1)
264
-
265
- # flatten
266
- keep_mask = flatten(keep_mask)
267
-
268
- # our r steps
269
- r_steps = torch.linspace(0, 1, sampling_steps + 1)[1:].to(self.device)
270
-
271
- # how many tokens did we keep on init?
272
- num_kept_on_init = keep_mask.sum()
273
-
274
- # how many codebooks are we inferring vs conditioning on?
275
- n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
276
-
277
- for i in range(sampling_steps):
278
- # our current temperature
279
- tmpt = temperature[i]
280
-
281
- # our current schedule step
282
- r = r_steps[i : i + 1]
283
-
284
- with torch.inference_mode():
285
- # mask our z
286
- keep_mask_unflat = unflatten(keep_mask, c=self.n_codebooks)
287
- z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
288
-
289
- # get latents
290
- latents = self.embedding.from_codes(z_masked, codec)
291
-
292
- # infer from latents
293
- logits = self.forward(latents, r)
294
- logits = logits.permute(0, 2, 1) # b, seq, prob
295
-
296
- # the schedule determines how many samples to keep
297
- num_tokens_to_infer = (z.shape[-1] * z.shape[-2]) - num_kept_on_init
298
- num_to_keep = num_kept_on_init + int(
299
- num_tokens_to_infer * (self.gamma(1 - r))
300
- )
301
-
302
- # figure out which logits we wanna keep
303
- if num_to_keep > 0:
304
- probs = logits.softmax(dim=-1)
305
-
306
- # do mod self.vocab_size to make sure we don't sample from the mask token
307
- # in case the mask token was in the og z
308
- keep_probs = F.one_hot(z%self.vocab_size, self.vocab_size)[:, :, :]
309
-
310
- probs = rearrange(
311
- probs, "b (t c) p -> b c t p", c=n_infer_codebooks
312
- )
313
- probs = torch.cat(
314
- [keep_probs[:, : self.n_conditioning_codebooks, ...], probs],
315
- dim=1,
316
- )
317
-
318
- keep_probs = rearrange(
319
- keep_probs, "b c t p -> b (t c) p", c=self.n_codebooks
320
- )
321
- probs = rearrange(probs, "b c t p -> b (t c) p", c=self.n_codebooks)
322
-
323
- keep_prob_mask = keep_mask.unsqueeze(-1).repeat(
324
- 1, 1, self.vocab_size
325
- )
326
- probs = (keep_prob_mask.long() * keep_probs) + (
327
- 1 - keep_prob_mask.long()
328
- ) * probs
329
-
330
- highest_probs = probs.max(dim=-1, keepdim=False)[0]
331
- v, _ = highest_probs.topk(num_to_keep, dim=-1)
332
-
333
- keep_mask = torch.ones_like(keep_mask).bool().clone()
334
- keep_mask[highest_probs < v[..., [-1]]] = 0
335
-
336
- logits = torch.log(probs)
337
-
338
- z_inferred = self.sample_from_logits(
339
- logits=logits,
340
- top_k=top_k,
341
- temperature=tmpt,
342
- sample=sample,
343
- typical_filtering=typical_filtering,
344
- typical_mass=typical_mass,
345
- typical_min_tokens=typical_min_tokens,
346
- )
347
-
348
- z = rearrange(z_inferred, "b (t c) -> b c t", c=self.n_codebooks)
349
-
350
- # add conditioning codebooks back
351
- # z = torch.cat([z[:, :self.n_conditioning_codebooks, :], z_inferred], dim=1)
352
-
353
- if return_signal:
354
- return self.to_signal(z, codec)
355
- else:
356
- return z
357
-
358
- def sample_from_logits(
359
- self,
360
- logits,
361
- top_k: int = None,
362
- temperature: float = 1.0,
363
- sample: str = "multinomial",
364
- typical_filtering=False,
365
- typical_mass=0.2,
366
- typical_min_tokens=1,
367
- ):
368
- # add temperature
369
- logits = logits / temperature
370
-
371
- # add topk
372
- if top_k is not None:
373
- v, topk_idx = logits.topk(top_k)
374
- logits[logits < v[..., [-1]]] = -float("inf")
375
-
376
- if typical_filtering:
377
- assert top_k is None
378
- nb, nt, _ = logits.shape
379
- x_flat = rearrange(logits, "b t l -> (b t ) l")
380
- x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
381
- x_flat_norm_p = torch.exp(x_flat_norm)
382
- entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
383
-
384
- c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
385
- c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
386
- x_flat_cumsum = (
387
- x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
388
- )
389
-
390
- last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
391
- sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
392
- 1, last_ind.view(-1, 1)
393
- )
394
- if typical_min_tokens > 1:
395
- sorted_indices_to_remove[..., :typical_min_tokens] = 0
396
- indices_to_remove = sorted_indices_to_remove.scatter(
397
- 1, x_flat_indices, sorted_indices_to_remove
398
- )
399
- x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
400
- logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
401
-
402
- if sample == "multinomial":
403
- probs = torch.softmax(logits, dim=-1)
404
- inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
405
- elif sample == "argmax":
406
- inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
407
- elif sample == "gumbel":
408
- inferred = gumbel_sample(logits, dim=-1)
409
- else:
410
- raise ValueError(f"invalid sampling method: {sample}")
411
-
412
- return inferred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vampnet/modules/layers.py CHANGED
@@ -162,20 +162,3 @@ class CodebookEmbedding(nn.Module):
162
  x = self.out_proj(latents)
163
  return x
164
 
165
- def flatten(self, tokens: torch.Tensor, n_codebooks: int = None):
166
- """
167
- flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
168
- """
169
- n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
170
- return rearrange(tokens, "b c t -> b (t c)", c=n_c)
171
-
172
- def unflatten(self, flat_tokens: torch.Tensor, n_codebooks: int = None):
173
- """
174
- unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
175
- """
176
- nb, nt = flat_tokens.shape
177
-
178
- n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
179
- tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
180
-
181
- return tokens
 
162
  x = self.out_proj(latents)
163
  return x
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vampnet/modules/transformer.py CHANGED
@@ -1,4 +1,5 @@
1
  import math
 
2
 
3
  import numpy as np
4
  import torch
@@ -6,16 +7,30 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  from einops import rearrange
8
  import loralib as lora
 
9
 
10
- from .base import VampBase
11
  from .activations import get_activation
12
  from .layers import CodebookEmbedding
13
  from .layers import FiLM
14
  from .layers import SequentialWithFiLM
15
  from .layers import WNConv1d
 
 
16
 
17
  LORA_R = 8
18
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class RMSNorm(nn.Module):
21
  def __init__(self, hidden_size: int, eps=1e-6):
@@ -435,7 +450,7 @@ class TransformerStack(nn.Module):
435
  return self.norm(x) if self.norm is not None else x
436
 
437
 
438
- class VampNet(VampBase):
439
  def __init__(
440
  self,
441
  n_heads: int = 20,
@@ -519,6 +534,270 @@ class VampNet(VampBase):
519
  out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
520
 
521
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
 
523
 
524
  if __name__ == "__main__":
@@ -538,8 +817,7 @@ if __name__ == "__main__":
538
  ).to(device)
539
 
540
  r = torch.zeros(batch_size).to(device)
541
- z_mask, mask = model.add_noise(z, r)
542
-
543
  z_mask_latent = torch.rand(
544
  batch_size, model.latent_dim * model.n_codebooks, seq_len
545
  ).to(device)
 
1
  import math
2
+ from typing import Optional, Tuple, Union
3
 
4
  import numpy as np
5
  import torch
 
7
  import torch.nn.functional as F
8
  from einops import rearrange
9
  import loralib as lora
10
+ import audiotools as at
11
 
 
12
  from .activations import get_activation
13
  from .layers import CodebookEmbedding
14
  from .layers import FiLM
15
  from .layers import SequentialWithFiLM
16
  from .layers import WNConv1d
17
+ from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
18
+ from ..mask import _gamma
19
 
20
  LORA_R = 8
21
 
22
+ def log(t, eps=1e-20):
23
+ return torch.log(t + eps)
24
+
25
+
26
+ def gumbel_noise(t):
27
+ noise = torch.zeros_like(t).uniform_(0, 1)
28
+ return -log(-log(noise))
29
+
30
+
31
+ def gumbel_sample(t, temperature=1.0, dim=-1):
32
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
33
+
34
 
35
  class RMSNorm(nn.Module):
36
  def __init__(self, hidden_size: int, eps=1e-6):
 
450
  return self.norm(x) if self.norm is not None else x
451
 
452
 
453
+ class VampNet(at.ml.BaseModel):
454
  def __init__(
455
  self,
456
  n_heads: int = 20,
 
534
  out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
535
 
536
  return out
537
+
538
+ def r_embed(self, r, max_positions=10000):
539
+ if self.r_cond_dim > 0:
540
+ dtype = r.dtype
541
+
542
+ r = _gamma(r) * max_positions
543
+ half_dim = self.r_cond_dim // 2
544
+
545
+ emb = math.log(max_positions) / (half_dim - 1)
546
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
547
+
548
+ emb = r[:, None] * emb[None, :]
549
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
550
+
551
+ if self.r_cond_dim % 2 == 1: # zero pad
552
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
553
+
554
+ return emb.to(dtype)
555
+ else:
556
+ return r
557
+
558
+ @torch.no_grad()
559
+ def to_signal(self, z, codec):
560
+ """
561
+ convert a sequence of latents to a signal.
562
+ """
563
+ assert z.ndim == 3
564
+
565
+ signal = at.AudioSignal(
566
+ codec.decode(
567
+ codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
568
+ )["audio"],
569
+ codec.sample_rate,
570
+ )
571
+
572
+ # find where the mask token is and replace it with silence in the audio
573
+ for tstep in range(z.shape[-1]):
574
+ if torch.any(z[:, :, tstep] == self.mask_token):
575
+ sample_idx_0 = tstep * codec.hop_length
576
+ sample_idx_1 = sample_idx_0 + codec.hop_length
577
+ signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
578
+
579
+ return signal
580
+
581
+ def add_truth_to_logits(
582
+ self,
583
+ z_true,
584
+ z_hat,
585
+ mask,
586
+ ):
587
+ if self.noise_mode == "mask":
588
+ z_true = z_true[:, self.n_conditioning_codebooks :, :]
589
+ mask = mask[:, self.n_conditioning_codebooks :, :]
590
+
591
+ truth = F.one_hot(z_true, self.vocab_size)
592
+ mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
593
+ z_hat = rearrange(
594
+ z_hat,
595
+ "b p (t c) -> b c t p",
596
+ c=self.n_codebooks - self.n_conditioning_codebooks,
597
+ )
598
+
599
+ z_hat = z_hat * mask + truth * (1 - mask)
600
+
601
+ z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
602
+ else:
603
+ raise ValueError(f"invalid noise mode for adding truth to logits {self.noise_mode}")
604
+
605
+ return z_hat
606
+
607
+
608
+ @torch.no_grad()
609
+ def sample(
610
+ self,
611
+ codec,
612
+ time_steps: int = 300,
613
+ sampling_steps: int = 36,
614
+ start_tokens: Optional[torch.Tensor] = None,
615
+ mask: Optional[torch.Tensor] = None,
616
+ temperature: Union[float, Tuple[float, float]] = 0.8,
617
+ top_k: int = None,
618
+ sample: str = "gumbel",
619
+ typical_filtering=True,
620
+ typical_mass=0.2,
621
+ typical_min_tokens=1,
622
+ return_signal=True,
623
+ ):
624
+ if isinstance(temperature, float):
625
+ temperature = torch.tensor(temperature).repeat(sampling_steps)
626
+ elif isinstance(temperature, tuple):
627
+ assert len(temperature) == 2
628
+ l, h = temperature
629
+ temperature = torch.linspace(l, h, sampling_steps)
630
+ else:
631
+ raise TypeError(f"invalid type for temperature")
632
+
633
+ z = start_tokens
634
+
635
+ if z is None:
636
+ z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
637
+ self.device
638
+ )
639
+
640
+ if mask is None:
641
+ mask = torch.ones_like(z).to(self.device).int()
642
+ mask[:, : self.n_conditioning_codebooks, :] = 0.0
643
+ if mask.ndim == 2:
644
+ mask = mask[:, None, :].repeat(1, z.shape[1], 1)
645
+
646
+ # figure out which timesteps we're keeping
647
+ keep_mask = 1 - mask
648
+
649
+ # any conditioning codebook levels need to be in the keep mask
650
+ # if self.n_conditioning_codebooks > 0:
651
+ # cond_mask = torch.ones(z.shape[0], self.n_conditioning_codebooks, z.shape[-1]).to(z.device)
652
+ # keep_mask = torch.cat([cond_mask, keep_mask], dim=1)
653
+
654
+ # flatten
655
+ keep_mask = codebook_flatten(keep_mask)
656
+
657
+ # our r steps
658
+ r_steps = torch.linspace(0, 1, sampling_steps + 1)[1:].to(self.device)
659
+
660
+ # how many tokens did we keep on init?
661
+ num_kept_on_init = keep_mask.sum()
662
+
663
+ # how many codebooks are we inferring vs conditioning on?
664
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
665
+
666
+ for i in range(sampling_steps):
667
+ # our current temperature
668
+ tmpt = temperature[i]
669
+
670
+ # our current schedule step
671
+ r = r_steps[i : i + 1]
672
+
673
+ with torch.inference_mode():
674
+ # mask our z
675
+ keep_mask_unflat = codebook_unflatten(keep_mask, n_c=self.n_codebooks)
676
+ z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
677
+
678
+ # get latents
679
+ latents = self.embedding.from_codes(z_masked, codec)
680
+
681
+ # infer from latents
682
+ logits = self.forward(latents, r)
683
+ logits = logits.permute(0, 2, 1) # b, seq, prob
684
+
685
+ # the schedule determines how many samples to keep
686
+ num_tokens_to_infer = (z.shape[-1] * z.shape[-2]) - num_kept_on_init
687
+ num_to_keep = num_kept_on_init + int(
688
+ num_tokens_to_infer * (_gamma(1 - r))
689
+ )
690
+
691
+ # figure out which logits we wanna keep
692
+ if num_to_keep > 0:
693
+ probs = logits.softmax(dim=-1)
694
+
695
+ # do mod self.vocab_size to make sure we don't sample from the mask token
696
+ # in case the mask token was in the og z
697
+ keep_probs = F.one_hot(z%self.vocab_size, self.vocab_size)[:, :, :]
698
+
699
+ probs = rearrange(
700
+ probs, "b (t c) p -> b c t p", c=n_infer_codebooks
701
+ )
702
+ probs = torch.cat(
703
+ [keep_probs[:, : self.n_conditioning_codebooks, ...], probs],
704
+ dim=1,
705
+ )
706
+
707
+ keep_probs = rearrange(
708
+ keep_probs, "b c t p -> b (t c) p", c=self.n_codebooks
709
+ )
710
+ probs = rearrange(probs, "b c t p -> b (t c) p", c=self.n_codebooks)
711
+
712
+ keep_prob_mask = keep_mask.unsqueeze(-1).repeat(
713
+ 1, 1, self.vocab_size
714
+ )
715
+ probs = (keep_prob_mask.long() * keep_probs) + (
716
+ 1 - keep_prob_mask.long()
717
+ ) * probs
718
+
719
+ highest_probs = probs.max(dim=-1, keepdim=False)[0]
720
+ v, _ = highest_probs.topk(num_to_keep, dim=-1)
721
+
722
+ keep_mask = torch.ones_like(keep_mask).bool().clone()
723
+ keep_mask[highest_probs < v[..., [-1]]] = 0
724
+
725
+ logits = torch.log(probs)
726
+
727
+ z_inferred = self.sample_from_logits(
728
+ logits=logits,
729
+ top_k=top_k,
730
+ temperature=tmpt,
731
+ sample=sample,
732
+ typical_filtering=typical_filtering,
733
+ typical_mass=typical_mass,
734
+ typical_min_tokens=typical_min_tokens,
735
+ )
736
+
737
+ z = codebook_unflatten(z_inferred, n_c=self.n_codebooks)
738
+
739
+
740
+ if return_signal:
741
+ return self.to_signal(z, codec)
742
+ else:
743
+ return z
744
+
745
+ def sample_from_logits(
746
+ self,
747
+ logits,
748
+ top_k: int = None,
749
+ temperature: float = 1.0,
750
+ sample: str = "multinomial",
751
+ typical_filtering=False,
752
+ typical_mass=0.2,
753
+ typical_min_tokens=1,
754
+ ):
755
+ # add temperature
756
+ logits = logits / temperature
757
+
758
+ # add topk
759
+ if top_k is not None:
760
+ v, topk_idx = logits.topk(top_k)
761
+ logits[logits < v[..., [-1]]] = -float("inf")
762
+
763
+ if typical_filtering:
764
+ assert top_k is None
765
+ nb, nt, _ = logits.shape
766
+ x_flat = rearrange(logits, "b t l -> (b t ) l")
767
+ x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
768
+ x_flat_norm_p = torch.exp(x_flat_norm)
769
+ entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
770
+
771
+ c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
772
+ c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
773
+ x_flat_cumsum = (
774
+ x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
775
+ )
776
+
777
+ last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
778
+ sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
779
+ 1, last_ind.view(-1, 1)
780
+ )
781
+ if typical_min_tokens > 1:
782
+ sorted_indices_to_remove[..., :typical_min_tokens] = 0
783
+ indices_to_remove = sorted_indices_to_remove.scatter(
784
+ 1, x_flat_indices, sorted_indices_to_remove
785
+ )
786
+ x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
787
+ logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
788
+
789
+ if sample == "multinomial":
790
+ probs = torch.softmax(logits, dim=-1)
791
+ inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
792
+ elif sample == "argmax":
793
+ inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
794
+ elif sample == "gumbel":
795
+ inferred = gumbel_sample(logits, dim=-1)
796
+ else:
797
+ raise ValueError(f"invalid sampling method: {sample}")
798
+
799
+ return inferred
800
+
801
 
802
 
803
  if __name__ == "__main__":
 
817
  ).to(device)
818
 
819
  r = torch.zeros(batch_size).to(device)
820
+
 
821
  z_mask_latent = torch.rand(
822
  batch_size, model.latent_dim * model.n_codebooks, seq_len
823
  ).to(device)
vampnet/signal.py DELETED
@@ -1,5 +0,0 @@
1
- import torch
2
- from typing import Optional, Tuple
3
-
4
- from .util import scalar_to_batch_tensor
5
-
 
 
 
 
 
 
vampnet/util.py CHANGED
@@ -1,6 +1,7 @@
1
  import tqdm
2
 
3
  import torch
 
4
 
5
  def scalar_to_batch_tensor(x, batch_size):
6
  return torch.tensor(x).repeat(batch_size)
@@ -29,4 +30,17 @@ def parallelize(
29
  elif parallel == "single":
30
  return [fn(x) for x in tqdm.tqdm(*iterables)]
31
  else:
32
- raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import tqdm
2
 
3
  import torch
4
+ from einops import rearrange
5
 
6
  def scalar_to_batch_tensor(x, batch_size):
7
  return torch.tensor(x).repeat(batch_size)
 
30
  elif parallel == "single":
31
  return [fn(x) for x in tqdm.tqdm(*iterables)]
32
  else:
33
+ raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
34
+
35
+ def codebook_flatten(tokens: torch.Tensor):
36
+ """
37
+ flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
38
+ """
39
+ return rearrange(tokens, "b c t -> b (t c)")
40
+
41
+ def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
42
+ """
43
+ unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
44
+ """
45
+ tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
46
+ return tokens