Use try-except for flash_attn #5
by LiangliangMa - opened

modeling_deepseek.py CHANGED (+3 -3)
@@ -48,7 +48,6 @@ from transformers.pytorch_utils import (
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -58,10 +57,11 @@ from .configuration_deepseek import DeepseekV2Config
 import torch.distributed as dist
 import numpy as np
 
-if is_flash_attn_2_available():
+try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
+except ImportError:
+    pass
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
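
In short, the patch lets modeling_deepseek.py be imported on machines where flash-attn is not installed: the flash_attn imports are attempted and silently skipped on ImportError instead of being gated on is_flash_attn_2_available(). Below is a minimal sketch of the same pattern; the _flash_attn_available flag and the attention_implementation() helper are illustrative assumptions, not part of this PR, which simply passes on failure.

# Sketch of the try/except import pattern used in this PR.
# NOTE: `_flash_attn_available` and `attention_implementation()` are hypothetical
# helpers added for illustration; the PR itself only does `pass` in the except branch.
_flash_attn_available = False
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401

    _flash_attn_available = True
except ImportError:
    # flash-attn is optional: fall back to the standard attention path.
    pass


def attention_implementation() -> str:
    """Pick an attention backend based on whether flash-attn imported cleanly."""
    return "flash_attention_2" if _flash_attn_available else "eager"

One practical consequence of this style is that an installed-but-broken flash-attn build (e.g. an ABI mismatch raising ImportError at import time) degrades gracefully instead of making the whole model module unimportable.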