dinalt committed
Commit 7820c23 · verified · 1 Parent(s): 0b1deee

Workaround for issue


Workaround for issue: get_imports failing to respect conditionals on imports
https://github.com/huggingface/transformers/issues/28459

This should allow the code to work without the flash2 (flash_attn) module installed, and therefore allow it to run on a CPU.
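In isolation, the pattern the commit introduces looks roughly like the sketch below: the optional import stays behind the availability check, but is additionally wrapped in try/except so a missing flash_attn package produces a console message instead of an ImportError. This is an illustrative sketch, not a copy of the file; the None defaults and the narrower except ImportError are editorial additions (the commit itself uses a bare except and print).

# Sketch of the optional-import pattern added by this commit.
from transformers.utils import is_flash_attn_2_available

flash_attn_func = None
flash_attn_qkvpacked_func = None

if is_flash_attn_2_available():
    try:
        from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
    except ImportError:
        # Fall back to a non-flash attention path; the model can still run on CPU.
        print("Could not import flash2")

As a quick check, transformers.dynamic_module_utils.get_imports("modelling_walsh.py") lists the module names the remote-code loader detects in the file (assuming that helper keeps its current name and signature); get_imports is the function the linked issue is about.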

Files changed (1)
  1. modelling_walsh.py +14 -7
modelling_walsh.py CHANGED
@@ -27,6 +27,13 @@ from transformers.utils import (
     is_flash_attn_greater_or_equal_2_10,
 )
 
+# Workaround for https://github.com/huggingface/transformers/issues/28459
+if is_flash_attn_2_available():
+    try:
+        from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+    except:
+        print("Could not import flash2")
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 
@@ -825,7 +832,7 @@ class CausalSelfAttention(nn.Module):
         init.constant_(self.output_linear.bias, 0.)
 
     # Project QKV input through input matrices, reshape to (batch_size, n_heads, seq_len, d_model), and apply cache.
-    def project_input(self, qkv, past_key_values):
+    def _project_input(self, qkv, past_key_values):
         batch_size, seq_len, d_embed = qkv.shape
         proj = self.in_proj(qkv)
         query, key, value = proj.chunk(chunks=3, dim=-1)
@@ -857,15 +864,15 @@ class CausalSelfAttention(nn.Module):
 
         if attn_type == "flash2":
             if use_cache is None or use_cache == False:
-                return self.flash2_forward(qkv)
+                return self._flash2_forward(qkv)
             else:
-                return self.flash2_forward_cached(qkv, past_key_values)
+                return self._flash2_forward_cached(qkv, past_key_values)
 
         # qkv: (batch_size, seq_len, d_embed)
         batch_size, seq_len, d_embed = qkv.shape
 
         # Feed the inputs through the K, Q, V matrices.
-        query, key, value = self.project_input(qkv, past_key_values)
+        query, key, value = self._project_input(qkv, past_key_values)
         kv_seq_len = key.shape[-2]
 
         # Default to returning empty attention weights.
@@ -922,7 +929,7 @@ class CausalSelfAttention(nn.Module):
         )
 
     # No cache support, but faster
-    def flash2_forward(
+    def _flash2_forward(
         self,
         qkv,
     ):
@@ -961,7 +968,7 @@ class CausalSelfAttention(nn.Module):
 
     # See https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
     #https://huggingface.co/docs/transformers/internal/generation_utils
-    def flash2_forward_cached(
+    def _flash2_forward_cached(
        self,
        qkv,
        past_key_values,
@@ -969,7 +976,7 @@ class CausalSelfAttention(nn.Module):
        batch_size, seq_len, d_embed = qkv.shape
 
        # Feed the inputs through the K, Q, V matrices.
-        query, key, value = self.project_input(qkv, past_key_values)
+        query, key, value = self._project_input(qkv, past_key_values)
        query, key, value = self._downcast_to_float16(query, key, value)
 
        # Expected inputs to flash2:
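For orientation, the renamed helpers follow a standard fused-QKV attention pattern: a single in_proj linear produces Q, K, and V in one matmul, the result is chunked and reshaped per head, and the flash2 kernels are used only when the optional import succeeded and the tensors live on a GPU; otherwise a plain PyTorch attention path runs, which is what makes CPU execution possible. The sketch below is a generic illustration of that pattern under those assumptions, not the actual CausalSelfAttention implementation: the class name SketchSelfAttention is hypothetical, and the KV-cache handling (past_key_values) is omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Optional dependency; mirrors the guarded import at the top of modelling_walsh.py.
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None

class SketchSelfAttention(nn.Module):
    """Hypothetical stand-in showing the fused-QKV + optional-flash2 pattern."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.in_proj = nn.Linear(d_model, 3 * d_model)   # fused Q/K/V projection
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        # One matmul, then split into query / key / value.
        query, key, value = self.in_proj(x).chunk(3, dim=-1)
        # (batch, seq, d_model) -> (batch, n_heads, seq, d_head)
        query, key, value = (
            t.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
            for t in (query, key, value)
        )

        if flash_attn_func is not None and x.is_cuda:
            # flash_attn_func expects (batch, seq, n_heads, d_head) tensors in fp16/bf16.
            q, k, v = (t.transpose(1, 2).to(torch.float16) for t in (query, key, value))
            out = flash_attn_func(q, k, v, causal=True).to(x.dtype).transpose(1, 2)
        else:
            # CPU / no-flash fallback: PyTorch scaled-dot-product attention.
            out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # (batch, n_heads, seq, d_head) -> (batch, seq, d_model)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, d_model)
        return self.out_proj(out)

A minimal smoke test on CPU is attn = SketchSelfAttention(256, 8) followed by attn(torch.randn(1, 16, 256)); without flash_attn installed this exercises only the fallback branch, which is the situation the commit is trying to support.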