Hugo Flores Garcia committed
Commit: fa490b8
Parent: 4908bb4

interface for max

conf/interface-spotdl.yml CHANGED
@@ -3,7 +3,7 @@ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
 Interface.codec_ckpt: ./models/spotdl/codec.pth
 Interface.coarse_chunk_size_s: 10
 Interface.coarse2fine_chunk_size_s: 3
-Interface.wavebeat_ckpt: ./models/wavebeat.pth
+# Interface.wavebeat_ckpt: ./models/wavebeat.pth


 AudioLoader.sources:
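
Each dotted key in this file corresponds to a keyword argument of the Interface constructor; with the wavebeat_ckpt line commented out, the interface is built without the beat-tracking checkpoint. A minimal sketch of wiring the config by hand, assuming plain PyYAML and a one-to-one mapping from Interface.<arg> keys to Interface.__init__ keyword arguments (the project may resolve these keys through its own config machinery):

import yaml

from vampnet.interface import Interface

# Read the flat config and keep only the keys addressed to Interface.
with open("conf/interface-spotdl.yml") as f:
    conf = yaml.safe_load(f)

interface_kwargs = {
    key.split(".", 1)[1]: value
    for key, value in conf.items()
    if key.startswith("Interface.")
}

# With wavebeat_ckpt commented out it never appears in the kwargs,
# so the interface is constructed without the beat tracker.
interface = Interface(**interface_kwargs)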
demo.py CHANGED
@@ -59,6 +59,43 @@ def load_random_audio():
     return sig.path_to_file
 
 
+def ez_vamp(
+    input_audio, init_temp, final_temp,
+    mask_periodic_amt, mask_periodic_width, num_steps,
+):
+    print(input_audio)
+    sig = at.AudioSignal(input_audio)
+
+    print(f"running standard vampnet with {num_vamps} vamps")
+    zv = interface.coarse_vamp_v2(
+        sig,
+        sampling_steps=num_steps,
+        temperature=(init_temp, final_temp),
+        prefix_dur_s=0.0,
+        suffix_dur_s=0.0,
+        num_vamps=1,
+        downsample_factor=mask_periodic_amt,
+        periodic_width=mask_periodic_width,
+        periodic_dropout=0.0,
+        periodic_width_dropout=0.0,
+        n_conditioning_codebooks=None,
+        intensity=1.0,
+        ext_mask=None,
+    )
+
+    zv = interface.coarse_to_fine(zv)
+
+    sig = interface.to_signal(zv).cpu()
+    print("done")
+
+    out_dir = OUT_DIR / str(uuid.uuid4())
+    out_dir.mkdir()
+    sig.write(out_dir / "output.wav")
+    # mask.write(out_dir / "mask.wav")
+    # return sig.path_to_file, mask.path_to_file
+    return sig.path_to_file
+
+
 def vamp(
     input_audio, init_temp, final_temp,
     prefix_s, suffix_s, rand_mask_intensity,
@@ -68,7 +105,7 @@ def vamp(
     num_vamps, mode, use_beats, num_steps, snap_to_beats,
     beat_unmask_drop, mask_periodic_width,
     mask_periodic_dropout, mask_periodic_width_dropout,
-    n_conditioning_codebooks
+    n_conditioning_codebooks, use_coarse2fine
 ):
     # try:
     print(input_audio)
@@ -119,36 +156,19 @@ def vamp(
             return_mask=True
         )
 
-        zv = interface.coarse_to_fine(zv)
-        mask = interface.to_signal(mask_z).cpu()
+        if use_coarse2fine:
+            zv = interface.coarse_to_fine(zv)
+        # mask = interface.to_signal(mask_z).cpu()
 
         sig = interface.to_signal(zv).cpu()
         print("done")
-    elif mode == "loop":
-        print(f"running loop vampnet with {num_vamps} vamps")
-        sig, mask = interface.loop(
-            sig,
-            temperature=(init_temp, final_temp),
-            prefix_dur_s=prefix_s,
-            suffix_dur_s=prefix_s, # suffix should be same length as prefix
-            num_loops=num_vamps,
-            downsample_factor=mask_periodic_amt,
-            periodic_width=mask_periodic_width,
-            intensity=rand_mask_intensity,
-            ext_mask=beat_mask,
-            verbose=True,
-            return_mask=True
-        )
-        sig = sig.cpu()
-        mask = mask.cpu()
-        print("done")
-
 
     out_dir = OUT_DIR / str(uuid.uuid4())
     out_dir.mkdir()
     sig.write(out_dir / "output.wav")
-    mask.write(out_dir / "mask.wav")
-    return sig.path_to_file, mask.path_to_file
+    # mask.write(out_dir / "mask.wav")
+    # return sig.path_to_file, mask.path_to_file
+    return sig.path_to_file, None
     # except Exception as e:
     #     raise gr.Error(f"failed with error: {e}")
 
@@ -160,7 +180,7 @@ def save_vamp(
     mask_up_chk, up_factor,
     num_vamps, mode, output_audio, notes, use_beats, num_steps, snap_to_beats,
     beat_unmask_drop, mask_periodic_width, mask_periodic_dropout, mask_periodic_width_dropout,
-    n_conditioning_codebooks
+    n_conditioning_codebooks, use_coarse2fine
 ):
     out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -193,7 +213,8 @@ def save_vamp(
         "mask_periodic_width": mask_periodic_width,
         "mask_periodic_dropout": mask_periodic_dropout,
         "mask_periodic_width_dropout": mask_periodic_width_dropout,
-        "n_conditioning_codebooks": n_conditioning_codebooks
+        "n_conditioning_codebooks": n_conditioning_codebooks,
+        "use_coarse2fine": use_coarse2fine,
     }
 
     # save with yaml
@@ -260,6 +281,10 @@ with gr.Blocks() as demo:
                choices=["standard",],
                value="standard"
            )
+           use_coarse2fine = gr.Checkbox(
+               label="use coarse2fine",
+               value=True
+           )
            num_vamps = gr.Number(
                label="number of vamps. more vamps = longer generated audio",
                value=1,
@@ -476,9 +501,10 @@ with gr.Blocks() as demo:
            num_vamps, mode, use_beats, num_steps, snap_to_beats,
            beat_unmask_drop, mask_periodic_width,
            mask_periodic_dropout, mask_periodic_width_dropout,
-           n_conditioning_codebooks
+           n_conditioning_codebooks, use_coarse2fine
        ],
-       outputs=[output_audio, audio_mask]
+       outputs=[output_audio, audio_mask],
+       api_name="vamp"
    )
 
    save_button.click(
@@ -494,9 +520,18 @@ with gr.Blocks() as demo:
            notes_text, use_beats, num_steps, snap_to_beats,
            beat_unmask_drop, mask_periodic_width,
            mask_periodic_dropout, mask_periodic_width_dropout,
-           n_conditioning_codebooks
+           n_conditioning_codebooks, use_coarse2fine
        ],
        outputs=[thank_you, download_file]
    )
 
-   demo.launch(share=True, enable_queue=True)
+   ez_vamp_button = gr.Button("ez vamp")
+   ez_vamp_button.click(
+       fn=ez_vamp,
+       inputs=[input_audio, init_temp, final_temp, mask_periodic_amt,
+           mask_periodic_width, num_steps ],
+       outputs=[output_audio],
+       api_name="ez_vamp"
+   )
+
+   demo.launch(share=True, enable_queue=False, debug=True)
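
The new api_name arguments, together with the commit message ("interface for max"), name these click handlers as stable endpoints so an external client, for example a Max/MSP patch, can reach them over HTTP instead of driving the browser UI. A rough sketch of calling the ez_vamp endpoint, assuming a local launch on Gradio's default port and that input_audio accepts a path visible to the server; the exact payload layout depends on the Gradio version and component types, and all values below are illustrative:

import requests

# Inputs mirror the order wired into ez_vamp_button.click:
# [input_audio, init_temp, final_temp, mask_periodic_amt, mask_periodic_width, num_steps]
payload = {
    "data": [
        "/path/to/input.wav",  # hypothetical server-side audio path
        0.8,   # init_temp
        1.0,   # final_temp
        16,    # mask_periodic_amt
        1,     # mask_periodic_width
        36,    # num_steps
    ]
}

resp = requests.post("http://127.0.0.1:7860/run/ez_vamp", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["data"])  # location of the generated output audio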
vampnet/interface.py CHANGED
@@ -20,7 +20,7 @@ def signal_concat(
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
 
-class Interface:
+class Interface(torch.nn.Module):
     def __init__(
         self,
         coarse_ckpt: str = None,
@@ -31,6 +31,7 @@ class Interface:
         coarse_chunk_size_s: int = 5,
         coarse2fine_chunk_size_s: int = 3,
     ):
+        super().__init__()
         assert codec_ckpt is not None, "must provide a codec checkpoint"
         self.codec = LAC.load(Path(codec_ckpt))
         self.codec.eval()
@@ -240,103 +241,7 @@ class Interface:
         fine_z = torch.cat(fine_z, dim=-1)
         return fine_z[:, :, :length].clone()
 
-    def coarse_vamp(
-        self,
-        signal,
-        prefix_dur_s: float = 1.25,
-        suffix_dur_s: float = 1.25,
-        num_loops: int = 3,
-        mode="impute",
-        downsample_factor: int = None,
-        debug=False,
-        **kwargs
-    ):
-        z = self.encode(signal)
-
-        assert signal.duration == self.coarse.chunk_size_s, "signal duration must match coarse chunk size for now"
-
-        # coarse z
-        cz = z[:, : self.coarse.n_codebooks, :].clone()
-        c_seq_len = cz.shape[-1]
-        n_prefix = self.s2t(prefix_dur_s)
-        n_suffix = self.s2t(suffix_dur_s)
-
-        # we'll keep the final codes sequence here
-        c_vamp = {
-            'prefix': [cz[:, :, :n_prefix].clone()],
-            'suffix': [cz[:, :, c_seq_len-n_suffix:].clone()]
-        }
-
-        _cz = cz.clone()
-        for _ in range(num_loops):
-            # add noise
-            cz_masked, cz_mask = self.coarse.add_noise(
-                _cz, r=0.0,
-                n_prefix=n_prefix,
-                n_suffix=n_suffix,
-                downsample_factor=downsample_factor
-            )
-            if debug:
-                print("tokens to infer")
-                self.to_signal(cz_masked).cpu().widget()
-
-            # sample!
-            cz_sampled = self.coarse.sample(
-                codec=self.codec,
-                time_steps=self.s2t(self.coarse.chunk_size_s),
-                start_tokens=_cz,
-                mask=cz_mask,
-                return_signal=False,
-                **kwargs
-            )
-
-            if debug:
-                print("tokens sampled")
-                self.to_signal(cz_sampled).cpu().widget()
-
-            cz_imputed = cz_sampled[:, :, n_prefix:c_seq_len-n_suffix].clone()
-
-            if mode == "impute":
-                # split the imputed codes into two halves
-                cz_imputed_a = cz_imputed[:, :, : cz_imputed.shape[-1] // 2].clone()
-                cz_imputed_b = cz_imputed[:, :, cz_imputed.shape[-1] // 2 :].clone()
-            elif mode == "continue":
-                cz_imputed_a = cz_imputed[:, :, : cz_imputed.shape[-1]].clone()
-                cz_imputed_b = _cz[:, :, :0].clone() # empty
-            elif mode == "reverse-continue":
-                cz_imputed_a = _cz[:, :, :0].clone() # empty
-                cz_imputed_b = cz_imputed[:, :, : cz_imputed.shape[-1]].clone()
-            else:
-                raise ValueError(f"mode {mode} not supported")
-
-            if debug:
-                # add to our c_vamp
-                if cz_imputed_a.shape[-1] > 0:
-                    print("new_prefix added")
-                    self.to_signal(cz_imputed_a).cpu().widget()
-                if cz_imputed_b.shape[-1] > 0:
-                    print("new_suffix added")
-                    self.to_signal(cz_imputed_b).cpu().widget()
-
-            c_vamp['prefix'].append(cz_imputed_a.clone())
-            c_vamp['suffix'].insert(0, cz_imputed_b.clone())
-
-            n_to_insert = c_seq_len - (cz_imputed_a.shape[-1] + cz_imputed_b.shape[-1])
-            to_insert = torch.zeros(cz_imputed_a.shape[0], cz_imputed_a.shape[1], n_to_insert).long().to(self.device)
-            _cz = torch.cat([cz_imputed_a, to_insert, cz_imputed_b], dim=-1)
-
-            if debug:
-                print("tokens to infer next round (area to insert in the middle)")
-                self.to_signal(_cz).cpu().widget()
-
-
-
-
-        prefix_codes = torch.cat(c_vamp['prefix'], dim=-1)
-        suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
-        c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
-        return c_vamp
-
+
     def coarse_vamp_v2(
         self,
         signal,
@@ -390,6 +295,7 @@ class Interface:
             downsample_factor=downsample_factor,
             periodic_width=periodic_width,
            periodic_dropout=periodic_dropout,
+            add_random_periodic_offset=True,
            periodic_width_dropout=periodic_width_dropout,
            mask=cz_mask,
            ext_mask=ext_mask,
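
Deriving Interface from torch.nn.Module (with the matching super().__init__() call) turns the codec and the coarse / coarse2fine models assigned to self into registered submodules. A small sketch of what that buys, using placeholder checkpoint paths and assuming the constructor keywords implied by the config above:

import torch

from vampnet.interface import Interface

interface = Interface(
    coarse_ckpt="./models/spotdl/coarse.pth",    # placeholder path
    coarse2fine_ckpt="./models/spotdl/c2f.pth",  # assumed keyword, inferred from the config keys
    codec_ckpt="./models/spotdl/codec.pth",
)

# As an nn.Module, the whole interface can be moved between devices and
# switched to eval mode in one call, cascading to every registered submodule.
interface.to("cuda" if torch.cuda.is_available() else "cpu")
interface.eval()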
vampnet/modules/base.py CHANGED
@@ -44,6 +44,7 @@ class VampBase(at.ml.BaseModel):
         periodic_width: int = 1,
         periodic_width_dropout: float = 0.0,
         periodic_dropout: float = 0.0,
+        add_random_periodic_offset: bool = False, # TODO: should be always false lol this is hacky
         n_conditioning_codebooks: Optional[int] = None,
         noise_mode: str = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -84,6 +85,7 @@ class VampBase(at.ml.BaseModel):
                 if torch.bernoulli(torch.tensor(periodic_dropout)).item() == 1:
                     # if we win, skip
                     continue
+
                 # figure out how wide the mask should be
                 j_start = max(0, j - periodic_width // 2)
                 j_end = min(probs.shape[-1] - 1, j + periodic_width // 2) + 1
@@ -92,6 +94,10 @@ class VampBase(at.ml.BaseModel):
                 j_fill = torch.ones_like(j_mask) * (1 - j_mask)
                 # fill
                 probs[i, :, j_start:j_end] = 1 - j_fill
+            if add_random_periodic_offset:
+                # add a random offset to the mask
+                offset = torch.randint(0, downsample_factor[0], (1,))
+                probs = torch.roll(probs, offset.item(), dims=-1)
 
         mask = torch.bernoulli(probs)
         mask = mask.round().long()
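
The add_random_periodic_offset flag rolls the periodic mask by a random amount so the kept positions do not always land on the same grid (0, k, 2k, ...). A self-contained sketch of the idea on a dummy probability tensor; the shapes and the 1-means-masked convention here are illustrative, not the model's exact ones:

import torch

batch, codebooks, seq_len = 1, 4, 32
downsample_factor = torch.tensor([8])  # keep roughly one position in every 8

# start with everything masked (prob 1), then punch periodic holes to keep
probs = torch.ones(batch, codebooks, seq_len)
for j in range(0, seq_len, int(downsample_factor[0])):
    probs[0, :, j] = 0.0

# the new flag: shift the whole periodic pattern by a random offset
offset = torch.randint(0, int(downsample_factor[0]), (1,))
probs = torch.roll(probs, offset.item(), dims=-1)

mask = torch.bernoulli(probs).round().long()
print(offset.item())
print(mask[0, 0])  # kept (0) positions now start at the random offset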