|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def softmax_fwd_kernel(
    x,
    p,
    D: tl.constexpr,
    B: tl.constexpr
):
    """Numerically stable row-wise softmax: one program handles one row.

    Args:
        x: pointer to the input, laid out as contiguous rows of length D.
        p: pointer to the output, same layout as `x`.
        D: logical row length.
        B: block size, the next power of two >= D.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, B)
    valid = cols < D

    # Padding lanes read -inf so they contribute exp(-inf) = 0 to the sum.
    vals = tl.load(x + row * D + cols, mask=valid, other=-float('inf'))
    # Subtract the row max before exponentiating for numerical stability.
    num = tl.exp(vals - tl.max(vals, 0))
    out = num / tl.sum(num, 0)

    tl.store(p + row * D + cols, out.to(p.dtype.element_ty), mask=valid)
|
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32)
    ],
    key=['D']
)
@triton.jit
def softmax_bwd_kernel(
    p,
    dp,
    ds,
    D: tl.constexpr,
    B: tl.constexpr
):
    """Row-wise softmax backward: one program handles one row.

    Computes ds = p * (dp - sum(p * dp)) per row, the standard softmax
    Jacobian-vector product.

    Args:
        p:  pointer to the forward softmax output, contiguous rows of length D.
        dp: pointer to the upstream gradient, same layout.
        ds: pointer to the resulting input gradient, same layout.
        D:  logical row length.
        B:  block size, the next power of two >= D.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, B)
    valid = cols < D

    # Masked lanes load 0 so they add nothing to the reduction.
    probs = tl.load(p + row * D + cols, mask=valid, other=0.)
    grads = tl.load(dp + row * D + cols, mask=valid, other=0.)
    prod = probs * grads
    # p * dp - p * sum(p * dp)
    result = prod - probs * tl.sum(prod, 0)
    tl.store(ds + row * D + cols, result.to(ds.dtype.element_ty), mask=valid)
|
|
|
|
|
def softmax_fwd(
    x: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Compute softmax over the last dimension of `x` using a Triton kernel.

    Args:
        x: Input tensor of arbitrary shape; softmax is taken along dim -1.
        dtype: Output dtype (default float32; `None` keeps `x`'s dtype,
            per `torch.empty_like` semantics).

    Returns:
        Tensor of the same shape as `x` holding row-wise softmax values.
    """
    shape = x.shape
    # The kernel addresses rows with flat offsets `i_n * D + o_d`, so the
    # input must be contiguous. `.view` alone raises on non-contiguous
    # tensors; `.contiguous()` is a no-op for already-contiguous inputs.
    x = x.contiguous().view(-1, shape[-1])

    N, D = x.shape
    # Each program loads a full row into a power-of-2 sized block.
    B = triton.next_power_of_2(D)

    p = torch.empty_like(x, dtype=dtype)
    # One program per row.
    softmax_fwd_kernel[(N,)](
        x=x,
        p=p,
        D=D,
        B=B
    )
    return p.view(*shape)
|
|
|
|
|
def softmax_bwd(
    p: torch.Tensor,
    dp: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Compute the softmax backward pass using a Triton kernel.

    Args:
        p: Softmax output from the forward pass, arbitrary shape.
        dp: Upstream gradient, same shape as `p`.
        dtype: Output dtype (default float32; `None` keeps `p`'s dtype,
            per `torch.empty_like` semantics).

    Returns:
        Gradient with respect to the softmax input, same shape as `p`.
    """
    shape = p.shape
    # The kernel addresses BOTH `p` and `dp` with flat row offsets, so both
    # must be contiguous. Previously only `p` was reshaped; a non-contiguous
    # `dp` would have been read with wrong strides, silently corrupting the
    # gradient.
    p = p.contiguous().view(-1, shape[-1])
    dp = dp.contiguous().view(-1, shape[-1])
    ds = torch.empty_like(p, dtype=dtype)

    N, D = p.shape
    # Each program loads a full row into a power-of-2 sized block.
    B = triton.next_power_of_2(D)
    # One program per row.
    softmax_bwd_kernel[(N,)](
        p=p,
        dp=dp,
        ds=ds,
        D=D,
        B=B
    )
    return ds.view(*shape)
|
|