Hugo Flores Garcia commited on
Commit
793d060
1 Parent(s): 3a5996b

sampling cutoff trick

Browse files
Files changed (1) hide show
  1. vampnet/modules/transformer.py +7 -3
vampnet/modules/transformer.py CHANGED
@@ -597,7 +597,8 @@ class VampNet(at.ml.BaseModel):
597
  typical_min_tokens=1,
598
  top_p=None,
599
  return_signal=True,
600
- seed: int = None
 
601
  ):
602
  if seed is not None:
603
  at.util.seed(seed)
@@ -676,10 +677,13 @@ class VampNet(at.ml.BaseModel):
676
  logging.debug(f"permuted logits with shape: {logits.shape}")
677
 
678
  sampled_z, selected_probs = sample_from_logits(
679
- logits, sample=True, temperature=t_sched[i],
 
 
 
680
  typical_filtering=typical_filtering, typical_mass=typical_mass,
681
  typical_min_tokens=typical_min_tokens,
682
- top_k=None, top_p=top_p, return_probs=True
683
  )
684
 
685
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
597
  typical_min_tokens=1,
598
  top_p=None,
599
  return_signal=True,
600
+ seed: int = None,
601
+ sample_cutoff: float = 1.0
602
  ):
603
  if seed is not None:
604
  at.util.seed(seed)
 
677
  logging.debug(f"permuted logits with shape: {logits.shape}")
678
 
679
  sampled_z, selected_probs = sample_from_logits(
680
+ logits, sample=(
681
+ (i / sampling_steps) <= sample_cutoff
682
+ ),
683
+ temperature=t_sched[i],
684
  typical_filtering=typical_filtering, typical_mass=typical_mass,
685
  typical_min_tokens=typical_min_tokens,
686
+ top_k=None, top_p=top_p, return_probs=True,
687
  )
688
 
689
  logging.debug(f"sampled z with shape: {sampled_z.shape}")