Hugo Flores Garcia committed
Commit d98455c
Parent: c068a29
scripts/exp/train.py CHANGED
@@ -342,8 +342,6 @@ def train(
         dtype = torch.bfloat16 if accel.amp else None
         with accel.autocast(dtype=dtype):
             z_hat = model(z_mask_latent, r)
-            # for mask mode
-            z_hat = vn.add_truth_to_logits(z, z_hat, mask)
 
         target = codebook_flatten(
             z[:, vn.n_conditioning_codebooks :, :],
@@ -414,8 +412,6 @@ def train(
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)
-        # for mask mode
-        z_hat = vn.add_truth_to_logits(z, z_hat, mask)
 
         target = codebook_flatten(
             z[:, vn.n_conditioning_codebooks :, :],
@@ -573,8 +569,6 @@ def train(
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)
-        # for mask mode
-        z_hat = vn.add_truth_to_logits(z, z_hat, mask)
 
         z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
         z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
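
Note: all three hunks above drop the same post-processing step from the train, validation, and sampling paths: the model's logits are no longer run through `vn.add_truth_to_logits` (the helper itself is deleted from vampnet/modules/transformer.py further down). A minimal pure-torch sketch of what the removed step computed; this version elides the codebook flatten/unflatten that the real helper performed:

    import torch
    import torch.nn.functional as F

    def add_truth_to_logits_sketch(z_true, z_hat, mask, vocab_size):
        # z_true: (batch, codebooks, time) ground-truth token ids
        # z_hat:  (batch, codebooks, time, vocab) predicted logits
        # mask:   (batch, codebooks, time), 1 where the model had to predict
        truth = F.one_hot(z_true, vocab_size).float()
        m = mask[..., None].float()
        # keep predictions at masked positions, inject one-hot truth elsewhere
        return z_hat * m + truth * (1 - m)

After this commit the raw logits flow straight into the loss and metrics, so unmasked positions are scored on the model's actual predictions rather than on injected ground truth.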
 
scripts/utils/vamp_folder.py CHANGED
@@ -9,10 +9,13 @@ from tqdm import tqdm
 import torch
 
 from vampnet.interface import Interface
+from vampnet import mask as pmask
 import audiotools as at
 
 Interface: Interface = argbind.bind(Interface)
 
+
+
 def calculate_bitrate(
     interface, num_codebooks,
     downsample_factor
@@ -38,29 +41,19 @@ def coarse2fine(sig, interface):
     z = interface.coarse_to_fine(z)
     return interface.to_signal(z)
 
-def coarse2fine_argmax(sig, interface):
-    z = interface.encode(sig)
-    z = z[:, :interface.c2f.n_conditioning_codebooks, :]
-
-    z = interface.coarse_to_fine(z,
-        sample="argmax", sampling_steps=1,
-        temperature=1.0
-    )
-    return interface.to_signal(z)
-
 class CoarseCond:
 
-    def __init__(self, num_codebooks, downsample_factor):
-        self.num_codebooks = num_codebooks
+    def __init__(self, num_conditioning_codebooks, downsample_factor):
+        self.num_conditioning_codebooks = num_conditioning_codebooks
         self.downsample_factor = downsample_factor
 
     def __call__(self, sig, interface):
-        n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
-        zv = interface.coarse_vamp(sig,
-            n_conditioning_codebooks=n_conditioning_codebooks,
-            downsample_factor=self.downsample_factor,
-        )
+        z = interface.encode(sig)
+        mask = pmask.full_mask(z)
+        mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks)
+        mask = pmask.periodic_mask(mask, self.downsample_factor)
 
+        zv = interface.coarse_vamp(z, mask)
         zv = interface.coarse_to_fine(zv)
         return interface.to_signal(zv)
 
@@ -97,24 +90,24 @@ def opus(sig, interface, bitrate=128):
 
 def mask_ratio_1_step(ratio=1.0):
     def wrapper(sig, interface):
-        r = interface.coarse.invgamma(ratio).to(interface.device)
-        intensity = 1-r
-
+        z = interface.encode(sig)
+        mask = pmask.linear_random(z, ratio)
         zv = interface.coarse_vamp(
-            sig,
-            sample='argmax',
+            z,
+            mask,
             sampling_steps=1,
-            intensity=intensity
         )
 
         return interface.to_signal(zv)
     return wrapper
 
 def num_sampling_steps(num_steps=1):
-    def wrapper(sig, interface):
+    def wrapper(sig, interface: Interface):
+        z = interface.encode(sig)
+        mask = pmask.periodic_mask(z, 16)
         zv = interface.coarse_vamp(
-            sig,
-            downsample_factor=16,
+            z,
+            mask,
             sampling_steps=num_steps,
         )
 
@@ -130,9 +123,9 @@ def beat_mask(ctx_time):
             after_beat_s=ctx_time,
             invert=True
         )
+        z = interface.encode(sig)
         zv = interface.coarse_vamp(
-            sig,
-            ext_mask=beat_mask,
+            z, beat_mask,
         )
 
         zv = interface.coarse_to_fine(zv)
@@ -140,17 +133,28 @@ def beat_mask(ctx_time):
     return wrapper
 
 def inpaint(ctx_time):
-    def wrapper(sig, interface):
-        zv = interface.coarse_vamp(
-            sig,
-            prefix_dur_s=ctx_time,
-            suffix_dur_s=ctx_time,
-        )
+    def wrapper(sig, interface: Interface):
+        z = interface.encode(sig)
+        mask = pmask.inpaint(z, interface.s2t(ctx_time), interface.s2t(ctx_time))
 
+        zv = interface.coarse_vamp(z, mask)
         zv = interface.coarse_to_fine(zv)
+
         return interface.to_signal(zv)
     return wrapper
 
+def token_noise(noise_amt):
+    def wrapper(sig, interface: Interface):
+        z = interface.encode(sig)
+        mask = pmask.random(z, noise_amt)
+        z = torch.where(
+            mask,
+            torch.randint_like(z, 0, interface.coarse.vocab_size),
+            z
+        )
+        return interface.to_signal(z)
+    return wrapper
+
 EXP_REGISTRY = {}
 
 EXP_REGISTRY["gen-compression"] = {
@@ -158,62 +162,27 @@ EXP_REGISTRY["gen-compression"] = {
     "reconstructed": reconstructed,
     "coarse2fine": coarse2fine,
     **{
-        f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_codebooks=n, downsample_factor=x)
+        f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
         for (n, x) in (
-            (4, 2), # 4 codebooks, downsampled 2x,
-            (2, 2), # 2 codebooks, downsampled 2x
-            (1, None), # 1 codebook, no downsampling
+            (1, 1), # 1 codebook, no downsampling
             (4, 4), # 4 codebooks, downsampled 4x
-            (1, 2), # 1 codebook, downsampled 2x,
-            (4, 6), # 4 codebooks, downsampled 6x
-            (4, 8), # 4 codebooks, downsampled 8x
             (4, 16), # 4 codebooks, downsampled 16x
             (4, 32), # 4 codebooks, downsampled 16x
         )
     },
+    **{
+        f"token_noise_{x}": mask_ratio_1_step(ratio=x)
+        for x in [0.25, 0.5, 0.75]
+    },
 
 }
 
-EXP_REGISTRY["opus-jazzpop"] = {
-    f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
-    for bitrate in [5620, 1875, 1250, 625]
-}
-
-EXP_REGISTRY["opus-spotdl"] = {
-    f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
-    for bitrate in [8036, 2296, 1148, 574]
-}
-
-EXP_REGISTRY["opus-baseline"] = {
-    f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
-    for bitrate in [8000, 12000, 16000]
-}
-
-EXP_REGISTRY["c2f"] = {
-    "baseline": baseline,
-    "reconstructed": reconstructed,
-    "coarse2fine": coarse2fine,
-    "coarse2fine_argmax": coarse2fine_argmax,
-}
-
-EXP_REGISTRY["token-noise"] = {
-    f"token_noise_{r}": token_noise(r) for r in [0.25, 0.5, 0.75, 1.0]
-}
-
-EXP_REGISTRY["mask-ratio"] = {
-    "codec": reconstructed,
-    **{f"mask_ratio_{r}": mask_ratio_1_step(r) for r in [0.25, 0.5, 0.75, 0.9]}
-}
 
 EXP_REGISTRY["sampling-steps"] = {
-    "codec": reconstructed,
-    **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 24, 36, 64, 72]},
+    # "codec": reconstructed,
+    **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
 }
 
-EXP_REGISTRY["baseline"] = {
-    "baseline": baseline,
-    "codec": reconstructed,
-}
 
 EXP_REGISTRY["musical-sampling"] = {
     "baseline": baseline,
@@ -226,12 +195,13 @@ EXP_REGISTRY["musical-sampling"] = {
 @argbind.bind(without_prefix=True)
 def main(
     sources=[
-        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
+        "/media/CHONK/hugo/spotdl/audio-test",
     ],
     output_dir: str = "./samples",
-    max_excerpts: int = 5000,
-    exp_type: str = "coarse",
+    max_excerpts: int = 2000,
+    exp_type: str = "gen-compression",
     seed: int = 0,
+    ext: str = [".mp3"],
 ):
     at.util.seed(seed)
     interface = Interface()
@@ -241,7 +211,7 @@ def main(
 
     from audiotools.data.datasets import AudioLoader, AudioDataset
 
-    loader = AudioLoader(sources=sources, shuffle_state=seed)
+    loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
     dataset = AudioDataset(loader,
                            sample_rate=interface.codec.sample_rate,
                            duration=interface.coarse.chunk_size_s,
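
Note: every rewritten helper above follows the same new pattern: encode the signal first, build an explicit mask with the new `vampnet.mask` helpers (imported as `pmask`), then hand both codes and mask to `coarse_vamp`, which no longer takes steering keywords like `n_conditioning_codebooks`, `downsample_factor`, `prefix_dur_s`, or `intensity`. A sketch of the composition, mirroring `CoarseCond.__call__` above; it assumes an argbind-configured `Interface` and an audiotools `AudioSignal` named `sig`:

    from vampnet.interface import Interface
    from vampnet import mask as pmask

    interface = Interface()                # configured via argbind, as in this script
    z = interface.encode(sig)              # sig: an audiotools AudioSignal

    mask = pmask.full_mask(z)              # start with everything masked
    mask = pmask.codebook_unmask(mask, 1)  # unmask 1 conditioning codebook
    mask = pmask.periodic_mask(mask, 4)    # keep every 4th timestep as an anchor

    zv = interface.coarse_vamp(z, mask)    # regenerate the masked tokens
    zv = interface.coarse_to_fine(zv)
    out = interface.to_signal(zv)

The comments on the mask helpers describe their apparent intent as used in this script; the exact semantics live in the `vampnet.mask` module.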
 
vampnet/interface.py CHANGED
@@ -321,7 +321,7 @@ class Interface(torch.nn.Module):
         cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
         cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
 
-        gen_fn = gen_fn or self.coarse.sample
+        gen_fn = gen_fn or self.coarse.generate
         c_vamp = gen_fn(
             codec=self.codec,
             time_steps=cz.shape[-1],
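
Note: this one-line rename tracks the transformer change below: `VampNet.sample` is deleted, so `generate` becomes the only decoding entry point, and callers of `coarse_vamp` that never passed an explicit `gen_fn` now route through it.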
 
vampnet/modules/transformer.py CHANGED
@@ -572,173 +572,13 @@ class VampNet(at.ml.BaseModel):
 
         return signal
 
-    def add_truth_to_logits(
-        self,
-        z_true,
-        z_hat,
-        mask,
-    ):
-        z_true = z_true[:, self.n_conditioning_codebooks :, :]
-        mask = mask[:, self.n_conditioning_codebooks :, :]
-
-        truth = F.one_hot(z_true, self.vocab_size)
-        mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
-        z_hat = rearrange(
-            z_hat,
-            "b p (t c) -> b c t p",
-            c=self.n_codebooks - self.n_conditioning_codebooks,
-        )
-
-        z_hat = z_hat * mask + truth * (1 - mask)
-
-        z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
-
-        return z_hat
-
-
-    @torch.no_grad()
-    def sample(
-        self,
-        codec,
-        time_steps: int = 300,
-        sampling_steps: int = 36,
-        start_tokens: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        temperature: Union[float, Tuple[float, float]] = 0.8,
-        top_k: int = None,
-        sample: str = "gumbel",
-        typical_filtering=True,
-        typical_mass=0.2,
-        typical_min_tokens=1,
-        return_signal=True,
-    ):
-        if isinstance(temperature, float):
-            temperature = torch.tensor(temperature).repeat(sampling_steps)
-        elif isinstance(temperature, tuple):
-            assert len(temperature) == 2
-            l, h = temperature
-            temperature = torch.linspace(l, h, sampling_steps)
-        else:
-            raise TypeError(f"invalid type for temperature")
-
-        z = start_tokens
-
-        if z is None:
-            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
-                self.device
-            )
-
-        if mask is None:
-            mask = torch.ones_like(z).to(self.device).int()
-            mask[:, : self.n_conditioning_codebooks, :] = 0.0
-        if mask.ndim == 2:
-            mask = mask[:, None, :].repeat(1, z.shape[1], 1)
-
-        # figure out which timesteps we're keeping
-        keep_mask = 1 - mask
-
-        # any conditioning codebook levels need to be in the keep mask
-        # if self.n_conditioning_codebooks > 0:
-        #     cond_mask = torch.ones(z.shape[0], self.n_conditioning_codebooks, z.shape[-1]).to(z.device)
-        #     keep_mask = torch.cat([cond_mask, keep_mask], dim=1)
-
-        # flatten
-        keep_mask = codebook_flatten(keep_mask)
-
-        # our r steps
-        r_steps = torch.linspace(0, 1, sampling_steps + 1)[1:].to(self.device)
-
-        # how many tokens did we keep on init?
-        num_kept_on_init = keep_mask.sum()
-
-        # how many codebooks are we inferring vs conditioning on?
-        n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
-
-        for i in range(sampling_steps):
-            # our current temperature
-            tmpt = temperature[i]
-
-            # our current schedule step
-            r = r_steps[i : i + 1]
-
-            with torch.inference_mode():
-                # mask our z
-                keep_mask_unflat = codebook_unflatten(keep_mask, n_c=self.n_codebooks)
-                z_masked = z.masked_fill(~keep_mask_unflat.bool(), self.mask_token)
-
-                # get latents
-                latents = self.embedding.from_codes(z_masked, codec)
-
-                # infer from latents
-                logits = self.forward(latents, r)
-                logits = logits.permute(0, 2, 1)  # b, seq, prob
-
-                # the schedule determines how many samples to keep
-                num_tokens_to_infer = (z.shape[-1] * z.shape[-2]) - num_kept_on_init
-                num_to_keep = num_kept_on_init + int(
-                    num_tokens_to_infer * (_gamma(1 - r))
-                )
-
-                # figure out which logits we wanna keep
-                if num_to_keep > 0:
-                    probs = logits.softmax(dim=-1)
-
-                    # do mod self.vocab_size to make sure we don't sample from the mask token
-                    # in case the mask token was in the og z
-                    keep_probs = F.one_hot(z%self.vocab_size, self.vocab_size)[:, :, :]
-
-                    probs = rearrange(
-                        probs, "b (t c) p -> b c t p", c=n_infer_codebooks
-                    )
-                    probs = torch.cat(
-                        [keep_probs[:, : self.n_conditioning_codebooks, ...], probs],
-                        dim=1,
-                    )
-
-                    keep_probs = rearrange(
-                        keep_probs, "b c t p -> b (t c) p", c=self.n_codebooks
-                    )
-                    probs = rearrange(probs, "b c t p -> b (t c) p", c=self.n_codebooks)
-
-                    keep_prob_mask = keep_mask.unsqueeze(-1).repeat(
-                        1, 1, self.vocab_size
-                    )
-                    probs = (keep_prob_mask.long() * keep_probs) + (
-                        1 - keep_prob_mask.long()
-                    ) * probs
-
-                    highest_probs = probs.max(dim=-1, keepdim=False)[0]
-                    v, _ = highest_probs.topk(num_to_keep, dim=-1)
-
-                    keep_mask = torch.ones_like(keep_mask).bool().clone()
-                    keep_mask[highest_probs < v[..., [-1]]] = 0
-
-                    logits = torch.log(probs)
-
-                z_inferred = sample_from_logits(
-                    logits=logits,
-                    top_k=top_k,
-                    temperature=tmpt,
-                    sample=sample,
-                    typical_filtering=typical_filtering,
-                    typical_mass=typical_mass,
-                    typical_min_tokens=typical_min_tokens,
-                )
-
-                z = codebook_unflatten(z_inferred, n_c=self.n_codebooks)
-
-
-        if return_signal:
-            return self.to_signal(z, codec)
-        else:
-            return z
 
     @torch.no_grad()
     def generate(
         self,
         codec,
         time_steps: int = 300,
-        sampling_steps: int = 36,
+        sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         temperature: Union[float, Tuple[float, float]] = 8.0,
@@ -747,7 +587,7 @@ class VampNet(at.ml.BaseModel):
         typical_min_tokens=1,
         return_signal=True,
     ):
-        logging.info(f"beginning generation with {sampling_steps} steps")
+        logging.debug(f"beginning generation with {sampling_steps} steps")
 
         #####################
         # resolve temperature #
@@ -761,7 +601,7 @@ class VampNet(at.ml.BaseModel):
         else:
             raise TypeError(f"invalid type for temperature")
 
-        logging.info(f"temperature: {temperature}")
+        logging.debug(f"temperature: {temperature}")
 
 
        #####################
@@ -774,7 +614,7 @@ class VampNet(at.ml.BaseModel):
                 self.device
             )
 
-        logging.info(f"created z with shape {z.shape}")
+        logging.debug(f"created z with shape {z.shape}")
 
 
         #################
@@ -788,7 +628,7 @@ class VampNet(at.ml.BaseModel):
             mask = mask[:, None, :].repeat(1, z.shape[1], 1)
         # init_mask = mask.clone()
 
-        logging.info(f"created mask with shape {mask.shape}")
+        logging.debug(f"created mask with shape {mask.shape}")
 
 
         ###########
@@ -796,38 +636,38 @@ class VampNet(at.ml.BaseModel):
         ##########
         # apply the mask to z
         z_masked = z.masked_fill(mask.bool(), self.mask_token)
-        # logging.info(f"z_masked: {z_masked}")
+        # logging.debug(f"z_masked: {z_masked}")
 
         # how many mask tokens to begin with?
         num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
-        logging.info(f"num mask tokens at start: {num_mask_tokens_at_start}")
+        logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
 
         # our r steps
         r_steps = torch.linspace(1e-10, 1, sampling_steps+1)[1:].to(self.device)
-        logging.info(f"r steps: {r_steps}")
+        logging.debug(f"r steps: {r_steps}")
 
         # how many codebooks are we inferring vs conditioning on?
         n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
-        logging.info(f"n infer codebooks: {n_infer_codebooks}")
+        logging.debug(f"n infer codebooks: {n_infer_codebooks}")
 
         #################
         # begin sampling #
         #################
 
         for i in range(sampling_steps):
-            logging.info(f"step {i} of {sampling_steps}")
+            logging.debug(f"step {i} of {sampling_steps}")
 
             # our current temperature
             tmpt = temperature[i]
-            logging.info(f"temperature: {tmpt}")
+            logging.debug(f"temperature: {tmpt}")
 
             # our current schedule step
             r = r_steps[i : i + 1]
-            logging.info(f"r: {r}")
+            logging.debug(f"r: {r}")
 
             # get latents
             latents = self.embedding.from_codes(z_masked, codec)
-            logging.info(f"computed latents with shape: {latents.shape}")
+            logging.debug(f"computed latents with shape: {latents.shape}")
 
 
             # infer from latents
@@ -841,12 +681,12 @@ class VampNet(at.ml.BaseModel):
             )
 
 
-            logging.info(f"permuted logits with shape: {logits.shape}")
+            logging.debug(f"permuted logits with shape: {logits.shape}")
 
 
             # logits2probs
             probs = torch.softmax(logits, dim=-1)
-            logging.info(f"computed probs with shape: {probs.shape}")
+            logging.debug(f"computed probs with shape: {probs.shape}")
 
 
             # sample from logits with multinomial sampling
@@ -857,7 +697,7 @@ class VampNet(at.ml.BaseModel):
 
             sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
             probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
-            logging.info(f"sampled z with shape: {sampled_z.shape}")
+            logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
 
             # flatten z_masked and mask, so we can deal with the sampling logic
@@ -868,12 +708,12 @@ class VampNet(at.ml.BaseModel):
             mask = (z_masked == self.mask_token).int()
 
             # update the mask, remove conditioning codebooks from the mask
-            logging.info(f"updated mask with shape: {mask.shape}")
+            logging.debug(f"updated mask with shape: {mask.shape}")
             # add z back into sampled z where the mask was false
             sampled_z = torch.where(
                 mask.bool(), sampled_z, z_masked
             )
-            logging.info(f"added z back into sampled z with shape: {sampled_z.shape}")
+            logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
 
 
             # get the confidences: which tokens did we sample?
@@ -891,7 +731,7 @@ class VampNet(at.ml.BaseModel):
 
             # get the num tokens to mask, according to the schedule
             num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
-            logging.info(f"num to mask: {num_to_mask}")
+            logging.debug(f"num to mask: {num_to_mask}")
 
             num_to_mask = torch.maximum(
                 torch.tensor(1),
@@ -911,17 +751,17 @@ class VampNet(at.ml.BaseModel):
             z_masked = torch.where(
                 mask.bool(), self.mask_token, sampled_z
             )
-            logging.info(f"updated z_masked with shape: {z_masked.shape}")
+            logging.debug(f"updated z_masked with shape: {z_masked.shape}")
 
             z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
             mask = codebook_unflatten(mask, n_infer_codebooks)
-            logging.info(f"unflattened z_masked with shape: {z_masked.shape}")
+            logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
 
             # add conditioning codebooks back to z_masked
             z_masked = torch.cat(
                 (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
             )
-            logging.info(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
+            logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
 
 
             # add conditioning codebooks back to sampled_z
@@ -930,7 +770,7 @@ class VampNet(at.ml.BaseModel):
                 (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
             )
 
-        logging.info(f"finished sampling")
+        logging.debug(f"finished sampling")
 
         if return_signal:
             return self.to_signal(sampled_z, codec)
@@ -945,28 +785,28 @@ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float
         probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
         temperature (float, optional): temperature. Defaults to 1.0.
     """
-    logging.info(f"masking by random topk")
-    logging.info(f"num to mask: {num_to_mask}")
-    logging.info(f"probs shape: {probs.shape}")
-    logging.info(f"temperature: {temperature}")
-    logging.info("")
+    logging.debug(f"masking by random topk")
+    logging.debug(f"num to mask: {num_to_mask}")
+    logging.debug(f"probs shape: {probs.shape}")
+    logging.debug(f"temperature: {temperature}")
+    logging.debug("")
 
     confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
-    logging.info(f"confidence shape: {confidence.shape}")
+    logging.debug(f"confidence shape: {confidence.shape}")
 
     sorted_confidence, sorted_idx = confidence.sort(dim=-1)
-    logging.info(f"sorted confidence shape: {sorted_confidence.shape}")
-    logging.info(f"sorted idx shape: {sorted_idx.shape}")
+    logging.debug(f"sorted confidence shape: {sorted_confidence.shape}")
+    logging.debug(f"sorted idx shape: {sorted_idx.shape}")
 
     # get the cut off threshold, given the mask length
     cut_off = torch.take_along_dim(
        sorted_confidence, num_to_mask, axis=-1
    )
-    logging.info(f"cut off shape: {cut_off.shape}")
+    logging.debug(f"cut off shape: {cut_off.shape}")
 
     # mask out the tokens
     mask = confidence < cut_off
-    logging.info(f"mask shape: {mask.shape}")
+    logging.debug(f"mask shape: {mask.shape}")
 
     return mask
 
@@ -999,61 +839,6 @@ def typical_filter(
     logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
     return logits
 
-def sample_from_logits(
-    logits,
-    top_k: int = None,
-    temperature: float = 1.0,
-    sample: str = "multinomial",
-    typical_filtering=False,
-    typical_mass=0.2,
-    typical_min_tokens=1,
-):
-    # add temperature
-    logits = logits / temperature
-
-    # add topk
-    if top_k is not None and typical_filtering == False:
-        v, topk_idx = logits.topk(top_k)
-        logits[logits < v[..., [-1]]] = -float("inf")
-
-    if typical_filtering:
-        assert top_k is None
-        nb, nt, _ = logits.shape
-        x_flat = rearrange(logits, "b t l -> (b t ) l")
-        x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
-        x_flat_norm_p = torch.exp(x_flat_norm)
-        entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
-
-        c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
-        c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
-        x_flat_cumsum = (
-            x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
-        )
-
-        last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
-        sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
-            1, last_ind.view(-1, 1)
-        )
-        if typical_min_tokens > 1:
-            sorted_indices_to_remove[..., :typical_min_tokens] = 0
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, x_flat_indices, sorted_indices_to_remove
-        )
-        x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
-        logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
-
-    if sample == "multinomial":
-        probs = torch.softmax(logits, dim=-1)
-        inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
-    elif sample == "argmax":
-        inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
-    elif sample == "gumbel":
-        inferred = gumbel_sample(logits, dim=-1)
-    else:
-        raise ValueError(f"invalid sampling method: {sample}")
-
-    return inferred
-
 
 if __name__ == "__main__":
     # import argbind
 