Hugo Flores committed
Commit a63cce0
1 Parent(s): 4a2dc41

interface improvements
scripts/utils/process_folder-c2f.py ADDED
@@ -0,0 +1,122 @@
+from audiotools import AudioSignal
+import torch
+from pathlib import Path
+import argbind
+from tqdm import tqdm
+import random
+
+from collections import defaultdict
+
+
+def coarse2fine_infer(
+    signal,
+    model,
+    vqvae,
+    device,
+    signal_window=3,
+    signal_hop=1.5,
+    max_excerpts=25,
+):
+    output = defaultdict(list)
+
+    # split the signal into overlapping windows (signal_window seconds long, signal_hop apart)
+    windows = [s for s in signal.clone().windows(signal_window, signal_hop)]
+    random.shuffle(windows)
+    for w in windows[1:max_excerpts]:  # skip the first window since it's mostly zero padded?
+        with torch.no_grad():
+            # get codes
+            w = w.to(device)
+            z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]
+
+            model.to(device)
+            # append (rather than assign) so every key in `output` stays a
+            # list, which the .extend() in main() relies on
+            output["reconstructed"].append(model.to_signal(z, vqvae).cpu())
+
+            # make a full mask, leaving only the conditioning codebooks unmasked
+            mask = torch.ones_like(z)
+            mask[:, :model.n_conditioning_codebooks, :] = 0
+
+            output["sampled"].append(model.sample(
+                codec=vqvae,
+                time_steps=z.shape[-1],
+                sampling_steps=12,
+                start_tokens=z,
+                mask=mask,
+                temperature=0.85,
+                top_k=None,
+                sample="gumbel",
+                typical_filtering=True,
+                return_signal=True
+            ).cpu())
+
+            output["argmax"].append(model.sample(
+                codec=vqvae,
+                time_steps=z.shape[-1],
+                sampling_steps=1,
+                start_tokens=z,
+                mask=mask,
+                temperature=1.0,
+                top_k=None,
+                sample="argmax",
+                typical_filtering=True,
+                return_signal=True
+            ).cpu())
+
+    return output
+
+
+@argbind.bind(without_prefix=True)
+def main(
+    sources=[
+        "/home/hugo/data/spotdl/audio/val", "/home/hugo/data/spotdl/audio/test"
+    ],
+    audio_ext="mp3",
+    exp_name="noise_mode",
+    model_paths=[
+        "ckpt/mask/best/vampnet/weights.pth",
+        "ckpt/random/best/vampnet/weights.pth",
+    ],
+    model_keys=[
+        "noise_mode=mask",
+        "noise_mode=random",
+    ],
+    vqvae_path="ckpt/wav2wav.pth",
+    device="cuda",
+):
+    from vampnet.modules.transformer import VampNet
+    from lac.model.lac import LAC
+    from audiotools.post import audio_zip
+
+    models = {
+        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
+    }
+    for model in models.values():
+        model.eval()
+    print(f"Loaded {len(models)} models.")
+
+    vqvae = LAC.load(vqvae_path)
+    vqvae.to(device)
+    vqvae.eval()
+    print("Loaded VQVAE.")
+
+    audio_dict = defaultdict(list)
+    for source in sources:
+        print(f"Processing {source}...")
+        for path in tqdm(list(Path(source).glob(f"**/*.{audio_ext}"))):
+            sig = AudioSignal(path)
+            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)
+
+            for model_key, model in models.items():
+                out = coarse2fine_infer(sig, model, vqvae, device)
+                for k in out:
+                    audio_dict[f"{model_key}-{k}"].extend(out[k])
+
+    audio_zip(audio_dict, f"{exp_name}-results.zip")
+
+
+if __name__ == "__main__":
+    args = argbind.parse_args()
+
+    with argbind.scope(args):
+        main()
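Because `main` is bound with `argbind.bind(without_prefix=True)`, each keyword argument above should surface as a top-level command-line flag (`--sources`, `--exp_name`, `--device`, and so on). A minimal sketch of overriding one default programmatically instead, assuming `argbind.parse_args()` returns a plain dict keyed by argument name for `without_prefix` bindings:

    import argbind

    # Hypothetical smoke test, appended to the script itself: override one
    # bound default, then run main() under the resulting scope, mirroring
    # the __main__ block above.
    args = argbind.parse_args()
    args["exp_name"] = "noise_mode_smoke"  # assumes plain-name keys; verify against argbind's docs
    with argbind.scope(args):
        main()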
vampnet/gradio.py ADDED
@@ -0,0 +1,4 @@
+
+import gradio as gr
+
+
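`vampnet/gradio.py` is only an import stub so far. Below is a hedged sketch of the demo wiring it could grow into; nothing in it is part of this commit, and `interface` is assumed to be constructed as in the usage sketch after the `vampnet/interface.py` diff further down.

    import gradio as gr
    from audiotools import AudioSignal

    interface = ...  # a vampnet.interface.Interface; see the usage sketch below

    def vamp(audio_path):
        # read the uploaded file, create a variation, and return a playable path
        sig = AudioSignal(audio_path)
        out = interface.variation(sig)
        out.write("vamped.wav")
        return "vamped.wav"

    demo = gr.Interface(
        fn=vamp,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Audio(),
    )

    if __name__ == "__main__":
        demo.launch()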
vampnet/interface.py CHANGED
@@ -4,11 +4,21 @@ import math

 import torch
 from audiotools import AudioSignal
+import tqdm

 from .modules.transformer import VampNet
 from lac.model.lac import LAC


+def signal_concat(
+    audio_signals: list,
+):
+    audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)
+
+    return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
+
+
 class Interface:
     def __init__(
         self,
@@ -26,20 +36,29 @@ class Interface:
         self.coarse = VampNet.load(location=Path(coarse_ckpt), map_location="cpu")
         self.coarse.to(device)
         self.coarse.eval()
-        self.coarse.chunk_size_s = coarse_chunk_size_s
+        self.coarse.chunk_size_s = self.s2t2s(coarse_chunk_size_s)

-        self.c2f = VampNet.load(
-            location=Path(coarse2fine_ckpt), map_location="cpu"
-        )
-        self.c2f.to(device)
-        self.c2f.eval()
-        self.c2f.chunk_size_s = coarse2fine_chunk_size_s
+        if coarse2fine_ckpt is not None:
+            self.c2f = VampNet.load(
+                location=Path(coarse2fine_ckpt), map_location="cpu"
+            )
+            self.c2f.to(device)
+            self.c2f.eval()
+            self.c2f.chunk_size_s = self.s2t2s(coarse2fine_chunk_size_s)

         self.device = device

     def s2t(self, seconds: float):
         """seconds to tokens"""
-        return int(seconds * self.codec.sample_rate / self.codec.hop_length)
+        return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
+
+    def s2t2s(self, seconds: float):
+        """seconds to tokens to seconds"""
+        return self.t2s(self.s2t(seconds))
+
+    def t2s(self, tokens: int):
+        """tokens to seconds"""
+        return tokens * self.codec.hop_length / self.codec.sample_rate

     def to(self, device):
         self.device = device
@@ -51,15 +70,22 @@ class Interface:
     def to_signal(self, z: torch.Tensor):
         return self.coarse.to_signal(z, self.codec)

-    @torch.inference_mode()
-    def encode(self, signal: AudioSignal):
+    def autoencode(self, signal: AudioSignal):
+        z = self.encode(signal)
+        return self.to_signal(z)
+
+    def preprocess(self, signal: AudioSignal):
         signal = (
-            signal.clone().to(self.device)
+            signal.clone()
             .resample(self.codec.sample_rate)
             .to_mono()
             .normalize(-24)
             .ensure_max_of_audio(1.0)
         )
+        return signal
+
+    @torch.inference_mode()
+    def encode(self, signal: AudioSignal):
+        signal = self.preprocess(signal).to(self.device)
         z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
         return z

@@ -68,6 +94,7 @@ class Interface:
         coarse_z: torch.Tensor,
         **kwargs
     ):
+        assert self.c2f is not None, "No coarse2fine model loaded"
         length = coarse_z.shape[-1]
         chunk_len = self.s2t(self.c2f.chunk_size_s)
         n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len)
@@ -198,24 +225,30 @@ class Interface:
     def coarse_vamp_v2(
         self,
         signal,
-        prefix_dur_s: float = 1.25,
-        suffix_dur_s: float = 1.25,
-        num_loops: int = 3,
+        prefix_dur_s: float = 0.0,
+        suffix_dur_s: float = 0.0,
+        num_vamps: int = 1,
         downsample_factor: int = None,
+        intensity: float = 1.0,
         debug=False,
+        swap_prefix_suffix=False,
         **kwargs
     ):
         z = self.encode(signal)

-        assert signal.duration == self.coarse.chunk_size_s, "signal duration must match coarse chunk size for now"
-
         # coarse z
         cz = z[:, : self.coarse.n_codebooks, :].clone()
         c_seq_len = cz.shape[-1]
         n_prefix = self.s2t(prefix_dur_s)
         n_suffix = self.s2t(suffix_dur_s)

+        assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the provided token sequence must fit within the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
         assert n_prefix + n_suffix < c_seq_len, "prefix and suffix must be smaller than the chunk size"
+
+        if swap_prefix_suffix:
+            # swap the prefix and suffix regions in cz
+            assert n_prefix == n_suffix, "prefix and suffix must be the same size for now"
+            cz[:, :, :n_prefix], cz[:, :, c_seq_len - n_suffix:] = cz[:, :, c_seq_len - n_suffix:], cz[:, :, :n_prefix].clone()

         # we'll keep the final codes sequence here
         c_vamp = {
@@ -225,10 +258,10 @@ class Interface:

         _cz = cz.clone()
         cz_mask = None
-        for _ in range(num_loops):
+        for _ in range(num_vamps):
             # add noise
             cz_masked, cz_mask = self.coarse.add_noise(
-                _cz, r=0.0,
+                _cz, r=1.0 - intensity,
                 n_prefix=n_prefix,
                 n_suffix=n_suffix,
                 downsample_factor=downsample_factor,
@@ -244,7 +277,7 @@ class Interface:
             print(f"z: {_cz[:,0,:]}")
             cz_sampled = self.coarse.sample(
                 codec=self.codec,
-                time_steps=self.s2t(self.coarse.chunk_size_s),
+                time_steps=_cz.shape[-1],
                 start_tokens=_cz,
                 mask=cz_mask,
                 return_signal=False,
@@ -329,17 +362,113 @@ class Interface:
         c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
         return c_vamp

+    # create a variation of an audio signal
+    def variation(
+        self,
+        signal: AudioSignal,
+        overlap_hop_ratio: float = 1.0,  # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap-add?
+        verbose: bool = False,
+        **kwargs
+    ):
+        signal = signal.clone()
+
+        # autoencode first, so the samples get rounded up to the nearest tokens
+        signal = self.autoencode(signal).cpu()
+
+        # pad the signal to the nearest chunk size
+        req_len = (
+            math.ceil(signal.duration / self.coarse.chunk_size_s)
+            * self.coarse.chunk_size_s
+        )
+        hop_duration = self.coarse.chunk_size_s * overlap_hop_ratio
+        original_length = signal.length
+
+        signal.zero_pad_to(req_len)
+
+        # window the signal
+        signal = signal.collect_windows(
+            window_duration=self.coarse.chunk_size_s,
+            hop_duration=hop_duration,
+        )
+
+        range_fn = range if not verbose else tqdm.trange
+        for i in range_fn(signal.batch_size):
+            sig = AudioSignal(
+                signal.samples[i, ...], signal.sample_rate
+            )
+            sig.to(self.device)
+            out_z = self.coarse_vamp_v2(
+                sig,
+                num_vamps=1,
+                swap_prefix_suffix=False,
+                **kwargs
+            )
+            if self.c2f is not None:
+                out_z = self.coarse_to_fine(out_z)
+            out_sig = self.to_signal(out_z).cpu()
+
+            signal.samples[i] = out_sig.samples
+
+        output = signal.overlap_and_add(hop_duration)
+
+        output.truncate_samples(original_length)
+        return output
+
+    # create a loop of a single region with variations
+    # TODO: this would work nicer if we could trim at the beat;
+    # otherwise the model has to awkwardly fill up space that won't match
+    # the beat unless the signal is exactly the right length
+    def loop(
+        self,
+        signal: AudioSignal,
+        prefix_dur_s: float = 0.0,
+        suffix_dur_s: float = 0.0,
+        num_loops: int = 4,
+        # overlap_hop_ratio: float = 1.0,  # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap-add?
+        verbose: bool = False,
+        **kwargs,
+    ):
+        assert prefix_dur_s >= 0.0, "prefix duration must be >= 0"
+        assert suffix_dur_s >= 0.0, "suffix duration must be >= 0"
+        signal = self.preprocess(signal)
+
+        suffix_len_samples = int(suffix_dur_s * signal.sample_rate)
+        prefix_len_tokens = self.s2t(prefix_dur_s)
+        suffix_len_tokens = self.s2t(suffix_dur_s)
+
+        loops = [
+            # add everything but the suffix at the beginning
+            self.encode(signal.clone().trim(before=0, after=suffix_len_samples))
+        ]
+        range_fn = range if not verbose else tqdm.trange
+        for i in range_fn(num_loops):
+            is_flipped = i % 2 == 0
+            vamped = self.coarse_vamp_v2(
+                signal,
+                prefix_dur_s=prefix_dur_s,
+                suffix_dur_s=suffix_dur_s,
+                swap_prefix_suffix=is_flipped,
+                **kwargs
+            )
+            # if we're flipped, trim the prefix off of the end;
+            # otherwise, trim the suffix off of the end
+            trim_len = prefix_len_tokens if is_flipped else suffix_len_tokens
+            vamped = vamped[:, :, :vamped.shape[-1] - trim_len]
+
+            loops.append(vamped)
+
+            if is_flipped:
+                loops.append(
+                    # add everything but the prefix at the end
+                    self.encode(signal.clone())
+                )
+
+        if self.c2f is not None:
+            loops = [self.coarse_to_fine(l) for l in loops]
+
+        loops = [self.to_signal(l) for l in loops]
+
+        return signal_concat(loops)
vampnet/modules/base.py CHANGED
@@ -31,6 +31,11 @@ class VampBase(at.ml.BaseModel):
     def forward(self, x: torch.Tensor, r: torch.Tensor):
         raise NotImplementedError

+    # TODO: add a beat tracking method
+    # that uses a beat tracking model to find beat positions
+    # and then unmask the codes in those positions (with some width),
+    # drop them out with some randomness,
+    # and have the option to not drop out downbeats
     def add_noise(
         self,
         x: torch.Tensor,
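The TODO sketches a beat-aware mask. One hedged way to realize it, assuming beat positions are already expressed in token indices and that the mask follows the convention used in `process_folder-c2f.py` (1 = noised/masked, 0 = kept); every name below is hypothetical, not part of this commit:

    import torch

    def beat_unmask(mask: torch.Tensor, beat_positions: list,
                    downbeat_positions: tuple = (), width: int = 2,
                    drop_prob: float = 0.2):
        # mask: (batch, n_codebooks, seq_len) with 1 = noised/masked, 0 = kept
        for t in beat_positions:
            # optionally never drop downbeats, per the TODO
            if t not in downbeat_positions and torch.rand(1).item() < drop_prob:
                continue  # randomly drop this beat so its region stays masked
            lo, hi = max(0, t - width), min(mask.shape[-1], t + width + 1)
            mask[:, :, lo:hi] = 0  # keep (unmask) the codes around the beat
        return mask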