Add ApplyRoPE and RMSNorm kernels written in OpenAI Triton to `dev_triton` branch

#9
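None of the three files in this diff contain the Triton kernels named in the title; the changes below only touch README.md, assets/wechat.png, and modeling_qwen.py. As a rough illustration of the kind of kernel the title refers to, here is a minimal RMSNorm sketch in OpenAI Triton. It is a sketch under assumptions, not the dev_triton branch's actual implementation; the kernel name, wrapper, and block-size handling are all hypothetical.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def rmsnorm_kernel(x_ptr, w_ptr, out_ptr, n_cols, eps, BLOCK_SIZE: tl.constexpr):
    # One program instance normalizes one row of a (n_rows, n_cols) matrix.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    # RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned weight.
    rms = tl.sqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    w = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + row * n_cols + cols, (x / rms) * w, mask=mask)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Flatten leading dimensions so every row is normalized independently.
    x_2d = x.reshape(-1, x.shape[-1]).contiguous()
    out = torch.empty_like(x_2d, dtype=torch.float32)
    BLOCK_SIZE = triton.next_power_of_2(x_2d.shape[-1])
    rmsnorm_kernel[(x_2d.shape[0],)](
        x_2d, weight, out, x_2d.shape[-1], eps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out.reshape(x.shape).to(x.dtype)
```

An ApplyRoPE kernel would follow the same one-row-per-program pattern, loading the cached cos/sin values and rotating the first rotary_dim channels of each head.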
Files changed (3)
  1. README.md +1 -8
  2. assets/wechat.png +0 -0
  3. modeling_qwen.py +6 -4
README.md CHANGED
@@ -6,9 +6,6 @@ tags:
 - qwen
 pipeline_tag: text-generation
 inference: false
-license: other
-license_name: tongyi-qianwen-license-agreement
-license_link: https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT
 ---
 
 # Qwen-7B-Chat-Int4
@@ -21,7 +18,7 @@ license_link: https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT
 <p align="center">
 🤗 <a href="https://huggingface.co/Qwen">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/organization/qwen">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://arxiv.org/abs/2309.16609">Paper</a> &nbsp&nbsp | &nbsp&nbsp🖥️ <a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>
 <br>
-<a href="https://github.com/QwenLM/Qwen/blob/main/assets/wechat.png">WeChat (微信)</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/z3GAxXZ9Ce">Discord</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://dashscope.aliyun.com">API</a>
+<a href="assets/wechat.png">WeChat (微信)</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/z3GAxXZ9Ce">Discord</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://dashscope.aliyun.com">API</a>
 </p>
 <br>
 
@@ -70,10 +67,6 @@ cd flash-attention && pip install .
 # pip install csrc/layer_norm
 # pip install csrc/rotary
 ```
-
-如果您有更高推理性能方面的需求,但上述可选加速项`layer_norm`及`rotary`未能安装成功,或是您所使用的GPU不满足`flash-attention`库所要求的NVIDIA Ampere/Ada/Hopper架构,您可以尝试切换至dev_triton分支,使用该分支下基于Triton实现的推理加速方案。该方案适用于更宽范围的GPU产品,在pytorch 2.0及以上版本原生支持,无需额外安装操作。
-
-If you require higher inference performance yet encounter some problems when installing the optional acceleration features (i.e., `layer_norm` and `rotary`) or if the GPU you are using does not meet the NVIDIA Ampere/Ada/Hopper architecture required by the `flash-attention` library, you may switch to the dev_triton branch and consider trying the inference acceleration solution implemented with Triton in this branch. This solution adapts to a wider range of GPU products and does not require extra package installation with pytorch version 2.0 and above.
 <br>
 
 
assets/wechat.png CHANGED
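The paragraph removed from README.md above directs users whose GPUs fall outside the Ampere/Ada/Hopper requirement of `flash-attention` (or who could not build the optional `layer_norm`/`rotary` extensions) to the dev_triton branch. A minimal usage sketch of selecting that branch through the `revision` argument in 🤗 Transformers; the chat call mirrors the usual Qwen quickstart and is assumed here, not quoted from this diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Qwen/Qwen-7B-Chat-Int4"

# Load the tokenizer and the model code/weights from the Triton-accelerated branch.
tokenizer = AutoTokenizer.from_pretrained(
    repo_id, revision="dev_triton", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    revision="dev_triton",  # branch carrying the Triton ApplyRoPE/RMSNorm kernels
    device_map="auto",
    trust_remote_code=True,
).eval()

# model.chat is provided by the repo's remote code (modeling_qwen.py).
response, history = model.chat(tokenizer, "你好", history=None)
print(response)
```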
modeling_qwen.py CHANGED
@@ -520,7 +520,9 @@ class QWenAttention(nn.Module):
 
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+                attention_mask = attention_mask.expand(
+                    -1, -1, causal_mask.size(2), -1
+                )
                 if causal_mask is not None:
                     attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
             else:
@@ -1328,14 +1330,14 @@ def apply_rotary_pos_emb(t, freqs):
         t (tensor(batch_size, seq_len, n_head, head_dim)):
             the input embedding/hidden states
         freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
-            the cached cos/sin position embeddings
+            the cached cos/sin position embeddings
     """
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
     if apply_rotary_emb_func is not None and t.is_cuda:
-        # apply_rotary_emb in flash_attn requires cos/sin to be of
-        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
+        # apply_rotary_emb in flash_attn requires cos/sin to be of
+        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
         cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
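The second hunk only touches the docstring and the comment noting that flash-attn's `apply_rotary_emb` expects `cos`/`sin` of shape `(seqlen, rotary_dim / 2)` and rotates only the first `rotary_dim` channels of the input. For readers without flash-attn, a minimal pure-PyTorch sketch of that operation using the shapes from the docstring; the helper names are illustrative and not the file's actual fallback path:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # t:        (batch_size, seq_len, n_head, head_dim)
    # cos, sin: (1, seq_len, 1, rotary_dim), as cached by the rotary embedding
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = t_rot * cos + rotate_half(t_rot) * sin
    # Channels beyond rotary_dim pass through unchanged.
    return torch.cat((t_rot, t_pass), dim=-1)
```

The flash-attn path shown in the hunk instead squeezes the cached tensors and slices them to the first `rot_dim // 2` columns before calling `apply_rotary_emb_func`, which is exactly what the comment describes.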