Hugo Flores Garcia committed on
Commit
0321745
1 Parent(s): 2ddbddc

taking prompt c2f tokens into account

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. vampnet/interface.py +40 -20
app.py CHANGED
@@ -114,7 +114,7 @@ def _vamp(data, return_mask=False):
114
  )
115
 
116
  if use_coarse2fine:
117
- zv = interface.coarse_to_fine(zv, temperature=data[temp])
118
 
119
  sig = interface.to_signal(zv).cpu()
120
  print("done")
@@ -410,7 +410,8 @@ with gr.Blocks() as demo:
410
 
411
  use_coarse2fine = gr.Checkbox(
412
  label="use coarse2fine",
413
- value=True
 
414
  )
415
 
416
  num_steps = gr.Slider(
 
114
  )
115
 
116
  if use_coarse2fine:
117
+ zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
118
 
119
  sig = interface.to_signal(zv).cpu()
120
  print("done")
 
410
 
411
  use_coarse2fine = gr.Checkbox(
412
  label="use coarse2fine",
413
+ value=True,
414
+ visible=False
415
  )
416
 
417
  num_steps = gr.Slider(
vampnet/interface.py CHANGED
@@ -22,6 +22,7 @@ def signal_concat(
22
 
23
  return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
24
 
 
25
  def _load_model(
26
  ckpt: str,
27
  lora_ckpt: str = None,
@@ -275,36 +276,47 @@ class Interface(torch.nn.Module):
275
 
276
  def coarse_to_fine(
277
  self,
278
- coarse_z: torch.Tensor,
 
279
  **kwargs
280
  ):
281
  assert self.c2f is not None, "No coarse2fine model loaded"
282
- length = coarse_z.shape[-1]
283
  chunk_len = self.s2t(self.c2f.chunk_size_s)
284
- n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len)
285
 
286
  # zero pad to chunk_len
287
  if length % chunk_len != 0:
288
  pad_len = chunk_len - (length % chunk_len)
289
- coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len))
 
290
 
291
- n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1]
292
  if n_codebooks_to_append > 0:
293
- coarse_z = torch.cat([
294
- coarse_z,
295
- torch.zeros(coarse_z.shape[0], n_codebooks_to_append, coarse_z.shape[-1]).long().to(self.device)
296
  ], dim=1)
297
 
 
 
 
 
 
298
  fine_z = []
299
  for i in range(n_chunks):
300
- chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len]
 
 
301
  chunk = self.c2f.generate(
302
  codec=self.codec,
303
  time_steps=chunk_len,
304
  start_tokens=chunk,
305
  return_signal=False,
 
306
  **kwargs
307
  )
 
308
  fine_z.append(chunk)
309
 
310
  fine_z = torch.cat(fine_z, dim=-1)
@@ -337,6 +349,12 @@ class Interface(torch.nn.Module):
337
  **kwargs
338
  )
339
 
 
 
 
 
 
 
340
  if return_mask:
341
  return c_vamp, cz_masked
342
 
@@ -352,17 +370,18 @@ if __name__ == "__main__":
352
  at.util.seed(42)
353
 
354
  interface = Interface(
355
- coarse_ckpt="./models/spotdl/coarse.pth",
356
- coarse2fine_ckpt="./models/spotdl/c2f.pth",
357
- codec_ckpt="./models/spotdl/codec.pth",
358
  device="cuda",
359
  wavebeat_ckpt="./models/wavebeat.pth"
360
  )
361
 
362
 
363
- sig = at.AudioSignal.zeros(duration=10, sample_rate=44100)
364
 
365
  z = interface.encode(sig)
 
366
 
367
  # mask = linear_random(z, 1.0)
368
  # mask = mask_and(
@@ -374,13 +393,14 @@ if __name__ == "__main__":
374
  # )
375
  # )
376
 
377
- mask = interface.make_beat_mask(
378
- sig, 0.0, 0.075
379
- )
380
  # mask = dropout(mask, 0.0)
381
  # mask = codebook_unmask(mask, 0)
 
 
382
 
