Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
·
793d060
1
Parent(s):
3a5996b
sampling cutoff trick
Browse files
vampnet/modules/transformer.py
CHANGED
@@ -597,7 +597,8 @@ class VampNet(at.ml.BaseModel):
|
|
597 |
typical_min_tokens=1,
|
598 |
top_p=None,
|
599 |
return_signal=True,
|
600 |
-
seed: int = None
|
|
|
601 |
):
|
602 |
if seed is not None:
|
603 |
at.util.seed(seed)
|
@@ -676,10 +677,13 @@ class VampNet(at.ml.BaseModel):
|
|
676 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
677 |
|
678 |
sampled_z, selected_probs = sample_from_logits(
|
679 |
-
logits, sample=
|
|
|
|
|
|
|
680 |
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
681 |
typical_min_tokens=typical_min_tokens,
|
682 |
-
top_k=None, top_p=top_p, return_probs=True
|
683 |
)
|
684 |
|
685 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
|
|
597 |
typical_min_tokens=1,
|
598 |
top_p=None,
|
599 |
return_signal=True,
|
600 |
+
seed: int = None,
|
601 |
+
sample_cutoff: float = 1.0
|
602 |
):
|
603 |
if seed is not None:
|
604 |
at.util.seed(seed)
|
|
|
677 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
678 |
|
679 |
sampled_z, selected_probs = sample_from_logits(
|
680 |
+
logits, sample=(
|
681 |
+
(i / sampling_steps) <= sample_cutoff
|
682 |
+
),
|
683 |
+
temperature=t_sched[i],
|
684 |
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
685 |
typical_min_tokens=typical_min_tokens,
|
686 |
+
top_k=None, top_p=top_p, return_probs=True,
|
687 |
)
|
688 |
|
689 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|