Workaround for issue
Workaround for issue: get_imports failing to respect conditionals on imports
https://github.com/huggingface/transformers/issues/28459
This should allow the code to work without the flash2 (flash_attn) module installed, and allow it to run on a CPU.
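For background: when a model is loaded with trust_remote_code, transformers scans the module's import statements (get_imports) to decide which packages must be installed, and per the issue above it does not respect the is_flash_attn_2_available() conditional. Imports wrapped in a try/except block are, as far as that dependency scan is concerned, treated as optional, which is what the workaround relies on. Below is a minimal sketch of the guard pattern added in the first hunk (the committed patch uses a bare except and a print; catching ImportError as shown here is a tidier variant, not the exact committed code):

    from transformers.utils import is_flash_attn_2_available

    # Workaround for https://github.com/huggingface/transformers/issues/28459:
    # wrapping the import in try/except keeps flash_attn optional, so the
    # remote-code import scan no longer hard-requires it on CPU-only installs.
    if is_flash_attn_2_available():
        try:
            from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
        except ImportError:
            print("Could not import flash2")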
modelling_walsh.py  CHANGED  (+14 -7)
@@ -27,6 +27,13 @@ from transformers.utils import (
     is_flash_attn_greater_or_equal_2_10,
 )
 
+# Workaround for https://github.com/huggingface/transformers/issues/28459
+if is_flash_attn_2_available():
+    try:
+        from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+    except:
+        print("Could not import flash2")
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 
@@ -825,7 +832,7 @@ class CausalSelfAttention(nn.Module):
         init.constant_(self.output_linear.bias, 0.)
 
     # Project QKV input through input matrices, reshape to (batch_size, n_heads, seq_len, d_model), and apply cache.
-    def
+    def _project_input(self, qkv, past_key_values):
         batch_size, seq_len, d_embed = qkv.shape
         proj = self.in_proj(qkv)
         query, key, value = proj.chunk(chunks=3, dim=-1)
@@ -857,15 +864,15 @@ class CausalSelfAttention(nn.Module):
 
         if attn_type == "flash2":
             if use_cache is None or use_cache == False:
-                return self.
+                return self._flash2_forward(qkv)
             else:
-                return self.
+                return self._flash2_forward_cached(qkv, past_key_values)
 
         # qkv: (batch_size, seq_len, d_embed)
         batch_size, seq_len, d_embed = qkv.shape
 
         # Feed the inputs through the K, Q, V matrices.
-        query, key, value = self.
+        query, key, value = self._project_input(qkv, past_key_values)
         kv_seq_len = key.shape[-2]
 
         # Default to returning empty attention weights.
@@ -922,7 +929,7 @@ class CausalSelfAttention(nn.Module):
         )
 
     # No cache support, but faster
-    def
+    def _flash2_forward(
         self,
         qkv,
     ):
@@ -961,7 +968,7 @@ class CausalSelfAttention(nn.Module):
 
     # See https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
     #https://huggingface.co/docs/transformers/internal/generation_utils
-    def
+    def _flash2_forward_cached(
         self,
         qkv,
         past_key_values,
@@ -969,7 +976,7 @@ class CausalSelfAttention(nn.Module):
         batch_size, seq_len, d_embed = qkv.shape
 
         # Feed the inputs through the K, Q, V matrices.
-        query, key, value = self.
+        query, key, value = self._project_input(qkv, past_key_values)
         query, key, value = self._downcast_to_float16(query, key, value)
 
         # Expected inputs to flash2:
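As a rough way to confirm the workaround, the model can be loaded on a CPU-only machine without flash_attn installed, via the usual remote-code path. This is only a sketch: the repo id below is a placeholder rather than the actual model name, and it assumes the model's default attention path does not require flash2 when the package is absent.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "namespace/walsh-causal-lm"  # placeholder repo id, not the real one
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float32,  # CPU-friendly dtype; the flash2 path stays unused
    )

    inputs = tokenizer("Hello", return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    print(out.logits.shape)

If flash_attn is installed and a CUDA device is available, the flash2 branch should still be taken exactly as before, since the original is_flash_attn_2_available() import block is left in place.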