Hugo Flores Garcia committed
Commit 9496f0e (1 parent: 308d855)

sampling tricks!

Files changed (2):
  1. app.py +49 -11
  2. vampnet/modules/transformer.py +109 -37
app.py CHANGED
@@ -97,28 +97,35 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
+    print(data)
+    _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
+    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-        temperature=data[temp]*10,
+        mask_temperature=data[masktemp]*10,
+        sampling_temperature=data[sampletemp],
         return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
+        top_p=_top_p,
         gen_fn=interface.coarse.generate,
+        seed=_seed,
     )
 
     if use_coarse2fine:
         zv = interface.coarse_to_fine(
             zv,
-            temperature=data[temp],
+            mask_temperature=data[masktemp]*10,
+            sampling_temperature=data[sampletemp],
             mask=mask,
-            sampling_steps=data[num_steps]
+            sampling_steps=data[num_steps],
+            seed=_seed,
         )
 
     sig = interface.to_signal(zv).cpu()
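The new `_top_p` and `_seed` guards above use 0 as an "off" sentinel: the Gradio widgets only emit numbers, so the handler maps 0 to None before calling the interface (top_p=None disables nucleus filtering; seed=None leaves generation unseeded). A minimal sketch of the convention; `zero_to_none` is a hypothetical helper, not part of the commit:

    def zero_to_none(value):
        # mirrors the `data[x] if data[x] > 0 else None` guards in _vamp
        return value if value > 0 else None

    assert zero_to_none(0.9) == 0.9   # widget engaged: pass the value through
    assert zero_to_none(0.0) is None  # 0 means "off": feature disabled downstream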
@@ -152,7 +159,9 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        "temp": data[temp],
+        "masktemp": data[masktemp],
+        "sampletemp": data[sampletemp],
+        "top_p": data[top_p],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -163,6 +172,7 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
+        "seed": data[seed],
     }
 
     # save with yaml
@@ -385,16 +395,28 @@ with gr.Blocks() as demo:
             value=0.0
         )
 
-        temp = gr.Slider(
-            label="temperature",
+        masktemp = gr.Slider(
+            label="mask temperature",
             minimum=0.0,
             maximum=10.0,
-            value=0.8
+            value=1.5
         )
-
+        sampletemp = gr.Slider(
+            label="sample temperature",
+            minimum=0.1,
+            maximum=2.0,
+            value=1.0
+        )
+
 
 
         with gr.Accordion("sampling settings", open=False):
+            top_p = gr.Slider(
+                label="top p (0.0 = off)",
+                minimum=0.0,
+                maximum=1.0,
+                value=0.0
+            )
             typical_filtering = gr.Checkbox(
                 label="typical filtering ",
                 value=False
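The single `temp` control is now split in two. With the new defaults, the effective values reaching VampNet.generate work out as below (the *10 rescaling of masktemp happens in `_vamp`, per the first hunk):

    masktemp, sampletemp = 1.5, 1.0  # slider defaults above
    print(masktemp * 10)  # 15.0 -> mask_temperature (confidence re-masking)
    print(sampletemp)     # 1.0  -> sampling_temperature (softmax scaling)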
@@ -435,6 +457,18 @@ with gr.Blocks() as demo:
                 value=0.0
             )
 
+        use_new_trick = gr.Checkbox(
+            label="new trick",
+            value=False
+        )
+
+        seed = gr.Number(
+            label="seed (0 for random)",
+            value=0,
+            precision=0,
+        )
+
+
 
         # mask settings
         with gr.Column():
@@ -463,7 +497,9 @@ with gr.Blocks() as demo:
     _inputs = {
         input_audio,
         num_steps,
-        temp,
+        masktemp,
+        sampletemp,
+        top_p,
         prefix_s, suffix_s,
         rand_mask_intensity,
         periodic_p, periodic_w,
@@ -476,7 +512,9 @@ with gr.Blocks() as demo:
         typical_mass,
         typical_min_tokens,
         beat_mask_width,
-        beat_mask_downbeats
+        beat_mask_downbeats,
+        seed,
+        seed
     }
 
     # connect widgets
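Since `_inputs` is a set literal, the repeated `seed` entry in the last hunk is harmless; sets deduplicate:

    widgets = {"num_steps", "masktemp", "sampletemp", "seed", "seed"}
    print(len(widgets))  # 4 -- the duplicate seed collapses

Note that the new `use_new_trick` checkbox is created in the UI but not wired into `_inputs` by this commit.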
vampnet/modules/transformer.py CHANGED
@@ -367,6 +367,15 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
+def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
+    x = np.linspace(0, 1, n_steps)
+    a = (0.5 - min_temp) / (max_temp - min_temp)
+
+    x = (x * 12) - 6
+    x0 = np.log((1 / a - 1) + 1e-5) / k
+    y = (1 / (1 + np.exp(-k * (x - x0))))[::-1]
+
+    return y
 
 class TransformerStack(nn.Module):
     def __init__(
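`t_schedule` replaces the fixed per-step temperature with a reversed sigmoid that decays toward zero over the sampling steps; note that `max_temp` and `min_temp` shift the curve through `x0` rather than rescaling the output range. A standalone check of the shape, duplicating the function so the sketch runs without vampnet installed:

    import numpy as np

    def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
        x = np.linspace(0, 1, n_steps)
        a = (0.5 - min_temp) / (max_temp - min_temp)
        x = (x * 12) - 6
        x0 = np.log((1 / a - 1) + 1e-5) / k
        y = (1 / (1 + np.exp(-k * (x - x0))))[::-1]
        return y

    print(np.round(t_schedule(8), 4))
    # approx. [0.9975 0.9864 0.929 0.702 0.298 0.071 0.0136 0.0025] -- monotone decay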
@@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel):
         time_steps: int = 300,
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-        temperature: float = 2.5,
+        mask_temperature: float = 20.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
+        top_p=None,
         return_signal=True,
+        seed: int = None
     ):
+        if seed is not None:
+            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
-        #####################
-        # resolve temperature #
-        #####################
-
-        logging.debug(f"temperature: {temperature}")
 
 
         #####################
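The new `seed` argument makes generation reproducible by seeding the global RNGs up front through audiotools (the `at.util.seed` call above). A quick sanity sketch of that pattern, assuming audiotools is installed:

    import torch
    import audiotools as at

    at.util.seed(42)
    a = torch.randn(3)
    at.util.seed(42)
    b = torch.randn(3)
    assert torch.equal(a, b)  # same seed -> identical draws, so repeatable sampling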
@@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel):
         #################
         # begin sampling #
         #################
+        t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
         for i in range(sampling_steps):
             logging.debug(f"step {i} of {sampling_steps}")
 
-            # our current temperature
-            logging.debug(f"temperature: {temperature}")
-
             # our current schedule step
             r = scalar_to_batch_tensor(
                 (i + 1) / sampling_steps,
@@ -664,39 +671,19 @@ class VampNet(at.ml.BaseModel):
             # NOTE: this collapses the codebook dimension into the sequence dimension
             logits = self.forward(latents, r)  # b, prob, seq
             logits = logits.permute(0, 2, 1)  # b, seq, prob
-            if typical_filtering:
-                typical_filter(logits,
-                               typical_mass=typical_mass,
-                               typical_min_tokens=typical_min_tokens
-                               )
-
+            b = logits.shape[0]
 
             logging.debug(f"permuted logits with shape: {logits.shape}")
 
-            # logits2probs
-            probs = torch.softmax(logits, dim=-1)
-            logging.debug(f"computed probs with shape: {probs.shape}")
-
-
-            # sample from logits with multinomial sampling
-            b = probs.shape[0]
-            probs = rearrange(probs, "b seq prob -> (b seq) prob")
-
-            sampled_z = torch.multinomial(probs, 1).squeeze(-1)
-
-            sampled_z = rearrange(sampled_z, "(b seq) -> b seq", b=b)
-            probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=True, temperature=t_sched[i],
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True
+            )
 
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
@@ -733,7 +720,7 @@ class VampNet(at.ml.BaseModel):
 
             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs, temperature * (1-r)
+                num_to_mask, selected_probs, mask_temperature * (1-r)
             )
 
             # update the mask
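The re-masking temperature is now annealed as `mask_temperature * (1 - r)`: it decays linearly to zero as the schedule step `r` approaches 1, so the choice of which tokens to re-mask gets less random on late steps. The per-step factor with the app's defaults (masktemp 1.5, scaled by 10 in `_vamp`):

    S = 24  # sampling_steps default
    mask_temperature = 1.5 * 10
    factors = [round(mask_temperature * (1 - (i + 1) / S), 2) for i in range(S)]
    print(factors[:3], factors[-3:])  # [14.38, 13.75, 13.12] [1.25, 0.62, 0.0]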
@@ -766,6 +753,91 @@ class VampNet(at.ml.BaseModel):
         else:
             return sampled_z
 
+def sample_from_logits(
+    logits,
+    sample: bool = True,
+    temperature: float = 1.0,
+    top_k: int = None,
+    top_p: float = None,
+    typical_filtering: bool = False,
+    typical_mass: float = 0.2,
+    typical_min_tokens: int = 1,
+    return_probs: bool = False
+):
+    """Convenience function to sample from a categorical distribution with input as
+    unnormalized logits.
+
+    Parameters
+    ----------
+    logits : Tensor[..., vocab_size]
+    config: SamplingConfig
+        The set of hyperparameters to be used for sampling
+    sample : bool, optional
+        Whether to perform multinomial sampling, by default True
+    temperature : float, optional
+        Scaling parameter when multinomial sampling, by default 1.0
+    top_k : int, optional
+        Restricts sampling to only `top_k` values acc. to probability,
+        by default None
+    top_p : float, optional
+        Restricts sampling to only those values with cumulative
+        probability >= `top_p`, by default None
+
+    Returns
+    -------
+    Tensor[...]
+        Sampled tokens
+    """
+    shp = logits.shape[:-1]
+
+    if typical_filtering:
+        typical_filter(logits,
+                       typical_mass=typical_mass,
+                       typical_min_tokens=typical_min_tokens
+                       )
+
+    # Apply top_k sampling
+    if top_k is not None:
+        v, _ = logits.topk(top_k)
+        logits[logits < v[..., [-1]]] = -float("inf")
+
+    # Apply top_p (nucleus) sampling
+    if top_p is not None and top_p < 1.0:
+        v, sorted_indices = logits.sort(descending=True)
+        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Right shift indices_to_remove to keep 1st token over threshold
+        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+            ..., :-1
+        ]
+
+        # Compute indices_to_remove in unsorted array
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            -1, sorted_indices, sorted_indices_to_remove
+        )
+
+        logits[indices_to_remove] = -float("inf")
+
+    # Perform multinomial sampling after normalizing logits
+    probs = (
+        F.softmax(logits / temperature, dim=-1)
+        if temperature > 0
+        else logits.softmax(dim=-1)
+    )
+    token = (
+        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+        if sample
+        else logits.argmax(-1)
+    )
+
+    if return_probs:
+        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+        return token, token_probs
+    else:
+        return token
+
+
 
 def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """