Update modeling_bunny_qwen2.py (#1)
Files changed (1)
  1. modeling_bunny_qwen2.py +8 -3
modeling_bunny_qwen2.py CHANGED
@@ -858,10 +858,15 @@ from .configuration_bunny_qwen2 import Qwen2Config


  if is_flash_attn_2_available():
-     from flash_attn import flash_attn_func, flash_attn_varlen_func
-     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+     try:
+         from flash_attn import flash_attn_func, flash_attn_varlen_func
+         from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

-     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+         _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+
+     except:
+         _flash_supports_window_size, flash_attn_func, flash_attn_varlen_func, \
+             index_first_axis, pad_input, unpad_input = None, None, None, None, None, None


  logger = logging.get_logger(__name__)
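For context, the change applies the usual optional-dependency pattern: if importing flash-attn fails even though is_flash_attn_2_available() reported it as installed (for example, a broken or ABI-mismatched build), the symbols become None placeholders instead of raising at module import time, so the checkpoint can still be loaded with a non-flash attention implementation. The standalone sketch below illustrates the same pattern; the attention_backend() helper and the "sdpa" fallback name are illustrative assumptions, not code from this file.

    import inspect

    # Optional import: fall back to None placeholders when flash-attn is missing or broken.
    try:
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    except Exception:  # the PR uses a bare `except:`; Exception is the slightly narrower equivalent
        flash_attn_func, flash_attn_varlen_func = None, None
        _flash_supports_window_size = None

    def attention_backend() -> str:
        # Hypothetical helper: downstream code would choose a backend based on the placeholders.
        return "flash_attention_2" if flash_attn_func is not None else "sdpa"

    print(attention_backend())  # prints "sdpa" on machines without a working flash-attn install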