Hugo Flores Garcia committed
Commit 51f416f
1 parent: a689560
Files changed (4)
  1. app.py +3 -3
  2. vampnet/interface.py +271 -42
  3. vampnet/mask.py +11 -4
  4. vampnet/modules/transformer.py +262 -12
app.py CHANGED
@@ -494,7 +494,7 @@ with gr.Blocks() as demo:
         minimum=1,
         maximum=16,
         step=1,
-        value=1
+        value=3
     )
 
     win_dur= gr.Slider(
@@ -580,8 +580,8 @@ with gr.Blocks() as demo:
     from pyharp import ModelCard, build_endpoint
 
     model_card = ModelCard(
-        name="salad bowl",
-        description="sounds",
+        name="nesquik",
+        description="the ultimate 8-bit crusher",
         author="hugo flores garcía",
         tags=["generative","sound"],
     )
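For orientation, the slider hunk above amounts to the following minimal sketch; the variable name and label are hypothetical, since the surrounding app code is not part of this diff:

import gradio as gr

with gr.Blocks() as demo:
    num_vamps = gr.Slider(
        label="number of vamps",  # hypothetical label
        minimum=1,
        maximum=16,
        step=1,
        value=3,  # the new default introduced by this commit
    )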
vampnet/interface.py CHANGED
@@ -110,26 +110,26 @@ class Interface(torch.nn.Module):
         # check if we already loaded, if so, don't reload
         if self.coarse_path == Path(coarse_ckpt):
             print(f"already loaded {coarse_ckpt}")
-            return
-        self.coarse = _load_model(
-            ckpt=coarse_ckpt,
-            device=self.device,
-            chunk_size_s=self.coarse.chunk_size_s,
-        )
-        self.coarse_path = Path(coarse_ckpt)
-        print(f"loaded {coarse_ckpt}")
+        else:
+            self.coarse = _load_model(
+                ckpt=coarse_ckpt,
+                device=self.device,
+                chunk_size_s=self.coarse.chunk_size_s,
+            )
+            self.coarse_path = Path(coarse_ckpt)
+            print(f"loaded {coarse_ckpt}")
 
         if c2f_ckpt is not None:
             if self.c2f_path == Path(c2f_ckpt):
                 print(f"already loaded {c2f_ckpt}")
-                return
-            self.c2f = _load_model(
-                ckpt=c2f_ckpt,
-                device=self.device,
-                chunk_size_s=self.c2f.chunk_size_s,
-            )
-            self.c2f_path = Path(c2f_ckpt)
-            print(f"loaded {c2f_ckpt}")
+            else:
+                self.c2f = _load_model(
+                    ckpt=c2f_ckpt,
+                    device=self.device,
+                    chunk_size_s=self.c2f.chunk_size_s,
+                )
+                self.c2f_path = Path(c2f_ckpt)
+                print(f"loaded {c2f_ckpt}")
 
     def s2t(self, seconds: float):
         """seconds to tokens"""
@@ -273,11 +273,15 @@ class Interface(torch.nn.Module):
         else:
             mask = mask.repeat(1, self.coarse.n_codebooks, 1)
         return mask
+
+    def set_chunk_size(self, chunk_size_s: float):
+        self.coarse.chunk_size_s = chunk_size_s
 
     def coarse_to_fine(
         self,
         z: torch.Tensor,
         mask: torch.Tensor = None,
+        return_mask: bool = False,
         **kwargs
     ):
         assert self.c2f is not None, "No coarse2fine model loaded"
@@ -289,7 +293,7 @@ class Interface(torch.nn.Module):
         if length % chunk_len != 0:
             pad_len = chunk_len - (length % chunk_len)
             z = torch.nn.functional.pad(z, (0, pad_len))
-            mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
+            mask = torch.nn.functional.pad(mask, (0, pad_len), value=1) if mask is not None else None
 
         n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
         if n_codebooks_to_append > 0:
@@ -297,6 +301,7 @@ class Interface(torch.nn.Module):
                 z,
                 torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
             ], dim=1)
+            print(f"appended {n_codebooks_to_append} codebooks to z")
 
         # set the mask to 0 for all conditioning codebooks
         if mask is not None:
@@ -319,6 +324,9 @@ class Interface(torch.nn.Module):
             fine_z.append(chunk)
 
         fine_z = torch.cat(fine_z, dim=-1)
+        if return_mask:
+            return fine_z[:, :, :length].clone(), apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone()
+
         return fine_z[:, :, :length].clone()
 
     def coarse_vamp(
@@ -331,22 +339,52 @@ class Interface(torch.nn.Module):
     ):
         # coarse z
         cz = z[:, : self.coarse.n_codebooks, :].clone()
-        assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
-
         mask = mask[:, : self.coarse.n_codebooks, :]
+        # assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
 
-        cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
-        cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
-
-        gen_fn = gen_fn or self.coarse.generate
-        c_vamp = gen_fn(
-            codec=self.codec,
-            time_steps=cz.shape[-1],
-            start_tokens=cz,
-            mask=mask,
-            return_signal=False,
-            **kwargs
-        )
+        # cut into chunks, keep the last chunk separate if it's too small
+        chunk_len = self.s2t(self.coarse.chunk_size_s)
+        n_chunks = math.ceil(cz.shape[-1] / chunk_len)
+        last_chunk_len = cz.shape[-1] % chunk_len
+
+        cz_chunks = []
+        mask_chunks = []
+        for i in range(n_chunks):
+            chunk = cz[:, :, i * chunk_len : (i + 1) * chunk_len]
+            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]
+
+            # make sure that the very first and last timestep of each chunk is 0 so that we don't get a weird
+            # discontinuity when we stitch the chunks back together
+            # only if there's already a 0 somewhere in the chunk
+            if torch.any(mask_chunk == 0):
+                mask_chunk[:, :, 0] = 0
+                mask_chunk[:, :, -1] = 0
+
+            cz_chunks.append(chunk)
+            mask_chunks.append(mask_chunk)
+
+        # now vamp each chunk
+        cz_masked_chunks = []
+        cz_vamped_chunks = []
+        for chunk, mask_chunk in zip(cz_chunks, mask_chunks):
+            cz_masked_chunk, mask_chunk = apply_mask(chunk, mask_chunk, self.coarse.mask_token)
+            cz_masked_chunk = cz_masked_chunk[:, : self.coarse.n_codebooks, :]
+            cz_masked_chunks.append(cz_masked_chunk)
+
+            gen_fn = gen_fn or self.coarse.generate
+            c_vamp_chunk = gen_fn(
+                codec=self.codec,
+                time_steps=chunk_len,
+                start_tokens=cz_masked_chunk,
+                return_signal=False,
+                mask=mask_chunk,
+                **kwargs
+            )
+            cz_vamped_chunks.append(c_vamp_chunk)
+
+        # stitch the chunks back together
+        cz_masked = torch.cat(cz_masked_chunks, dim=-1)
+        c_vamp = torch.cat(cz_vamped_chunks, dim=-1)
 
         # add the fine codes back in
         c_vamp = torch.cat(
@@ -358,16 +396,169 @@ class Interface(torch.nn.Module):
             return c_vamp, cz_masked
 
         return c_vamp
-
-    # def chunked_coarse_vamp(
-    #     self,
-    #     z,
-    #     mask,
-    #     return_mask=False,
-    #     gen_fn=None,
-    #     **kwargs
-    # )
+
+    def build_mask(self,
+        z: torch.Tensor,
+        sig: AudioSignal = None,
+        rand_mask_intensity: float = 1.0,
+        prefix_s: float = 0.0,
+        suffix_s: float = 0.0,
+        periodic_prompt: int = 7,
+        periodic_prompt2: int = 7,
+        periodic_prompt_width: int = 1,
+        onset_mask_width: int = 0,
+        _dropout: float = 0.0,
+        upper_codebook_mask: int = 3,
+        upper_codebook_mask_2: int = None,
+        ncc: int = 0,
+    ):
+
+        mask = linear_random(z, rand_mask_intensity)
+        mask = mask_and(
+            mask,
+            inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
+        )
+
+        pmask1 = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
+        pmask2 = periodic_mask(z, periodic_prompt2, periodic_prompt_width, random_roll=True)
+        # interpolate the two masks
+        pmask = torch.round(
+            pmask1 * torch.linspace(1, 0, pmask1.shape[-1], device=pmask1.device) +
+            pmask2 * torch.linspace(0, 1, pmask2.shape[-1], device=pmask2.device)
+        ).long()
+
+        mask = mask_and(mask, pmask)
+
+        if onset_mask_width > 0:
+            assert sig is not None, f"must provide a signal to use onset mask"
+            mask = mask_or(
+                mask, onset_mask(
+                    sig, z, interface,
+                    width=onset_mask_width
+                )
+            )
+
+        mask = dropout(mask, _dropout)
+        mask = codebook_unmask(mask, ncc)
+
+        mask = codebook_mask(mask, int(upper_codebook_mask), upper_codebook_mask_2)
+        return mask
+
+    def ez_vamp(
+        self,
+        sig: AudioSignal,
+        batch_size: int = 4,
+        feedback_steps: int = 1,
+        time_stretch_factor: int = 1,
+        return_mask: bool = False,
+        build_mask_kwargs: dict = None,
+        vamp_kwargs: dict = None,
+    ):
+        feedback_steps = int(feedback_steps)
+        build_mask_kwargs = build_mask_kwargs or {}
+        vamp_kwargs = vamp_kwargs or {}
+
+        loudness = sig.loudness()
+        sig = self.preprocess(sig)
+
+        z = self.encode(sig)
+
+        # expand z to batch size
+        z = z.expand(batch_size, -1, -1)
+        mask = self.build_mask(
+            z=z,
+            **build_mask_kwargs
+        )
+        mask = mask.expand(batch_size, -1, -1)
+
+        # stretch mask and z to match the time stretch factor
+        # we'll add (stretch_factor - 1) mask tokens in between each timestep of z
+        # and we'll make the mask 1 in all the new slots we added
+        if time_stretch_factor > 1:
+            z = z.repeat_interleave(time_stretch_factor, dim=-1)
+            mask = mask.repeat_interleave(time_stretch_factor, dim=-1)
+            added_mask = torch.ones_like(mask)
+            added_mask[:, :, ::time_stretch_factor] = 0
+            mask = mask.bool() | added_mask.bool()
+            mask = mask.long()
+
+        prev_zvs = []
+        for i in tqdm.tqdm(range(feedback_steps), desc="feedback steps"):
+            print(z.shape)
+
+            vamp_kwargs.pop("mask", None)
+            vamp_kwargs.pop('return_mask', None)
+            print("coarse!")
+            zv, mask_z = self.coarse_vamp(
+                z,
+                mask=mask,
+                return_mask=True,
+                **vamp_kwargs
+            )
+
+            # add the top codebooks back in
+            if zv.shape[1] < z.shape[1]:
+                print(f"adding {z.shape[1] - zv.shape[1]} codebooks back in")
+                zv = torch.cat(
+                    [zv, z[:, self.coarse.n_codebooks :, :]],
+                    dim=1
+                )
+
+            # now, coarse2fine
+            print(f"coarse2fine!")
+            zv, fine_zv_mask = self.coarse_to_fine(
+                zv,
+                mask=mask,
+                **vamp_kwargs,
+                _sampling_steps=[2, 2, 1, 1],
+                return_mask=True
+            )
+            mask_z = torch.cat(
+                [mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
+                dim=1
+            )
+
+            prev_zvs.append(zv)
+            z = zv
+
+        # perform to_signal batch item by batch
+        sigs = []
+        for zv in prev_zvs:
+            # do it in timestep chunks of 1024
+            _sigs = []
+            for i in range(0, zv.shape[-1], 1024):
+                _sigs.append(self.to_signal(zv[:, :, i:i+1024]).cpu())
+            sigs.append(signal_concat(_sigs))
+        print("done")
+
+        sig = AudioSignal.batch(sigs)
+
+        # sig = self.to_signal(zv).cpu()
+        # print("done")
+
+        sig = sig.normalize(loudness)
+
+        if return_mask:
+            return sig, mask_z.cpu(), zv.cpu()
+        else:
+            return sig
+
+    def visualize_codes(self, z: torch.Tensor):
+        import matplotlib.pyplot as plt
+        # make sure the figsize is square when imshow is called
+        fig = plt.figure(figsize=(10, 7))
+        # in subplots, plot z[0] and the mask
+        # set title to "codes" and "mask"
+        fig.add_subplot(2, 1, 1)
+        plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
+        plt.title("codes")
+        plt.ylabel("codebook index")
+        # set the xticks to seconds
+        plt.xticks(
+            np.arange(0, z.shape[-1], self.s2t(1)),
+            np.arange(0, self.t2s(z.shape[-1]), 1)
+        )
+        plt.xlabel("time (s)")
 
 if __name__ == "__main__":
     import audiotools as at
@@ -389,8 +580,6 @@ if __name__ == "__main__":
     sig = at.AudioSignal('assets/example.wav')
 
     z = interface.encode(sig)
-    breakpoint()
-
     # mask = linear_random(z, 1.0)
     # mask = mask_and(
     #     mask, periodic_mask(
@@ -429,4 +618,44 @@ if __name__ == "__main__":
     sig = interface.to_signal(zv).cpu()
     print("done")
 
-
+
+
+
+    # example plotting code
+    # import matplotlib.pyplot as plt
+    # from pathlib import Path
+    # Path(".vampnet").mkdir(exist_ok=True)
+    # plt.clf()
+    # # close all figs
+    # plt.close('all')
+    # # set the fig size
+    # plt.subplot(4, 1, 1)
+    # # sig = self.to_signal(sampled_z, codec)
+    # # sig.cpu().specshow()
+
+    # plt.subplot(4, 1, 2)
+    # # since z_masked is a codebook, we want to plot the colormap
+    # # with distinct colors for each codebook index
+    # # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
+    # # make it so that anywhere where the mask is 1, we make that pixel black
+    # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r',)
+
+
+    # plt.subplot(4, 1, 3)
+    # # plot the mask (which is a matrix)
+    # plt.imshow(mask[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r')
+    # plt.subplot(4, 1, 4)
+    # # replace any inf or -inf with 0
+    # _selected_probs = torch.where(
+    #     selected_probs == torch.inf, torch.zeros_like(selected_probs), selected_probs
+    # )
+    # _selected_probs = torch.where(
+    #     selected_probs == -torch.inf, torch.zeros_like(selected_probs), selected_probs
+    # )
+    # # fig = plt.gcf()
+    # # fig.set_figheight(15)
+    # # fig.set_figwidth(15)
+    # plt.imshow(codebook_unflatten(_selected_probs, n_infer_codebooks)[0].cpu().numpy(), aspect='auto', origin='lower', cmap="viridis" )
+    # # plt.show()
+    # plt.savefig(f".vampnet/c={codebook_level}_{i}.png")
+    # plt.close('all')
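To make the seam handling in the new chunked coarse_vamp concrete: the first and last timestep of each chunk's mask are pinned to 0 (kept) so stitched chunks share real tokens at their boundaries, but only when the chunk already keeps something. A small self-contained sketch of that logic on toy tensors (standalone torch, not the repo's API):

import torch

chunk_len = 4
mask = torch.ones(1, 2, 8).long()   # (batch, codebooks, time), everything masked
mask[:, :, 2] = 0                   # pretend a prompt kept timestep 2

for i in range(mask.shape[-1] // chunk_len):
    mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]  # a view into mask
    # only pin the edges if something in this chunk is already kept
    if torch.any(mask_chunk == 0):
        mask_chunk[:, :, 0] = 0
        mask_chunk[:, :, -1] = 0

print(mask[0, 0].tolist())  # [0, 1, 0, 0, 1, 1, 1, 1] -> the all-masked second chunk is untouched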
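The time-stretch block in ez_vamp can be hard to picture; here is the same tensor manipulation on toy shapes, so the interleaving is visible (standalone torch, not the repo's API):

import torch

time_stretch_factor = 3
z = torch.arange(4).reshape(1, 1, 4)   # (batch, codebooks, time)
mask = torch.zeros_like(z)             # keep every original token

z = z.repeat_interleave(time_stretch_factor, dim=-1)
mask = mask.repeat_interleave(time_stretch_factor, dim=-1)

added_mask = torch.ones_like(mask)
added_mask[:, :, ::time_stretch_factor] = 0          # original positions stay unmasked
mask = (mask.bool() | added_mask.bool()).long()      # new in-between slots get mask=1

print(z[0, 0].tolist())     # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(mask[0, 0].tolist())  # [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1]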
vampnet/mask.py CHANGED
@@ -60,6 +60,7 @@ def linear_random(
     assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
     if not isinstance(r, torch.Tensor):
         r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+    r = r[:, None, None]
 
     probs = torch.ones_like(x).to(x.device).float()
     # expand to batch and codebook dims
@@ -98,7 +99,7 @@ def inpaint(x: torch.Tensor,
     return mask
 
 def periodic_mask(x: torch.Tensor,
-                  period: int, width: int = 1,
+                  period: int,width: int = 1,
                   random_roll=False,
                   ):
     mask = full_mask(x)
@@ -140,9 +141,15 @@ def codebook_unmask(
     mask[:, :n_conditioning_codebooks, :] = 0
     return mask
 
-def codebook_mask(mask: torch.Tensor, start: int):
+def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
     mask = mask.clone()
-    mask[:, start:, :] = 1
+    mask[:, val1:, :] = 1
+    # val2 = val2 or val1
+    # vs = torch.linspace(val1, val2, mask.shape[1])
+    # for t, v in enumerate(vs):
+    #     v = int(v)
+    #     mask[:, v:, t] = 1
+
     return mask
 
 def mask_and(
@@ -239,4 +246,4 @@ def onset_mask(
 
 
 if __name__ == "__main__":
-    pass
+    pass
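The one-line addition in linear_random is a broadcasting repair; a sketch of the failure it avoids, assuming r scales per-position mask probabilities as the surrounding code suggests:

import torch

batch, n_codebooks, seq = 2, 3, 5
r = torch.tensor([0.2, 0.9])            # one mask intensity per batch item
probs = torch.ones(batch, n_codebooks, seq)

# probs * r                             # RuntimeError: (2, 3, 5) vs (2,) don't broadcast
probs = probs * r[:, None, None]        # (2, 1, 1) broadcasts over codebooks and time
print(probs[0, 0, 0].item(), probs[1, 0, 0].item())  # 0.2 0.9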
 
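As committed, codebook_mask keeps only the flat behavior (the ramp between val1 and val2 is commented out): everything at or above codebook level val1 is forced to 1, i.e. regenerated. On toy tensors:

import torch

mask = torch.zeros(1, 6, 4).long()  # (batch, codebooks, time), nothing masked
val1 = 3
mask[:, val1:, :] = 1               # mask codebooks 3..5 at every timestep
print(mask[0, :, 0].tolist())       # [0, 0, 0, 1, 1, 1]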
vampnet/modules/transformer.py CHANGED
@@ -1,6 +1,6 @@
 import math
 import logging
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 
 import numpy as np
 import torch
@@ -572,6 +572,8 @@ class VampNet(at.ml.BaseModel):
         """
         assert z.ndim == 3
 
+        # remove mask token
+        z = z.masked_fill(z == self.mask_token, 0)
         signal = at.AudioSignal(
             codec.decode(
                 codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
@@ -581,34 +583,279 @@ class VampNet(at.ml.BaseModel):
 
         # find where the mask token is and replace it with silence in the audio
         for tstep in range(z.shape[-1]):
-            if torch.any(z[:, :, tstep] == self.mask_token):
+            if torch.all(z[:, :, tstep] == self.mask_token):
                 sample_idx_0 = tstep * codec.hop_length
                 sample_idx_1 = sample_idx_0 + codec.hop_length
                 signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
 
         return signal
-
-
+
     @torch.no_grad()
     def generate(
         self,
         codec,
         time_steps: int = 300,
-        sampling_steps: int = 36,
+        _sampling_steps: List[int] = [16, 8, 8, 2, 2, 2, 2, 1, 1],
         start_tokens: Optional[torch.Tensor] = None,
         sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
         mask_temperature: float = 10.5,
-        typical_filtering=False,
+        typical_filtering=True,
         typical_mass=0.2,
         typical_min_tokens=1,
-        top_p=None,
-        return_signal=True,
+        top_p=0.9,
         seed: int = None,
-        sample_cutoff: float = 1.0,
+        sample_cutoff: float = 0.9,
+        return_signal=True,
+        debug=False,
+        causal_weight: float = 0.0,
+        use_og_method: bool = False,
+    ):
+        if use_og_method:
+            return self.og_method(
+                codec,
+                time_steps,
+                _sampling_steps,
+                start_tokens,
+                sampling_temperature,
+                mask,
+                mask_temperature,
+                typical_filtering,
+                typical_mass,
+                typical_min_tokens,
+                top_p,
+                seed,
+                sample_cutoff,
+                return_signal,
+                debug,
+                causal_weight,
+            )
+
+        if seed is not None:
+            at.util.seed(seed)
+
+        #####################
+        # resolve initial z #
+        #####################
+        z = start_tokens
+
+        if z is None:
+            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
+                self.device
+            )
+
+        logging.debug(f"created z with shape {z.shape}")
+
+        #################
+        # resolve mask #
+        #################
+
+        if mask is None:
+            mask = torch.ones_like(z).to(self.device).int()
+            mask[:, : self.n_conditioning_codebooks, :] = 0.0
+        if mask.ndim == 2:
+            mask = mask[:, None, :].repeat(1, z.shape[1], 1)
+        orig_mask = mask
+        logging.debug(f"created mask with shape {mask.shape}")
+
+        ###########
+        # set up #
+        ##########
+        # apply the mask to z
+        z_masked = z.masked_fill(mask.bool(), self.mask_token)
+
+        # how many codebooks are we inferring vs conditioning on?
+        n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
+        logging.debug(f"n infer codebooks: {n_infer_codebooks}")
+
+        #################
+        # begin sampling #
+        #################
+        # add one sampling step for each codebook level
+        logging.debug(f"initial mask: {mask}")
+        logging.debug(f"adding {n_infer_codebooks} sampling steps")
+        steps = _sampling_steps + [1 for _ in range(n_infer_codebooks - len(_sampling_steps))]
+        # truncate if we have too many
+        steps = steps[:n_infer_codebooks]
+        for codebook_level, nsteps in enumerate(steps):
+
+            # apply the orig mask to z_masked, only in the current codebook level
+            # this is crucial due to the stemgen random masking we did during training
+            # which ensures all upper codebooks are masked while inferring the bottom ones.
+            z_masked[:, codebook_level, :] = torch.where(
+                orig_mask[:, codebook_level, :].bool(),
+                self.mask_token,
+                z_masked[:, codebook_level, :]
+            )
+
+            # how many mask tokens to begin with?
+            num_mask_tokens_at_start = (z_masked[:, codebook_level, :] == self.mask_token).sum(dim=-1)
+            logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
+
+            for i in range(nsteps):
+                logging.debug(f"processing cb level {codebook_level} of {len(steps)}")
+                logging.debug(f"step {i} of {nsteps}")
+
+                # our current schedule step
+                r = scalar_to_batch_tensor(
+                    (i + 1) / nsteps,
+                    z.shape[0]
+                ).to(z.device)
+                logging.debug(f"r: {r}")
+
+                # get latents
+                latents = self.embedding.from_codes(z_masked, codec)
+                logging.debug(f"computed latents with shape: {latents.shape}")
+
+                # infer from latents
+                # NOTE: this collapses the codebook dimension into the sequence dimension
+                logits = self.forward(
+                    latents,
+                )  # b, prob, seq
+                logits = logits.permute(0, 2, 1)  # b, seq, prob
+                logging.debug(f"permuted logits with shape: {logits.shape}")
+
+                sampled_z, selected_probs = sample_from_logits(
+                    logits, sample=(
+                        (i / nsteps) <= sample_cutoff
+                    ),
+                    temperature=sampling_temperature,
+                    typical_filtering=typical_filtering, typical_mass=typical_mass,
+                    typical_min_tokens=typical_min_tokens,
+                    top_k=None, top_p=top_p, return_probs=True,
+                )
+
+                # fill selected probs with -inf if we're not in the codebook level we are sampling from
+                # find out which codebook we are sampling from
+                selected_probs = codebook_unflatten(selected_probs, n_infer_codebooks)
+                selected_probs[:, codebook_level+1:, :,] = -float("inf")  # all the ones above
+                # selected_probs[:, :codebook_level, :,] = -float("inf")
+                logging.debug(f"masking all but codebook {codebook_level}")
+                logging.debug(f"selected probs: {selected_probs}")
+                logging.debug(mask)
+                selected_probs = codebook_flatten(selected_probs)
+
+                logging.debug(f"sampled z with shape: {sampled_z.shape}")
+
+                # flatten z_masked and mask, so we can deal with the sampling logic
+                # we'll unflatten them at the end of the loop for the next forward pass
+                # remove conditioning codebooks, we'll add them back at the end
+                z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
+
+                mask = (z_masked == self.mask_token).int()
+                logging.debug(f"mask now: {mask}")
+
+                # update the mask, remove conditioning codebooks from the mask
+                logging.debug(f"updated mask with shape: {mask.shape}")
+
+                # add z back into sampled z where the mask was false
+                sampled_z = torch.where(
+                    mask.bool(), sampled_z, z_masked
+                )
+                logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
+
+                # get the num tokens to mask, according to the schedule
+                num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
+                # num_to_mask = torch.floor(r * num_mask_tokens_at_start).unsqueeze(1).long()  # doesn't work at all this way
+                logging.debug(f"num to mask: {num_to_mask}")
+                logging.debug(f"masking {num_to_mask.sum()} tokens")
+
+                if i != (nsteps - 1):
+                    mask = codebook_unflatten(mask, n_infer_codebooks)
+                    num_to_mask = torch.maximum(
+                        torch.tensor(1),
+                        torch.minimum(
+                            mask[:, codebook_level, :].sum(dim=-1, keepdim=True) - 1,
+                            num_to_mask
+                        )
+                    )
+                    logging.debug(f"will mask {num_to_mask.sum()} tokens")
+                    mask = codebook_flatten(mask)
+
+                # ignore any tokens that weren't masked
+                selected_probs = torch.where(
+                    mask.bool(), selected_probs, torch.inf
+                )
+
+                # add a causal weight to the selected probs
+                # NOTE: some experiments i did showed that this didn't help.
+                # set it to 0 until further eval
+                causal_probs = torch.linspace(1, 0, z_masked.shape[-1], device=z_masked.device)
+                causal_probs = causal_probs.repeat(z_masked.shape[0], 1)
+                selected_probs = selected_probs + causal_probs * causal_weight
+
+                # # get our new mask
+                ############
+                mask = codebook_unflatten(mask, n_infer_codebooks)
+                selected_probs = codebook_unflatten(selected_probs, n_infer_codebooks)
+
+                # only consider probs at current level
+                selected_probs_cur_level = selected_probs[:, codebook_level, :]
+                mask_cur_level = mask_by_random_topk(
+                    num_to_mask, selected_probs_cur_level, mask_temperature * (1-r.unsqueeze(1))
+                )
+                mask[:, codebook_level, :] = mask_cur_level
+
+                mask = codebook_flatten(mask)
+                selected_probs = codebook_flatten(selected_probs)
+                ###############
+
+                # update the mask
+                z_masked = torch.where(
+                    mask.bool(), self.mask_token, sampled_z
+                )
+                logging.debug(f"updated z_masked with shape: {z_masked.shape}")
+
+                z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
+                mask = codebook_unflatten(mask, n_infer_codebooks)
+                logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
+
+                # add conditioning codebooks back to z_masked
+                z_masked = torch.cat(
+                    (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
+                )
+                logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
+
+        # add conditioning codebooks back to sampled_z
+        sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
+        sampled_z = torch.cat(
+            (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
+        )
+
+        logging.debug(f"finished sampling")
+
+        if return_signal:
+            return self.to_signal(sampled_z, codec)
+        else:
+            return sampled_z
+
+
+    def og_method(
+        self,
+        codec,
+        time_steps: int = 300,
+        _sampling_steps: List[int] = [16, 8, 8, 2, 2, 2, 2, 1, 1],
+        start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
+        mask: Optional[torch.Tensor] = None,
+        mask_temperature: float = 10.5,
+        typical_filtering=True,
+        typical_mass=0.2,
+        typical_min_tokens=1,
+        top_p=0.9,
+        seed: int = None,
+        sample_cutoff: float = 0.75,
+        return_signal=True,
+        debug=False,
+        causal_weight: float = 0.0,
     ):
         if seed is not None:
             at.util.seed(seed)
+        sampling_steps = sum(_sampling_steps)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
 
@@ -763,6 +1010,10 @@ class VampNet(at.ml.BaseModel):
         else:
             return sampled_z
 
+
+
+
+
 def sample_from_logits(
     logits,
     sample: bool = True,
@@ -942,12 +1193,11 @@ if __name__ == "__main__":
     pred = z_hat.argmax(dim=1)
     pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
 
-    print(f"model has {num_params(model)/1e6:<.3f}M parameters")
-    print(f"prediction has shape {pred.shape}")
+    logging.debug(f"model has {num_params(model)/1e6:<.3f}M parameters")
+    logging.debug(f"prediction has shape {pred.shape}")
     breakpoint()
 
     args = argbind.parse_args()
     with argbind.scope(args):
         try_model()
-
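The new generate() walks codebook levels with a per-level step budget; resolving that budget is just pad-then-truncate, shown here standalone:

_sampling_steps = [16, 8, 8, 2]
n_infer_codebooks = 9

# pad with 1-step levels for any codebook beyond the list
# (if the list is longer than the codebooks, the range is empty and nothing is appended)
steps = _sampling_steps + [1 for _ in range(n_infer_codebooks - len(_sampling_steps))]
# then truncate in case the list had more entries than codebooks
steps = steps[:n_infer_codebooks]
print(steps)  # [16, 8, 8, 2, 1, 1, 1, 1, 1]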
 
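The per-step re-masking count follows _gamma(r); assuming the usual MaskGIT-style cosine schedule gamma(r) = cos(r * pi / 2) (the definition of _gamma is not shown in this diff), the fraction of tokens kept masked shrinks toward 0 as r approaches 1:

import math

def gamma(r: float) -> float:
    # assumed cosine schedule, as in MaskGIT-style samplers
    return math.cos(r * math.pi / 2)

num_mask_tokens_at_start = 100
nsteps = 4
for i in range(nsteps):
    r = (i + 1) / nsteps
    num_to_mask = math.floor(gamma(r) * num_mask_tokens_at_start)
    print(f"step {i}: r={r:.2f} -> re-mask {num_to_mask}/100 tokens")
# step 0: r=0.25 -> re-mask 92/100 tokens
# step 3: r=1.00 -> re-mask 0/100 tokens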