# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
# code adapted from
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2),
        triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2),
        # Good configs for fp8 inputs.
        triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8),
        triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4),
        triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4)
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'HAS_INPUT': lambda args: args['input'] is not None,
    'HAS_ALPHA': lambda args: args['alpha'] is not None,
    'HAS_BETA': lambda args: args['beta'] is not None
})
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a,
    b,
    c,
    input,
    alpha,
    beta,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `s_am` is how much to increase `a`
    # by to get the element one row down (A has M rows).
    s_am,
    s_ak,
    s_bk,
    s_bn,
    s_cm,
    s_cn,
    # Meta-parameters
    BM: tl.constexpr,
    BK: tl.constexpr,
    BN: tl.constexpr,
    G: tl.constexpr,
    ACTIVATION: tl.constexpr,
    HAS_INPUT: tl.constexpr,
    HAS_ALPHA: tl.constexpr,
    HAS_BETA: tl.constexpr
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N).
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton matmul tutorial for details.
    NM, NN = tl.num_programs(0), tl.num_programs(1)
    i_m, i_n = tl.program_id(0), tl.program_id(1)
    i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G)

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction and accumulate.
    # `p_a` is a block of [BM, BK] pointers
    # `p_b` is a block of [BK, BN] pointers
    # See the `Pointer Arithmetic` section of the Triton matmul tutorial for details.
    o_am = (i_m * BM + tl.arange(0, BM)) % M
    o_bn = (i_n * BN + tl.arange(0, BN)) % N
    o_k = tl.arange(0, BK)
    p_a = a + (o_am[:, None] * s_am + o_k[None, :] * s_ak)
    p_b = b + (o_k[:, None] * s_bk + o_bn[None, :] * s_bn)

    b_acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Load the next block of A and B, generating a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0)
        b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0)
        # We accumulate along the K dimension.
        b_acc += tl.dot(b_a, b_b, allow_tf32=False)
        # Advance the pointers to the next K block.
        p_a += BK * s_ak
        p_b += BK * s_bk

    o_cm = i_m * BM + tl.arange(0, BM)
    o_cn = i_n * BN + tl.arange(0, BN)
    mask = (o_cm[:, None] < M) & (o_cn[None, :] < N)

    b_c = b_acc
    # You can fuse arbitrary activation functions here
    # while the accumulator `b_acc` is still in FP32!
    if ACTIVATION == "leaky_relu":
        b_c = leaky_relu(b_c)
    if HAS_ALPHA:
        b_c *= tl.load(alpha)
    if HAS_INPUT:
        p_i = input + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]
        b_i = tl.load(p_i, mask=mask, other=0.0).to(tl.float32)
        if HAS_BETA:
            b_i *= tl.load(beta)
        b_c += b_i

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    p_c = c + s_cm * o_cm[:, None] + s_cn * o_cn[None, :]
    tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)


@contiguous
def matmul(a, b, activation=''):
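    """
    Compute C = A x B with the autotuned Triton kernel above.

    `a` has shape (M, K) and `b` has shape (K, N); the result `c` has shape (M, N).
    Passing `activation="leaky_relu"` fuses the activation into the kernel epilogue.
    """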
    assert a.shape[1] == b.shape[0], 'Incompatible dimensions (A: {}x{}, B: {}x{})'.format(*a.shape, *b.shape)
    M, K = a.shape
    K, N = b.shape
    # Allocate the output.
    c = a.new_empty(M, N)

    # 2D launch grid: each program computes one [BM, BN] block of C.
    def grid(meta):
        return (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, None, None, None,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=activation,
    )
    return c


@contiguous
def addmm(
    x: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    inplace: Optional[bool] = False
) -> torch.Tensor:
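    """
    Compute `alpha * (A x B) + beta * x`, analogous to `torch.addmm`.

    `alpha` and `beta` are optional scale factors (treated as 1 when omitted).
    If `inplace` is True, the result is written into `x` instead of a newly allocated tensor.
    """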
    assert a.shape[1] == b.shape[0], 'Incompatible dimensions (A: {}x{}, B: {}x{})'.format(*a.shape, *b.shape)
    M, K = a.shape
    K, N = b.shape
    # Allocate the output (or reuse `x` when running in place).
    c = x if inplace else a.new_empty(M, N)

    # 2D launch grid: each program computes one [BM, BN] block of C.
    def grid(meta):
        return (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    matmul_kernel[grid](
        a, b, c, x, alpha, beta,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=None,
    )
    return c
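

# A minimal smoke-test sketch (not part of the library API): it assumes a CUDA-capable
# device is available and checks `matmul`/`addmm` against their PyTorch references.
if __name__ == '__main__':
    if torch.cuda.is_available():
        a = torch.randn(512, 256, device='cuda', dtype=torch.float32)
        b = torch.randn(256, 384, device='cuda', dtype=torch.float32)
        x = torch.randn(512, 384, device='cuda', dtype=torch.float32)
        # `matmul` should match a plain A @ B; `addmm` with default alpha/beta should match x + A @ B.
        assert torch.allclose(matmul(a, b), a @ b, atol=1e-3)
        assert torch.allclose(addmm(x, a, b), x + a @ b, atol=1e-3)
        print('matmul/addmm smoke test passed')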