fix audiotools version + sampling trick

#7
by hugggof - opened
Files changed (5)
  1. .gitignore +2 -0
  2. app.py +53 -14
  3. requirements.txt +1 -1
  4. scripts/exp/train.py +7 -5
  5. vampnet/modules/transformer.py +109 -37
.gitignore CHANGED
@@ -175,6 +175,7 @@ lyrebird-audio-codec
 samples-*/**
 
 gradio-outputs/
+models/
 samples*/
 models-all/
 models.zip
@@ -183,3 +184,4 @@ descript-audio-codec/
 # *.pth
 .git-old
 conf/generated/*
+runs*/
app.py CHANGED
@@ -107,24 +107,36 @@ def _vamp(data, return_mask=False):
         mask = pmask.codebook_unmask(mask, ncc)
 
 
-        print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
+        print(data)
+        _top_p = data[top_p] if data[top_p] > 0 else None
         # save the mask as a txt file
         np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
+        _seed = data[seed] if data[seed] > 0 else None
         zv, mask_z = interface.coarse_vamp(
             z,
             mask=mask,
             sampling_steps=data[num_steps],
-            temperature=float(data[temp]*10),
+            mask_temperature=data[masktemp]*10,
+            sampling_temperature=data[sampletemp],
             return_mask=True,
             typical_filtering=data[typical_filtering],
             typical_mass=data[typical_mass],
             typical_min_tokens=data[typical_min_tokens],
+            top_p=_top_p,
             gen_fn=interface.coarse.generate,
+            seed=_seed,
         )
 
         if use_coarse2fine:
-            zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
+            zv = interface.coarse_to_fine(
+                zv,
+                mask_temperature=data[masktemp]*10,
+                sampling_temperature=data[sampletemp],
+                mask=mask,
+                sampling_steps=data[num_steps],
+                seed=_seed,
+            )
 
         sig = interface.to_signal(zv).cpu()
         print("done")
@@ -157,7 +169,9 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        "temp": data[temp],
+        "masktemp": data[masktemp],
+        "sampletemp": data[sampletemp],
+        "top_p": data[top_p],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -168,6 +182,7 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
+        "seed": data[seed],
     }
 
     # save with yaml
@@ -183,13 +198,14 @@ def save_vamp(data):
     return f"saved! your save code is {out_dir.stem}", zip_path
 
 
+
 with gr.Blocks() as demo:
 
     with gr.Row():
         with gr.Column():
-            gr.Markdown("# VampNet")
+            gr.Markdown("# VampNet Audio Vamping")
             gr.Markdown("""## Description:
-            This is a demo of VampNet, a masked generative music model capable of doing music variations.
+            This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
             You can control the extent and nature of variation with a set of manual controls and presets.
             Use this interface to experiment with different mask settings and explore the audio outputs.
             """)
@@ -197,8 +213,8 @@ with gr.Blocks() as demo:
             gr.Markdown("""
            ## Instructions:
            1. You can start by uploading some audio, or by loading the example audio.
-            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button.
-            3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints.
+            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
+            3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
            4. Optionally, you can add some notes and save the result.
            5. You can also use the output as the new input and continue experimenting!
            """)
@@ -377,16 +393,28 @@ with gr.Blocks() as demo:
                     value=0.0
                 )
 
-                temp = gr.Slider(
-                    label="temperature",
+                masktemp = gr.Slider(
+                    label="mask temperature",
                     minimum=0.0,
                     maximum=10.0,
-                    value=1.8
+                    value=1.5
                 )
-
+                sampletemp = gr.Slider(
+                    label="sample temperature",
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=1.0
+                )
+
 
 
             with gr.Accordion("sampling settings", open=False):
+                top_p = gr.Slider(
+                    label="top p (0.0 = off)",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.0
+                )
                 typical_filtering = gr.Checkbox(
                     label="typical filtering ",
                     value=False
@@ -428,6 +456,14 @@ with gr.Blocks() as demo:
                 )
 
 
+                seed = gr.Number(
+                    label="seed (0 for random)",
+                    value=0,
+                    precision=0,
+                )
+
+
+
         # mask settings
         with gr.Column():
             vamp_button = gr.Button("generate (vamp)!!!")
@@ -455,7 +491,9 @@ with gr.Blocks() as demo:
         _inputs = {
                 input_audio,
                 num_steps,
-                temp,
+                masktemp,
+                sampletemp,
+                top_p,
                 prefix_s, suffix_s,
                 rand_mask_intensity,
                 periodic_p, periodic_w,
@@ -468,6 +506,7 @@ with gr.Blocks() as demo:
                 typical_mass,
                 typical_min_tokens,
                 beat_mask_width,
+                seed,
                 beat_mask_downbeats
         }
 
@@ -498,4 +537,4 @@ with gr.Blocks() as demo:
         outputs=[thank_you, download_file]
     )
 
-    demo.queue().launch()
+    demo.launch()
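
Note on the new app.py controls: both top_p and seed treat 0 as "off" before reaching the sampler, and the mask temperature slider is rescaled on the way in. A minimal sketch of that convention, reusing the names from the diff above with explanatory comments added:

    _top_p = data[top_p] if data[top_p] > 0 else None    # slider at 0.0 disables nucleus (top-p) filtering
    _seed = data[seed] if data[seed] > 0 else None        # 0 leaves generation unseeded, so each run differs
    # the 0-10 mask temperature slider is multiplied by 10 before it reaches
    # mask_by_random_topk; the sample temperature slider is passed through unchanged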
requirements.txt CHANGED
@@ -5,4 +5,4 @@ gradio
 loralib
 wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
 lac @ git+https://github.com/hugofloresgarcia/lac.git
-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
+descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
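
The dependency now points at the upstream descriptinc package pinned to the 0.7.2 tag instead of the personal fork. To reproduce the pin outside of requirements.txt, an equivalent one-off install would look something like:

    pip install "descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2"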
scripts/exp/train.py CHANGED
@@ -485,7 +485,6 @@ def load(
     save_path: str,
     resume: bool = False,
     tag: str = "latest",
-    load_weights: bool = False,
     fine_tune_checkpoint: Optional[str] = None,
     grad_clip_val: float = 5.0,
 ) -> State:
@@ -498,7 +497,7 @@
     kwargs = {
         "folder": f"{save_path}/{tag}",
         "map_location": "cpu",
-        "package": not load_weights,
+        "package": False,
     }
     tracker.print(f"Loading checkpoint from {kwargs['folder']}")
     if (Path(kwargs["folder"]) / "vampnet").exists():
@@ -511,11 +510,14 @@
 
     if args["fine_tune"]:
         assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
-        model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
-
+        model = torch.compile(
+            VampNet.load(location=Path(fine_tune_checkpoint),
+            map_location="cpu",
+            )
+        )
 
-    model = VampNet() if model is None else model
 
+    model = torch.compile(VampNet()) if model is None else model
     model = accel.prepare_model(model)
 
     # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
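
The load() changes do two things: checkpoints are now always read with "package": False instead of deriving it from the removed load_weights flag, and the VampNet instance is wrapped in torch.compile on the fine-tune path. A rough, repo-agnostic sketch of what that wrapper does (assumes PyTorch >= 2.0, where torch.compile is available; the Linear module is just a stand-in):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)            # stand-in for the VampNet module
    compiled = torch.compile(model)    # returns an optimized module with the same call signature
    out = compiled(torch.randn(2, 8))  # used exactly like the original module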
vampnet/modules/transformer.py CHANGED
@@ -367,6 +367,15 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
+def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
+    x = np.linspace(0, 1, n_steps)
+    a = (0.5 - min_temp) / (max_temp - min_temp)
+
+    x = (x * 12) - 6
+    x0 = np.log((1 / a - 1) + 1e-5) / k
+    y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
+
+    return y
 
 class TransformerStack(nn.Module):
     def __init__(
@@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel):
         time_steps: int = 300,
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-        temperature: float = 2.5,
+        mask_temperature: float = 20.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
+        top_p=None,
         return_signal=True,
+        seed: int = None
     ):
+        if seed is not None:
+            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
-        #####################
-        # resolve temperature #
-        #####################
-
-        logging.debug(f"temperature: {temperature}")
 
 
         #####################
@@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel):
        #################
        # begin sampling #
        #################
+        t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
        for i in range(sampling_steps):
            logging.debug(f"step {i} of {sampling_steps}")
 
-            # our current temperature
-            logging.debug(f"temperature: {temperature}")
-
            # our current schedule step
            r = scalar_to_batch_tensor(
                (i + 1) / sampling_steps,
@@ -664,39 +671,19 @@ class VampNet(at.ml.BaseModel):
            # NOTE: this collapses the codebook dimension into the sequence dimension
            logits = self.forward(latents, r)  # b, prob, seq
            logits = logits.permute(0, 2, 1)  # b, seq, prob
-            if typical_filtering:
-                typical_filter(logits,
-                               typical_mass=typical_mass,
-                               typical_min_tokens=typical_min_tokens
-                               )
-
+            b = logits.shape[0]
 
            logging.debug(f"permuted logits with shape: {logits.shape}")
 
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=True, temperature=t_sched[i],
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True
+            )
 
-            # logits2probs
-            probs = torch.softmax(logits, dim=-1)
-            logging.debug(f"computed probs with shape: {probs.shape}")
-
-
-            # sample from logits with multinomial sampling
-            b = probs.shape[0]
-            probs = rearrange(probs, "b seq prob -> (b seq) prob")
-
-            sampled_z = torch.multinomial(probs, 1).squeeze(-1)
-
-            sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
-            probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
            logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
            # flatten z_masked and mask, so we can deal with the sampling logic
            # we'll unflatten them at the end of the loop for the next forward pass
            # remove conditioning codebooks, we'll add them back at the end
@@ -733,7 +720,7 @@ class VampNet(at.ml.BaseModel):
 
            # get our new mask
            mask = mask_by_random_topk(
-                num_to_mask, selected_probs, temperature * (1-r)
+                num_to_mask, selected_probs, mask_temperature * (1-r)
            )
 
            # update the mask
@@ -766,6 +753,91 @@ class VampNet(at.ml.BaseModel):
        else:
            return sampled_z
 
+def sample_from_logits(
+    logits,
+    sample: bool = True,
+    temperature: float = 1.0,
+    top_k: int = None,
+    top_p: float = None,
+    typical_filtering: bool = False,
+    typical_mass: float = 0.2,
+    typical_min_tokens: int = 1,
+    return_probs: bool = False
+):
+    """Convenience function to sample from a categorial distribution with input as
+    unnormalized logits.
+
+    Parameters
+    ----------
+    logits : Tensor[..., vocab_size]
+    config: SamplingConfig
+        The set of hyperparameters to be used for sampling
+    sample : bool, optional
+        Whether to perform multinomial sampling, by default True
+    temperature : float, optional
+        Scaling parameter when multinomial samping, by default 1.0
+    top_k : int, optional
+        Restricts sampling to only `top_k` values acc. to probability,
+        by default None
+    top_p : float, optional
+        Restricts sampling to only those values with cumulative
+        probability = `top_p`, by default None
+
+    Returns
+    -------
+    Tensor[...]
+        Sampled tokens
+    """
+    shp = logits.shape[:-1]
+
+    if typical_filtering:
+        typical_filter(logits,
+                       typical_mass=typical_mass,
+                       typical_min_tokens=typical_min_tokens
+                       )
+
+    # Apply top_k sampling
+    if top_k is not None:
+        v, _ = logits.topk(top_k)
+        logits[logits < v[..., [-1]]] = -float("inf")
+
+    # Apply top_p (nucleus) sampling
+    if top_p is not None and top_p < 1.0:
+        v, sorted_indices = logits.sort(descending=True)
+        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Right shift indices_to_remove to keep 1st token over threshold
+        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+            ..., :-1
+        ]
+
+        # Compute indices_to_remove in unsorted array
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            -1, sorted_indices, sorted_indices_to_remove
+        )
+
+        logits[indices_to_remove] = -float("inf")
+
+    # Perform multinomial sampling after normalizing logits
+    probs = (
+        F.softmax(logits / temperature, dim=-1)
+        if temperature > 0
+        else logits.softmax(dim=-1)
+    )
+    token = (
+        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+        if sample
+        else logits.argmax(-1)
+    )
+
+    if return_probs:
+        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+        return token, token_probs
+    else:
+        return token
+
+
 
 def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """