Hugo Flores Garcia commited on
Commit
4d0cbfe
1 Parent(s): 85e8a86

tiny sampling refactor

Browse files
Files changed (1) hide show
  1. vampnet/modules/transformer.py +55 -54
vampnet/modules/transformer.py CHANGED
@@ -724,7 +724,7 @@ class VampNet(at.ml.BaseModel):
724
 
725
  logits = torch.log(probs)
726
 
727
- z_inferred = self.sample_from_logits(
728
  logits=logits,
729
  top_k=top_k,
730
  temperature=tmpt,
@@ -742,61 +742,60 @@ class VampNet(at.ml.BaseModel):
742
  else:
743
  return z
744
 
745
- def sample_from_logits(
746
- self,
747
- logits,
748
- top_k: int = None,
749
- temperature: float = 1.0,
750
- sample: str = "multinomial",
751
- typical_filtering=False,
752
- typical_mass=0.2,
753
- typical_min_tokens=1,
754
- ):
755
- # add temperature
756
- logits = logits / temperature
757
-
758
- # add topk
759
- if top_k is not None:
760
- v, topk_idx = logits.topk(top_k)
761
- logits[logits < v[..., [-1]]] = -float("inf")
762
-
763
- if typical_filtering:
764
- assert top_k is None
765
- nb, nt, _ = logits.shape
766
- x_flat = rearrange(logits, "b t l -> (b t ) l")
767
- x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
768
- x_flat_norm_p = torch.exp(x_flat_norm)
769
- entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
770
-
771
- c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
772
- c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
773
- x_flat_cumsum = (
774
- x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
775
- )
776
 
777
- last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
778
- sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
779
- 1, last_ind.view(-1, 1)
780
- )
781
- if typical_min_tokens > 1:
782
- sorted_indices_to_remove[..., :typical_min_tokens] = 0
783
- indices_to_remove = sorted_indices_to_remove.scatter(
784
- 1, x_flat_indices, sorted_indices_to_remove
785
- )
786
- x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
787
- logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
788
-
789
- if sample == "multinomial":
790
- probs = torch.softmax(logits, dim=-1)
791
- inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
792
- elif sample == "argmax":
793
- inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
794
- elif sample == "gumbel":
795
- inferred = gumbel_sample(logits, dim=-1)
796
- else:
797
- raise ValueError(f"invalid sampling method: {sample}")
798
 
799
- return inferred
 
 
 
 
 
 
 
 
 
 
800
 
801
 
802
 
@@ -833,3 +832,5 @@ if __name__ == "__main__":
833
  args = argbind.parse_args()
834
  with argbind.scope(args):
835
  try_model()
 
 
 
724
 
725
  logits = torch.log(probs)
726
 
727
+ z_inferred = sample_from_logits(
728
  logits=logits,
729
  top_k=top_k,
730
  temperature=tmpt,
 
742
  else:
743
  return z
744
 
745
+ def sample_from_logits(
746
+ logits,
747
+ top_k: int = None,
748
+ temperature: float = 1.0,
749
+ sample: str = "multinomial",
750
+ typical_filtering=False,
751
+ typical_mass=0.2,
752
+ typical_min_tokens=1,
753
+ ):
754
+ # add temperature
755
+ logits = logits / temperature
756
+
757
+ # add topk
758
+ if top_k is not None and typical_filtering == False:
759
+ v, topk_idx = logits.topk(top_k)
760
+ logits[logits < v[..., [-1]]] = -float("inf")
761
+
762
+ if typical_filtering:
763
+ assert top_k is None
764
+ nb, nt, _ = logits.shape
765
+ x_flat = rearrange(logits, "b t l -> (b t ) l")
766
+ x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
767
+ x_flat_norm_p = torch.exp(x_flat_norm)
768
+ entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
769
+
770
+ c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
771
+ c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
772
+ x_flat_cumsum = (
773
+ x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
774
+ )
 
775
 
776
+ last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
777
+ sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
778
+ 1, last_ind.view(-1, 1)
779
+ )
780
+ if typical_min_tokens > 1:
781
+ sorted_indices_to_remove[..., :typical_min_tokens] = 0
782
+ indices_to_remove = sorted_indices_to_remove.scatter(
783
+ 1, x_flat_indices, sorted_indices_to_remove
784
+ )
785
+ x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
786
+ logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
 
 
 
 
 
 
 
 
 
 
787
 
788
+ if sample == "multinomial":
789
+ probs = torch.softmax(logits, dim=-1)
790
+ inferred = torch.stack([pr.multinomial(1).squeeze(-1) for pr in probs])
791
+ elif sample == "argmax":
792
+ inferred = torch.softmax(logits, dim=-1).argmax(dim=-1)
793
+ elif sample == "gumbel":
794
+ inferred = gumbel_sample(logits, dim=-1)
795
+ else:
796
+ raise ValueError(f"invalid sampling method: {sample}")
797
+
798
+ return inferred
799
 
800
 
801
 
 
832
  args = argbind.parse_args()
833
  with argbind.scope(args):
834
  try_model()
835
+
836
+