Update README of branch dev_triton.

#11
by Cheshire94 - opened
Files changed (6)
  1. README.md +28 -1
  2. assets/wechat.png +0 -0
  3. config.json +1 -0
  4. configuration_qwen.py +2 -0
  5. modeling_qwen.py +36 -8
  6. triton_kernels.py +125 -0
README.md CHANGED
@@ -18,7 +18,7 @@ inference: false
 <p align="center">
         🤗 <a href="https://huggingface.co/Qwen">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/organization/qwen">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://arxiv.org/abs/2309.16609">Paper</a> &nbsp&nbsp | &nbsp&nbsp🖥️ <a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>
 <br>
-<a href="assets/wechat.png">WeChat (微信)</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/z3GAxXZ9Ce">Discord</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://dashscope.aliyun.com">API</a>
+<a href="https://github.com/QwenLM/Qwen/blob/main/assets/wechat.png">WeChat (微信)</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://discord.gg/z3GAxXZ9Ce">Discord</a>&nbsp&nbsp | &nbsp&nbsp<a href="https://dashscope.aliyun.com">API</a>
 </p>
 <br>
 
@@ -67,6 +67,14 @@ cd flash-attention && pip install .
 # pip install csrc/layer_norm
 # pip install csrc/rotary
 ```
+
+如果您有更高推理性能方面的需求,但上述可选加速项`layer_norm`及`rotary`未能安装成功,或是您所使用的GPU不满足`flash-attention`库所要求的NVIDIA Ampere/Ada/Hopper架构,您可以尝试使用该分支下基于Triton进行实现的推理加速方案。该方案适用于更宽范围的GPU产品,且无需安装。您可以通过将config.json里的`use_triton`设置为true来进行启用。
+
+**(在dev_triton分支下`use_triton`默认设置为auto,由于pytorch 2.0及以上版本已默认安装了Triton,因此上述优化方案无需其它安装与配置操作即可直接启用。如果您不想开启该优化,请将config.json里的`use_triton`设置为false)**
+
+If you need higher inference performance but the optional acceleration kernels above (`layer_norm` and `rotary`) fail to install, or your GPU does not meet the NVIDIA Ampere/Ada/Hopper architecture required by the `flash-attention` library, you can try the Triton-based inference acceleration implemented in this branch. It supports a wider range of GPU products and requires no installation. You can enable it by setting the `use_triton` option to true in config.json.
+
+**(In the dev_triton branch, `use_triton` defaults to `auto`. Since Triton ships with PyTorch 2.0 and above, this acceleration can be used directly without additional installation or configuration. If you prefer not to enable this optimization, set `use_triton` to false in config.json.)**
 <br>
 
 
@@ -140,6 +148,25 @@ In detail, the setting of profiling is generating 8192 new tokens with 1 context
 
 Note: The generation speed of the Int4/Int8 models mentioned above is provided by the autogptq library. The current speed of the model loaded using "AutoModelForCausalLM.from_pretrained" will be approximately 20% slower. We have reported this issue to the HuggingFace team and will update it promptly if a solution is available.
 
+另外,我们也测算了在使用不同GPU及推理加速方法时Qwen-7B-Chat-Int4模型生成2048和8192个token的平均推理速度。所有评测均使用PyTorch 2.1.0和CUDA 11.8。
+
+In addition, we also measured the average inference speed of the Qwen-7B-Chat-Int4 model when generating 2048 and 8192 tokens on different GPU devices and with different acceleration methods. All results were obtained with PyTorch 2.1.0 and CUDA 11.8.
+
+| GPU Device | Method       | Speed (2048 tokens) | Speed (8192 tokens) |
+| :--------: | :----------: | :-----------------: | :-----------------: |
+| A10        | FlashAttn v2 | 41.28               | 30.78               |
+| A10        | Triton       | 49.04               | 29.17               |
+| A10        | Disabled     | 39.26               | 26.81               |
+| V100       | FlashAttn v2 | N/A                 | N/A                 |
+| V100       | Triton       | 37.01               | 27.66               |
+| V100       | Disabled     | 24.47               | 20.40               |
+| P100       | FlashAttn v2 | N/A                 | N/A                 |
+| P100       | Triton       | 29.03               | 13.85               |
+| P100       | Disabled     | 20.50               | 12.73               |
+| T4         | FlashAttn v2 | N/A                 | N/A                 |
+| T4         | Triton       | 27.98               | 15.22               |
+| T4         | Disabled     | 23.11               | 14.55               |
+
 ### 显存使用 (GPU Memory Usage)
 
 我们还测算了不同模型精度编码2048个token及生成8192个token的峰值显存占用情况。(显存消耗在是否使用FlashAttn的情况下均类似。)结果如下所示:
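For reference, a minimal sketch (not part of this PR) of switching the Triton path on or off at load time instead of editing config.json; the checkpoint name and `revision` below are placeholders for this branch:

```python
# Hypothetical usage sketch: load the config from the dev_triton branch, set
# `use_triton` explicitly, and pass the config back to from_pretrained.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen-7B-Chat-Int4"  # placeholder checkpoint name

config = AutoConfig.from_pretrained(model_id, revision="dev_triton", trust_remote_code=True)
config.use_triton = True  # "auto" is the dev_triton default; set False to opt out

tokenizer = AutoTokenizer.from_pretrained(model_id, revision="dev_triton", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    revision="dev_triton",
    device_map="auto",
    trust_remote_code=True,
).eval()
```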
assets/wechat.png CHANGED
config.json CHANGED
@@ -44,6 +44,7 @@
   "use_cache": true,
   "use_dynamic_ntk": true,
   "use_flash_attn": "auto",
+  "use_triton": "auto",
   "use_logn_attn": true,
   "vocab_size": 151936
 }
configuration_qwen.py CHANGED
@@ -32,6 +32,7 @@ class QWenConfig(PretrainedConfig):
         use_dynamic_ntk=True,
         use_logn_attn=True,
         use_flash_attn="auto",
+        use_triton="auto",
         intermediate_size=22016,
         no_bias=True,
         tie_word_embeddings=False,
@@ -61,6 +62,7 @@ class QWenConfig(PretrainedConfig):
         self.use_dynamic_ntk = use_dynamic_ntk
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
+        self.use_triton = use_triton
         self.no_bias = no_bias
         self.use_cache_quantization = use_cache_quantization
         self.use_cache_kernel = use_cache_kernel
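As a quick sanity check (an illustrative sketch, assuming this repo's configuration_qwen.py and transformers are importable from the working directory), the new flag behaves like any other QWenConfig field and survives serialization back to config.json:

```python
# Hypothetical check of the new `use_triton` field on QWenConfig.
from configuration_qwen import QWenConfig

config = QWenConfig(use_triton="auto")
print(config.use_triton)                # "auto"
print(config.to_dict()["use_triton"])   # "auto" -> round-trips through config.json
```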
modeling_qwen.py CHANGED
@@ -35,7 +35,7 @@ except ImportError:
 from torch import nn
 
 SUPPORT_CUDA = torch.cuda.is_available()
-SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
+SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
 
@@ -76,7 +76,9 @@ We detect you have activated flash attention support, but running model computat
 """
 
 apply_rotary_emb_func = None
+apply_rotary_emb_func_triton = None
 rms_norm = None
+rms_norm_triton = None
 flash_attn_unpadded_func = None
 flash_attn_func = None
 
@@ -120,6 +122,24 @@ def _import_flash_attn():
             "https://github.com/Dao-AILab/flash-attention"
         )
 
+def _import_triton():
+    global apply_rotary_emb_func_triton, rms_norm_triton
+    try:
+        from .triton_kernels import apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
+        if apply_rotary_emb_func is not None:
+            logger.warn(
+                "Using Triton rotary kernel instead of flash_attn for inference."
+            )
+        apply_rotary_emb_func_triton = __apply_rotary_emb
+        if rms_norm is not None:
+            logger.warn(
+                "Using Triton rms_norm kernel instead of flash_attn for inference."
+            )
+        rms_norm_triton = __rms_norm
+    except ImportError:
+        logger.warn("Warning: Failed to import Triton kernels.")
+    return
+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
@@ -520,11 +540,9 @@ class QWenAttention(nn.Module):
 
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
                 if causal_mask is not None:
-                    attention_mask.masked_fill_(~causal_mask, torch.finfo(query.dtype).min)
+                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
             else:
                 attention_mask = causal_mask
             attn_output = F.scaled_dot_product_attention(
@@ -978,6 +996,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if config.use_flash_attn:
             _import_flash_attn()
 
+        if config.use_triton == "auto":
+            logger.warn("Try importing Triton kernels for faster inference...")
+            config.use_triton = SUPPORT_TORCH2
+        if config.use_triton:
+            _import_triton()
+
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
@@ -1330,12 +1354,14 @@ def apply_rotary_pos_emb(t, freqs):
         t (tensor(batch_size, seq_len, n_head, head_dim)):
             the input embedding/hidden states
         freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
-            the cached cos/sin position embeddings
+            the cached cos/sin position embeddings
     """
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
-    if apply_rotary_emb_func is not None and t.is_cuda:
+    if apply_rotary_emb_func_triton is not None and t.is_cuda and (not t.requires_grad):
+        return apply_rotary_emb_func_triton(t, cos, sin)
+    elif apply_rotary_emb_func is not None and t.is_cuda:
         # apply_rotary_emb in flash_attn requires cos/sin to be of
         # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
@@ -1358,7 +1384,9 @@ class RMSNorm(torch.nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        if rms_norm is not None and x.is_cuda:
+        if rms_norm_triton is not None and x.is_cuda and (not x.requires_grad):
+            return rms_norm_triton(x, self.weight, self.eps)
+        elif rms_norm is not None and x.is_cuda:
             return rms_norm(x, self.weight, self.eps)
         else:
             output = self._norm(x.float()).type_as(x)
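To summarize the dispatch order these changes introduce, here is a standalone sketch (hypothetical helper names, not the repo's code): the Triton kernels are preferred for CUDA tensors that do not require gradients, the flash-attn kernels come next, and plain PyTorch remains the fallback:

```python
import torch

# These globals mirror the ones in modeling_qwen.py; they stay None when the
# corresponding import (_import_triton / _import_flash_attn) fails.
rms_norm_triton = None
rms_norm = None

def _reference_rms_norm(x, weight, eps):
    # Plain-PyTorch fallback, equivalent to RMSNorm._norm followed by the weight scale.
    out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return out.type_as(x) * weight

def dispatch_rms_norm(x, weight, eps):
    # The Triton kernels are forward-only, so they are used only when no grad is needed.
    if rms_norm_triton is not None and x.is_cuda and not x.requires_grad:
        return rms_norm_triton(x, weight, eps)
    if rms_norm is not None and x.is_cuda:
        return rms_norm(x, weight, eps)
    return _reference_rms_norm(x, weight, eps)
```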
triton_kernels.py ADDED
@@ -0,0 +1,125 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This module provides ApplyRoPE and RMSNorm kernels written in OpenAI Triton.
+# Feel free to contact the contributors if you have any questions or issues regarding this code.
+# Contributors: Shangming Cai, Zihan Wang
+# Contacts: csmthu@gmail.com, wzh1999_frog@126.com
+
+from typing import Any, Callable, Dict, Hashable, Tuple
+
+import torch
+import triton
+import triton.language as tl
+from triton.compiler import CompiledKernel
+from triton.runtime import JITFunction
+
+try:
+    import triton.language.math as tlmath  # Triton 2.1
+except ImportError:
+    import triton.language.libdevice as tlmath  # Triton 2.0
+
+
+class TritonKernel:
+    def __init__(
+        self,
+        kernel_fn: JITFunction,
+        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
+    ) -> None:
+        self.kernel_fn_ = kernel_fn
+        self.grid_fn_ = grid_fn
+        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
+
+    def run(self, *args, **kwargs):
+        # Set current device
+        input_device = args[0].device
+        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
+        if input_device.index != cur_dev_idx:
+            prev_dev_idx = cur_dev_idx
+            torch.cuda.set_device(input_device.index)
+
+        # Compute grid
+        grid = self.grid_fn_(args)
+
+        # Use cached kernel if possible
+        kernel_key = (input_device,) + tuple(kwargs.items())
+        if kernel_key in self.kernel_cache_:
+            kernel = self.kernel_cache_[kernel_key]
+            kernel[grid](*args)
+        else:
+            # Compile and store new kernel
+            kernel = self.kernel_fn_[grid](*args, **kwargs)
+            self.kernel_cache_[kernel_key] = kernel
+
+        # Restore previous device
+        torch.cuda.set_device(prev_dev_idx)
+
+
+@triton.jit
+def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
+    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
+    block_idx = tl.arange(0, HEAD_DIM)
+    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
+    x = tl.load(X + x_base_idx + block_idx)
+    freq_idx = tok_idx * HEAD_DIM + block_idx
+    cos = tl.load(Cos + freq_idx)
+    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
+    x_rot = tl.load(X + x_base_idx + rot_idx)
+    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
+    sin = tl.load(Sin + freq_idx)
+    y_idx = (
+        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
+    ) * HEAD_DIM + block_idx
+    y = x * cos + x_rot * sin
+    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))
+
+
+apply_rope_fwd_kernel = TritonKernel(
+    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
+)
+
+
+def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
+    return y
+
+
+@triton.jit
+def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
+    tok_idx = tl.program_id(0)
+
+    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
+    for offset in range(0, hidden_dim, BLOCK_SIZE):
+        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(
+            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
+        ).to(tl.float32)
+        mean_sq += x * x / hidden_dim
+    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
+
+    for offset in range(0, hidden_dim, BLOCK_SIZE):
+        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
+        dim_mask = dim_idx < hidden_dim
+        hidden_idx = tok_idx * hidden_dim + dim_idx
+        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
+        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
+        y = x * rrms * w
+        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)
+
+
+rms_norm_fwd_kernel = TritonKernel(
+    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
+)
+
+
+def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
+    y = torch.empty_like(x)
+    hidden_dim = x.size(-1)
+    rms_norm_fwd_kernel.run(
+        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
+    )
+    return y
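A small verification sketch for the new kernels (assumes a CUDA device, Triton installed, and triton_kernels.py importable from the working directory); it checks the Triton RMSNorm against a plain-PyTorch reference:

```python
import torch
from triton_kernels import rms_norm

x = torch.randn(4, 32, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
eps = 1e-6

y_triton = rms_norm(x, w, eps)

# Reference computation in float32, matching RMSNorm._norm in modeling_qwen.py.
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).type_as(x) * w

print(torch.allclose(y_triton, ref, atol=1e-2, rtol=1e-2))  # expected: True
```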