"""Qwen-7B-Chat-Int4 / triton_kernels.py

ApplyRoPE and RMSNorm kernels written in OpenAI Triton.
Author: Shangming Cai
"""
from typing import Any, Callable, Dict, Hashable, Tuple
import torch
import triton
import triton.language as tl
from triton.compiler import CompiledKernel
from triton.runtime import JITFunction
try:
import triton.language.math as tlmath # Triton 2.1
except ImportError:
import triton.language.libdevice as tlmath # Triton 2.0
class TritonKernel:
def __init__(
self,
kernel_fn: JITFunction,
grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
) -> None:
self.kernel_fn_ = kernel_fn
self.grid_fn_ = grid_fn
self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}
def run(self, *args, **kwargs):
# Set current device
input_device = args[0].device
prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
if input_device.index != cur_dev_idx:
prev_dev_idx = cur_dev_idx
torch.cuda.set_device(input_device.index)
# Compute grid
grid = self.grid_fn_(args)
# Use cached kernel if possible
kernel_key = (input_device,) + tuple(kwargs.items())
if kernel_key in self.kernel_cache_:
kernel = self.kernel_cache_[kernel_key]
kernel[grid](*args)
else:
# Compile and store new kernel
kernel = self.kernel_fn_[grid](*args, **kwargs)
self.kernel_cache_[kernel_key] = kernel
        # Restore the previous device only if we actually switched above
        if prev_dev_idx >= 0:
            torch.cuda.set_device(prev_dev_idx)
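# Usage sketch (illustrative only): pair a @triton.jit function with a callback
# that maps the positional launch arguments to a 3-tuple grid, then call .run()
# with the tensors plus any tl.constexpr keyword arguments, e.g.
#
#     my_kernel = TritonKernel(_my_jit_fn, lambda args: (args[0].shape[0], 1, 1))
#     my_kernel.run(x, y, BLOCK_SIZE=1024)
#
# `_my_jit_fn`, `x`, `y`, and `BLOCK_SIZE` above are hypothetical placeholders;
# the real wrappers below follow this same pattern.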
@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    # One program per (batch, token, head); the launch grid is x.shape[:3].
    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
    block_idx = tl.arange(0, HEAD_DIM)
    # X is assumed to be a query/key head-slice view of the packed QKV tensor,
    # hence the factor of 3 in its per-token stride.
    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base_idx + block_idx)
    # Cos/Sin hold one row of HEAD_DIM values per token position.
    freq_idx = tok_idx * HEAD_DIM + block_idx
    cos = tl.load(Cos + freq_idx)
    # rotate_half: pair element i with element (i + HEAD_DIM // 2) % HEAD_DIM and
    # negate the values gathered into the first half.
    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
    x_rot = tl.load(X + x_base_idx + rot_idx)
    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
    sin = tl.load(Sin + freq_idx)
    # Y is a freshly allocated, contiguous [batch, seq, heads, head_dim] tensor,
    # so its per-token stride has no factor of 3.
    y_idx = (
        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
    ) * HEAD_DIM + block_idx
    # y = x * cos + rotate_half(x) * sin
    y = x * cos + x_rot * sin
    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))


# Launch grid: (batch, seq_len, num_heads) taken from the input tensor's shape.
apply_rope_fwd_kernel = TritonKernel(
    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq_len, num_heads, head_dim] head-slice of the packed QKV
    # projection (as assumed by the kernel's stride computation);
    # cos/sin: contiguous rotary tables indexed as [seq_len, head_dim].
    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
    return y
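# Example call (a sketch; the shapes and dtype below are assumptions for
# illustration, not values taken from the model configuration):
#
#     qkv = torch.randn(batch, seq_len, 3 * num_heads, head_dim,
#                       dtype=torch.float16, device="cuda")
#     q = qkv[:, :, :num_heads]                  # head-slice view of the packed tensor
#     q_rotated = apply_rotary_emb(q, cos, sin)  # cos/sin: [seq_len, head_dim]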
@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per token (row of the flattened [num_tokens, hidden_dim] input).
    tok_idx = tl.program_id(0)
    # First pass: accumulate the mean of squares in fp32 for numerical stability.
    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        x = tl.load(
            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
        ).to(tl.float32)
        mean_sq += x * x / hidden_dim
    # Reciprocal root mean square of the row.
    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
    # Second pass: normalize and scale by the learned weight.
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        dim_mask = dim_idx < hidden_dim
        hidden_idx = tok_idx * hidden_dim + dim_idx
        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
        y = x * rrms * w
        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)


# Launch grid: one program per token, i.e. the product of all leading dimensions.
rms_norm_fwd_kernel = TritonKernel(
    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
)
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    y = torch.empty_like(x)
    hidden_dim = x.size(-1)
    # BLOCK_SIZE is the hidden size rounded up to a power of two, so each loop in
    # the kernel completes in a single masked iteration.
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y
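# A minimal smoke test, assuming a CUDA device and fp16 inputs. The shapes, the
# 10000.0 rotary base, and the reference implementations below are illustrative
# assumptions for checking the kernels, not values taken from the Qwen model.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        batch, seq_len, num_heads, head_dim = 2, 8, 4, 32
        hidden_dim = num_heads * head_dim

        # RMSNorm: compare against a plain fp32 PyTorch reference.
        x = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float16, device="cuda")
        w = torch.randn(hidden_dim, dtype=torch.float16, device="cuda")
        eps = 1e-6
        ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
        ref = (ref * w.float()).to(x.dtype)
        out = rms_norm(x, w, eps)
        print("rms_norm max abs diff:", (out - ref).abs().max().item())

        # RoPE: the kernel expects a head-slice view of a packed QKV tensor, so
        # build one and rotate its query part.
        qkv = torch.randn(
            batch, seq_len, 3 * num_heads, head_dim, dtype=torch.float16, device="cuda"
        )
        q = qkv[:, :, :num_heads]
        pos = torch.arange(seq_len, device="cuda", dtype=torch.float32)
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim)
        )
        freqs = torch.outer(pos, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos().half().contiguous(), emb.sin().half().contiguous()

        def rotate_half(v):
            half = v.shape[-1] // 2
            return torch.cat((-v[..., half:], v[..., :half]), dim=-1)

        ref_q = (
            q.float() * cos.float()[None, :, None, :]
            + rotate_half(q.float()) * sin.float()[None, :, None, :]
        ).to(q.dtype)
        out_q = apply_rotary_emb(q, cos, sin)
        print("apply_rotary_emb max abs diff:", (out_q - ref_q).abs().max().item())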