Hugo Flores Garcia committed · commit 1062aec (parent: 1fedcf3)

sampling tricks, fix audiotools pin
Files changed:
- .gitignore (+2, -0)
- app.py (+53, -14)
- requirements.txt (+1, -1)
- scripts/exp/train.py (+7, -5)
- vampnet/modules/transformer.py (+109, -37)
.gitignore
CHANGED
@@ -175,6 +175,7 @@ lyrebird-audio-codec
 samples-*/**

 gradio-outputs/
+models/
 samples*/
 models-all/
 models.zip
@@ -183,3 +184,4 @@ descript-audio-codec/
 # *.pth
 .git-old
 conf/generated/*
+runs*/
app.py
CHANGED
@@ -107,24 +107,36 @@ def _vamp(data, return_mask=False):
     mask = pmask.codebook_unmask(mask, ncc)


-    print(
+    print(data)
+    _top_p = data[top_p] if data[top_p] > 0 else None
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())

+    _seed = data[seed] if data[seed] > 0 else None
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
         sampling_steps=data[num_steps],
-
+        mask_temperature=data[masktemp]*10,
+        sampling_temperature=data[sampletemp],
         return_mask=True,
         typical_filtering=data[typical_filtering],
         typical_mass=data[typical_mass],
         typical_min_tokens=data[typical_min_tokens],
+        top_p=_top_p,
         gen_fn=interface.coarse.generate,
+        seed=_seed,
     )

     if use_coarse2fine:
-        zv = interface.coarse_to_fine(
+        zv = interface.coarse_to_fine(
+            zv,
+            mask_temperature=data[masktemp]*10,
+            sampling_temperature=data[sampletemp],
+            mask=mask,
+            sampling_steps=data[num_steps],
+            seed=_seed,
+        )

     sig = interface.to_signal(zv).cpu()
     print("done")
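Gradio sliders and number boxes can't emit `None`, so the app adopts a zero-means-off convention: `top_p == 0.0` disables nucleus filtering and `seed == 0` means unseeded. A small sketch of the same gating outside the Gradio callback (`resolve_sampling_options` is a hypothetical name, not in this commit):

```python
# Hypothetical illustration of the "0 disables it" convention used above.
def resolve_sampling_options(settings: dict) -> dict:
    return {
        # 0.0 from the slider -> None -> nucleus sampling is skipped entirely
        "top_p": settings["top_p"] if settings["top_p"] > 0 else None,
        # 0 from the number box -> None -> unseeded, different output each run
        "seed": settings["seed"] if settings["seed"] > 0 else None,
    }

print(resolve_sampling_options({"top_p": 0.0, "seed": 42}))
# {'top_p': None, 'seed': 42}
```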
@@ -157,7 +169,9 @@ def save_vamp(data):
     sig_out.write(out_dir / "output.wav")

     _data = {
-        "
+        "masktemp": data[masktemp],
+        "sampletemp": data[sampletemp],
+        "top_p": data[top_p],
         "prefix_s": data[prefix_s],
         "suffix_s": data[suffix_s],
         "rand_mask_intensity": data[rand_mask_intensity],
@@ -168,6 +182,7 @@ def save_vamp(data):
         "n_conditioning_codebooks": data[n_conditioning_codebooks],
         "use_coarse2fine": data[use_coarse2fine],
         "stretch_factor": data[stretch_factor],
+        "seed": data[seed],
     }

     # save with yaml
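`_data` now records the full sampling configuration, seed included, so a saved artifact can be regenerated. The `# save with yaml` step itself lives outside this hunk; a plausible sketch of it (the `data.yaml` filename is an assumption):

```python
import yaml
from pathlib import Path

# Assumed shape of the save step referenced by the "# save with yaml" comment.
_data = {"masktemp": 1.5, "sampletemp": 1.0, "top_p": 0.0, "seed": 42}
out_dir = Path("gradio-outputs/example")
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "data.yaml", "w") as f:
    yaml.dump(_data, f)
```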
@@ -183,13 +198,14 @@ def save_vamp(data):
     return f"saved! your save code is {out_dir.stem}", zip_path


+
 with gr.Blocks() as demo:

     with gr.Row():
         with gr.Column():
-            gr.Markdown("# VampNet")
+            gr.Markdown("# VampNet Audio Vamping")
             gr.Markdown("""## Description:
-            This is a demo of VampNet, a
+            This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
             You can control the extent and nature of variation with a set of manual controls and presets.
             Use this interface to experiment with different mask settings and explore the audio outputs.
             """)
@@ -197,8 +213,8 @@ with gr.Blocks() as demo:
             gr.Markdown("""
             ## Instructions:
             1. You can start by uploading some audio, or by loading the example audio.
-            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
-            3. Click the "generate (vamp)!!!" button to
+            2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
+            3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
             4. Optionally, you can add some notes and save the result.
             5. You can also use the output as the new input and continue experimenting!
             """)
@@ -377,16 +393,28 @@ with gr.Blocks() as demo:
                 value=0.0
             )

-
-                label="temperature",
+            masktemp = gr.Slider(
+                label="mask temperature",
                 minimum=0.0,
                 maximum=10.0,
-                value=1.
+                value=1.5
             )
-
+            sampletemp = gr.Slider(
+                label="sample temperature",
+                minimum=0.1,
+                maximum=2.0,
+                value=1.0
+            )
+


            with gr.Accordion("sampling settings", open=False):
+                top_p = gr.Slider(
+                    label="top p (0.0 = off)",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.0
+                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=False
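The old single temperature slider is split in two: `sampletemp` rescales the logits before the softmax (per-token variety), while `masktemp`, multiplied by 10 upstream and passed as `mask_temperature`, controls how noisily tokens are re-masked between steps. A quick standalone look at what sampling temperature does to a softmax:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])
for t in (0.5, 1.0, 2.0):
    # higher temperature -> flatter distribution -> more adventurous samples
    print(t, F.softmax(logits / t, dim=-1).tolist())
```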
@@ -428,6 +456,14 @@ with gr.Blocks() as demo:
                )


+               seed = gr.Number(
+                   label="seed (0 for random)",
+                   value=0,
+                   precision=0,
+               )
+
+
+
         # mask settings
         with gr.Column():
             vamp_button = gr.Button("generate (vamp)!!!")
@@ -455,7 +491,9 @@ with gr.Blocks() as demo:
    _inputs = {
        input_audio,
        num_steps,
-
+       masktemp,
+       sampletemp,
+       top_p,
        prefix_s, suffix_s,
        rand_mask_intensity,
        periodic_p, periodic_w,
@@ -468,6 +506,7 @@ with gr.Blocks() as demo:
        typical_mass,
        typical_min_tokens,
        beat_mask_width,
+       seed,
        beat_mask_downbeats
    }

@@ -498,4 +537,4 @@ with gr.Blocks() as demo:
        outputs=[thank_you, download_file]
    )

-demo.
+demo.launch(share=True, enable_queue=False, debug=True)
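`share=True` requests a temporary public URL and `enable_queue=False` keeps requests synchronous. Worth noting for maintenance: later Gradio releases dropped the `enable_queue` launch argument in favor of an explicit `demo.queue()` call, so the queue-off equivalent there is simply the following (a version-dependent assumption, not part of this commit):

```python
# Under newer Gradio 3.x, queueing is opt-in via demo.queue(); leaving it out
# reproduces the old enable_queue=False behavior.
demo.launch(share=True, debug=True)
```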
requirements.txt
CHANGED
@@ -5,4 +5,4 @@ gradio
 loralib
 wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
 lac @ git+https://github.com/hugofloresgarcia/lac.git
-audiotools @ git+https://github.com/
+descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2
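This pin fixes two things at once: the distribution name (the project's packaging metadata declares it as descript-audiotools, and a PEP 508 direct reference must use the name before the `@`), and the version, since the `@0.7.2` suffix locks the install to that git ref instead of whatever the default branch currently holds.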
scripts/exp/train.py
CHANGED
@@ -485,7 +485,6 @@ def load(
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
-   load_weights: bool = False,
    fine_tune_checkpoint: Optional[str] = None,
    grad_clip_val: float = 5.0,
 ) -> State:
@@ -498,7 +497,7 @@ def load(
    kwargs = {
        "folder": f"{save_path}/{tag}",
        "map_location": "cpu",
-       "package":
+       "package": False,
    }
    tracker.print(f"Loading checkpoint from {kwargs['folder']}")
    if (Path(kwargs["folder"]) / "vampnet").exists():
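The `package` entry feeds audiotools' checkpoint machinery: in descript-audiotools, `BaseModel` checkpoints can be stored either as `torch.package` archives or as plain weight files, and `package=False` selects the plain-weights path. A hedged sketch of where these kwargs appear to end up (`load_from_folder` and its flags are the library's API as best understood, not shown in this diff):

```python
from pathlib import Path
from vampnet.modules.transformer import VampNet  # the class this commit modifies

# Assumed destination of the kwargs built above (audiotools 0.7.x-style API).
kwargs = {"folder": "ckpt/latest", "map_location": "cpu", "package": False}
if (Path(kwargs["folder"]) / "vampnet").exists():
    model = VampNet.load_from_folder(**kwargs)  # plain weights, no torch.package
```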
@@ -511,11 +510,14 @@ def load(

    if args["fine_tune"]:
        assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
-       model =
-
+       model = torch.compile(
+           VampNet.load(location=Path(fine_tune_checkpoint),
+           map_location="cpu",
+           )
+       )

-   model = VampNet() if model is None else model
+   model = torch.compile(VampNet()) if model is None else model
    model = accel.prepare_model(model)

    # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
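Both load paths now wrap the model in `torch.compile`, which needs PyTorch 2.0+. The wrapper is a drop-in `nn.Module` that captures and optimizes `forward` on first call; a minimal standalone illustration:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
compiled = torch.compile(model)  # returns an OptimizedModule wrapper

x = torch.randn(4, 8)
y = compiled(x)  # first call triggers graph capture + compilation
assert y.shape == (4, 8)
```

One checkpointing caveat worth knowing: a compiled module's `state_dict()` keys gain an `_orig_mod.` prefix, so saving and loading across compiled and uncompiled code paths needs care.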
vampnet/modules/transformer.py
CHANGED
@@ -367,6 +367,15 @@ class TransformerLayer(nn.Module):

        return x, position_bias, encoder_decoder_position_bias

+def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
+    x = np.linspace(0, 1, n_steps)
+    a = (0.5 - min_temp) / (max_temp - min_temp)
+
+    x = (x * 12) - 6
+    x0 = np.log((1 / a - 1) + 1e-5) / k
+    y = (1 / (1 + np.exp(- k *(x-x0))))[::-1]
+
+    return y

 class TransformerStack(nn.Module):
     def __init__(
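`t_schedule` turns the previously constant temperature into a per-step schedule: a sigmoid over [-6, 6], reversed so early steps sample hot and late steps sample near-greedily. As written, `max_temp` shifts the sigmoid's midpoint (through `a` and `x0`) rather than scaling its peak, so the returned values always lie in (0, 1). A standalone copy of the function to show the decay:

```python
import numpy as np

# Copy of the schedule added above: a reversed sigmoid over x in [-6, 6],
# so temperature starts high and decays toward zero across the steps.
def t_schedule(n_steps, max_temp=1.0, min_temp=0.0, k=1.0):
    x = np.linspace(0, 1, n_steps)
    a = (0.5 - min_temp) / (max_temp - min_temp)
    x = (x * 12) - 6
    x0 = np.log((1 / a - 1) + 1e-5) / k
    y = (1 / (1 + np.exp(-k * (x - x0))))[::-1]
    return y

print(np.round(t_schedule(8), 3))
# [0.998 0.986 0.929 0.702 0.298 0.071 0.014 0.002]
```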
@@ -580,20 +589,20 @@ class VampNet(at.ml.BaseModel):
         time_steps: int = 300,
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
-
+        mask_temperature: float = 20.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
+        top_p=None,
         return_signal=True,
+        seed: int = None
     ):
+        if seed is not None:
+            at.util.seed(seed)
         logging.debug(f"beginning generation with {sampling_steps} steps")

-        #####################
-        # resolve temperature #
-        #####################
-
-        logging.debug(f"temperature: {temperature}")


         #####################
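The new `seed` argument makes a generation reproducible by seeding all RNGs up front through audiotools. A small check of the pattern (assuming `at.util.seed` seeds Python, NumPy, and torch, which is its documented purpose):

```python
import audiotools as at
import torch

at.util.seed(42)
a = torch.multinomial(torch.ones(16), 4)
at.util.seed(42)
b = torch.multinomial(torch.ones(16), 4)
assert torch.equal(a, b)  # same seed -> identical draws
```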
@@ -641,13 +650,11 @@ class VampNet(at.ml.BaseModel):
        #################
        # begin sampling #
        #################
+       t_sched = t_schedule(sampling_steps, max_temp=sampling_temperature)

        for i in range(sampling_steps):
            logging.debug(f"step {i} of {sampling_steps}")

-           # our current temperature
-           logging.debug(f"temperature: {temperature}")
-
            # our current schedule step
            r = scalar_to_batch_tensor(
                (i + 1) / sampling_steps,
@@ -664,39 +671,19 @@ class VampNet(at.ml.BaseModel):
            # NOTE: this collapses the codebook dimension into the sequence dimension
            logits = self.forward(latents, r)  # b, prob, seq
            logits = logits.permute(0, 2, 1)  # b, seq, prob
-
-           typical_filter(logits,
-                          typical_mass=typical_mass,
-                          typical_min_tokens=typical_min_tokens
-           )
-
+           b = logits.shape[0]

            logging.debug(f"permuted logits with shape: {logits.shape}")

-           # logits2probs
-           probs = torch.softmax(logits, dim=-1)
-           logging.debug(f"computed probs with shape: {probs.shape}")
-
-
-           # sample from logits with multinomial sampling
-           b = probs.shape[0]
-           probs = rearrange(probs, "b seq prob -> (b seq) prob")
-
-           sampled_z = torch.multinomial(probs, 1).squeeze(-1)
-
-           sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
-           probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
+           sampled_z, selected_probs = sample_from_logits(
+               logits, sample=True, temperature=t_sched[i],
+               typical_filtering=typical_filtering, typical_mass=typical_mass,
+               typical_min_tokens=typical_min_tokens,
+               top_k=None, top_p=top_p, return_probs=True
+           )
+
            logging.debug(f"sampled z with shape: {sampled_z.shape}")

-           # get the confidences: which tokens did we sample?
-           selected_probs = (
-               torch.take_along_dim(
-                   probs, sampled_z.long().unsqueeze(-1),
-                   dim=-1
-               ).squeeze(-1)
-           )
-
            # flatten z_masked and mask, so we can deal with the sampling logic
            # we'll unflatten them at the end of the loop for the next forward pass
            # remove conditioning codebooks, we'll add them back at the end
@@ -733,7 +720,7 @@ class VampNet(at.ml.BaseModel):

        # get our new mask
        mask = mask_by_random_topk(
-           num_to_mask, selected_probs,
+           num_to_mask, selected_probs, mask_temperature * (1-r)
        )

        # update the mask
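The re-masking temperature is now annealed by `(1-r)`: since `r` ramps from near 0 to 1 over the schedule, early steps re-mask noisily (more exploration) while the final steps keep exactly the highest-confidence tokens. For intuition, a MaskGIT-style sketch of confidence-based re-masking (an assumption about the general technique; `mask_by_random_topk` below is the repo's actual implementation):

```python
import torch

def mask_lowest_confidence(num_to_mask: int, probs: torch.Tensor, temperature: float):
    # MaskGIT-style: perturb log-confidences with temperature-scaled Gumbel
    # noise, then re-mask the `num_to_mask` least confident positions.
    gumbel = -torch.log(-torch.log(torch.rand_like(probs)))
    confidence = torch.log(probs) + temperature * gumbel
    cutoff = confidence.sort(dim=-1).values[..., num_to_mask - 1, None]
    return confidence <= cutoff

probs = torch.rand(1, 16)
print(mask_lowest_confidence(4, probs, temperature=0.5).sum())  # tensor(4)
```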
@@ -766,6 +753,91 @@ class VampNet(at.ml.BaseModel):
        else:
            return sampled_z

+def sample_from_logits(
+        logits,
+        sample: bool = True,
+        temperature: float = 1.0,
+        top_k: int = None,
+        top_p: float = None,
+        typical_filtering: bool = False,
+        typical_mass: float = 0.2,
+        typical_min_tokens: int = 1,
+        return_probs: bool = False
+    ):
+    """Convenience function to sample from a categorial distribution with input as
+    unnormalized logits.
+
+    Parameters
+    ----------
+    logits : Tensor[..., vocab_size]
+    config: SamplingConfig
+        The set of hyperparameters to be used for sampling
+    sample : bool, optional
+        Whether to perform multinomial sampling, by default True
+    temperature : float, optional
+        Scaling parameter when multinomial samping, by default 1.0
+    top_k : int, optional
+        Restricts sampling to only `top_k` values acc. to probability,
+        by default None
+    top_p : float, optional
+        Restricts sampling to only those values with cumulative
+        probability = `top_p`, by default None
+
+    Returns
+    -------
+    Tensor[...]
+        Sampled tokens
+    """
+    shp = logits.shape[:-1]
+
+    if typical_filtering:
+        typical_filter(logits,
+                       typical_mass=typical_mass,
+                       typical_min_tokens=typical_min_tokens
+        )
+
+    # Apply top_k sampling
+    if top_k is not None:
+        v, _ = logits.topk(top_k)
+        logits[logits < v[..., [-1]]] = -float("inf")
+
+    # Apply top_p (nucleus) sampling
+    if top_p is not None and top_p < 1.0:
+        v, sorted_indices = logits.sort(descending=True)
+        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Right shift indices_to_remove to keep 1st token over threshold
+        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+            ..., :-1
+        ]
+
+        # Compute indices_to_remove in unsorted array
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            -1, sorted_indices, sorted_indices_to_remove
+        )
+
+        logits[indices_to_remove] = -float("inf")
+
+    # Perform multinomial sampling after normalizing logits
+    probs = (
+        F.softmax(logits / temperature, dim=-1)
+        if temperature > 0
+        else logits.softmax(dim=-1)
+    )
+    token = (
+        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+        if sample
+        else logits.argmax(-1)
+    )
+
+    if return_probs:
+        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+        return token, token_probs
+    else:
+        return token
+
+

 def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
     """
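A hypothetical call to the new helper, showing expected shapes and one practical detail: the top-k/top-p/typical branches write `-inf` into `logits` in place, so pass a copy if the raw logits are still needed.

```python
import torch
from vampnet.modules.transformer import sample_from_logits  # added in this commit

logits = torch.randn(2, 16, 1024)        # (batch, seq, vocab)
tokens, token_probs = sample_from_logits(
    logits.clone(),                      # clone: filtering mutates logits in place
    temperature=1.0,
    top_p=0.9,
    return_probs=True,
)
print(tokens.shape, token_probs.shape)   # torch.Size([2, 16]) torch.Size([2, 16])
```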