Hugo Flores Garcia committed
Commit c6f0e5a
1 Parent(s): a022742
app.py CHANGED
@@ -7,13 +7,15 @@ import audiotools as at
 import argbind
 import shutil
 import torch
+from datetime import datetime
 
 import gradio as gr
-from vampnet.interface import Interface
+from vampnet.interface import Interface, signal_concat
 from vampnet import mask as pmask
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+
 interface = Interface(
     device=device,
     coarse_ckpt="models/vampnet/coarse.pth",
@@ -78,11 +80,20 @@ def shift_pitch(signal, interval: int):
     )
     return signal
 
-def _vamp(data, return_mask=False):
+def _vamp(data, api: bool = False,  # api bypasses mask audio generation and other things to be faster
+):
+
+    _seed = data[seed] if data[seed] > 0 else None
+    if _seed is None:
+        # create a random seed
+        _seed = int(torch.randint(0, 2**32, (1,)).item())
+    at.util.seed(_seed)
 
-    out_dir = OUT_DIR / str(uuid.uuid4())
+    datentime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    out_dir = OUT_DIR / f"{Path(data[input_audio]).stem}-{datentime}-seed-{_seed}-model-{data[model_choice]}"
     out_dir.mkdir(parents=True)
     sig = at.AudioSignal(data[input_audio])
+    sig.write(out_dir / "input.wav")
 
     # reload the model if necessary
     interface.reload(
@@ -96,18 +107,21 @@ def _vamp(data, return_mask=False):
     if data[pitch_shift_amt] != 0:
         sig = shift_pitch(sig, data[pitch_shift_amt])
 
+    _p2 = data[periodic_p] if data[p2] == 0 else data[p2]
+    _n_codebooks_2 = data[n_mask_codebooks] if data[n_mask_codebooks_2] == 0 else data[n_mask_codebooks_2]
     build_mask_kwargs = dict(
         rand_mask_intensity=data[rand_mask_intensity],
         prefix_s=data[prefix_s],
         suffix_s=data[suffix_s],
         periodic_prompt=data[periodic_p],
+        periodic_prompt2=_p2,
         periodic_prompt_width=data[periodic_w],
         onset_mask_width=data[onset_mask_width],
         _dropout=data[dropout],
-        upper_codebook_mask=int(data[n_mask_codebooks])
+        upper_codebook_mask=int(data[n_mask_codebooks]),
+        upper_codebook_mask_2=int(_n_codebooks_2),
     )
 
-    _seed = data[seed] if data[seed] > 0 else None
     vamp_kwargs = dict(
        # _sampling_steps=[data[num_steps], 8, 8, 4, 4, 2, 2, 1, 1],
        mask_temperature=data[masktemp]*10,
@@ -121,31 +135,81 @@ def _vamp(data, return_mask=False):
     )
 
     # save the mask as a txt file
-    sig, mask = interface.ez_vamp(
+    interface.set_chunk_size(data[win_dur])
+    sig, mask, codes = interface.ez_vamp(
         sig,
-        batch_size=1,
+        batch_size=4 if not api else 1,
         feedback_steps=data[num_feedback_steps],
         time_stretch_factor=data[stretch_factor],
         build_mask_kwargs=build_mask_kwargs,
         vamp_kwargs=vamp_kwargs,
-        return_mask=return_mask,
+        return_mask=True,
     )
 
-    sig.write(out_dir / "output.wav")
+    if api:
+        sig.write(out_dir / "out.wav")
 
-    if return_mask:
-        mask = interface.to_signal(mask.cuda()).cpu()
-        mask.write(out_dir / "mask.wav")
-        return sig.path_to_file, mask.path_to_file
-    else:
         return sig.path_to_file
 
+    if not api:
+        # write codes to numpy file
+        np.save(out_dir / "codes.npy", codes.cpu().numpy())
+        metadata = {}
+        metadata["seed"] = _seed
+        metadata["model_choice"] = data[model_choice]
+        metadata["mask_kwargs"] = build_mask_kwargs
+        metadata["vamp_kwargs"] = vamp_kwargs
+        metadata["loudness"] = loudness
+        # save the metadata
+        with open(out_dir / "metadata.yml", "w") as f:
+            yaml.dump(metadata, f)
+
+        sig0 = sig[0].write(out_dir / "out1.wav")
+        sig1 = sig[1].write(out_dir / "out2.wav")
+        sig2 = sig[2].write(out_dir / "out3.wav")
+        sig3 = sig[3].write(out_dir / "out4.wav")
+
+        # write the mask to txt
+        with open(out_dir / "mask.txt", "w") as f:
+            m = mask[0].cpu().numpy()
+            # write to txt, each time step on a new line
+            for i in range(m.shape[-1]):
+                f.write(f"{m[:, i]}\n")
+
+        import matplotlib.pyplot as plt
+        plt.clf()
+        interface.visualize_codes(mask)
+        plt.savefig(out_dir / "mask.png")
+        plt.clf()
+        interface.visualize_codes(codes)
+        plt.savefig(out_dir / "codes.png")
+        plt.close()
+
+        # zip out dir, and return the path to the zip
+        shutil.make_archive(out_dir, 'zip', out_dir)
+
+        # chunk in groups of 1024 timesteps
+        _mask_sigs = []
+        for i in range(0, mask.shape[-1], 1024):
+            _mask_sigs.append(interface.to_signal(mask[:, :, i:i+1024].to(interface.device)).cpu())
+        mask = signal_concat(_mask_sigs)
+        mask.write(out_dir / "mask.wav")
+
+        return (
+            sig0.path_to_file, sig1.path_to_file,
+            sig2.path_to_file, sig3.path_to_file,
+            mask.path_to_file, str(out_dir.with_suffix(".zip")), out_dir / "mask.png"
+        )
 
 def vamp(data):
-    return _vamp(data, return_mask=True)
+    return _vamp(data, api=False)
 
 def api_vamp(data):
-    return _vamp(data, return_mask=False)
+    return _vamp(data, api=True)
 
 with gr.Blocks() as demo:
     with gr.Row():
@@ -193,7 +257,13 @@ with gr.Blocks() as demo:
                 step=1,
                 value=3,
            )
-
+            p2 = gr.Slider(
+                label="periodic prompt 2 (0 - same as p1, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
+                minimum=0,
+                maximum=128,
+                step=1,
+                value=0,
+            )
 
             onset_mask_width = gr.Slider(
                 label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
@@ -204,9 +274,13 @@
             )
 
             n_mask_codebooks = gr.Number(
-                label="first upper codebook level to mask",
+                label="compression prompt (masks entire upper codebook levels above the specified level)",
                 value=3,
            )
+            n_mask_codebooks_2 = gr.Number(
+                label="compression prompt 2 via linear interpolation (0 == constant)",
+                value=0,
+            )
 
             with gr.Accordion("extras ", open=False):
                 pitch_shift_amt = gr.Slider(
@@ -337,17 +411,45 @@ with gr.Blocks() as demo:
                 value=1
             )
 
+            win_dur = gr.Slider(
+                label="window duration (seconds)",
+                minimum=2,
+                maximum=10,
+                value=6)
+
+
             vamp_button = gr.Button("generate (vamp)!!!")
-            output_audio = gr.Audio(
-                label="output audio",
+            maskimg = gr.Image(
+                label="mask image",
+                interactive=False,
+                type="filepath"
+            )
+            out1 = gr.Audio(
+                label="output audio 1",
+                interactive=False,
+                type="filepath"
+            )
+            out2 = gr.Audio(
+                label="output audio 2",
+                interactive=False,
+                type="filepath"
+            )
+            out3 = gr.Audio(
+                label="output audio 3",
+                interactive=False,
+                type="filepath"
+            )
+            out4 = gr.Audio(
+                label="output audio 4",
                 interactive=False,
                 type="filepath"
             )
-
-            use_as_input_button = gr.Button("use output as input")
 
             thank_you = gr.Markdown("")
 
+            # download all the outputs
+            download = gr.File(type="filepath", label="download outputs")
+
 
     _inputs = {
         input_audio,
@@ -368,29 +470,31 @@ with gr.Blocks() as demo:
         n_mask_codebooks,
         pitch_shift_amt,
         sample_cutoff,
-        num_feedback_steps
+        num_feedback_steps,
+        p2,
+        n_mask_codebooks_2,
+        win_dur
     }
 
     # connect widgets
     vamp_button.click(
         fn=vamp,
         inputs=_inputs,
-        outputs=[output_audio, audio_mask],
+        outputs=[out1, out2, out3, out4, audio_mask, download, maskimg],
     )
 
     api_vamp_button = gr.Button("api vamp", visible=False)
     api_vamp_button.click(
         fn=api_vamp,
         inputs=_inputs,
-        outputs=[output_audio],
+        outputs=[out1],
         api_name="vamp"
     )
 
-    use_as_input_button.click(
-        fn=lambda x: x,
-        inputs=[output_audio],
-        outputs=[input_audio]
-    )
 
-demo.queue()
-demo.launch(share=True, debug=True)
+try:
+    demo.queue()
+    demo.launch(share=True, debug=True)
+except KeyboardInterrupt:
+    shutil.rmtree("gradio-outputs", ignore_errors=True)
+    raise
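
Note (reviewer sketch, not part of the commit): `_vamp` now resolves a seed up front (the widget value if positive, otherwise a fresh random one), seeds globally with `at.util.seed`, and names the run folder after the input stem, a timestamp, the seed, and the model. A minimal sketch of that naming, with stand-in values for `OUT_DIR`, the input stem, and the model name:

    from datetime import datetime
    from pathlib import Path
    import torch

    OUT_DIR = Path("gradio-outputs")  # assumption: stand-in for the app's output root
    seed = 0  # non-positive means "pick one at random"
    _seed = seed if seed > 0 else int(torch.randint(0, 2**32, (1,)).item())
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    out_dir = OUT_DIR / f"example-{stamp}-seed-{_seed}-model-default"  # stem/model are placeholders
    out_dir.mkdir(parents=True)
    print(out_dir)  # e.g. gradio-outputs/example-2024-01-01_12-00-00-seed-123456-model-default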
conf/generated/church-bells/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/church-bells/c2f
+train/AudioLoader.sources: &id001
+- data/church-bells
+val/AudioLoader.sources: *id001
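
Note (reviewer sketch, not part of the commit): the `&id001` / `*id001` pair in these generated configs is a YAML anchor and alias, so the train and val loaders share a single source list. A minimal check with PyYAML:

    import yaml

    doc = """\
    train/AudioLoader.sources: &id001
    - data/church-bells
    val/AudioLoader.sources: *id001
    """
    cfg = yaml.safe_load(doc)
    # the alias resolves to the very same list object as the anchor
    assert cfg["train/AudioLoader.sources"] is cfg["val/AudioLoader.sources"]
    print(cfg["val/AudioLoader.sources"])  # ['data/church-bells']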
conf/generated/church-bells/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/church-bells/coarse
+train/AudioLoader.sources: &id001
+- data/church-bells
+val/AudioLoader.sources: *id001
conf/generated/church-bells/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - data/church-bells
+Interface.coarse2fine_ckpt: ./runs/church-bells/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/church-bells/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/copepod/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/copepod/c2f
+train/AudioLoader.sources: &id001
+- data/copepod
+val/AudioLoader.sources: *id001
conf/generated/copepod/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/copepod/coarse
+train/AudioLoader.sources: &id001
+- data/copepod
+val/AudioLoader.sources: *id001
conf/generated/copepod/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - data/copepod
+Interface.coarse2fine_ckpt: ./runs/copepod/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/copepod/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/growl/c2f.yml ADDED
@@ -0,0 +1,16 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/growl/c2f
+train/AudioLoader.sources: &id001
+- data/growly
+- animals/
+val/AudioLoader.sources: *id001
conf/generated/growl/coarse.yml ADDED
@@ -0,0 +1,9 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/growl/coarse
+train/AudioLoader.sources: &id001
+- data/growly
+- animals/
+val/AudioLoader.sources: *id001
conf/generated/growl/interface.yml ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - data/growly
+  - animals/
+Interface.coarse2fine_ckpt: ./runs/growl/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/growl/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/sample-instrument/c2f.yml ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/sample-instrument/c2f
+train/AudioLoader.sources: &id001
+- data/sample-instrument/
+val/AudioLoader.sources: *id001
conf/generated/sample-instrument/coarse.yml ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/sample-instrument/coarse
+train/AudioLoader.sources: &id001
+- data/sample-instrument/
+val/AudioLoader.sources: *id001
conf/generated/sample-instrument/interface.yml ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - data/sample-instrument/
+Interface.coarse2fine_ckpt: ./runs/sample-instrument/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/sample-instrument/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
models/models/vampnet/c2f.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b10ea2d45459d34edb773cbacd71f40f7baa1f4e75ac8bcd93b022ac69f8fa63
+size 1101898865
models/models/vampnet/coarse.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78e4ad4f8398e8ec3651bc5e5c6ea2995e1080b6226be186723ccf4320c9756c
+size 1332182321
models/models/vampnet/codec.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3db3fa43ab5d160439ddb81fc540b5573ad5ae962230de3fc5b47d218845b855
+size 600996465
models/models/wavebeat.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ff1066a4470cb98b20edf1e489f6995b19e0435b9cfd5a190bf90a954d0cadb
+size 33248861
scripts/utils/split.py CHANGED
@@ -16,12 +16,11 @@ def train_test_split(
     audio_folder: str = ".",
     test_size: float = 0.2,
     seed: int = 42,
-    pattern: str = "**/*.mp3",
 ):
     print(f"finding audio")
 
     audio_folder = Path(audio_folder)
-    audio_files = list(tqdm(audio_folder.glob(pattern)))
+    audio_files = util.find_audio(audio_folder)
     print(f"found {len(audio_files)} audio files")
 
     # split according to test_size
@@ -49,7 +48,10 @@ def train_test_split(
     for file in tqdm(files):
         out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
         out_file.parent.mkdir(exist_ok=True, parents=True)
-        os.symlink(file, out_file)
+        try:
+            os.symlink(file, out_file)
+        except FileExistsError:
+            print(f"File {out_file} already exists, skipping")
 
     # save split as json
     with open(Path(audio_folder) / f"{split}.json", "w") as f:
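
Note (reviewer sketch, not part of the commit): wrapping `os.symlink` in try/except FileExistsError makes the split script safe to re-run over an existing split instead of crashing on the first duplicate link. The same pattern in isolation:

    import os
    from pathlib import Path

    def link_once(src: Path, dst: Path) -> None:
        # create the parent dir, then the symlink only if it isn't there yet
        dst.parent.mkdir(exist_ok=True, parents=True)
        try:
            os.symlink(src, dst)
        except FileExistsError:
            print(f"File {dst} already exists, skipping")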
vampnet/interface.py CHANGED
@@ -273,11 +273,15 @@ class Interface(torch.nn.Module):
         else:
             mask = mask.repeat(1, self.coarse.n_codebooks, 1)
         return mask
+
+    def set_chunk_size(self, chunk_size_s: float):
+        self.coarse.chunk_size_s = chunk_size_s
 
     def coarse_to_fine(
         self,
         z: torch.Tensor,
         mask: torch.Tensor = None,
+        return_mask: bool = False,
         **kwargs
     ):
         assert self.c2f is not None, "No coarse2fine model loaded"
@@ -320,6 +324,9 @@ class Interface(torch.nn.Module):
             fine_z.append(chunk)
 
         fine_z = torch.cat(fine_z, dim=-1)
+        if return_mask:
+            return fine_z[:, :, :length].clone(), apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone()
+
         return fine_z[:, :, :length].clone()
 
     def coarse_vamp(
@@ -397,10 +404,12 @@
         prefix_s: float = 0.0,
         suffix_s: float = 0.0,
         periodic_prompt: int = 7,
+        periodic_prompt2: int = 7,
         periodic_prompt_width: int = 1,
         onset_mask_width: int = 0,
         _dropout: float = 0.0,
         upper_codebook_mask: int = 3,
+        upper_codebook_mask_2: int = None,
         ncc: int = 0,
     ):
 
@@ -409,10 +418,17 @@
             mask,
             inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
         )
-        mask = mask_and(
-            mask,
-            periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True),
-        )
+
+        pmask1 = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
+        pmask2 = periodic_mask(z, periodic_prompt2, periodic_prompt_width, random_roll=True)
+        # interpolate the two masks
+        pmask = torch.round(
+            pmask1 * torch.linspace(1, 0, pmask1.shape[-1], device=pmask1.device) +
+            pmask2 * torch.linspace(0, 1, pmask2.shape[-1], device=pmask2.device)
+        ).long()
+
+        mask = mask_and(mask, pmask)
+
         if onset_mask_width > 0:
             assert sig is not None, f"must provide a signal to use onset mask"
             mask = mask_or(
@@ -424,7 +440,8 @@
 
         mask = dropout(mask, _dropout)
         mask = codebook_unmask(mask, ncc)
-        mask = codebook_mask(mask, int(upper_codebook_mask))
+
+        mask = codebook_mask(mask, int(upper_codebook_mask), upper_codebook_mask_2)
         return mask
 
     def ez_vamp(
@@ -440,8 +457,8 @@
         build_mask_kwargs = build_mask_kwargs or {}
         vamp_kwargs = vamp_kwargs or {}
 
-        sig = self.preprocess(sig)
         loudness = sig.loudness()
+        sig = self.preprocess(sig)
 
         z = self.encode(sig)
 
@@ -488,26 +505,60 @@
 
             # now, coarse2fine
             print(f"coarse2fine!")
-            zv = self.coarse_to_fine(
+            zv, fine_zv_mask = self.coarse_to_fine(
                 zv,
                 mask=mask,
                 **vamp_kwargs,
-                _sampling_steps=[4, 2, 2, 2]
+                _sampling_steps=[2, 2, 1, 1],
+                return_mask=True
+            )
+            mask_z = torch.cat(
+                [mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
+                dim=1
             )
 
             prev_zvs.append(zv)
             z = zv
 
-        sig = self.to_signal(zv).cpu()
+        # perform to_signal batch item by batch item
+        sigs = []
+        for zv in prev_zvs:
+            # do it in timestep chunks of 1024
+            _sigs = []
+            for i in range(0, zv.shape[-1], 1024):
+                _sigs.append(self.to_signal(zv[:, :, i:i+1024]).cpu())
+            sigs.append(signal_concat(_sigs))
         print("done")
 
+        sig = AudioSignal.batch(sigs)
+
+        # sig = self.to_signal(zv).cpu()
+        # print("done")
+
         sig = sig.normalize(loudness)
 
         if return_mask:
-            return sig, mask.cpu()
+            return sig, mask_z.cpu(), zv.cpu()
         else:
             return sig
 
+    def visualize_codes(self, z: torch.Tensor):
+        import matplotlib.pyplot as plt
+        # make sure the figsize is square when imshow is called
+        fig = plt.figure(figsize=(10, 7))
+        # in subplots, plot z[0] and the mask
+        # set title to "codes" and "mask"
+        fig.add_subplot(2, 1, 1)
+        plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
+        plt.title("codes")
+        plt.ylabel("codebook index")
+        # set the xticks to seconds
+        plt.xticks(
+            np.arange(0, z.shape[-1], self.s2t(1)),
+            np.arange(0, self.t2s(z.shape[-1]), 1)
+        )
+        plt.xlabel("time (s)")
+
 if __name__ == "__main__":
     import audiotools as at
     import logging
@@ -528,8 +579,6 @@ if __name__ == "__main__":
     sig = at.AudioSignal('assets/example.wav')
 
     z = interface.encode(sig)
-    breakpoint()
-
     # mask = linear_random(z, 1.0)
     # mask = mask_and(
     #     mask, periodic_mask(
@@ -569,3 +618,43 @@
     print("done")
 
 
+
+
+    # example plotting code
+    # import matplotlib.pyplot as plt
+    # from pathlib import Path
+    # Path(".vampnet").mkdir(exist_ok=True)
+    # plt.clf()
+    # # close all figs
+    # plt.close('all')
+    # # set the fig size
+    # plt.subplot(4, 1, 1)
+    # # sig = self.to_signal(sampled_z, codec)
+    # # sig.cpu().specshow()
+
+    # plt.subplot(4, 1, 2)
+    # # since z_masked is a codebook, we want to plot the colormap
+    # # with distinct colors for each codebook index
+    # # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
+    # # make it so that anywhere where the mask is 1, we make that pixel black
+    # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r',)
+
+
+    # plt.subplot(4, 1, 3)
+    # # plot the mask (which is a matrix)
+    # plt.imshow(mask[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r')
+    # plt.subplot(4, 1, 4)
+    # # replace any inf or -inf with 0
+    # _selected_probs = torch.where(
+    #     selected_probs == torch.inf, torch.zeros_like(selected_probs), selected_probs
+    # )
+    # _selected_probs = torch.where(
+    #     selected_probs == -torch.inf, torch.zeros_like(selected_probs), selected_probs
+    # )
+    # # fig = plt.gcf()
+    # # fig.set_figheight(15)
+    # # fig.set_figwidth(15)
+    # plt.imshow(codebook_unflatten(_selected_probs, n_infer_codebooks)[0].cpu().numpy(), aspect='auto', origin='lower', cmap="viridis" )
+    # # plt.show()
+    # plt.savefig(f".vampnet/c={codebook_level}_{i}.png")
+    # plt.close('all')
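
Note (reviewer sketch, not part of the commit): the new mask-building code crossfades two periodic masks over time, blending with linear weights and rounding back to 0/1, so the prompt density drifts from `periodic_prompt` to `periodic_prompt2` across the clip. The same arithmetic on dummy stand-in masks:

    import torch

    T = 8
    pmask1 = torch.ones(1, 4, T).long()   # stand-in for a dense periodic_mask output
    pmask2 = torch.zeros(1, 4, T).long()  # stand-in for a sparse one
    pmask = torch.round(
        pmask1 * torch.linspace(1, 0, T) +
        pmask2 * torch.linspace(0, 1, T)
    ).long()
    print(pmask[0, 0])  # tensor([1, 1, 1, 1, 0, 0, 0, 0]): hints from mask 1 fade out over time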
vampnet/mask.py CHANGED
@@ -60,6 +60,7 @@ def linear_random(
     assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
     if not isinstance(r, torch.Tensor):
         r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+        r = r[:, None, None]
 
     probs = torch.ones_like(x).to(x.device).float()
     # expand to batch and codebook dims
@@ -98,7 +99,7 @@ def inpaint(x: torch.Tensor,
     return mask
 
 def periodic_mask(x: torch.Tensor,
-                  period: int, width: int = 1,
+                  period: int,width: int = 1,
                   random_roll=False,
 ):
     mask = full_mask(x)
@@ -140,9 +141,15 @@ def codebook_unmask(
     mask[:, :n_conditioning_codebooks, :] = 0
     return mask
 
-def codebook_mask(mask: torch.Tensor, start: int):
+def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
     mask = mask.clone()
-    mask[:, start:, :] = 1
+    mask[:, val1:, :] = 1
+    # val2 = val2 or val1
+    # vs = torch.linspace(val1, val2, mask.shape[1])
+    # for t, v in enumerate(vs):
+    #     v = int(v)
+    #     mask[:, v:, t] = 1
+
     return mask
 
 def mask_and(
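
Note (reviewer sketch, not part of the commit): the commented-out lines in `codebook_mask` gesture at a time-varying cutoff; the committed code still applies the constant `val1` cut. A hypothetical `ramped_codebook_mask` showing the intended interpolation (it takes the linspace length from the time axis, which the commented code would still need to fix):

    import torch

    def ramped_codebook_mask(mask: torch.Tensor, val1: int, val2: int) -> torch.Tensor:
        # sweep the first fully-masked codebook level from val1 to val2 across time
        mask = mask.clone()
        vs = torch.linspace(val1, val2, mask.shape[-1])
        for t, v in enumerate(vs):
            mask[:, int(v):, t] = 1
        return mask

    m = ramped_codebook_mask(torch.zeros(1, 14, 100), val1=3, val2=10)
    print(int(m[0, :, 0].sum()), int(m[0, :, -1].sum()))  # 11 4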
vampnet/modules/transformer.py CHANGED
@@ -572,6 +572,8 @@ class VampNet(at.ml.BaseModel):
         """
         assert z.ndim == 3
 
+        # remove mask token
+        z = z.masked_fill(z == self.mask_token, 0)
         signal = at.AudioSignal(
             codec.decode(
                 codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
@@ -581,15 +583,13 @@
 
         # find where the mask token is and replace it with silence in the audio
         for tstep in range(z.shape[-1]):
-            if torch.any(z[:, :, tstep] == self.mask_token):
+            if torch.all(z[:, :, tstep] == self.mask_token):
                 sample_idx_0 = tstep * codec.hop_length
                 sample_idx_1 = sample_idx_0 + codec.hop_length
                 signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
 
         return signal
-
-
-
+
     @torch.no_grad()
     def generate(
         self,
@@ -600,17 +600,37 @@
         sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
         mask_temperature: float = 10.5,
-        typical_filtering=False,
+        typical_filtering=True,
         typical_mass=0.2,
         typical_min_tokens=1,
-        top_p=None,
+        top_p=0.9,
         seed: int = None,
-        sample_cutoff: float = 1.0,
+        sample_cutoff: float = 0.9,
         return_signal=True,
         debug=False,
         causal_weight: float = 0.0,
+        use_og_method: bool = False,
     ):
-
+        if use_og_method:
+            return self.og_method(
+                codec,
+                time_steps,
+                _sampling_steps,
+                start_tokens,
+                sampling_temperature,
+                mask,
+                mask_temperature,
+                typical_filtering,
+                typical_mass,
+                typical_min_tokens,
+                top_p,
+                seed,
+                sample_cutoff,
+                return_signal,
+                debug,
+                causal_weight,
+            )
+
         if seed is not None:
             at.util.seed(seed)
 
@@ -749,7 +769,7 @@
                     num_to_mask
                 )
             )
-            print(f"will mask {num_to_mask.sum()} tokens")
+            logging.debug(f"will mask {num_to_mask.sum()} tokens")
             mask = codebook_flatten(mask)
 
             # ignore any tokens that weren't masked
@@ -812,6 +832,188 @@
         else:
             return sampled_z
 
+
+
+    def og_method(
+        self,
+        codec,
+        time_steps: int = 300,
+        _sampling_steps: List[int] = [16, 8, 8, 2, 2, 2, 2, 1, 1],
+        start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
+        mask: Optional[torch.Tensor] = None,
+        mask_temperature: float = 10.5,
+        typical_filtering=True,
+        typical_mass=0.2,
+        typical_min_tokens=1,
+        top_p=0.9,
+        seed: int = None,
+        sample_cutoff: float = 0.75,
+        return_signal=True,
+        debug=False,
+        causal_weight: float = 0.0,
+    ):
+        if seed is not None:
+            at.util.seed(seed)
+        sampling_steps = sum(_sampling_steps)
+        logging.debug(f"beginning generation with {sampling_steps} steps")
+
+
+
+        #####################
+        # resolve initial z #
+        #####################
+        z = start_tokens
+
+        if z is None:
+            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
+                self.device
+            )
+
+        logging.debug(f"created z with shape {z.shape}")
+
+
+        #################
+        # resolve mask #
+        #################
+
+        if mask is None:
+            mask = torch.ones_like(z).to(self.device).int()
+            mask[:, : self.n_conditioning_codebooks, :] = 0.0
+        if mask.ndim == 2:
+            mask = mask[:, None, :].repeat(1, z.shape[1], 1)
+        # init_mask = mask.clone()
+
+        logging.debug(f"created mask with shape {mask.shape}")
+
+
+        ###########
+        # set up #
+        ##########
+        # apply the mask to z
+        z_masked = z.masked_fill(mask.bool(), self.mask_token)
+        # logging.debug(f"z_masked: {z_masked}")
+
+        # how many mask tokens to begin with?
+        num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
+        logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
+
+        # how many codebooks are we inferring vs conditioning on?
+        n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
+        logging.debug(f"n infer codebooks: {n_infer_codebooks}")
+
+        ##################
+        # begin sampling #
+        ##################
+
+        for i in range(sampling_steps):
+            logging.debug(f"step {i} of {sampling_steps}")
+
+            # our current schedule step
+            r = scalar_to_batch_tensor(
+                (i + 1) / sampling_steps,
+                z.shape[0]
+            ).to(z.device)
+            logging.debug(f"r: {r}")
+
+            # get latents
+            latents = self.embedding.from_codes(z_masked, codec)
+            logging.debug(f"computed latents with shape: {latents.shape}")
+
+
+            # infer from latents
+            # NOTE: this collapses the codebook dimension into the sequence dimension
+            logits = self.forward(latents)  # b, prob, seq
+            logits = logits.permute(0, 2, 1)  # b, seq, prob
+            b = logits.shape[0]
+
+            logging.debug(f"permuted logits with shape: {logits.shape}")
+
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=(
+                    (i / sampling_steps) <= sample_cutoff
+                ),
+                temperature=sampling_temperature,
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True,
+            )
+
+            logging.debug(f"sampled z with shape: {sampled_z.shape}")
+
+            # flatten z_masked and mask, so we can deal with the sampling logic
+            # we'll unflatten them at the end of the loop for the next forward pass
+            # remove conditioning codebooks, we'll add them back at the end
+            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
+
+            mask = (z_masked == self.mask_token).int()
+
+            # update the mask, remove conditioning codebooks from the mask
+            logging.debug(f"updated mask with shape: {mask.shape}")
+            # add z back into sampled z where the mask was false
+            sampled_z = torch.where(
+                mask.bool(), sampled_z, z_masked
+            )
+            logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
+
+            # ignore any tokens that weren't masked
+            selected_probs = torch.where(
+                mask.bool(), selected_probs, torch.inf
+            )
+
+            # get the num tokens to mask, according to the schedule
+            num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
+            logging.debug(f"num to mask: {num_to_mask}")
+
+            if i != (sampling_steps - 1):
+                num_to_mask = torch.maximum(
+                    torch.tensor(1),
+                    torch.minimum(
+                        mask.sum(dim=-1, keepdim=True) - 1,
+                        num_to_mask
+                    )
+                )
+
+
+            # get our new mask
+            mask = mask_by_random_topk(
+                num_to_mask, selected_probs, mask_temperature * (1-r)
+            )
+
+            # update the mask
+            z_masked = torch.where(
+                mask.bool(), self.mask_token, sampled_z
+            )
+            logging.debug(f"updated z_masked with shape: {z_masked.shape}")
+
+            z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
+            mask = codebook_unflatten(mask, n_infer_codebooks)
+            logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
+
+            # add conditioning codebooks back to z_masked
+            z_masked = torch.cat(
+                (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
+            )
+            logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
+
+
+        # add conditioning codebooks back to sampled_z
+        sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
+        sampled_z = torch.cat(
+            (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
+        )
+
+        logging.debug(f"finished sampling")
+
+        if return_signal:
+            return self.to_signal(sampled_z, codec)
+        else:
+            return sampled_z
+
+
 def sample_from_logits(
     logits,
     sample: bool = True,
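
Note (reviewer sketch, not part of the commit): `og_method` re-masks `floor(_gamma(r) * num_mask_tokens_at_start)` tokens at each step, so the masked count decays to zero over the schedule. Assuming the MaskGIT-style cosine schedule `_gamma(r) = cos(r * pi / 2)` (the definition is not shown in this diff):

    import torch

    def _gamma(r: torch.Tensor) -> torch.Tensor:
        # assumption: cosine schedule, as in MaskGIT-style iterative decoding
        return (r * torch.pi / 2).cos()

    num_mask_tokens_at_start = 1000
    sampling_steps = 8
    for i in range(sampling_steps):
        r = torch.tensor((i + 1) / sampling_steps)
        print(f"step {i}: keep {int(_gamma(r) * num_mask_tokens_at_start)} masked")
    # prints 980, 923, 831, 707, 555, 382, 195, 0 -- fully unmasked by the last step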