Transformers
Safetensors
dplm2
custom_code
lhallee committed on
Commit
99890bd
·
verified ·
1 Parent(s): b904f01

Upload modeling_dplm2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm2.py +70 -7
modeling_dplm2.py CHANGED
@@ -722,15 +722,34 @@ def _kernels_flash_forward(
722
  key_states: torch.Tensor,
723
  value_states: torch.Tensor,
724
  causal: bool = False,
 
725
  ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
726
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
727
  if FLASH_KERNEL_VARIANT == "flash_attn2":
728
- return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
 
 
 
729
  if FLASH_KERNEL_VARIANT == "flash_attn3":
730
  try:
731
- output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
 
 
 
732
  except TypeError:
733
- output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
 
 
 
734
  if isinstance(output, tuple):
735
  return output[0]
736
  return output
@@ -746,14 +765,20 @@ def _kernels_flash_varlen_forward(
746
  max_seqlen_in_batch_q: int,
747
  max_seqlen_in_batch_k: int,
748
  causal: bool = False,
 
749
  ) -> torch.Tensor:
 
 
 
 
 
750
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
751
  if FLASH_KERNEL_VARIANT == "flash_attn2":
752
  return FLASH_KERNEL.varlen_fwd(
753
  q=query_states, k=key_states, v=value_states,
754
  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
755
  max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
756
- is_causal=causal,
757
  )[0]
758
  if FLASH_KERNEL_VARIANT == "flash_attn3":
759
  try:
@@ -761,14 +786,14 @@ def _kernels_flash_varlen_forward(
761
  q=query_states, k=key_states, v=value_states,
762
  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
763
  max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
764
- causal=causal,
765
  )
766
  except TypeError:
767
  output = FLASH_KERNEL.flash_attn_varlen_func(
768
  query_states, key_states, value_states,
769
  cu_seqlens_q, cu_seqlens_k,
770
  max_seqlen_in_batch_q, max_seqlen_in_batch_k,
771
- 0.0, None, causal,
772
  )
773
  if isinstance(output, tuple):
774
  return output[0]
@@ -849,7 +874,21 @@ def kernels_flash_attention_func(
849
  value_states: torch.Tensor,
850
  attention_mask_2d: Optional[torch.Tensor] = None,
851
  causal: bool = False,
 
852
  ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
853
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
854
  if not causal and attention_mask_2d is not None:
855
  batch_size, q_len = query_states.shape[:2]
@@ -861,11 +900,13 @@ def kernels_flash_attention_func(
861
  query_states=query_states, key_states=key_states, value_states=value_states,
862
  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
863
  max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
 
864
  )
865
  return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
866
  else:
867
  return _kernels_flash_forward(
868
- query_states=query_states, key_states=key_states, value_states=value_states, causal=causal,
 
869
  )
870
 
871
 
@@ -948,6 +989,25 @@ def get_attention_mask(
948
  attention_mask_4d = attention_mask_2d[:, None, None, :]
949
  return attention_mask_2d, attention_mask_4d, None
950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
  """
952
  FastPLMs-compatible DPLM2 implementation.
953
  """
@@ -1289,9 +1349,12 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
1289
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
1290
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
1291
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
 
 
1292
  attn_output = kernels_flash_attention_func(
1293
  query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
1294
  attention_mask_2d=attention_mask_2d, causal=False,
 
1295
  )
1296
  return rearrange(attn_output, "b s h d -> b s (h d)"), None
1297
 
 
722
  key_states: torch.Tensor,
723
  value_states: torch.Tensor,
724
  causal: bool = False,
725
+ softmax_scale: Optional[float] = None,
726
  ) -> torch.Tensor:
727
+ """Flash-attention forward, optionally overriding the softmax scale.
728
+
729
+ When `softmax_scale is None`, the flash kernel applies its default
730
+ `1 / sqrt(head_dim)`. Pass `softmax_scale=1.0` if the caller has already
731
+ pre-scaled Q (the convention used by ESM2, DPLM, DPLM2, E1, ESMFold).
732
+ Failing to override when Q is pre-scaled produces DOUBLE scaling and
733
+ catastrophic downstream drift -- on DPLM-150M (30 layers) this was observed
734
+ as pooled-embedding cosine ~-0.12 and argmax agreement ~0.27 vs sdpa.
735
+ """
736
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
737
  if FLASH_KERNEL_VARIANT == "flash_attn2":
738
+ return FLASH_KERNEL.fwd(
739
+ q=query_states, k=key_states, v=value_states,
740
+ softmax_scale=softmax_scale, is_causal=causal,
741
+ )[0]
742
  if FLASH_KERNEL_VARIANT == "flash_attn3":
743
  try:
744
+ output = FLASH_KERNEL.flash_attn_func(
745
+ q=query_states, k=key_states, v=value_states,
746
+ softmax_scale=softmax_scale, causal=causal,
747
+ )
748
  except TypeError:
749
+ output = FLASH_KERNEL.flash_attn_func(
750
+ query_states, key_states, value_states,
751
+ 0.0, softmax_scale, causal,
752
+ )
753
  if isinstance(output, tuple):
754
  return output[0]
755
  return output
 
765
  max_seqlen_in_batch_q: int,
766
  max_seqlen_in_batch_k: int,
767
  causal: bool = False,
768
+ softmax_scale: Optional[float] = None,
769
  ) -> torch.Tensor:
770
+ """Varlen flash-attention forward, optionally overriding the softmax scale.
771
+
772
+ See `_kernels_flash_forward` docstring for why `softmax_scale=1.0` must be
773
+ passed when Q has been pre-scaled by the caller.
774
+ """
775
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
776
  if FLASH_KERNEL_VARIANT == "flash_attn2":
777
  return FLASH_KERNEL.varlen_fwd(
778
  q=query_states, k=key_states, v=value_states,
779
  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
780
  max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
781
+ softmax_scale=softmax_scale, is_causal=causal,
782
  )[0]
783
  if FLASH_KERNEL_VARIANT == "flash_attn3":
784
  try:
 
786
  q=query_states, k=key_states, v=value_states,
787
  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
788
  max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
789
+ softmax_scale=softmax_scale, causal=causal,
790
  )
