Update README of branch dev_triton.
#11
by
Cheshire94
- opened
- README.md +28 -1
- assets/wechat.png +0 -0
- config.json +1 -0
- configuration_qwen.py +2 -0
- modeling_qwen.py +36 -8
- triton_kernels.py +125 -0
README.md
CHANGED
@@ -18,7 +18,7 @@ inference: false
|
|
18 |
<p align="center">
|
19 |
🤗 <a href="https://huggingface.co/Qwen">Hugging Face</a>   |   🤖 <a href="https://modelscope.cn/organization/qwen">ModelScope</a>   |    📑 <a href="https://arxiv.org/abs/2309.16609">Paper</a>    |   🖥️ <a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>
|
20 |
<br>
|
21 |
-
<a href="assets/wechat.png">WeChat (微信)</a>   |   <a href="https://discord.gg/z3GAxXZ9Ce">Discord</a>   |   <a href="https://dashscope.aliyun.com">API</a>
|
22 |
</p>
|
23 |
<br>
|
24 |
|
@@ -67,6 +67,14 @@ cd flash-attention && pip install .
|
|
67 |
# pip install csrc/layer_norm
|
68 |
# pip install csrc/rotary
|
69 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
<br>
|
71 |
|
72 |
|
@@ -140,6 +148,25 @@ In detail, the setting of profiling is generating 8192 new tokens with 1 context
|
|
140 |
|
141 |
Note: The generation speed of the Int4/Int8 models mentioned above is provided by the autogptq library. The current speed of the model loaded using "AutoModelForCausalLM.from_pretrained" will be approximately 20% slower. We have reported this issue to the HuggingFace team and will update it promptly if a solution is available.
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
### 显存使用 (GPU Memory Usage)
|
144 |
|
145 |
我们还测算了不同模型精度编码2048个token及生成8192个token的峰值显存占用情况。(显存消耗在是否使用FlashAttn的情况下均类似。)结果如下所示:
|
|
|
18 |
<p align="center">
|
19 |
🤗 <a href="https://huggingface.co/Qwen">Hugging Face</a>   |   🤖 <a href="https://modelscope.cn/organization/qwen">ModelScope</a>   |    📑 <a href="https://arxiv.org/abs/2309.16609">Paper</a>    |   🖥️ <a href="https://modelscope.cn/studios/qwen/Qwen-7B-Chat-Demo/summary">Demo</a>
|
20 |
<br>
|
21 |
+
<a href="https://github.com/QwenLM/Qwen/blob/main/assets/wechat.png">WeChat (微信)</a>   |   <a href="https://discord.gg/z3GAxXZ9Ce">Discord</a>   |   <a href="https://dashscope.aliyun.com">API</a>
|
22 |
</p>
|
23 |
<br>
|
24 |
|
|
|
67 |
# pip install csrc/layer_norm
|
68 |
# pip install csrc/rotary
|
69 |
```
|
70 |
+
|
71 |
+
如果您有更高推理性能方面的需求,但上述可选加速项`layer_norm`及`rotary`未能安装成功,或是您所使用的GPU不满足`flash-attention`库所要求的NVIDIA Ampere/Ada/Hopper架构,您可以尝试使用该分支下基于Triton进行实现的推理加速方案。该方案适用于更宽范围的GPU产品,且无需安装。您可以通过将config.json里的`use_triton`设置为true来进行启用。
|
72 |
+
|
73 |
+
**(在dev_triton分支下`use_triton`默认设置为auto,由于pytorch 2.0及以上版本已默认安装了Triton,因此上述优化方案无需其它安装与配置操作即可直接启用。如果您不想开启该优化,请将config.json里的`use_triton`设置为false)**
|
74 |
+
|
75 |
+
If you require higher inference performance yet encounter some problems when installing the optional acceleration features (i.e., `layer_norm` and `rotary`) or if the GPU you are using does not meet the NVIDIA Ampere/Ada/Hopper architecture required by the `flash-attention` library, you may consider trying the inference acceleration solution implemented with Triton in this branch. This solution adapts to a wider range of GPU products and does not require installation. You can enable this acceleration feature by setting the `use_triton` option to true in the config.json file.
|
76 |
+
|
77 |
+
**(In the dev_triton branch, `use_triton` is set to 'auto' by default. As Triton is pre-installed with pytorch version 2.0 and above, this acceleration solution can be enabled directly without additional installation or configuration. If you prefer not to activate this optimization, please set `use_triton` to false in the config.json file.)**
|
78 |
<br>
|
79 |
|
80 |
|
|
|
148 |
|
149 |
Note: The generation speed of the Int4/Int8 models mentioned above is provided by the autogptq library. The current speed of the model loaded using "AutoModelForCausalLM.from_pretrained" will be approximately 20% slower. We have reported this issue to the HuggingFace team and will update it promptly if a solution is available.
|
150 |
|
151 |
+
另外,我们也测算了在使用不同GPU及推理加速方法时Qwen-7B-Chat-Int4模型生成2048和8192个token的平均推理速度。所有评测均使用PyTorch 2.1.0和CUDA 11.8。
|
152 |
+
|
153 |
+
In addition, we also measured the average inference speed of generating 2048 and 8192 tokens with different GPU devices and acceleration methods, respectively. All results run with PyTorch 2.1.0 and CUDA 11.8.
|
154 |
+
|
155 |
+
| GPU Device | Method | Speed (2048 tokens) | Speed (8192 tokens) |
|
156 |
+
| :--------: | :----------: | :------------------:| :------------------:|
|
157 |
+
| A10 | FlashAttn v2 | 41.28 | 30.78 |
|
158 |
+
| A10 | Triton | 49.04 | 29.17 |
|
159 |
+
| A10 | Disabled | 39.26 | 26.81 |
|
160 |
+
| V100 | FlashAttn v2 | N/A | N/A |
|
161 |
+
| V100 | Triton | 37.01 | 27.66 |
|
162 |
+
| V100 | Disabled | 24.47 | 20.40 |
|
163 |
+
| P100 | FlashAttn v2 | N/A | N/A |
|
164 |
+
| P100 | Triton | 29.03 | 13.85 |
|
165 |
+
| P100 | Disabled | 20.50 | 12.73 |
|
166 |
+
| T4 | FlashAttn v2 | N/A | N/A |
|
167 |
+
| T4 | Triton | 27.98 | 15.22 |
|
168 |
+
| T4 | Disabled | 23.11 | 14.55 |
|
169 |
+
|
170 |
### 显存使用 (GPU Memory Usage)
|
171 |
|
172 |
我们还测算了不同模型精度编码2048个token及生成8192个token的峰值显存占用情况。(显存消耗在是否使用FlashAttn的情况下均类似。)结果如下所示:
|
assets/wechat.png
CHANGED
config.json
CHANGED
@@ -44,6 +44,7 @@
|
|
44 |
"use_cache": true,
|
45 |
"use_dynamic_ntk": true,
|
46 |
"use_flash_attn": "auto",
|
|
|
47 |
"use_logn_attn": true,
|
48 |
"vocab_size": 151936
|
49 |
}
|
|
|
44 |
"use_cache": true,
|
45 |
"use_dynamic_ntk": true,
|
46 |
"use_flash_attn": "auto",
|
47 |
+
"use_triton": "auto",
|
48 |
"use_logn_attn": true,
|
49 |
"vocab_size": 151936
|
50 |
}
|
configuration_qwen.py
CHANGED
@@ -32,6 +32,7 @@ class QWenConfig(PretrainedConfig):
|
|
32 |
use_dynamic_ntk=True,
|
33 |
use_logn_attn=True,
|
34 |
use_flash_attn="auto",
|
|
|
35 |
intermediate_size=22016,
|
36 |
no_bias=True,
|
37 |
tie_word_embeddings=False,
|
@@ -61,6 +62,7 @@ class QWenConfig(PretrainedConfig):
|
|
61 |
self.use_dynamic_ntk = use_dynamic_ntk
|
62 |
self.use_logn_attn = use_logn_attn
|
63 |
self.use_flash_attn = use_flash_attn
|
|
|
64 |
self.no_bias = no_bias
|
65 |
self.use_cache_quantization = use_cache_quantization
|
66 |
self.use_cache_kernel = use_cache_kernel
|
|
|
32 |
use_dynamic_ntk=True,
|
33 |
use_logn_attn=True,
|
34 |
use_flash_attn="auto",
|
35 |
+
use_triton="auto",
|
36 |
intermediate_size=22016,
|
37 |
no_bias=True,
|
38 |
tie_word_embeddings=False,
|
|
|
62 |
self.use_dynamic_ntk = use_dynamic_ntk
|
63 |
self.use_logn_attn = use_logn_attn
|
64 |
self.use_flash_attn = use_flash_attn
|
65 |
+
self.use_triton = use_triton
|
66 |
self.no_bias = no_bias
|
67 |
self.use_cache_quantization = use_cache_quantization
|
68 |
self.use_cache_kernel = use_cache_kernel
|
modeling_qwen.py
CHANGED
@@ -35,7 +35,7 @@ except ImportError:
|
|
35 |
from torch import nn
|
36 |
|
37 |
SUPPORT_CUDA = torch.cuda.is_available()
|
38 |
-
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.
|
39 |
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
40 |
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
|
41 |
|
@@ -76,7 +76,9 @@ We detect you have activated flash attention support, but running model computat
|
|
76 |
"""
|
77 |
|
78 |
apply_rotary_emb_func = None
|
|
|
79 |
rms_norm = None
|
|
|
80 |
flash_attn_unpadded_func = None
|
81 |
flash_attn_func = None
|
82 |
|
@@ -120,6 +122,24 @@ def _import_flash_attn():
|
|
120 |
"https://github.com/Dao-AILab/flash-attention"
|
121 |
)
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
def quantize_cache_v(fdata, bits, qmax, qmin):
|
124 |
# b, s, head, h-dim->b, head, s, h-dim
|
125 |
qtype = torch.uint8
|
@@ -520,11 +540,9 @@ class QWenAttention(nn.Module):
|
|
520 |
|
521 |
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
522 |
if attention_mask is not None:
|
523 |
-
attention_mask = attention_mask.expand(
|
524 |
-
-1, -1, causal_mask.size(2), -1
|
525 |
-
)
|
526 |
if causal_mask is not None:
|
527 |
-
attention_mask.
|
528 |
else:
|
529 |
attention_mask = causal_mask
|
530 |
attn_output = F.scaled_dot_product_attention(
|
@@ -978,6 +996,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
978 |
if config.use_flash_attn:
|
979 |
_import_flash_attn()
|
980 |
|
|
|
|
|
|
|
|
|
|
|
|
|
981 |
self.transformer = QWenModel(config)
|
982 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
983 |
|
@@ -1330,12 +1354,14 @@ def apply_rotary_pos_emb(t, freqs):
|
|
1330 |
t (tensor(batch_size, seq_len, n_head, head_dim)):
|
1331 |
the input embedding/hidden states
|
1332 |
freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
|
1333 |
-
the cached cos/sin position embeddings
|
1334 |
"""
|
1335 |
rot_dim = freqs[0].shape[-1]
|
1336 |
cos, sin = freqs
|
1337 |
t_float = t.float()
|
1338 |
-
if
|
|
|
|
|
1339 |
# apply_rotary_emb in flash_attn requires cos/sin to be of
|
1340 |
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
|
1341 |
# to the first rotary_dim of the input
|
@@ -1358,7 +1384,9 @@ class RMSNorm(torch.nn.Module):
|
|
1358 |
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
1359 |
|
1360 |
def forward(self, x):
|
1361 |
-
if
|
|
|
|
|
1362 |
return rms_norm(x, self.weight, self.eps)
|
1363 |
else:
|
1364 |
output = self._norm(x.float()).type_as(x)
|
|
|
35 |
from torch import nn
|
36 |
|
37 |
SUPPORT_CUDA = torch.cuda.is_available()
|
38 |
+
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 8
|
39 |
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
40 |
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
|
41 |
|
|
|
76 |
"""
|
77 |
|
78 |
apply_rotary_emb_func = None
|
79 |
+
apply_rotary_emb_func_triton = None
|
80 |
rms_norm = None
|
81 |
+
rms_norm_triton = None
|
82 |
flash_attn_unpadded_func = None
|
83 |
flash_attn_func = None
|
84 |
|
|
|
122 |
"https://github.com/Dao-AILab/flash-attention"
|
123 |
)
|
124 |
|
125 |
+
def _import_triton():
|
126 |
+
global apply_rotary_emb_func_triton, rms_norm_triton
|
127 |
+
try:
|
128 |
+
from .triton_kernels import apply_rotary_emb as __apply_rotary_emb, rms_norm as __rms_norm
|
129 |
+
if apply_rotary_emb_func is not None:
|
130 |
+
logger.warn(
|
131 |
+
"Using Triton rotary kernel instead of flash_attn for inference."
|
132 |
+
)
|
133 |
+
apply_rotary_emb_func_triton = __apply_rotary_emb
|
134 |
+
if rms_norm is not None:
|
135 |
+
logger.warn(
|
136 |
+
"Using Triton rms_norm kernel instead of flash_attn for inference."
|
137 |
+
)
|
138 |
+
rms_norm_triton = __rms_norm
|
139 |
+
except ImportError:
|
140 |
+
logger.warn("Warning: Failed to import Triton kernels.")
|
141 |
+
return
|
142 |
+
|
143 |
def quantize_cache_v(fdata, bits, qmax, qmin):
|
144 |
# b, s, head, h-dim->b, head, s, h-dim
|
145 |
qtype = torch.uint8
|
|
|
540 |
|
541 |
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
542 |
if attention_mask is not None:
|
543 |
+
attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
|
|
|
|
|
544 |
if causal_mask is not None:
|
545 |
+
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
|
546 |
else:
|
547 |
attention_mask = causal_mask
|
548 |
attn_output = F.scaled_dot_product_attention(
|
|
|
996 |
if config.use_flash_attn:
|
997 |
_import_flash_attn()
|
998 |
|
999 |
+
if config.use_triton == "auto":
|
1000 |
+
logger.warn("Try importing Triton kernels for faster inference...")
|
1001 |
+
config.use_triton = SUPPORT_TORCH2
|
1002 |
+
if config.use_triton:
|
1003 |
+
_import_triton()
|
1004 |
+
|
1005 |
self.transformer = QWenModel(config)
|
1006 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
1007 |
|
|
|
1354 |
t (tensor(batch_size, seq_len, n_head, head_dim)):
|
1355 |
the input embedding/hidden states
|
1356 |
freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
|
1357 |
+
the cached cos/sin position embeddings
|
1358 |
"""
|
1359 |
rot_dim = freqs[0].shape[-1]
|
1360 |
cos, sin = freqs
|
1361 |
t_float = t.float()
|
1362 |
+
if apply_rotary_emb_func_triton is not None and t.is_cuda and (not t.requires_grad):
|
1363 |
+
return apply_rotary_emb_func_triton(t, cos, sin)
|
1364 |
+
elif apply_rotary_emb_func is not None and t.is_cuda:
|
1365 |
# apply_rotary_emb in flash_attn requires cos/sin to be of
|
1366 |
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
|
1367 |
# to the first rotary_dim of the input
|
|
|
1384 |
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
1385 |
|
1386 |
def forward(self, x):
|
1387 |
+
if rms_norm_triton is not None and x.is_cuda and (not x.requires_grad):
|
1388 |
+
return rms_norm_triton(x, self.weight, self.eps)
|
1389 |
+
elif rms_norm is not None and x.is_cuda:
|
1390 |
return rms_norm(x, self.weight, self.eps)
|
1391 |
else:
|
1392 |
output = self._norm(x.float()).type_as(x)
|
triton_kernels.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba Cloud.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# This module provides ApplyRoPE and RMSNorm kernels written in OpenAI Triton.
|
7 |
+
# Feel free to contact the contributors if you have any questions or issues regarding this code.
|
8 |
+
# Contributors: Shangming Cai, Zihan Wang
|
9 |
+
# Contacts: csmthu@gmail.com, wzh1999_frog@126.com
|
10 |
+
|
11 |
+
from typing import Any, Callable, Dict, Hashable, Tuple
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import triton
|
15 |
+
import triton.language as tl
|
16 |
+
from triton.compiler import CompiledKernel
|
17 |
+
from triton.runtime import JITFunction
|
18 |
+
|
19 |
+
try:
|
20 |
+
import triton.language.math as tlmath # Triton 2.1
|
21 |
+
except ImportError:
|
22 |
+
import triton.language.libdevice as tlmath # Triton 2.0
|
23 |
+
|
24 |
+
|
25 |
+
class TritonKernel:
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
kernel_fn: JITFunction,
|
29 |
+
grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
|
30 |
+
) -> None:
|
31 |
+
self.kernel_fn_ = kernel_fn
|
32 |
+
self.grid_fn_ = grid_fn
|
33 |
+
self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
|
34 |
+
|
35 |
+
def run(self, *args, **kwargs):
|
36 |
+
# Set current device
|
37 |
+
input_device = args[0].device
|
38 |
+
prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
|
39 |
+
if input_device.index != cur_dev_idx:
|
40 |
+
prev_dev_idx = cur_dev_idx
|
41 |
+
torch.cuda.set_device(input_device.index)
|
42 |
+
|
43 |
+
# Compute grid
|
44 |
+
grid = self.grid_fn_(args)
|
45 |
+
|
46 |
+
# Use cached kernel if possible
|
47 |
+
kernel_key = (input_device,) + tuple(kwargs.items())
|
48 |
+
if kernel_key in self.kernel_cache_:
|
49 |
+
kernel = self.kernel_cache_[kernel_key]
|
50 |
+
kernel[grid](*args)
|
51 |
+
else:
|
52 |
+
# Compile and store new kernel
|
53 |
+
kernel = self.kernel_fn_[grid](*args, **kwargs)
|
54 |
+
self.kernel_cache_[kernel_key] = kernel
|
55 |
+
|
56 |
+
# Restore previous device
|
57 |
+
torch.cuda.set_device(prev_dev_idx)
|
58 |
+
|
59 |
+
|
60 |
+
@triton.jit
|
61 |
+
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
|
62 |
+
batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
63 |
+
seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
|
64 |
+
block_idx = tl.arange(0, HEAD_DIM)
|
65 |
+
x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
|
66 |
+
x = tl.load(X + x_base_idx + block_idx)
|
67 |
+
freq_idx = tok_idx * HEAD_DIM + block_idx
|
68 |
+
cos = tl.load(Cos + freq_idx)
|
69 |
+
rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
|
70 |
+
x_rot = tl.load(X + x_base_idx + rot_idx)
|
71 |
+
x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
|
72 |
+
sin = tl.load(Sin + freq_idx)
|
73 |
+
y_idx = (
|
74 |
+
(batch_idx * seq_len + tok_idx) * num_heads + head_idx
|
75 |
+
) * HEAD_DIM + block_idx
|
76 |
+
y = x * cos + x_rot * sin
|
77 |
+
tl.store(Y + y_idx, y.to(Y.dtype.element_ty))
|
78 |
+
|
79 |
+
|
80 |
+
apply_rope_fwd_kernel = TritonKernel(
|
81 |
+
_apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
86 |
+
y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
|
87 |
+
apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
|
88 |
+
return y
|
89 |
+
|
90 |
+
|
91 |
+
@triton.jit
|
92 |
+
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
|
93 |
+
tok_idx = tl.program_id(0)
|
94 |
+
|
95 |
+
mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
|
96 |
+
for offset in range(0, hidden_dim, BLOCK_SIZE):
|
97 |
+
dim_idx = offset + tl.arange(0, BLOCK_SIZE)
|
98 |
+
x = tl.load(
|
99 |
+
X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
|
100 |
+
).to(tl.float32)
|
101 |
+
mean_sq += x * x / hidden_dim
|
102 |
+
rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
|
103 |
+
|
104 |
+
for offset in range(0, hidden_dim, BLOCK_SIZE):
|
105 |
+
dim_idx = offset + tl.arange(0, BLOCK_SIZE)
|
106 |
+
dim_mask = dim_idx < hidden_dim
|
107 |
+
hidden_idx = tok_idx * hidden_dim + dim_idx
|
108 |
+
x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
|
109 |
+
w = tl.load(W + dim_idx, mask=dim_mask, other=0)
|
110 |
+
y = x * rrms * w
|
111 |
+
tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)
|
112 |
+
|
113 |
+
|
114 |
+
rms_norm_fwd_kernel = TritonKernel(
|
115 |
+
_rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
|
120 |
+
y = torch.empty_like(x)
|
121 |
+
hidden_dim = x.size(-1)
|
122 |
+
rms_norm_fwd_kernel.run(
|
123 |
+
x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
|
124 |
+
)
|
125 |
+
return y
|