lhallee committed
Commit dd70d75 · verified · 1 Parent(s): 1ebec01

Upload modeling_ankh.py with huggingface_hub

Files changed (1):
  1. modeling_ankh.py +72 -12
modeling_ankh.py CHANGED
@@ -722,15 +722,34 @@ def _kernels_flash_forward(
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Flash-attention forward, optionally overriding the softmax scale.
+
+    When `softmax_scale is None`, the flash kernel applies its default
+    `1 / sqrt(head_dim)`. Pass `softmax_scale=1.0` if the caller has already
+    pre-scaled Q (the convention used by ESM2, DPLM, DPLM2, E1, ESMFold).
+    Failing to override when Q is pre-scaled produces DOUBLE scaling and
+    catastrophic downstream drift -- on DPLM-150M (30 layers) this was observed
+    as pooled-embedding cosine ~-0.12 and argmax agreement ~0.27 vs sdpa.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if FLASH_KERNEL_VARIANT == "flash_attn2":
-        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
+        return FLASH_KERNEL.fwd(
+            q=query_states, k=key_states, v=value_states,
+            softmax_scale=softmax_scale, is_causal=causal,
+        )[0]
     if FLASH_KERNEL_VARIANT == "flash_attn3":
         try:
-            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
+            output = FLASH_KERNEL.flash_attn_func(
+                q=query_states, k=key_states, v=value_states,
+                softmax_scale=softmax_scale, causal=causal,
+            )
         except TypeError:
-            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
+            output = FLASH_KERNEL.flash_attn_func(
+                query_states, key_states, value_states,
+                0.0, softmax_scale, causal,
+            )
     if isinstance(output, tuple):
         return output[0]
     return output
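The docstring above is the crux of this change. A minimal CPU-only sketch (illustrative, not part of modeling_ankh.py) of why letting the kernel apply its default 1/sqrt(head_dim) on top of an already pre-scaled Q changes the attention weights:

import torch

torch.manual_seed(0)
head_dim = 64
q = torch.randn(1, 8, 16, head_dim)           # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, head_dim)
scale = head_dim ** -0.5

q_prescaled = q * scale                       # caller already scaled Q (ESM2-style)

# Correct: the scale is applied exactly once.
correct = torch.softmax(q_prescaled @ k.transpose(-1, -2), dim=-1)

# Double scaling: the kernel's default 1/sqrt(head_dim) lands on top of the
# caller's, so the effective scale becomes 1/head_dim.
double = torch.softmax((q_prescaled @ k.transpose(-1, -2)) * scale, dim=-1)

print((correct - double).abs().max())         # clearly non-zero: the weights diverge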
@@ -746,14 +765,20 @@ def _kernels_flash_varlen_forward(
     max_seqlen_in_batch_q: int,
     max_seqlen_in_batch_k: int,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Varlen flash-attention forward, optionally overriding the softmax scale.
+
+    See `_kernels_flash_forward` docstring for why `softmax_scale=1.0` must be
+    passed when Q has been pre-scaled by the caller.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if FLASH_KERNEL_VARIANT == "flash_attn2":
         return FLASH_KERNEL.varlen_fwd(
             q=query_states, k=key_states, v=value_states,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
             max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
-            is_causal=causal,
+            softmax_scale=softmax_scale, is_causal=causal,
         )[0]
     if FLASH_KERNEL_VARIANT == "flash_attn3":
         try:
@@ -761,14 +786,14 @@ def _kernels_flash_varlen_forward(
                 q=query_states, k=key_states, v=value_states,
                 cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
-                causal=causal,
+                softmax_scale=softmax_scale, causal=causal,
             )
         except TypeError:
             output = FLASH_KERNEL.flash_attn_varlen_func(
                 query_states, key_states, value_states,
                 cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_in_batch_q, max_seqlen_in_batch_k,
-                0.0, None, causal,
+                0.0, softmax_scale, causal,
             )
     if isinstance(output, tuple):
         return output[0]
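The varlen path above consumes cu_seqlens_* and max_seqlen_* descriptors that are built elsewhere in the file. A minimal sketch of the usual flash-attention convention (int32 cumulative sequence lengths with a leading zero), assuming a bool padding mask where True marks real tokens; build_cu_seqlens is a hypothetical helper name, not from this diff:

import torch
import torch.nn.functional as F

def build_cu_seqlens(attention_mask_2d: torch.Tensor):
    # attention_mask_2d: (batch, seq_len) bool, True = real token.
    seqlens = attention_mask_2d.sum(dim=-1, dtype=torch.int32)
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
print(build_cu_seqlens(mask))                 # (tensor([0, 3, 5], dtype=torch.int32), 3)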
@@ -849,7 +874,21 @@ def kernels_flash_attention_func(
     value_states: torch.Tensor,
     attention_mask_2d: Optional[torch.Tensor] = None,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Public flash-attention entry point with optional padding handling.
+
+    `softmax_scale`:
+        None  -> kernel applies its default `1 / sqrt(head_dim)`.
+        float -> kernel uses the given scale (pass 1.0 when Q is pre-scaled
+                 by the caller).
+
+    IMPORTANT: if your family multiplies Q by `1/sqrt(head_dim)` before calling
+    this function (as ESM2, DPLM, DPLM2, E1, and ESMFold do) you MUST pass
+    `softmax_scale=1.0`. Otherwise the kernel applies its default scale ON TOP
+    of the caller's, producing effective scale `1/head_dim` and catastrophic
+    downstream drift that compounds across layers.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if not causal and attention_mask_2d is not None:
         batch_size, q_len = query_states.shape[:2]
@@ -861,11 +900,13 @@ def kernels_flash_attention_func(
             query_states=query_states, key_states=key_states, value_states=value_states,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
             max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
+            softmax_scale=softmax_scale,
         )
         return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
     else:
         return _kernels_flash_forward(
-            query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
+            query_states=query_states, key_states=key_states, value_states=value_states,
+            causal=causal, softmax_scale=softmax_scale,
         )
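A CPU reference for the softmax_scale contract spelled out in the docstring, using torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.1 for the scale argument) as a stand-in for the flash kernel. Illustrative only; this is not code from the file:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 16, 32)                 # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
scale = q.shape[-1] ** -0.5

# Unscaled Q with the default scale ...
ref_default = F.scaled_dot_product_attention(q, k, v)
# ... matches pre-scaled Q with the scale overridden to 1.0.
ref_prescaled = F.scaled_dot_product_attention(q * scale, k, v, scale=1.0)

print(torch.allclose(ref_default, ref_prescaled, atol=1e-5))   # True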
 
@@ -948,6 +989,25 @@ def get_attention_mask(
     attention_mask_4d = attention_mask_2d[:, None, None, :]
     return attention_mask_2d, attention_mask_4d, None
 
+
+def bool_to_additive_mask(
+    bool_mask: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Convert a bool mask (True = valid) to a float additive mask (0.0 valid, -inf invalid).
+
+    Why this exists: calling `bool_mask.masked_fill(bool_mask.logical_not(), float('-inf'))`
+    directly on a bool tensor returns a bool tensor -- because `-inf` casts to `True` -- and
+    silently drops the mask entirely. Always allocate a float tensor first, then fill it.
+    This helper is the sanctioned way to build an SDPA additive mask from a bool validity mask.
+    """
+    assert bool_mask.dtype == torch.bool, (
+        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
+    )
+    additive = torch.zeros_like(bool_mask, dtype=dtype)
+    additive.masked_fill_(bool_mask.logical_not(), float("-inf"))
+    return additive
+
 import math
 
 import torch
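A short CPU sketch of what the new helper produces for a (batch, 1, 1, seq) validity mask; the body below mirrors bool_to_additive_mask rather than importing it:

import torch

attention_mask_2d = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)    # True = real token
attention_mask_4d = attention_mask_2d[:, None, None, :]

additive = torch.zeros_like(attention_mask_4d, dtype=torch.float32)  # float first, then fill
additive.masked_fill_(attention_mask_4d.logical_not(), float("-inf"))
print(additive)                               # tensor([[[[0., 0., 0., -inf]]]])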
@@ -1091,7 +1151,9 @@ class AnkhSelfAttention(nn.Module):
         self.k = nn.Linear(config.d_model, self.inner_dim, bias=False)
         self.v = nn.Linear(config.d_model, self.inner_dim, bias=False)
         self.o = nn.Linear(self.inner_dim, config.d_model, bias=False)
-        self.scale = self.d_kv ** -0.5
+        # T5/ANKH attention is unscaled: scores = Q K^T (no 1/sqrt(d_kv)).
+        # The learned relative position bias absorbs any temperature.
+        self.scale = 1.0
 
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(
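With self.scale fixed at 1.0, scores are raw Q K^T plus the learned bias. An illustrative equivalence check (assumed shapes; not the file's actual _manual_attn) showing that the unscaled manual computation matches SDPA when the default scaling is disabled and the bias is passed as an additive mask:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 8, 16)                  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
position_bias = torch.randn(1, 2, 8, 8)       # learned relative-position bias

# Manual T5/ANKH-style path: no 1/sqrt(d_kv) on the scores.
scores = q @ k.transpose(-1, -2) + position_bias
manual = torch.softmax(scores, dim=-1) @ v

# SDPA path: additive bias as attn_mask, default scaling disabled (PyTorch >= 2.1).
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=position_bias, scale=1.0)

print(torch.allclose(manual, sdpa, atol=1e-5))  # True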
@@ -1163,11 +1225,9 @@ class AnkhSelfAttention(nn.Module):
         # Compute position bias on first layer (SDPA/manual only; flex uses score_mod)
         if position_bias is None and self.has_relative_attention_bias and self.attn_backend != AttentionBackend.FLEX:
             position_bias = self.compute_bias(seq_length, seq_length, hidden_states.device)
-            # Fold padding mask into position bias so layers don't need separate mask
+            # Fold padding mask into position bias so layers don't need separate mask.
             if attention_mask_4d is not None:
-                position_bias = position_bias + attention_mask_4d.masked_fill(
-                    attention_mask_4d.logical_not(), float("-inf")
-                )
+                position_bias = position_bias + bool_to_additive_mask(attention_mask_4d, position_bias.dtype)
 
         if output_attentions:
             attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
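Folding the padding mask into position_bias works because the (batch, 1, 1, key_len) additive mask broadcasts against the (1, heads, query_len, key_len) bias, so one tensor carries both signals to every layer. A small illustrative check (shapes assumed):

import torch

batch, heads, q_len, k_len = 2, 4, 6, 6
position_bias = torch.randn(1, heads, q_len, k_len)
additive_pad = torch.zeros(batch, 1, 1, k_len)
additive_pad[0, ..., -2:] = float("-inf")      # last two keys of sequence 0 are padding

folded = position_bias + additive_pad          # broadcasts to (batch, heads, q_len, k_len)
print(folded.shape)                            # torch.Size([2, 4, 6, 6])
print(torch.isinf(folded[0, :, :, -2:]).all()) # True: padded keys masked for every head/query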
 