791
  except TypeError:
792
  output = FLASH_KERNEL.flash_attn_varlen_func(
793
  query_states, key_states, value_states,
794
  cu_seqlens_q, cu_seqlens_k,
795
  max_seqlen_in_batch_q, max_seqlen_in_batch_k,
796
+ 0.0, softmax_scale, causal,
797
  )
798
  if isinstance(output, tuple):
799
  return output[0]
 
874
  value_states: torch.Tensor,
875
  attention_mask_2d: Optional[torch.Tensor] = None,
876
  causal: bool = False,
877
+ softmax_scale: Optional[float] = None,
878
  ) -> torch.Tensor:
879
+ """Public flash-attention entry point with optional padding handling.
880
+
881
+ `softmax_scale`:
882
+ None -> kernel applies its default `1 / sqrt(head_dim)`.
883
+ float -> kernel uses the given scale (pass 1.0 when Q is pre-scaled
884
+ by the caller).
885
+
886
+ IMPORTANT: if your family multiplies Q by `1/sqrt(head_dim)` before calling
887
+ this function (as ESM2, DPLM, DPLM2, E1, and ESMFold do) you MUST pass
888
+ `softmax_scale=1.0`. Otherwise the kernel applies its default scale ON TOP
889
+ of the caller's, producing effective scale `1/head_dim` and catastrophic
890
+ downstream drift that compounds across layers.
891
+ """
892
  assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
893
  if not causal and attention_mask_2d is not None:
894
  batch_size, q_len = query_states.shape[:2]
 
900
  query_states=query_states, key_states=key_states, value_states=value_states,
901
  cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
902
  max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
903
+ softmax_scale=softmax_scale,
904
  )
905
  return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
906
  else:
907
  return _kernels_flash_forward(
908
+ query_states=query_states, key_states=key_states, value_states=value_states,
909
+ causal=causal, softmax_scale=softmax_scale,
910
  )
911
 
912
 
 
989
  attention_mask_4d = attention_mask_2d[:, None, None, :]
990
  return attention_mask_2d, attention_mask_4d, None
991
 
992
+
993
def bool_to_additive_mask(
    bool_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Convert a bool mask (True = valid) to a float additive mask (0.0 valid, -inf invalid).

    Why this exists: calling `bool_mask.masked_fill(bool_mask.logical_not(), float('-inf'))`
    directly on a bool tensor returns a bool tensor -- because `-inf` casts to `True` -- and
    silently drops the mask entirely. Always allocate a float tensor first, then fill it.
    This helper is the sanctioned way to build an SDPA additive mask from a bool validity mask.
    """
    assert bool_mask.dtype == torch.bool, (
        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
    )
    # Build the float tensor first so the fill happens in the target dtype,
    # then mark every invalid (False) position with -inf out-of-place.
    zeros = torch.zeros_like(bool_mask, dtype=dtype)
    return zeros.masked_fill(~bool_mask, float("-inf"))
1010
+
1011
  """
1012
  FastPLMs-compatible DPLM2 implementation.
1013
  """
 
1349
  query_BLHD = query_BHLD.transpose(1, 2).contiguous()
1350
  key_BLHD = key_BHLD.transpose(1, 2).contiguous()
1351
  value_BLHD = value_BHLD.transpose(1, 2).contiguous()
1352
+ # Q is pre-scaled by self.scale in forward() -- pass softmax_scale=1.0
1353
+ # to prevent the kernel from applying its default 1/sqrt(head_dim).
1354
  attn_output = kernels_flash_attention_func(
1355
  query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
1356
  attention_mask_2d=attention_mask_2d, causal=False,
1357
+ softmax_scale=1.0,
1358
  )
1359
  return rearrange(attn_output, "b s h d -> b s (h d)"), None
1360