yangapku commited on
Commit
405556d
1 Parent(s): 1a2571e

fix flash-attention usage

Browse files
Files changed (3) hide show
  1. README.md +1 -2
  2. config.json +1 -1
  3. modeling_qwen.py +19 -12
README.md CHANGED
@@ -16,7 +16,7 @@ inference: false
16
  <br>
17
 
18
  <p align="center">
19
- Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 <a> | <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp | Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 <a>| <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp | &nbsp<a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>&nbsp | &nbsp<a href="https://github.com/QwenLM/Qwen-7B/blob/main/tech_memo.md">Report</a>
20
  </p>
21
  <br>
22
 
@@ -319,4 +319,3 @@ Our code and checkpoints are open to research purpose, and they are allowed for
319
  如果你想给我们的研发团队和产品团队留言,请通过邮件(qianwen_opensource@alibabacloud.com)联系我们。
320
 
321
  If you are interested to leave a message to either our research team or product team, feel free to send an email to qianwen_opensource@alibabacloud.com.
322
-
 
16
  <br>
17
 
18
  <p align="center">
19
+ Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a> | <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp | Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 </a>| <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp | &nbsp<a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>&nbsp | &nbsp<a href="https://github.com/QwenLM/Qwen-7B/blob/main/tech_memo.md">Report</a>
20
  </p>
21
  <br>
22
 
 
319
  如果你想给我们的研发团队和产品团队留言,请通过邮件(qianwen_opensource@alibabacloud.com)联系我们。
320
 
321
  If you are interested to leave a message to either our research team or product team, feel free to send an email to qianwen_opensource@alibabacloud.com.
 
config.json CHANGED
@@ -38,7 +38,7 @@
38
  "tokenizer_type": "QWenTokenizer",
39
  "transformers_version": "4.31.0",
40
  "use_cache": true,
41
- "use_flash_attn": false,
42
  "vocab_size": 151936,
43
  "use_dynamic_ntk": false,
44
  "use_logn_attn": false
 
38
  "tokenizer_type": "QWenTokenizer",
39
  "transformers_version": "4.31.0",
40
  "use_cache": true,
41
+ "use_flash_attn": true,
42
  "vocab_size": 151936,
43
  "use_dynamic_ntk": false,
44
  "use_logn_attn": false
modeling_qwen.py CHANGED
@@ -36,18 +36,17 @@ try:
36
  from einops import rearrange
37
 
38
  use_flash_rotary = True
39
- print("use flash_attn rotary")
40
  except ImportError:
41
  use_flash_rotary = False
42
- print("import flash_attn rotary fail")
 
43
 
44
  try:
45
  from flash_attn.ops.rms_norm import rms_norm
46
-
47
- print("use flash_attn rms_norm")
48
  except ImportError:
49
  rms_norm = None
50
- print("import flash_attn rms_norm fail")
 
51
 
52
  from .configuration_qwen import QWenConfig
53
  from .qwen_generation_utils import (
@@ -70,6 +69,8 @@ try:
70
  from flash_attn.flash_attn_interface import flash_attn_unpadded_func
71
  except ImportError:
72
  flash_attn_unpadded_func = None
 
 
73
 
74
 
75
  class FlashSelfAttention(torch.nn.Module):
@@ -176,7 +177,7 @@ class QWenAttention(nn.Module):
176
  config.hidden_size, self.projection_size, bias=not config.no_bias
177
  )
178
 
179
- if self.use_flash_attn:
180
  self.core_attention_flash = FlashSelfAttention(
181
  causal=True, attention_dropout=config.attn_pdrop
182
  )
@@ -333,7 +334,7 @@ class QWenAttention(nn.Module):
333
  if layer_past:
334
  # layer past[0] shape: bs * seq_len * head_num * dim
335
  kv_seq_len += layer_past[0].shape[1]
336
- if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1]:
337
  context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
338
  ntk_alpha = 2 ** math.ceil(context_value) - 1
339
  ntk_alpha = max(ntk_alpha, 1)
@@ -367,7 +368,7 @@ class QWenAttention(nn.Module):
367
  else:
368
  present = None
369
 
370
- if self.use_logn_attn:
371
  if self.logn_tensor.device != query.device:
372
  self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
373
  seq_start = key.size(0) - query.size(0)
