x54-729 committed on
Commit
2d0920c
1 Parent(s): 2750ce8

fix flash attention import

Browse files
configuration_internlm2.py CHANGED
@@ -169,5 +169,12 @@ class InternLM2Config(PretrainedConfig):
169
  raise ValueError(
170
  f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
171
  )
172
- if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
173
- raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
 
 
 
 
 
 
 
 
169
  raise ValueError(
170
  f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
171
  )
172
+ if (
173
+ rope_scaling_factor is None
174
+ or not isinstance(rope_scaling_factor, (float, int))
175
+ or rope_scaling_factor < 1.0
176
+ ):
177
+ raise ValueError(
178
+ f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
179
+ f"of type {type(rope_scaling_factor)}"
180
+ )
modeling_internlm2.py CHANGED
@@ -40,7 +40,6 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
40
  from transformers.utils import (
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
43
- is_flash_attn_2_available,
44
  is_flash_attn_greater_or_equal_2_10,
45
  logging,
46
  replace_return_docstrings,
@@ -53,9 +52,12 @@ except Exception:
53
 
54
  from .configuration_internlm2 import InternLM2Config
55
 
56
- if is_flash_attn_2_available():
 
57
  from flash_attn import flash_attn_func, flash_attn_varlen_func
58
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
 
 
59
 
60
 
61
  logger = logging.get_logger(__name__)
 
40
  from transformers.utils import (
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
 
43
  is_flash_attn_greater_or_equal_2_10,
44
  logging,
45
  replace_return_docstrings,
 
52
 
53
  from .configuration_internlm2 import InternLM2Config
54
 
55
+
56
+ try:
57
  from flash_attn import flash_attn_func, flash_attn_varlen_func
58
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
59
+ except:
60
+ pass
61
 
62
 
63
  logger = logging.get_logger(__name__)