from typing import Any, Callable, Dict, Hashable, Tuple

import torch
import triton
import triton.language as tl
from triton.compiler import CompiledKernel
from triton.runtime import JITFunction

# `triton.language.libdevice` was moved to `triton.language.math` in newer
# Triton releases; fall back to the old module name on older installs.
try:
    import triton.language.math as tlmath
except ImportError:
    import triton.language.libdevice as tlmath


class TritonKernel:
    """Lightweight launcher that caches compiled Triton kernels.

    Compiled binaries are keyed by (input device, constexpr kwargs) so that
    repeated launches with the same meta-parameters bypass JITFunction dispatch.
    """

    def __init__(
        self,
        kernel_fn: JITFunction,
        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
    ) -> None:
        self.kernel_fn_ = kernel_fn
        self.grid_fn_ = grid_fn
        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}

    def run(self, *args, **kwargs):
        # Launch on the device holding the first tensor argument, switching the
        # current CUDA device if needed.
        input_device = args[0].device
        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
        if input_device.index != cur_dev_idx:
            prev_dev_idx = cur_dev_idx
            torch.cuda.set_device(input_device.index)

        # The grid function receives the positional args as a tuple.
        grid = self.grid_fn_(args)

        # Cache key: target device plus the constexpr meta-parameters (kwargs).
        kernel_key = (input_device,) + tuple(kwargs.items())
        if kernel_key in self.kernel_cache_:
            # Relaunch the previously compiled binary directly.
            kernel = self.kernel_cache_[kernel_key]
            kernel[grid](*args)
        else:
            # First launch for this key: compile via the JITFunction and cache
            # the resulting CompiledKernel.
            kernel = self.kernel_fn_[grid](*args, **kwargs)
            self.kernel_cache_[kernel_key] = kernel

        # Restore the previous device; torch.cuda.set_device is a no-op for a
        # negative index, i.e. when the device was never switched.
        torch.cuda.set_device(prev_dev_idx)


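# Illustrative usage sketch (not from the original code; `_scale_kernel` and its
# arguments are hypothetical): a @triton.jit function wrapped in TritonKernel so
# that repeated launches with the same constexpr kwargs reuse the cached binary.
#
#     scale_kernel = TritonKernel(
#         _scale_kernel,
#         lambda args: (triton.cdiv(args[0].numel(), 1024), 1, 1),
#     )
#     scale_kernel.run(x, y, x.numel(), BLOCK_SIZE=1024)

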
@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    # One program per (batch, token, head).
    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
    block_idx = tl.arange(0, HEAD_DIM)
    # X is addressed with a token stride of 3 * num_heads * HEAD_DIM elements,
    # matching a Q or K slice of a fused QKV buffer laid out as
    # (batch, seq, 3, num_heads, head_dim).
    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base_idx + block_idx)
    # Cos/Sin are laid out as (seq_len, HEAD_DIM).
    freq_idx = tok_idx * HEAD_DIM + block_idx
    cos = tl.load(Cos + freq_idx)
    # rotate_half: pair each element with the one HEAD_DIM // 2 positions away
    # and negate the contribution for the first half, i.e. (-x2, x1).
    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
    x_rot = tl.load(X + x_base_idx + rot_idx)
    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
    sin = tl.load(Sin + freq_idx)
    # Y is a contiguous (batch, seq, num_heads, head_dim) output.
    y_idx = (
        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
    ) * HEAD_DIM + block_idx
    y = x * cos + x_rot * sin
    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))


# Grid: one program per (batch, token, head) of the input.
apply_rope_fwd_kernel = TritonKernel(
    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embeddings to `x` and return a new tensor.

    `x` is expected to have shape (batch, seq, num_heads, head_dim) and to be a
    slice of a fused QKV buffer (see the stride assumption in the kernel above);
    `cos` and `sin` are expected to be contiguous (seq, head_dim) tables.
    """
    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
    return y


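# Illustrative reference (not part of the original code): an equivalent
# rotate-half RoPE in plain PyTorch, useful for spot-checking the Triton kernel
# on small inputs. It assumes x is (batch, seq, num_heads, head_dim) and
# cos/sin are (seq, head_dim), mirroring the assumptions documented above.
def _apply_rotary_emb_reference(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    half = x.size(-1) // 2
    # rotate_half(x) = concat(-x2, x1) along the last dimension.
    x_rot = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
    # Broadcast the (seq, head_dim) tables over the batch and head dimensions.
    cos = cos[None, :, None, :].to(torch.float32)
    sin = sin[None, :, None, :].to(torch.float32)
    return (x.float() * cos + x_rot.float() * sin).to(x.dtype)

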
@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per token (row of the flattened (num_tokens, hidden_dim) view).
    tok_idx = tl.program_id(0)

    # First pass: accumulate mean(x^2) in float32, BLOCK_SIZE elements at a time.
    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        x = tl.load(
            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
        ).to(tl.float32)
        mean_sq += x * x / hidden_dim
    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)

    # Second pass: scale by the reciprocal RMS and the learned weight.
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        dim_mask = dim_idx < hidden_dim
        hidden_idx = tok_idx * hidden_dim + dim_idx
        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
        y = x * rrms * w
        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)


# Grid: one program per token; all leading dimensions of x are flattened.
rms_norm_fwd_kernel = TritonKernel(
    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    """RMS-normalize `x` over its last dimension and scale by `weight`."""
    y = torch.empty_like(x)
    hidden_dim = x.size(-1)
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y


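# Illustrative reference (not part of the original code): the same RMSNorm in
# plain PyTorch, with the reduction done in float32 like the kernel above.
# Handy for comparing against rms_norm on small inputs.
def _rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    rrms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rrms * weight.float()).to(x.dtype)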