Upload modeling_ankh.py with huggingface_hub
modeling_ankh.py  CHANGED  (+72 -12)
@@ -722,15 +722,34 @@ def _kernels_flash_forward(
     key_states: torch.Tensor,
     value_states: torch.Tensor,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Flash-attention forward, optionally overriding the softmax scale.
+
+    When `softmax_scale is None`, the flash kernel applies its default
+    `1 / sqrt(head_dim)`. Pass `softmax_scale=1.0` if the caller has already
+    pre-scaled Q (the convention used by ESM2, DPLM, DPLM2, E1, ESMFold).
+    Failing to override when Q is pre-scaled produces DOUBLE scaling and
+    catastrophic downstream drift -- on DPLM-150M (30 layers) this was observed
+    as pooled-embedding cosine ~-0.12 and argmax agreement ~0.27 vs sdpa.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if FLASH_KERNEL_VARIANT == "flash_attn2":
-        return FLASH_KERNEL.fwd(
+        return FLASH_KERNEL.fwd(
+            q=query_states, k=key_states, v=value_states,
+            softmax_scale=softmax_scale, is_causal=causal,
+        )[0]
     if FLASH_KERNEL_VARIANT == "flash_attn3":
         try:
-            output = FLASH_KERNEL.flash_attn_func(
+            output = FLASH_KERNEL.flash_attn_func(
+                q=query_states, k=key_states, v=value_states,
+                softmax_scale=softmax_scale, causal=causal,
+            )
         except TypeError:
-            output = FLASH_KERNEL.flash_attn_func(
+            output = FLASH_KERNEL.flash_attn_func(
+                query_states, key_states, value_states,
+                0.0, softmax_scale, causal,
+            )
         if isinstance(output, tuple):
             return output[0]
         return output
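For reference, a minimal standalone sketch (not part of this commit) of the double-scaling failure the new docstring warns about. Plain torch SDPA stands in for the flash kernels, the tensor shapes are made up, and the `scale` keyword requires a recent PyTorch; the point is only the arithmetic: pre-scaled Q with the default scale left on gives an effective 1/head_dim, while `scale=1.0` restores parity.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, L, D = 2, 4, 16, 32            # illustrative shapes
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Reference: unscaled Q, the kernel applies its default 1/sqrt(D).
ref = F.scaled_dot_product_attention(q, k, v)

# Caller pre-scales Q (the ESM2-style convention mentioned in the docstring)...
q_pre = q / math.sqrt(D)

# ...but forgets to override the scale: the default 1/sqrt(D) lands on top,
# so the effective scale is 1/D and the outputs drift.
double_scaled = F.scaled_dot_product_attention(q_pre, k, v)

# Correct: tell the kernel the scores are already scaled.
fixed = F.scaled_dot_product_attention(q_pre, k, v, scale=1.0)

print("fixed  vs ref:", (fixed - ref).abs().max().item())          # ~0 (matches)
print("double vs ref:", (double_scaled - ref).abs().max().item())  # clearly nonzero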
@@ -746,14 +765,20 @@ def _kernels_flash_varlen_forward(
     max_seqlen_in_batch_q: int,
     max_seqlen_in_batch_k: int,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Varlen flash-attention forward, optionally overriding the softmax scale.
+
+    See `_kernels_flash_forward` docstring for why `softmax_scale=1.0` must be
+    passed when Q has been pre-scaled by the caller.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if FLASH_KERNEL_VARIANT == "flash_attn2":
         return FLASH_KERNEL.varlen_fwd(
             q=query_states, k=key_states, v=value_states,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
             max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
-            is_causal=causal,
+            softmax_scale=softmax_scale, is_causal=causal,
         )[0]
     if FLASH_KERNEL_VARIANT == "flash_attn3":
         try:
@@ -761,14 +786,14 @@ def _kernels_flash_varlen_forward(
                 q=query_states, k=key_states, v=value_states,
                 cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                 max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
-                causal=causal,
+                softmax_scale=softmax_scale, causal=causal,
             )
         except TypeError:
             output = FLASH_KERNEL.flash_attn_varlen_func(
                 query_states, key_states, value_states,
                 cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_in_batch_q, max_seqlen_in_batch_k,
-                0.0,
+                0.0, softmax_scale, causal,
             )
         if isinstance(output, tuple):
             return output[0]
@@ -849,7 +874,21 @@ def kernels_flash_attention_func(
     value_states: torch.Tensor,
     attention_mask_2d: Optional[torch.Tensor] = None,
     causal: bool = False,
+    softmax_scale: Optional[float] = None,
 ) -> torch.Tensor:
+    """Public flash-attention entry point with optional padding handling.
+
+    `softmax_scale`:
+        None  -> kernel applies its default `1 / sqrt(head_dim)`.
+        float -> kernel uses the given scale (pass 1.0 when Q is pre-scaled
+                 by the caller).
+
+    IMPORTANT: if your family multiplies Q by `1/sqrt(head_dim)` before calling
+    this function (as ESM2, DPLM, DPLM2, E1, and ESMFold do) you MUST pass
+    `softmax_scale=1.0`. Otherwise the kernel applies its default scale ON TOP
+    of the caller's, producing effective scale `1/head_dim` and catastrophic
+    downstream drift that compounds across layers.
+    """
     assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
     if not causal and attention_mask_2d is not None:
         batch_size, q_len = query_states.shape[:2]
@@ -861,11 +900,13 @@ def kernels_flash_attention_func(
             query_states=query_states, key_states=key_states, value_states=value_states,
             cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
             max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
+            softmax_scale=softmax_scale,
         )
         return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
     else:
         return _kernels_flash_forward(
-            query_states=query_states, key_states=key_states, value_states=value_states,
+            query_states=query_states, key_states=key_states, value_states=value_states,
+            causal=causal, softmax_scale=softmax_scale,
         )
 
 
@@ -948,6 +989,25 @@ def get_attention_mask(
     attention_mask_4d = attention_mask_2d[:, None, None, :]
     return attention_mask_2d, attention_mask_4d, None
 
+
+def bool_to_additive_mask(
+    bool_mask: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Convert a bool mask (True = valid) to a float additive mask (0.0 valid, -inf invalid).
+
+    Why this exists: calling `bool_mask.masked_fill(bool_mask.logical_not(), float('-inf'))`
+    directly on a bool tensor returns a bool tensor -- because `-inf` casts to `True` -- and
+    silently drops the mask entirely. Always allocate a float tensor first, then fill it.
+    This helper is the sanctioned way to build an SDPA additive mask from a bool validity mask.
+    """
+    assert bool_mask.dtype == torch.bool, (
+        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
+    )
+    additive = torch.zeros_like(bool_mask, dtype=dtype)
+    additive.masked_fill_(bool_mask.logical_not(), float("-inf"))
+    return additive
+
 import math
 
 import torch
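To make the pitfall concrete, a tiny standalone sketch (not part of this diff) of the helper and its output; the mask values are made up, and the function is re-stated locally so the snippet runs on its own.

import torch

def bool_to_additive_mask(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Same logic as the helper added above: start from a float tensor,
    # then write -inf at the invalid (padded) positions.
    additive = torch.zeros_like(bool_mask, dtype=dtype)
    additive.masked_fill_(bool_mask.logical_not(), float("-inf"))
    return additive

valid = torch.tensor([[True, True, True, False, False]])   # True = real token
print(bool_to_additive_mask(valid, torch.float32))
# tensor([[0., 0., 0., -inf, -inf]])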
@@ -1091,7 +1151,9 @@ class AnkhSelfAttention(nn.Module):
         self.k = nn.Linear(config.d_model, self.inner_dim, bias=False)
         self.v = nn.Linear(config.d_model, self.inner_dim, bias=False)
         self.o = nn.Linear(self.inner_dim, config.d_model, bias=False)
-
+        # T5/ANKH attention is unscaled: scores = Q K^T (no 1/sqrt(d_kv)).
+        # The learned relative position bias absorbs any temperature.
+        self.scale = 1.0
 
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(
@@ -1163,11 +1225,9 @@ class AnkhSelfAttention(nn.Module):
         # Compute position bias on first layer (SDPA/manual only; flex uses score_mod)
         if position_bias is None and self.has_relative_attention_bias and self.attn_backend != AttentionBackend.FLEX:
             position_bias = self.compute_bias(seq_length, seq_length, hidden_states.device)
-            # Fold padding mask into position bias so layers don't need separate mask
+            # Fold padding mask into position bias so layers don't need separate mask.
             if attention_mask_4d is not None:
-                position_bias = position_bias + attention_mask_4d.masked_fill(
-                    attention_mask_4d.logical_not(), float("-inf")
-                )
+                position_bias = position_bias + bool_to_additive_mask(attention_mask_4d, position_bias.dtype)
 
         if output_attentions:
             attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
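As a rough standalone illustration (not part of this diff) of what the last two hunks describe, the sketch below runs an unscaled (scale=1.0) attention call with a relative position bias into which a padding mask has been folded. All shapes and names are assumptions for the example, and plain torch SDPA stands in for the model's attention backends.

import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 6, 8                       # illustrative sizes
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
position_bias = torch.randn(1, H, L, L)       # learned relative bias, one map per head

# Fold the padding mask into the bias: 0.0 at valid keys, -inf at padded keys,
# mirroring bool_to_additive_mask plus the addition done in the layer above.
padding = torch.tensor([[1, 1, 1, 1, 0, 0],
                        [1, 1, 1, 0, 0, 0]], dtype=torch.bool)
additive = torch.zeros(padding.shape, dtype=q.dtype)
additive.masked_fill_(padding.logical_not(), float("-inf"))
biased_mask = position_bias + additive[:, None, None, :]   # broadcasts to [B, H, L, L]

# T5/ANKH-style attention is unscaled: softmax(Q K^T + bias) V, hence scale=1.0.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=biased_mask, scale=1.0)
print(out.shape)  # torch.Size([2, 4, 6, 8])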