Hugo Flores Garcia committed • 9496f0e
Parent(s): 308d855

sampling tricks!

Files changed:
- app.py (+49, -11)
- vampnet/modules/transformer.py (+109, -37)
app.py CHANGED
@@ -97,28 +97,35 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)
 
 
-    print(…)
+    print(data)
+    _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
    np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
+    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-        …
+        mask_temperature=data[masktemp]*10,
+        sampling_temperature=data[sampletemp],
         return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
+        top_p=_top_p,
         gen_fn=interface.coarse.generate,
+        seed=_seed,
     )
 
     if use_coarse2fine:
         zv = interface.coarse_to_fine(
             zv,
-            …
+            mask_temperature=data[masktemp]*10,
+            sampling_temperature=data[sampletemp],
             mask=mask,
-            sampling_steps=data[num_steps]
+            sampling_steps=data[num_steps],
+            seed=_seed,
         )
 
     sig = interface.to_signal(zv).cpu()
@@ -152,7 +159,9 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")
 
     _data = {
-        …
+        "masktemp": data[masktemp],
+        "sampletemp": data[sampletemp],
+        "top_p": data[top_p],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -163,6 +172,7 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
+        "seed": data[seed],
     }
 
     # save with yaml
@@ -385,16 +395,28 @@ with gr.Blocks() as demo:
                 value=0.0
             )
 
-            …
-                label="temperature",
+            masktemp = gr.Slider(
+                label="mask temperature",
                 minimum=0.0,
                 maximum=10.0,
-                value=…
+                value=1.5
             )
-            …
+            sampletemp = gr.Slider(
+                label="sample temperature",
+                minimum=0.1,
+                maximum=2.0,
+                value=1.0
+            )
+
 
 
             with gr.Accordion("sampling settings", open=False):
+                top_p = gr.Slider(
+                    label="top p (0.0 = off)",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.0
+                )
                 typical_filtering = gr.Checkbox(
                     label="typical filtering ",
                     value=False
@@ -435,6 +457,18 @@
                 value=0.0
             )
 
+            use_new_trick = gr.Checkbox(
+                label="new trick",
+                value=False
+            )
+
+            seed = gr.Number(
+                label="seed (0 for random)",
+                value=0,
+                precision=0,
+            )
+
+
 
         # mask settings
         with gr.Column():
@@ -463,7 +497,9 @@
     _inputs = {
         input_audio,
         num_steps,
-        …
+        masktemp,
+        sampletemp,
+        top_p,
         prefix_s, suffix_s,
         rand_mask_intensity,
         periodic_p, periodic_w,
@@ -476,7 +512,9 @@
         typical_mass,
         typical_min_tokens,
         beat_mask_width,
-        beat_mask_downbeats
+        beat_mask_downbeats,
+        seed,
+        seed
     }
 
     # connect widgets
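Taken together, the new controls thread through to the sampler as keyword arguments, with 0 as the UI sentinel for "off". A minimal sketch of the same call outside Gradio, assuming a constructed vampnet Interface plus the `z` tokens and `mask` built earlier in `_vamp` (the literal values are placeholders):

# placeholder values standing in for the slider/number widgets
masktemp, sampletemp, top_p_val, seed_val = 1.5, 1.0, 0.0, 0

zv, mask_z = interface.coarse_vamp(
    z,
    mask=mask,
    sampling_steps=36,                           # placeholder step count
    mask_temperature=masktemp * 10,              # UI value scaled by 10, as in _vamp
    sampling_temperature=sampletemp,
    top_p=top_p_val if top_p_val > 0 else None,  # 0 disables nucleus sampling
    seed=seed_val if seed_val > 0 else None,     # 0 leaves the RNG unseeded
    return_mask=True,
    gen_fn=interface.coarse.generate,
)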
vampnet/modules/transformer.py CHANGED
@@ -367,6 +367,15 @@ class TransformerLayer(nn.Module):
 
         return x, position_bias, encoder_decoder_position_bias
 
+def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
+    x = np.linspace(0, 1, n_steps)
+    a = (0.5 - min_temp) / (max_temp - min_temp)
+
+    x = (x * 12) - 6
+    x0 = np.log((1 / a - 1) + 1e-5) / k
+    y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
+
+    return y
 
 class TransformerStack(nn.Module):
     def __init__(
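The new `t_schedule` is a reversed sigmoid over the sampling steps: the temperature starts near `max_temp` and decays toward `min_temp`, so early steps sample freely and later steps commit. A self-contained check, with the function body copied from the hunk above:

import numpy as np

def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
    x = np.linspace(0, 1, n_steps)
    a = (0.5 - min_temp) / (max_temp - min_temp)
    x = (x * 12) - 6                      # spread the steps over [-6, 6]
    x0 = np.log((1 / a - 1) + 1e-5) / k   # midpoint shift set by max/min temp
    return (1 / (1 + np.exp(-k * (x - x0))))[::-1]

print(np.round(t_schedule(8), 3))
# -> [0.998 0.986 0.929 0.702 0.298 0.071 0.014 0.002], high to low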
@@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel):
         time_steps: int = 300,
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-        …
+        mask_temperature: float = 20.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
+        top_p=None,
         return_signal=True,
+        seed: int = None
     ):
+        if seed is not None:
+            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")
 
-        #####################
-        # resolve temperature #
-        #####################
-
-        logging.debug(f"temperature: {temperature}")
 
 
         #####################
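`generate` can now seed the RNG via `at.util.seed` (audiotools) before sampling, making a run reproducible. The effect is the same idea as this plain-torch sketch, with `torch.manual_seed` standing in for `at.util.seed`:

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])

torch.manual_seed(42)
draw_a = torch.multinomial(probs, num_samples=8, replacement=True)
torch.manual_seed(42)
draw_b = torch.multinomial(probs, num_samples=8, replacement=True)

assert torch.equal(draw_a, draw_b)  # same seed, identical draws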
@@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel):
         #################
         # begin sampling #
         #################
+        t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)
 
         for i in range(sampling_steps):
             logging.debug(f"step {i} of {sampling_steps}")
 
-            # our current temperature
-            logging.debug(f"temperature: {temperature}")
-
             # our current schedule step
             r = scalar_to_batch_tensor(
                 (i + 1) / sampling_steps,
@@ -664,39 +671,19 @@
             # NOTE: this collapses the codebook dimension into the sequence dimension
             logits = self.forward(latents, r)  # b, prob, seq
             logits = logits.permute(0, 2, 1)  # b, seq, prob
-
-            typical_filter(logits,
-                typical_mass=typical_mass,
-                typical_min_tokens=typical_min_tokens
-            )
-
+            b = logits.shape[0]
 
             logging.debug(f"permuted logits with shape: {logits.shape}")
 
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=True, temperature=t_sched[i],
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True
+            )
 
-            # logits2probs
-            probs = torch.softmax(logits, dim=-1)
-            logging.debug(f"computed probs with shape: {probs.shape}")
-
-
-            # sample from logits with multinomial sampling
-            b = probs.shape[0]
-            probs = rearrange(probs, "b seq prob -> (b seq) prob")
-
-            sampled_z = torch.multinomial(probs, 1).squeeze(-1)
-
-            sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
-            probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
@@ -733,7 +720,7 @@
 
             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs,
+                num_to_mask, selected_probs, mask_temperature * (1-r)
             )
 
             # update the mask
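The third argument is the substantive change: the masking temperature is now annealed by `(1 - r)`, so the random-top-k remasking is noisiest at the first step and purely confidence-driven by the last. A scalar sketch of the coefficient (in the real loop `r` is a batch tensor; `20.5` is the new `mask_temperature` default):

sampling_steps, mask_temperature = 12, 20.5
for i in range(sampling_steps):
    r = (i + 1) / sampling_steps
    print(f"step {i}: mask temperature coeff = {mask_temperature * (1 - r):.2f}")
# step 0: 18.79, ..., step 11: 0.00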
@@ -766,6 +753,91 @@
         else:
             return sampled_z
 
+def sample_from_logits(
+    logits,
+    sample: bool = True,
+    temperature: float = 1.0,
+    top_k: int = None,
+    top_p: float = None,
+    typical_filtering: bool = False,
+    typical_mass: float = 0.2,
+    typical_min_tokens: int = 1,
+    return_probs: bool = False
+):
+    """Convenience function to sample from a categorical distribution with input as
+    unnormalized logits.
+
+    Parameters
+    ----------
+    logits : Tensor[..., vocab_size]
+        The unnormalized logits of the categorical
+        distribution to sample from
+    sample : bool, optional
+        Whether to perform multinomial sampling, by default True
+    temperature : float, optional
+        Scaling parameter when multinomial sampling, by default 1.0
+    top_k : int, optional
+        Restricts sampling to only `top_k` values acc. to probability,
+        by default None
+    top_p : float, optional
+        Restricts sampling to only those values with cumulative
+        probability <= `top_p`, by default None
+
+    Returns
+    -------
+    Tensor[...]
+        Sampled tokens
+    """
+    shp = logits.shape[:-1]
+
+    if typical_filtering:
+        typical_filter(logits,
+                       typical_mass=typical_mass,
+                       typical_min_tokens=typical_min_tokens
+        )
+
+    # Apply top_k sampling
+    if top_k is not None:
+        v, _ = logits.topk(top_k)
+        logits[logits < v[..., [-1]]] = -float("inf")
+
+    # Apply top_p (nucleus) sampling
+    if top_p is not None and top_p < 1.0:
+        v, sorted_indices = logits.sort(descending=True)
+        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Right shift indices_to_remove to keep 1st token over threshold
+        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+            ..., :-1
+        ]
+
+        # Compute indices_to_remove in unsorted array
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            -1, sorted_indices, sorted_indices_to_remove
+        )
+
+        logits[indices_to_remove] = -float("inf")
+
+    # Perform multinomial sampling after normalizing logits
+    probs = (
+        F.softmax(logits / temperature, dim=-1)
+        if temperature > 0
+        else logits.softmax(dim=-1)
+    )
+    token = (
+        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+        if sample
+        else logits.argmax(-1)
+    )
+
+    if return_probs:
+        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+        return token, token_probs
+    else:
+        return token
+
+
 
 def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """
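A short usage sketch of the new helper, importing it from the module as added above (shapes are illustrative):

import torch
from vampnet.modules.transformer import sample_from_logits

logits = torch.randn(2, 16, 1024)  # (batch, seq, vocab)

tokens, confidences = sample_from_logits(
    logits,
    sample=True,
    temperature=0.8,
    top_p=0.9,          # nucleus sampling: drop the tail beyond ~90% cumulative mass
    return_probs=True,  # also return each sampled token's probability
)
print(tokens.shape, confidences.shape)  # torch.Size([2, 16]) torch.Size([2, 16])

Note that the `top_k`/`top_p` paths write `-inf` into `logits` in place, so pass a copy if the logits are reused.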