@@ -375,7 +376,7 @@ class QWenAttention(nn.Module):
375
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
376
  query = query * logn_tensor.expand_as(query)
377
 
378
- if self.use_flash_attn:
379
  q, k, v = query, key, value
380
  context_layer = self.core_attention_flash(q, k, v)
381
 
@@ -396,7 +397,7 @@ class QWenAttention(nn.Module):
396
  attn_output = self.c_proj(context_layer)
397
  outputs = (attn_output, present)
398
  if output_attentions:
399
- if self.use_flash_attn:
400
  raise ValueError("Cannot output attentions while using flash-attn")
401
  else:
402
  outputs += (attn_weight,)
@@ -748,6 +749,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
748
  super().__init__(config)
749
  self.transformer = QWenModel(config)
750
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
 
 
 
 
 
 
751
  self.post_init()
752
 
753
  def get_output_embeddings(self):
@@ -957,8 +965,7 @@ class RotaryEmbedding(torch.nn.Module):
957
  super().__init__()
958
  self.dim = dim
959
  self.base = base
960
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
961
- self.register_buffer("inv_freq", inv_freq)
962
  if importlib.util.find_spec("einops") is None:
963
  raise RuntimeError("einops is required for Rotary Embedding")
964
 
 
36
  from einops import rearrange
37
 
38
  use_flash_rotary = True
 
39
  except ImportError:
40
  use_flash_rotary = False
41
+ print("Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
42
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary")
43
 
44
  try:
45
  from flash_attn.ops.rms_norm import rms_norm
 
 
46
  except ImportError:
47
  rms_norm = None
48
+ print("Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
49
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm")
50
 
51
  from .configuration_qwen import QWenConfig
52
  from .qwen_generation_utils import (
 
69
  from flash_attn.flash_attn_interface import flash_attn_unpadded_func
70
  except ImportError:
71
  flash_attn_unpadded_func = None
72
+ print("Warning: import flash_attn fail, please install FlashAttention "
73
+ "https://github.com/Dao-AILab/flash-attention")
74
 
75
 
76
  class FlashSelfAttention(torch.nn.Module):
 
177
  config.hidden_size, self.projection_size, bias=not config.no_bias
178
  )
179
 
180
+ if self.use_flash_attn and flash_attn_unpadded_func is not None:
181
  self.core_attention_flash = FlashSelfAttention(
182
  causal=True, attention_dropout=config.attn_pdrop
183
  )
 
334
  if layer_past:
335
  # layer past[0] shape: bs * seq_len * head_num * dim
336
  kv_seq_len += layer_past[0].shape[1]
337
+ if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1] and not self.training:
338
  context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
339
  ntk_alpha = 2 ** math.ceil(context_value) - 1
340
  ntk_alpha = max(ntk_alpha, 1)
 
368
  else:
369
  present = None
370
 
371
+ if self.use_logn_attn and not self.training:
372
  if self.logn_tensor.device != query.device:
373
  self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
374
  seq_start = key.size(0) - query.size(0)
 
376
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
377
  query = query * logn_tensor.expand_as(query)
378
 
379
+ if self.use_flash_attn and flash_attn_unpadded_func is not None:
380
  q, k, v = query, key, value
381
  context_layer = self.core_attention_flash(q, k, v)
382
 
 
397
  attn_output = self.c_proj(context_layer)
398
  outputs = (attn_output, present)
399
  if output_attentions:
400
+ if self.use_flash_attn and flash_attn_unpadded_func is not None:
401
  raise ValueError("Cannot output attentions while using flash-attn")
402
  else:
403
  outputs += (attn_weight,)
 
749
  super().__init__(config)
750
  self.transformer = QWenModel(config)
751
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
752
+ assert not(config.bf16 and config.fp16), ("In config, bf16 and fp16 cannot both be true")
753
+ if config.bf16:
754
+ self.transformer.bfloat16()
755
+ self.lm_head.bfloat16()
756
+ if config.fp16:
757
+ self.transformer.half()
758
+ self.lm_head.half()
759
  self.post_init()
760
 
761
  def get_output_embeddings(self):
 
965
  super().__init__()
966
  self.dim = dim
967
  self.base = base
968
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
 
969
  if importlib.util.find_spec("einops") is None:
970
  raise RuntimeError("einops is required for Rotary Embedding")
971