Upload modeling_dplm2.py with huggingface_hub
Browse files- modeling_dplm2.py +70 -7
modeling_dplm2.py
CHANGED
|
@@ -722,15 +722,34 @@ def _kernels_flash_forward(
|
|
| 722 |
key_states: torch.Tensor,
|
| 723 |
value_states: torch.Tensor,
|
| 724 |
causal: bool = False,
|
|
|
|
| 725 |
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
| 727 |
if FLASH_KERNEL_VARIANT == "flash_attn2":
|
| 728 |
-
return FLASH_KERNEL.fwd(
|
|
|
|
|
|
|
|
|
|
| 729 |
if FLASH_KERNEL_VARIANT == "flash_attn3":
|
| 730 |
try:
|
| 731 |
-
output = FLASH_KERNEL.flash_attn_func(
|
|
|
|
|
|
|
|
|
|
| 732 |
except TypeError:
|
| 733 |
-
output = FLASH_KERNEL.flash_attn_func(
|
|
|
|
|
|
|
|
|
|
| 734 |
if isinstance(output, tuple):
|
| 735 |
return output[0]
|
| 736 |
return output
|
|
@@ -746,14 +765,20 @@ def _kernels_flash_varlen_forward(
|
|
| 746 |
max_seqlen_in_batch_q: int,
|
| 747 |
max_seqlen_in_batch_k: int,
|
| 748 |
causal: bool = False,
|
|
|
|
| 749 |
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
| 751 |
if FLASH_KERNEL_VARIANT == "flash_attn2":
|
| 752 |
return FLASH_KERNEL.varlen_fwd(
|
| 753 |
q=query_states, k=key_states, v=value_states,
|
| 754 |
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 755 |
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
|
| 756 |
-
is_causal=causal,
|
| 757 |
)[0]
|
| 758 |
if FLASH_KERNEL_VARIANT == "flash_attn3":
|
| 759 |
try:
|
|
@@ -761,14 +786,14 @@ def _kernels_flash_varlen_forward(
|
|
| 761 |
q=query_states, k=key_states, v=value_states,
|
| 762 |
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 763 |
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
|
| 764 |
-
causal=causal,
|
| 765 |
)
|
| 766 |
except TypeError:
|
| 767 |
output = FLASH_KERNEL.flash_attn_varlen_func(
|
| 768 |
query_states, key_states, value_states,
|
| 769 |
cu_seqlens_q, cu_seqlens_k,
|
| 770 |
max_seqlen_in_batch_q, max_seqlen_in_batch_k,
|
| 771 |
-
0.0,
|
| 772 |
)
|
| 773 |
if isinstance(output, tuple):
|
| 774 |
return output[0]
|
|
@@ -849,7 +874,21 @@ def kernels_flash_attention_func(
|
|
| 849 |
value_states: torch.Tensor,
|
| 850 |
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 851 |
causal: bool = False,
|
|
|
|
| 852 |
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
| 854 |
if not causal and attention_mask_2d is not None:
|
| 855 |
batch_size, q_len = query_states.shape[:2]
|
|
@@ -861,11 +900,13 @@ def kernels_flash_attention_func(
|
|
| 861 |
query_states=query_states, key_states=key_states, value_states=value_states,
|
| 862 |
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 863 |
max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
|
|
|
|
| 864 |
)
|
| 865 |
return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
|
| 866 |
else:
|
| 867 |
return _kernels_flash_forward(
|
| 868 |
-
query_states=query_states, key_states=key_states, value_states=value_states,
|
|
|
|
| 869 |
)
|
| 870 |
|
| 871 |
|
|
@@ -948,6 +989,25 @@ def get_attention_mask(
|
|
| 948 |
attention_mask_4d = attention_mask_2d[:, None, None, :]
|
| 949 |
return attention_mask_2d, attention_mask_4d, None
|
| 950 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 951 |
"""
|
| 952 |
FastPLMs-compatible DPLM2 implementation.
|
| 953 |
"""
|
|
@@ -1289,9 +1349,12 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1289 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 1290 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 1291 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
|
|
|
|
|
| 1292 |
attn_output = kernels_flash_attention_func(
|
| 1293 |
query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
|
| 1294 |
attention_mask_2d=attention_mask_2d, causal=False,
|
|
|
|
| 1295 |
)
|
| 1296 |
return rearrange(attn_output, "b s h d -> b s (h d)"), None
|
| 1297 |
|
|
|
|
| 722 |
key_states: torch.Tensor,
|
| 723 |
value_states: torch.Tensor,
|
| 724 |
causal: bool = False,
|
| 725 |
+
softmax_scale: Optional[float] = None,
|
| 726 |
) -> torch.Tensor:
|
| 727 |
+
"""Flash-attention forward, optionally overriding the softmax scale.
|
| 728 |
+
|
| 729 |
+
When `softmax_scale is None`, the flash kernel applies its default
|
| 730 |
+
`1 / sqrt(head_dim)`. Pass `softmax_scale=1.0` if the caller has already
|
| 731 |
+
pre-scaled Q (the convention used by ESM2, DPLM, DPLM2, E1, ESMFold).
|
| 732 |
+
Failing to override when Q is pre-scaled produces DOUBLE scaling and
|
| 733 |
+
catastrophic downstream drift -- on DPLM-150M (30 layers) this was observed
|
| 734 |
+
as pooled-embedding cosine ~-0.12 and argmax agreement ~0.27 vs sdpa.
|
| 735 |
+
"""
|
| 736 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
| 737 |
if FLASH_KERNEL_VARIANT == "flash_attn2":
|
| 738 |
+
return FLASH_KERNEL.fwd(
|
| 739 |
+
q=query_states, k=key_states, v=value_states,
|
| 740 |
+
softmax_scale=softmax_scale, is_causal=causal,
|
| 741 |
+
)[0]
|
| 742 |
if FLASH_KERNEL_VARIANT == "flash_attn3":
|
| 743 |
try:
|
| 744 |
+
output = FLASH_KERNEL.flash_attn_func(
|
| 745 |
+
q=query_states, k=key_states, v=value_states,
|
| 746 |
+
softmax_scale=softmax_scale, causal=causal,
|
| 747 |
+
)
|
| 748 |
except TypeError:
|
| 749 |
+
output = FLASH_KERNEL.flash_attn_func(
|
| 750 |
+
query_states, key_states, value_states,
|
| 751 |
+
0.0, softmax_scale, causal,
|
| 752 |
+
)
|
| 753 |
if isinstance(output, tuple):
|
| 754 |
return output[0]
|
| 755 |
return output
|
|
|
|
| 765 |
max_seqlen_in_batch_q: int,
|
| 766 |
max_seqlen_in_batch_k: int,
|
| 767 |
causal: bool = False,
|
| 768 |
+
softmax_scale: Optional[float] = None,
|
| 769 |
) -> torch.Tensor:
|
| 770 |
+
"""Varlen flash-attention forward, optionally overriding the softmax scale.
|
| 771 |
+
|
| 772 |
+
See `_kernels_flash_forward` docstring for why `softmax_scale=1.0` must be
|
| 773 |
+
passed when Q has been pre-scaled by the caller.
|
| 774 |
+
"""
|
| 775 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
| 776 |
if FLASH_KERNEL_VARIANT == "flash_attn2":
|
| 777 |
return FLASH_KERNEL.varlen_fwd(
|
| 778 |
q=query_states, k=key_states, v=value_states,
|
| 779 |
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 780 |
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
|
| 781 |
+
softmax_scale=softmax_scale, is_causal=causal,
|
| 782 |
)[0]
|
| 783 |
if FLASH_KERNEL_VARIANT == "flash_attn3":
|
| 784 |
try:
|
|
|
|
| 786 |
q=query_states, k=key_states, v=value_states,
|
| 787 |
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 788 |
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
|
| 789 |
+
softmax_scale=softmax_scale, causal=causal,
|
| 790 |
)
|
| 791 |
except TypeError:
|
| 792 |
output = FLASH_KERNEL.flash_attn_varlen_func(
|
| 793 |
query_states, key_states, value_states,
|
| 794 |
cu_seqlens_q, cu_seqlens_k,
|
| 795 |
max_seqlen_in_batch_q, max_seqlen_in_batch_k,
|
| 796 |
+
0.0, softmax_scale, causal,
|
| 797 |
)
|
| 798 |
if isinstance(output, tuple):
|
| 799 |
return output[0]
|
|
|
|
| 874 |
value_states: torch.Tensor,
|
| 875 |
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 876 |
causal: bool = False,
|
| 877 |
+
softmax_scale: Optional[float] = None,
|
| 878 |
) -> torch.Tensor:
|
| 879 |
+
"""Public flash-attention entry point with optional padding handling.
|
| 880 |
+
|
| 881 |
+
`softmax_scale`:
|
| 882 |
+
None -> kernel applies its default `1 / sqrt(head_dim)`.
|
| 883 |
+
float -> kernel uses the given scale (pass 1.0 when Q is pre-scaled
|
| 884 |
+
by the caller).
|
| 885 |
+
|
| 886 |
+
IMPORTANT: if your family multiplies Q by `1/sqrt(head_dim)` before calling
|
| 887 |
+
this function (as ESM2, DPLM, DPLM2, E1, and ESMFold do) you MUST pass
|
| 888 |
+
`softmax_scale=1.0`. Otherwise the kernel applies its default scale ON TOP
|
| 889 |
+
of the caller's, producing effective scale `1/head_dim` and catastrophic
|
| 890 |
+
downstream drift that compounds across layers.
|
| 891 |
+
"""
|
| 892 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
| 893 |
if not causal and attention_mask_2d is not None:
|
| 894 |
batch_size, q_len = query_states.shape[:2]
|
|
|
|
| 900 |
query_states=query_states, key_states=key_states, value_states=value_states,
|
| 901 |
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 902 |
max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
|
| 903 |
+
softmax_scale=softmax_scale,
|
| 904 |
)
|
| 905 |
return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
|
| 906 |
else:
|
| 907 |
return _kernels_flash_forward(
|
| 908 |
+
query_states=query_states, key_states=key_states, value_states=value_states,
|
| 909 |
+
causal=causal, softmax_scale=softmax_scale,
|
| 910 |
)
|
| 911 |
|
| 912 |
|
|
|
|
| 989 |
attention_mask_4d = attention_mask_2d[:, None, None, :]
|
| 990 |
return attention_mask_2d, attention_mask_4d, None
|
| 991 |
|
| 992 |
+
|
| 993 |
+
def bool_to_additive_mask(
    bool_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a boolean validity mask into an additive float mask.

    Positions where `bool_mask` is True map to 0.0 and positions where it is
    False map to -inf, in a newly allocated tensor of the requested `dtype`.

    The output is always produced by filling a fresh `dtype` tensor rather
    than filling the bool tensor itself: `masked_fill` on a bool tensor keeps
    the bool dtype, `-inf` is cast to `True`, and the mask is silently lost.
    Use this helper whenever an SDPA additive mask has to be derived from a
    bool validity mask.

    Args:
        bool_mask: Tensor of dtype `torch.bool`; True marks a valid position.
        dtype: dtype of the returned additive mask.

    Returns:
        A tensor allocated with `torch.zeros_like(bool_mask, dtype=dtype)`
        holding 0.0 where `bool_mask` is True and -inf where it is False.

    Raises:
        AssertionError: if `bool_mask` is not a bool tensor.
    """
    assert bool_mask.dtype == torch.bool, (
        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
    )
    # For a bool tensor `~` is the logical inversion: True marks masked-out slots.
    masked_out = ~bool_mask
    # masked_fill_ works in place and returns the same tensor, so the freshly
    # allocated zeros can be filled and returned in a single expression.
    return torch.zeros_like(bool_mask, dtype=dtype).masked_fill_(masked_out, float("-inf"))
|
| 1010 |
+
|
| 1011 |
"""
|
| 1012 |
FastPLMs-compatible DPLM2 implementation.
|
| 1013 |
"""
|
|
|
|
| 1349 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 1350 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 1351 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
| 1352 |
+
# Q is pre-scaled by self.scale in forward() -- pass softmax_scale=1.0
|
| 1353 |
+
# to prevent the kernel from applying its default 1/sqrt(head_dim).
|
| 1354 |
attn_output = kernels_flash_attention_func(
|
| 1355 |
query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
|
| 1356 |
attention_mask_2d=attention_mask_2d, causal=False,
|
| 1357 |
+
softmax_scale=1.0,
|
| 1358 |
)
|
| 1359 |
return rearrange(attn_output, "b s h d -> b s (h d)"), None
|
| 1360 |
|