Hugo Flores Garcia
commited on
Commit
•
51f416f
1
Parent(s):
a689560
fix
Browse files- app.py +3 -3
- vampnet/interface.py +271 -42
- vampnet/mask.py +11 -4
- vampnet/modules/transformer.py +262 -12
app.py
CHANGED
@@ -494,7 +494,7 @@ with gr.Blocks() as demo:
|
|
494 |
minimum=1,
|
495 |
maximum=16,
|
496 |
step=1,
|
497 |
-
value=
|
498 |
)
|
499 |
|
500 |
win_dur= gr.Slider(
|
@@ -580,8 +580,8 @@ with gr.Blocks() as demo:
|
|
580 |
from pyharp import ModelCard, build_endpoint
|
581 |
|
582 |
model_card = ModelCard(
|
583 |
-
name="
|
584 |
-
description="
|
585 |
author="hugo flores garcía",
|
586 |
tags=["generative","sound"],
|
587 |
)
|
|
|
494 |
minimum=1,
|
495 |
maximum=16,
|
496 |
step=1,
|
497 |
+
value=3
|
498 |
)
|
499 |
|
500 |
win_dur= gr.Slider(
|
|
|
580 |
from pyharp import ModelCard, build_endpoint
|
581 |
|
582 |
model_card = ModelCard(
|
583 |
+
name="nesquik",
|
584 |
+
description="the ultimate 8-bit crusher",
|
585 |
author="hugo flores garcía",
|
586 |
tags=["generative","sound"],
|
587 |
)
|
vampnet/interface.py
CHANGED
@@ -110,26 +110,26 @@ class Interface(torch.nn.Module):
|
|
110 |
# check if we already loaded, if so, don't reload
|
111 |
if self.coarse_path == Path(coarse_ckpt):
|
112 |
print(f"already loaded {coarse_ckpt}")
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
|
122 |
if c2f_ckpt is not None:
|
123 |
if self.c2f_path == Path(c2f_ckpt):
|
124 |
print(f"already loaded {c2f_ckpt}")
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
|
134 |
def s2t(self, seconds: float):
|
135 |
"""seconds to tokens"""
|
@@ -273,11 +273,15 @@ class Interface(torch.nn.Module):
|
|
273 |
else:
|
274 |
mask = mask.repeat(1, self.coarse.n_codebooks, 1)
|
275 |
return mask
|
|
|
|
|
|
|
276 |
|
277 |
def coarse_to_fine(
|
278 |
self,
|
279 |
z: torch.Tensor,
|
280 |
mask: torch.Tensor = None,
|
|
|
281 |
**kwargs
|
282 |
):
|
283 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
@@ -289,7 +293,7 @@ class Interface(torch.nn.Module):
|
|
289 |
if length % chunk_len != 0:
|
290 |
pad_len = chunk_len - (length % chunk_len)
|
291 |
z = torch.nn.functional.pad(z, (0, pad_len))
|
292 |
-
mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
|
293 |
|
294 |
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
295 |
if n_codebooks_to_append > 0:
|
@@ -297,6 +301,7 @@ class Interface(torch.nn.Module):
|
|
297 |
z,
|
298 |
torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
|
299 |
], dim=1)
|
|
|
300 |
|
301 |
# set the mask to 0 for all conditioning codebooks
|
302 |
if mask is not None:
|
@@ -319,6 +324,9 @@ class Interface(torch.nn.Module):
|
|
319 |
fine_z.append(chunk)
|
320 |
|
321 |
fine_z = torch.cat(fine_z, dim=-1)
|
|
|
|
|
|
|
322 |
return fine_z[:, :, :length].clone()
|
323 |
|
324 |
def coarse_vamp(
|
@@ -331,22 +339,52 @@ class Interface(torch.nn.Module):
|
|
331 |
):
|
332 |
# coarse z
|
333 |
cz = z[:, : self.coarse.n_codebooks, :].clone()
|
334 |
-
assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
|
335 |
-
|
336 |
mask = mask[:, : self.coarse.n_codebooks, :]
|
|
|
337 |
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
# add the fine codes back in
|
352 |
c_vamp = torch.cat(
|
@@ -358,16 +396,169 @@ class Interface(torch.nn.Module):
|
|
358 |
return c_vamp, cz_masked
|
359 |
|
360 |
return c_vamp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
if __name__ == "__main__":
|
373 |
import audiotools as at
|
@@ -389,8 +580,6 @@ if __name__ == "__main__":
|
|
389 |
sig = at.AudioSignal('assets/example.wav')
|
390 |
|
391 |
z = interface.encode(sig)
|
392 |
-
breakpoint()
|
393 |
-
|
394 |
# mask = linear_random(z, 1.0)
|
395 |
# mask = mask_and(
|
396 |
# mask, periodic_mask(
|
@@ -429,4 +618,44 @@ if __name__ == "__main__":
|
|
429 |
sig = interface.to_signal(zv).cpu()
|
430 |
print("done")
|
431 |
|
432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
# check if we already loaded, if so, don't reload
|
111 |
if self.coarse_path == Path(coarse_ckpt):
|
112 |
print(f"already loaded {coarse_ckpt}")
|
113 |
+
else:
|
114 |
+
self.coarse = _load_model(
|
115 |
+
ckpt=coarse_ckpt,
|
116 |
+
device=self.device,
|
117 |
+
chunk_size_s=self.coarse.chunk_size_s,
|
118 |
+
)
|
119 |
+
self.coarse_path = Path(coarse_ckpt)
|
120 |
+
print(f"loaded {coarse_ckpt}")
|
121 |
|
122 |
if c2f_ckpt is not None:
|
123 |
if self.c2f_path == Path(c2f_ckpt):
|
124 |
print(f"already loaded {c2f_ckpt}")
|
125 |
+
else:
|
126 |
+
self.c2f = _load_model(
|
127 |
+
ckpt=c2f_ckpt,
|
128 |
+
device=self.device,
|
129 |
+
chunk_size_s=self.c2f.chunk_size_s,
|
130 |
+
)
|
131 |
+
self.c2f_path = Path(c2f_ckpt)
|
132 |
+
print(f"loaded {c2f_ckpt}")
|
133 |
|
134 |
def s2t(self, seconds: float):
|
135 |
"""seconds to tokens"""
|
|
|
273 |
else:
|
274 |
mask = mask.repeat(1, self.coarse.n_codebooks, 1)
|
275 |
return mask
|
276 |
+
|
277 |
+
def set_chunk_size(self, chunk_size_s: float):
|
278 |
+
self.coarse.chunk_size_s = chunk_size_s
|
279 |
|
280 |
def coarse_to_fine(
|
281 |
self,
|
282 |
z: torch.Tensor,
|
283 |
mask: torch.Tensor = None,
|
284 |
+
return_mask: bool = False,
|
285 |
**kwargs
|
286 |
):
|
287 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
|
|
293 |
if length % chunk_len != 0:
|
294 |
pad_len = chunk_len - (length % chunk_len)
|
295 |
z = torch.nn.functional.pad(z, (0, pad_len))
|
296 |
+
mask = torch.nn.functional.pad(mask, (0, pad_len), value=1) if mask is not None else None
|
297 |
|
298 |
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
299 |
if n_codebooks_to_append > 0:
|
|
|
301 |
z,
|
302 |
torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
|
303 |
], dim=1)
|
304 |
+
print(f"appended {n_codebooks_to_append} codebooks to z")
|
305 |
|
306 |
# set the mask to 0 for all conditioning codebooks
|
307 |
if mask is not None:
|
|
|
324 |
fine_z.append(chunk)
|
325 |
|
326 |
fine_z = torch.cat(fine_z, dim=-1)
|
327 |
+
if return_mask:
|
328 |
+
return fine_z[:, :, :length].clone(), apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone()
|
329 |
+
|
330 |
return fine_z[:, :, :length].clone()
|
331 |
|
332 |
def coarse_vamp(
|
|
|
339 |
):
|
340 |
# coarse z
|
341 |
cz = z[:, : self.coarse.n_codebooks, :].clone()
|
|
|
|
|
342 |
mask = mask[:, : self.coarse.n_codebooks, :]
|
343 |
+
# assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
|
344 |
|
345 |
+
# cut into chunks, keep the last chunk separate if it's too small
|
346 |
+
chunk_len = self.s2t(self.coarse.chunk_size_s)
|
347 |
+
n_chunks = math.ceil(cz.shape[-1] / chunk_len)
|
348 |
+
last_chunk_len = cz.shape[-1] % chunk_len
|
349 |
+
|
350 |
+
cz_chunks = []
|
351 |
+
mask_chunks = []
|
352 |
+
for i in range(n_chunks):
|
353 |
+
chunk = cz[:, :, i * chunk_len : (i + 1) * chunk_len]
|
354 |
+
mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]
|
355 |
+
|
356 |
+
# make sure that the very first and last timestep of each chunk is 0 so that we don't get a weird
|
357 |
+
# discontinuity when we stitch the chunks back together
|
358 |
+
# only if there's already a 0 somewhere in the chunk
|
359 |
+
if torch.any(mask_chunk == 0):
|
360 |
+
mask_chunk[:, :, 0] = 0
|
361 |
+
mask_chunk[:, :, -1] = 0
|
362 |
+
|
363 |
+
cz_chunks.append(chunk)
|
364 |
+
mask_chunks.append(mask_chunk)
|
365 |
+
|
366 |
+
# now vamp each chunk
|
367 |
+
cz_masked_chunks = []
|
368 |
+
cz_vamped_chunks = []
|
369 |
+
for chunk, mask_chunk in zip(cz_chunks, mask_chunks):
|
370 |
+
cz_masked_chunk, mask_chunk = apply_mask(chunk, mask_chunk, self.coarse.mask_token)
|
371 |
+
cz_masked_chunk = cz_masked_chunk[:, : self.coarse.n_codebooks, :]
|
372 |
+
cz_masked_chunks.append(cz_masked_chunk)
|
373 |
+
|
374 |
+
gen_fn = gen_fn or self.coarse.generate
|
375 |
+
c_vamp_chunk = gen_fn(
|
376 |
+
codec=self.codec,
|
377 |
+
time_steps=chunk_len,
|
378 |
+
start_tokens=cz_masked_chunk,
|
379 |
+
return_signal=False,
|
380 |
+
mask=mask_chunk,
|
381 |
+
**kwargs
|
382 |
+
)
|
383 |
+
cz_vamped_chunks.append(c_vamp_chunk)
|
384 |
+
|
385 |
+
# stitch the chunks back together
|
386 |
+
cz_masked = torch.cat(cz_masked_chunks, dim=-1)
|
387 |
+
c_vamp = torch.cat(cz_vamped_chunks, dim=-1)
|
388 |
|
389 |
# add the fine codes back in
|
390 |
c_vamp = torch.cat(
|
|
|
396 |
return c_vamp, cz_masked
|
397 |
|
398 |
return c_vamp
|
399 |
+
|
400 |
+
def build_mask(self,
|
401 |
+
z: torch.Tensor,
|
402 |
+
sig: AudioSignal = None,
|
403 |
+
rand_mask_intensity: float = 1.0,
|
404 |
+
prefix_s: float = 0.0,
|
405 |
+
suffix_s: float = 0.0,
|
406 |
+
periodic_prompt: int = 7,
|
407 |
+
periodic_prompt2: int = 7,
|
408 |
+
periodic_prompt_width: int = 1,
|
409 |
+
onset_mask_width: int = 0,
|
410 |
+
_dropout: float = 0.0,
|
411 |
+
upper_codebook_mask: int = 3,
|
412 |
+
upper_codebook_mask_2: int = None,
|
413 |
+
ncc: int = 0,
|
414 |
+
):
|
415 |
+
|
416 |
+
mask = linear_random(z, rand_mask_intensity)
|
417 |
+
mask = mask_and(
|
418 |
+
mask,
|
419 |
+
inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
|
420 |
+
)
|
421 |
|
422 |
+
pmask1 = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
|
423 |
+
pmask2 = periodic_mask(z, periodic_prompt2, periodic_prompt_width, random_roll=True)
|
424 |
+
# interpolate the two masks
|
425 |
+
pmask = torch.round(
|
426 |
+
pmask1 * torch.linspace(1, 0, pmask1.shape[-1], device=pmask1.device) +
|
427 |
+
pmask2 * torch.linspace(0, 1, pmask2.shape[-1], device=pmask2.device)
|
428 |
+
).long()
|
429 |
+
|
430 |
+
mask = mask_and(mask, pmask)
|
431 |
+
|
432 |
+
if onset_mask_width > 0:
|
433 |
+
assert sig is not None, f"must provide a signal to use onset mask"
|
434 |
+
mask = mask_or(
|
435 |
+
mask, onset_mask(
|
436 |
+
sig, z, interface,
|
437 |
+
width=onset_mask_width
|
438 |
+
)
|
439 |
+
)
|
440 |
+
|
441 |
+
mask = dropout(mask, _dropout)
|
442 |
+
mask = codebook_unmask(mask, ncc)
|
443 |
+
|
444 |
+
mask = codebook_mask(mask, int(upper_codebook_mask), upper_codebook_mask_2)
|
445 |
+
return mask
|
446 |
+
|
447 |
+
def ez_vamp(
|
448 |
+
self,
|
449 |
+
sig: AudioSignal,
|
450 |
+
batch_size: int = 4,
|
451 |
+
feedback_steps: int = 1,
|
452 |
+
time_stretch_factor: int = 1,
|
453 |
+
return_mask: bool = False,
|
454 |
+
build_mask_kwargs: dict = None,
|
455 |
+
vamp_kwargs: dict = None,
|
456 |
+
):
|
457 |
+
feedback_steps = int(feedback_steps)
|
458 |
+
build_mask_kwargs = build_mask_kwargs or {}
|
459 |
+
vamp_kwargs = vamp_kwargs or {}
|
460 |
+
|
461 |
+
loudness = sig.loudness()
|
462 |
+
sig = self.preprocess(sig)
|
463 |
+
|
464 |
+
z = self.encode(sig)
|
465 |
+
|
466 |
+
# expand z to batch size
|
467 |
+
z = z.expand(batch_size, -1, -1)
|
468 |
+
mask = self.build_mask(
|
469 |
+
z=z,
|
470 |
+
**build_mask_kwargs
|
471 |
+
)
|
472 |
+
mask = mask.expand(batch_size, -1, -1)
|
473 |
+
|
474 |
+
# stretch mask and z to match the time stretch factor
|
475 |
+
# we'll add (stretch_factor - 1) mask tokens in between each timestep of z
|
476 |
+
# and we'll make the mask 1 in all the new slots we added
|
477 |
+
if time_stretch_factor > 1:
|
478 |
+
z = z.repeat_interleave(time_stretch_factor, dim=-1)
|
479 |
+
mask = mask.repeat_interleave(time_stretch_factor, dim=-1)
|
480 |
+
added_mask = torch.ones_like(mask)
|
481 |
+
added_mask[:, :, ::time_stretch_factor] = 0
|
482 |
+
mask = mask.bool() | added_mask.bool()
|
483 |
+
mask = mask.long()
|
484 |
+
|
485 |
+
prev_zvs = []
|
486 |
+
for i in tqdm.tqdm(range(feedback_steps), desc="feedback steps"):
|
487 |
+
print(z.shape)
|
488 |
+
|
489 |
+
vamp_kwargs.pop("mask", None)
|
490 |
+
vamp_kwargs.pop('return_mask', None)
|
491 |
+
print("coarse!")
|
492 |
+
zv, mask_z = self.coarse_vamp(
|
493 |
+
z,
|
494 |
+
mask=mask,
|
495 |
+
return_mask=True,
|
496 |
+
**vamp_kwargs
|
497 |
+
)
|
498 |
+
|
499 |
+
# add the top codebooks back in
|
500 |
+
if zv.shape[1] < z.shape[1]:
|
501 |
+
print(f"adding {z.shape[1] - zv.shape[1]} codebooks back in")
|
502 |
+
zv = torch.cat(
|
503 |
+
[zv, z[:, self.coarse.n_codebooks :, :]],
|
504 |
+
dim=1
|
505 |
+
)
|
506 |
+
|
507 |
+
# now, coarse2fine
|
508 |
+
print(f"coarse2fine!")
|
509 |
+
zv, fine_zv_mask = self.coarse_to_fine(
|
510 |
+
zv,
|
511 |
+
mask=mask,
|
512 |
+
**vamp_kwargs,
|
513 |
+
_sampling_steps=[2, 2, 1, 1],
|
514 |
+
return_mask=True
|
515 |
+
)
|
516 |
+
mask_z = torch.cat(
|
517 |
+
[mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
|
518 |
+
dim=1
|
519 |
+
)
|
520 |
+
|
521 |
+
prev_zvs.append(zv)
|
522 |
+
z = zv
|
523 |
+
|
524 |
+
# perform to_signal batch item by batch
|
525 |
+
sigs = []
|
526 |
+
for zv in prev_zvs:
|
527 |
+
# do it in timestep chunks of 1024
|
528 |
+
_sigs = []
|
529 |
+
for i in range(0, zv.shape[-1], 1024):
|
530 |
+
_sigs.append(self.to_signal(zv[:, :, i:i+1024]).cpu())
|
531 |
+
sigs.append(signal_concat(_sigs))
|
532 |
+
print("done")
|
533 |
+
|
534 |
+
sig = AudioSignal.batch(sigs)
|
535 |
|
536 |
+
# sig = self.to_signal(zv).cpu()
|
537 |
+
# print("done")
|
538 |
+
|
539 |
+
sig = sig.normalize(loudness)
|
540 |
+
|
541 |
+
if return_mask:
|
542 |
+
return sig, mask_z.cpu(), zv.cpu()
|
543 |
+
else:
|
544 |
+
return sig
|
545 |
+
|
546 |
+
def visualize_codes(self, z: torch.Tensor):
|
547 |
+
import matplotlib.pyplot as plt
|
548 |
+
# make sure the figsize is square when imshow is called
|
549 |
+
fig = plt.figure(figsize=(10, 7))
|
550 |
+
# in subplots, plot z[0] and the mask
|
551 |
+
# set title to "codes" and "mask"
|
552 |
+
fig.add_subplot(2, 1, 1)
|
553 |
+
plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
|
554 |
+
plt.title("codes")
|
555 |
+
plt.ylabel("codebook index")
|
556 |
+
# set the xticks to seconds
|
557 |
+
plt.xticks(
|
558 |
+
np.arange(0, z.shape[-1], self.s2t(1)),
|
559 |
+
np.arange(0, self.t2s(z.shape[-1]), 1)
|
560 |
+
)
|
561 |
+
plt.xlabel("time (s)")
|
562 |
|
563 |
if __name__ == "__main__":
|
564 |
import audiotools as at
|
|
|
580 |
sig = at.AudioSignal('assets/example.wav')
|
581 |
|
582 |
z = interface.encode(sig)
|
|
|
|
|
583 |
# mask = linear_random(z, 1.0)
|
584 |
# mask = mask_and(
|
585 |
# mask, periodic_mask(
|
|
|
618 |
sig = interface.to_signal(zv).cpu()
|
619 |
print("done")
|
620 |
|
621 |
+
|
622 |
+
|
623 |
+
|
624 |
+
# example plotting code
|
625 |
+
# import matplotlib.pyplot as plt
|
626 |
+
# from pathlib import Path
|
627 |
+
# Path(".vampnet").mkdir(exist_ok=True)
|
628 |
+
# plt.clf()
|
629 |
+
# # close all figs
|
630 |
+
# plt.close('all')
|
631 |
+
# # set the fig size
|
632 |
+
# plt.subplot(4, 1, 1)
|
633 |
+
# # sig = self.to_signal(sampled_z, codec)
|
634 |
+
# # sig.cpu().specshow()
|
635 |
+
|
636 |
+
# plt.subplot(4, 1, 2)
|
637 |
+
# # since z_masked is a codebook, we want to plot the colormap
|
638 |
+
# # with distinct colors for each codebook index
|
639 |
+
# # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
|
640 |
+
# # make it so that anywhere where the mask is 1, we make that pixel black
|
641 |
+
# plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r',)
|
642 |
+
|
643 |
+
|
644 |
+
# plt.subplot(4, 1, 3)
|
645 |
+
# # plot the mask (which is a matrix)
|
646 |
+
# plt.imshow(mask[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r')
|
647 |
+
# plt.subplot(4, 1, 4)
|
648 |
+
# # replace any inf or -inf with 0
|
649 |
+
# _selected_probs = torch.where(
|
650 |
+
# selected_probs == torch.inf, torch.zeros_like(selected_probs), selected_probs
|
651 |
+
# )
|
652 |
+
# _selected_probs = torch.where(
|
653 |
+
# selected_probs == -torch.inf, torch.zeros_like(selected_probs), selected_probs
|
654 |
+
# )
|
655 |
+
# # fig = plt.gcf()
|
656 |
+
# # fig.set_figheight(15)
|
657 |
+
# # fig.set_figwidth(15)
|
658 |
+
# plt.imshow(codebook_unflatten(_selected_probs, n_infer_codebooks)[0].cpu().numpy(), aspect='auto', origin='lower', cmap="viridis" )
|
659 |
+
# # plt.show()
|
660 |
+
# plt.savefig(f".vampnet/c={codebook_level}_{i}.png")
|
661 |
+
# plt.close('all')
|
vampnet/mask.py
CHANGED
@@ -60,6 +60,7 @@ def linear_random(
|
|
60 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
61 |
if not isinstance(r, torch.Tensor):
|
62 |
r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
|
|
|
63 |
|
64 |
probs = torch.ones_like(x).to(x.device).float()
|
65 |
# expand to batch and codebook dims
|
@@ -98,7 +99,7 @@ def inpaint(x: torch.Tensor,
|
|
98 |
return mask
|
99 |
|
100 |
def periodic_mask(x: torch.Tensor,
|
101 |
-
period: int,
|
102 |
random_roll=False,
|
103 |
):
|
104 |
mask = full_mask(x)
|
@@ -140,9 +141,15 @@ def codebook_unmask(
|
|
140 |
mask[:, :n_conditioning_codebooks, :] = 0
|
141 |
return mask
|
142 |
|
143 |
-
def codebook_mask(mask: torch.Tensor,
|
144 |
mask = mask.clone()
|
145 |
-
mask[:,
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
return mask
|
147 |
|
148 |
def mask_and(
|
@@ -239,4 +246,4 @@ def onset_mask(
|
|
239 |
|
240 |
|
241 |
if __name__ == "__main__":
|
242 |
-
pass
|
|
|
60 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
61 |
if not isinstance(r, torch.Tensor):
|
62 |
r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
|
63 |
+
r = r[:, None, None]
|
64 |
|
65 |
probs = torch.ones_like(x).to(x.device).float()
|
66 |
# expand to batch and codebook dims
|
|
|
99 |
return mask
|
100 |
|
101 |
def periodic_mask(x: torch.Tensor,
|
102 |
+
period: int,width: int = 1,
|
103 |
random_roll=False,
|
104 |
):
|
105 |
mask = full_mask(x)
|
|
|
141 |
mask[:, :n_conditioning_codebooks, :] = 0
|
142 |
return mask
|
143 |
|
144 |
+
def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
|
145 |
mask = mask.clone()
|
146 |
+
mask[:, val1:, :] = 1
|
147 |
+
# val2 = val2 or val1
|
148 |
+
# vs = torch.linspace(val1, val2, mask.shape[1])
|
149 |
+
# for t, v in enumerate(vs):
|
150 |
+
# v = int(v)
|
151 |
+
# mask[:, v:, t] = 1
|
152 |
+
|
153 |
return mask
|
154 |
|
155 |
def mask_and(
|
|
|
246 |
|
247 |
|
248 |
if __name__ == "__main__":
|
249 |
+
pass
|
vampnet/modules/transformer.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import math
|
2 |
import logging
|
3 |
-
from typing import Optional, Tuple, Union
|
4 |
|
5 |
import numpy as np
|
6 |
import torch
|
@@ -572,6 +572,8 @@ class VampNet(at.ml.BaseModel):
|
|
572 |
"""
|
573 |
assert z.ndim == 3
|
574 |
|
|
|
|
|
575 |
signal = at.AudioSignal(
|
576 |
codec.decode(
|
577 |
codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
|
@@ -581,34 +583,279 @@ class VampNet(at.ml.BaseModel):
|
|
581 |
|
582 |
# find where the mask token is and replace it with silence in the audio
|
583 |
for tstep in range(z.shape[-1]):
|
584 |
-
if torch.
|
585 |
sample_idx_0 = tstep * codec.hop_length
|
586 |
sample_idx_1 = sample_idx_0 + codec.hop_length
|
587 |
signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
|
588 |
|
589 |
return signal
|
590 |
-
|
591 |
-
|
592 |
@torch.no_grad()
|
593 |
def generate(
|
594 |
self,
|
595 |
codec,
|
596 |
time_steps: int = 300,
|
597 |
-
|
598 |
start_tokens: Optional[torch.Tensor] = None,
|
599 |
sampling_temperature: float = 1.0,
|
600 |
mask: Optional[torch.Tensor] = None,
|
601 |
mask_temperature: float = 10.5,
|
602 |
-
typical_filtering=
|
603 |
typical_mass=0.2,
|
604 |
typical_min_tokens=1,
|
605 |
-
top_p=
|
606 |
-
return_signal=True,
|
607 |
seed: int = None,
|
608 |
-
sample_cutoff: float =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
):
|
610 |
if seed is not None:
|
611 |
at.util.seed(seed)
|
|
|
612 |
logging.debug(f"beginning generation with {sampling_steps} steps")
|
613 |
|
614 |
|
@@ -763,6 +1010,10 @@ class VampNet(at.ml.BaseModel):
|
|
763 |
else:
|
764 |
return sampled_z
|
765 |
|
|
|
|
|
|
|
|
|
766 |
def sample_from_logits(
|
767 |
logits,
|
768 |
sample: bool = True,
|
@@ -942,12 +1193,11 @@ if __name__ == "__main__":
|
|
942 |
pred = z_hat.argmax(dim=1)
|
943 |
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
944 |
|
945 |
-
|
946 |
-
|
947 |
breakpoint()
|
948 |
|
949 |
args = argbind.parse_args()
|
950 |
with argbind.scope(args):
|
951 |
try_model()
|
952 |
|
953 |
-
|
|
|
1 |
import math
|
2 |
import logging
|
3 |
+
from typing import Optional, Tuple, Union, List
|
4 |
|
5 |
import numpy as np
|
6 |
import torch
|
|
|
572 |
"""
|
573 |
assert z.ndim == 3
|
574 |
|
575 |
+
# remove mask token
|
576 |
+
z = z.masked_fill(z == self.mask_token, 0)
|
577 |
signal = at.AudioSignal(
|
578 |
codec.decode(
|
579 |
codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
|
|
|
583 |
|
584 |
# find where the mask token is and replace it with silence in the audio
|
585 |
for tstep in range(z.shape[-1]):
|
586 |
+
if torch.all(z[:, :, tstep] == self.mask_token):
|
587 |
sample_idx_0 = tstep * codec.hop_length
|
588 |
sample_idx_1 = sample_idx_0 + codec.hop_length
|
589 |
signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
|
590 |
|
591 |
return signal
|
592 |
+
|
|
|
593 |
@torch.no_grad()
|
594 |
def generate(
|
595 |
self,
|
596 |
codec,
|
597 |
time_steps: int = 300,
|
598 |
+
_sampling_steps: List[int] = [16, 8, 8, 2, 2, 2, 2, 1, 1],
|
599 |
start_tokens: Optional[torch.Tensor] = None,
|
600 |
sampling_temperature: float = 1.0,
|
601 |
mask: Optional[torch.Tensor] = None,
|
602 |
mask_temperature: float = 10.5,
|
603 |
+
typical_filtering=True,
|
604 |
typical_mass=0.2,
|
605 |
typical_min_tokens=1,
|
606 |
+
top_p=0.9,
|
|
|
607 |
seed: int = None,
|
608 |
+
sample_cutoff: float = 0.9,
|
609 |
+
return_signal=True,
|
610 |
+
debug=False,
|
611 |
+
causal_weight: float = 0.0,
|
612 |
+
use_og_method: bool = False,
|
613 |
+
):
|
614 |
+
if use_og_method:
|
615 |
+
return self.og_method(
|
616 |
+
codec,
|
617 |
+
time_steps,
|
618 |
+
_sampling_steps,
|
619 |
+
start_tokens,
|
620 |
+
sampling_temperature,
|
621 |
+
mask,
|
622 |
+
mask_temperature,
|
623 |
+
typical_filtering,
|
624 |
+
typical_mass,
|
625 |
+
typical_min_tokens,
|
626 |
+
top_p,
|
627 |
+
seed,
|
628 |
+
sample_cutoff,
|
629 |
+
return_signal,
|
630 |
+
debug,
|
631 |
+
causal_weight,
|
632 |
+
)
|
633 |
+
|
634 |
+
if seed is not None:
|
635 |
+
at.util.seed(seed)
|
636 |
+
|
637 |
+
#####################
|
638 |
+
# resolve initial z #
|
639 |
+
#####################
|
640 |
+
z = start_tokens
|
641 |
+
|
642 |
+
if z is None:
|
643 |
+
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
|
644 |
+
self.device
|
645 |
+
)
|
646 |
+
|
647 |
+
logging.debug(f"created z with shape {z.shape}")
|
648 |
+
|
649 |
+
#################
|
650 |
+
# resolve mask #
|
651 |
+
#################
|
652 |
+
|
653 |
+
if mask is None:
|
654 |
+
mask = torch.ones_like(z).to(self.device).int()
|
655 |
+
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
656 |
+
if mask.ndim == 2:
|
657 |
+
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
|
658 |
+
orig_mask = mask
|
659 |
+
logging.debug(f"created mask with shape {mask.shape}")
|
660 |
+
|
661 |
+
###########
|
662 |
+
# set up #
|
663 |
+
##########
|
664 |
+
# apply the mask to z
|
665 |
+
z_masked = z.masked_fill(mask.bool(), self.mask_token)
|
666 |
+
|
667 |
+
# how many codebooks are we inferring vs conditioning on?
|
668 |
+
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
|
669 |
+
logging.debug(f"n infer codebooks: {n_infer_codebooks}")
|
670 |
+
|
671 |
+
#################
|
672 |
+
# begin sampling #
|
673 |
+
#################
|
674 |
+
# add one sampling step for each codebook level
|
675 |
+
logging.debug(f"initial mask: {mask}")
|
676 |
+
logging.debug(f"adding {n_infer_codebooks} sampling steps")
|
677 |
+
steps = _sampling_steps + [1 for _ in range(n_infer_codebooks - len(_sampling_steps))]
|
678 |
+
# truncate if we have too many
|
679 |
+
steps = steps[:n_infer_codebooks]
|
680 |
+
for codebook_level, nsteps in enumerate(steps):
|
681 |
+
|
682 |
+
# apply the orig mask to z_masked, only in the current codebook level
|
683 |
+
# this is crucial due to the stemgen random masking we did during training
|
684 |
+
# which ensures all upper codebooks are masked while inferring the bottom ones.
|
685 |
+
z_masked[:, codebook_level, :] = torch.where(
|
686 |
+
orig_mask[:, codebook_level, :].bool(),
|
687 |
+
self.mask_token,
|
688 |
+
z_masked[:, codebook_level, :]
|
689 |
+
)
|
690 |
+
|
691 |
+
# how many mask tokens to begin with?
|
692 |
+
num_mask_tokens_at_start = (z_masked[:, codebook_level, :] == self.mask_token).sum(dim=-1)
|
693 |
+
logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
|
694 |
+
|
695 |
+
for i in range(nsteps):
|
696 |
+
logging.debug(f"processing cb level {codebook_level} of {len(steps)}")
|
697 |
+
logging.debug(f"step {i} of {nsteps}")
|
698 |
+
|
699 |
+
# our current schedule step
|
700 |
+
r = scalar_to_batch_tensor(
|
701 |
+
(i + 1) / nsteps,
|
702 |
+
z.shape[0]
|
703 |
+
).to(z.device)
|
704 |
+
logging.debug(f"r: {r}")
|
705 |
+
|
706 |
+
# get latents
|
707 |
+
latents = self.embedding.from_codes(z_masked, codec)
|
708 |
+
logging.debug(f"computed latents with shape: {latents.shape}")
|
709 |
+
|
710 |
+
# infer from latents
|
711 |
+
# NOTE: this collapses the codebook dimension into the sequence dimension
|
712 |
+
logits = self.forward(
|
713 |
+
latents,
|
714 |
+
) # b, prob, seq
|
715 |
+
logits = logits.permute(0, 2, 1) # b, seq, prob
|
716 |
+
logging.debug(f"permuted logits with shape: {logits.shape}")
|
717 |
+
|
718 |
+
sampled_z, selected_probs = sample_from_logits(
|
719 |
+
logits, sample=(
|
720 |
+
(i / nsteps) <= sample_cutoff
|
721 |
+
),
|
722 |
+
temperature=sampling_temperature,
|
723 |
+
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
724 |
+
typical_min_tokens=typical_min_tokens,
|
725 |
+
top_k=None, top_p=top_p, return_probs=True,
|
726 |
+
)
|
727 |
+
|
728 |
+
# fill selected probs with -inf if we're not in the codebook level we are sampling from
|
729 |
+
# find out which codebook we are sampling from
|
730 |
+
selected_probs = codebook_unflatten(selected_probs, n_infer_codebooks)
|
731 |
+
selected_probs[:, codebook_level+1:, :,] = -float("inf") # all the ones above
|
732 |
+
# selected_probs[:, :codebook_level, :,] = -float("inf")
|
733 |
+
logging.debug(f"masking all but codebook {codebook_level}")
|
734 |
+
logging.debug(f"selected probs: {selected_probs}")
|
735 |
+
logging.debug(mask)
|
736 |
+
selected_probs = codebook_flatten(selected_probs)
|
737 |
+
|
738 |
+
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
739 |
+
|
740 |
+
# flatten z_masked and mask, so we can deal with the sampling logic
|
741 |
+
# we'll unflatten them at the end of the loop for the next forward pass
|
742 |
+
# remove conditioning codebooks, we'll add them back at the end
|
743 |
+
z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
|
744 |
+
|
745 |
+
mask = (z_masked == self.mask_token).int()
|
746 |
+
logging.debug(f"mask now: {mask}")
|
747 |
+
|
748 |
+
# update the mask, remove conditioning codebooks from the mask
|
749 |
+
logging.debug(f"updated mask with shape: {mask.shape}")
|
750 |
+
|
751 |
+
# add z back into sampled z where the mask was false
|
752 |
+
sampled_z = torch.where(
|
753 |
+
mask.bool(), sampled_z, z_masked
|
754 |
+
)
|
755 |
+
logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
|
756 |
+
|
757 |
+
# get the num tokens to mask, according to the schedule
|
758 |
+
num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
|
759 |
+
# num_to_mask = torch.floor(r * num_mask_tokens_at_start).unsqueeze(1).long() # doesn't work at all this way
|
760 |
+
logging.debug(f"num to mask: {num_to_mask}")
|
761 |
+
logging.debug(f"masking {num_to_mask.sum()} tokens")
|
762 |
+
|
763 |
+
if i != (nsteps - 1):
|
764 |
+
mask = codebook_unflatten(mask, n_infer_codebooks)
|
765 |
+
num_to_mask = torch.maximum(
|
766 |
+
torch.tensor(1),
|
767 |
+
torch.minimum(
|
768 |
+
mask[:, codebook_level, :].sum(dim=-1, keepdim=True) - 1,
|
769 |
+
num_to_mask
|
770 |
+
)
|
771 |
+
)
|
772 |
+
logging.debug(f"will mask {num_to_mask.sum()} tokens")
|
773 |
+
mask = codebook_flatten(mask)
|
774 |
+
|
775 |
+
# ignore any tokens that weren't masked
|
776 |
+
selected_probs = torch.where(
|
777 |
+
mask.bool(), selected_probs, torch.inf
|
778 |
+
)
|
779 |
+
|
780 |
+
# add a causal weight to the selected probs
|
781 |
+
# NOTE: some experiments i did showed that this didn't help.
|
782 |
+
# set it to 0 until further eval
|
783 |
+
causal_probs = torch.linspace(1, 0, z_masked.shape[-1], device=z_masked.device)
|
784 |
+
causal_probs = causal_probs.repeat(z_masked.shape[0], 1)
|
785 |
+
selected_probs = selected_probs + causal_probs * causal_weight
|
786 |
+
|
787 |
+
# # get our new mask
|
788 |
+
############
|
789 |
+
mask = codebook_unflatten(mask, n_infer_codebooks)
|
790 |
+
selected_probs = codebook_unflatten(selected_probs, n_infer_codebooks)
|
791 |
+
|
792 |
+
# only consider probs at current level
|
793 |
+
selected_probs_cur_level = selected_probs[:, codebook_level, :]
|
794 |
+
mask_cur_level = mask_by_random_topk(
|
795 |
+
num_to_mask, selected_probs_cur_level, mask_temperature * (1-r.unsqueeze(1))
|
796 |
+
)
|
797 |
+
mask[:, codebook_level, :] = mask_cur_level
|
798 |
+
|
799 |
+
mask = codebook_flatten(mask)
|
800 |
+
selected_probs = codebook_flatten(selected_probs)
|
801 |
+
###############
|
802 |
+
|
803 |
+
|
804 |
+
# update the mask
|
805 |
+
z_masked = torch.where(
|
806 |
+
mask.bool(), self.mask_token, sampled_z
|
807 |
+
)
|
808 |
+
logging.debug(f"updated z_masked with shape: {z_masked.shape}")
|
809 |
+
|
810 |
+
z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
|
811 |
+
mask = codebook_unflatten(mask, n_infer_codebooks)
|
812 |
+
logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
|
813 |
+
|
814 |
+
# add conditioning codebooks back to z_masked
|
815 |
+
z_masked = torch.cat(
|
816 |
+
(z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
|
817 |
+
)
|
818 |
+
logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
|
819 |
+
|
820 |
+
|
821 |
+
# add conditioning codebooks back to sampled_z
|
822 |
+
sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
|
823 |
+
sampled_z = torch.cat(
|
824 |
+
(z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
|
825 |
+
)
|
826 |
+
|
827 |
+
logging.debug(f"finished sampling")
|
828 |
+
|
829 |
+
|
830 |
+
if return_signal:
|
831 |
+
return self.to_signal(sampled_z, codec)
|
832 |
+
else:
|
833 |
+
return sampled_z
|
834 |
+
|
835 |
+
|
836 |
+
|
837 |
+
def og_method(
|
838 |
+
self,
|
839 |
+
codec,
|
840 |
+
time_steps: int = 300,
|
841 |
+
_sampling_steps: List[int] = [16, 8, 8, 2, 2, 2, 2, 1, 1],
|
842 |
+
start_tokens: Optional[torch.Tensor] = None,
|
843 |
+
sampling_temperature: float = 1.0,
|
844 |
+
mask: Optional[torch.Tensor] = None,
|
845 |
+
mask_temperature: float = 10.5,
|
846 |
+
typical_filtering=True,
|
847 |
+
typical_mass=0.2,
|
848 |
+
typical_min_tokens=1,
|
849 |
+
top_p=0.9,
|
850 |
+
seed: int = None,
|
851 |
+
sample_cutoff: float = 0.75,
|
852 |
+
return_signal=True,
|
853 |
+
debug=False,
|
854 |
+
causal_weight: float = 0.0,
|
855 |
):
|
856 |
if seed is not None:
|
857 |
at.util.seed(seed)
|
858 |
+
sampling_steps = sum(_sampling_steps)
|
859 |
logging.debug(f"beginning generation with {sampling_steps} steps")
|
860 |
|
861 |
|
|
|
1010 |
else:
|
1011 |
return sampled_z
|
1012 |
|
1013 |
+
|
1014 |
+
|
1015 |
+
|
1016 |
+
|
1017 |
def sample_from_logits(
|
1018 |
logits,
|
1019 |
sample: bool = True,
|
|
|
1193 |
pred = z_hat.argmax(dim=1)
|
1194 |
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
1195 |
|
1196 |
+
logging.debug(f"model has {num_params(model)/1e6:<.3f}M parameters")
|
1197 |
+
logging.debug(f"prediction has shape {pred.shape}")
|
1198 |
breakpoint()
|
1199 |
|
1200 |
args = argbind.parse_args()
|
1201 |
with argbind.scope(args):
|
1202 |
try_model()
|
1203 |
|
|