383
- breakpoint()
384
  zv, mask_z = interface.coarse_vamp(
385
  z,
386
  mask=mask,
@@ -389,16 +409,16 @@ if __name__ == "__main__":
389
  return_mask=True,
390
  gen_fn=interface.coarse.generate
391
  )
 
392
 
393
  use_coarse2fine = True
394
  if use_coarse2fine:
395
- zv = interface.coarse_to_fine(zv, temperature=0.8)
 
396
 
397
  mask = interface.to_signal(mask_z).cpu()
398
 
399
  sig = interface.to_signal(zv).cpu()
400
  print("done")
401
 
402
- sig.write("output3.wav")
403
- mask.write("mask.wav")
404
 
 
22
 
23
  return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
24
 
25
+
26
  def _load_model(
27
  ckpt: str,
28
  lora_ckpt: str = None,
 
276
 
277
  def coarse_to_fine(
278
  self,
279
+ z: torch.Tensor,
280
+ mask: torch.Tensor = None,
281
  **kwargs
282
  ):
283
  assert self.c2f is not None, "No coarse2fine model loaded"
284
+ length = z.shape[-1]
285
  chunk_len = self.s2t(self.c2f.chunk_size_s)
286
+ n_chunks = math.ceil(z.shape[-1] / chunk_len)
287
 
288
  # zero pad to chunk_len
289
  if length % chunk_len != 0:
290
  pad_len = chunk_len - (length % chunk_len)
291
+ z = torch.nn.functional.pad(z, (0, pad_len))
292
+ mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
293
 
294
+ n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
295
  if n_codebooks_to_append > 0:
296
+ z = torch.cat([
297
+ z,
298
+ torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
299
  ], dim=1)
300
 
301
+ # set the mask to 0 for all conditioning codebooks
302
+ if mask is not None:
303
+ mask = mask.clone()
304
+ mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
305
+
306
  fine_z = []
307
  for i in range(n_chunks):
308
+ chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
309
+ mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
310
+
311
  chunk = self.c2f.generate(
312
  codec=self.codec,
313
  time_steps=chunk_len,
314
  start_tokens=chunk,
315
  return_signal=False,
316
+ mask=mask_chunk,
317
  **kwargs
318
  )
319
+ breakpoint()
320
  fine_z.append(chunk)
321
 
322
  fine_z = torch.cat(fine_z, dim=-1)
 
349
  **kwargs
350
  )
351
 
352
+ # add the fine codes back in
353
+ c_vamp = torch.cat(
354
+ [c_vamp, z[:, self.coarse.n_codebooks :, :]],
355
+ dim=1
356
+ )
357
+
358
  if return_mask:
359
  return c_vamp, cz_masked
360
 
 
370
  at.util.seed(42)
371
 
372
  interface = Interface(
373
+ coarse_ckpt="./models/vampnet/coarse.pth",
374
+ coarse2fine_ckpt="./models/vampnet/c2f.pth",
375
+ codec_ckpt="./models/vampnet/codec.pth",
376
  device="cuda",
377
  wavebeat_ckpt="./models/wavebeat.pth"
378
  )
379
 
380
 
381
+ sig = at.AudioSignal('assets/example.wav')
382
 
383
  z = interface.encode(sig)
384
+ breakpoint()
385
 
386
  # mask = linear_random(z, 1.0)
387
  # mask = mask_and(
 
393
  # )
394
  # )
395
 
396
+ # mask = interface.make_beat_mask(
397
+ # sig, 0.0, 0.075
398
+ # )
399
  # mask = dropout(mask, 0.0)
400
  # mask = codebook_unmask(mask, 0)
401
+
402
+ mask = inpaint(z, n_prefix=100, n_suffix=100)
403
 
 
404
  zv, mask_z = interface.coarse_vamp(
405
  z,
406
  mask=mask,
 
409
  return_mask=True,
410
  gen_fn=interface.coarse.generate
411
  )
412
+
413
 
414
  use_coarse2fine = True
415
  if use_coarse2fine:
416
+ zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
417
+ breakpoint()
418
 
419
  mask = interface.to_signal(mask_z).cpu()
420
 
421
  sig = interface.to_signal(zv).cpu()
422
  print("done")
423
 
 
 
424