File size: 7,348 Bytes
4f41cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
# Copyright 2024 CATIE. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications to the orignal file
# - add weights gradients
# - remove the mask if size is a power of 2
# - support for torch.compile

import triton
import triton.language as tl
import torch


MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2

def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps


@triton.jit
def _rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr,
    IS_EVEN_X: tl.constexpr
):
    """
        Fast RMS Layernorm kernel
        Inspiration from a Triton tutorial:
        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    if IS_EVEN_X:
        X_row = tl.load(X + col_offsets).to(tl.float32)
        W_row = tl.load(W + col_offsets)
    else:
        X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
        W_row = tl.load(W + col_offsets, mask=mask, other=0)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    normed = normed.to(W_row.dtype) # Exact copy from HF
    output = normed * W_row

    if IS_EVEN_X:
        tl.store(Y + col_offsets, output)
    else:
        tl.store(Y + col_offsets, output, mask=mask)

@triton.jit
def _rms_layernorm_backward(
    dY, dY_row_stride,
    X,   X_row_stride,
    W,   W_row_stride,
    r,   r_row_stride,
    dW, dW_row_stride,
    dX, dX_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr,
    IS_EVEN_X: tl.constexpr
):
    """
        Fast RMS Layernorm kernel for the backward pass
        Inspiration from a Triton tutorial:
        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dY += row_idx * dY_row_stride
    X  += row_idx *  X_row_stride
    r  += row_idx *  r_row_stride
    dW += row_idx * dW_row_stride
    dX += row_idx * dX_row_stride

    if IS_EVEN_X:
        dY_row = tl.load(dY + col_offsets).to(tl.float32)
        X_row  = tl.load(X  + col_offsets).to(tl.float32)
        W_row  = tl.load(W  + col_offsets).to(tl.float32)
    else:
        dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
        X_row  = tl.load(X  + col_offsets, mask=mask, other=0).to(tl.float32)
        W_row  = tl.load(W  + col_offsets, mask=mask, other=0).to(tl.float32)

    # Get saved row variance
    inv_var = tl.load(r).to(tl.float32)
    normed = X_row * inv_var
    dW_row = dY_row * normed

    dY_W = dY_row * W_row
    rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
    output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)

    if IS_EVEN_X:
        tl.store(dW + col_offsets, dW_row)
        tl.store(dX + col_offsets, output)
    else:
        tl.store(dW + col_offsets, dW_row, mask=mask)
        tl.store(dX + col_offsets, output, mask=mask)


# Wrapper for triton kernel for torch.compile - should be unecessary for PyTorch 2.3 ?
torch.library.define("flasht5::rmsnorm_triton_fwd", "(Tensor X, Tensor W, float eps, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)")

@torch.library.impl("flasht5::rmsnorm_triton_fwd", "default")
def rmsnorm_triton_fwd(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device="cuda")
    r = torch.empty(n_rows, dtype=torch.float32, device="cuda")

    _rms_layernorm_forward[(n_rows,)](
        Y, Y.stride(0),
        X, X.stride(0),
        W, W.stride(0),
        r, r.stride(0),
        n_cols, eps,
        BLOCK_SIZE=BLOCK_SIZE,
        IS_EVEN_X=((n_cols % BLOCK_SIZE) == 0),
        num_warps=num_warps
    )

    return Y, r


@torch.library.impl_abstract("flasht5::rmsnorm_triton_fwd", rmsnorm_triton_fwd)
def rmsnorm_triton_fwd_abstract(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
    Y = X.new_empty((n_rows, n_cols))
    r = X.new_empty((n_rows))
    return Y, r

torch.library.define("flasht5::rmsnorm_triton_bwd", "(Tensor dY, Tensor r, Tensor X, Tensor W, float eps, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)")

@torch.library.impl("flasht5::rmsnorm_triton_bwd", "default")
def rmsnorm_triton_bwd(dY, r, X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):

    dX = torch.empty_like(dY)
    dW = torch.empty_like(dY)

    _rms_layernorm_backward[(n_rows,)](
        dY, dY.stride(0),
        X,  X.stride(0),
        W,  1,
        r,  1,
        dW, dW.stride(0),
        dX, dX.stride(0),
        n_cols, eps,
        BLOCK_SIZE=BLOCK_SIZE,
        IS_EVEN_X=((n_cols % BLOCK_SIZE) == 0),
        num_warps=num_warps,
    )

    return dX, dW


@torch.library.impl_abstract("flasht5::rmsnorm_triton_bwd", rmsnorm_triton_bwd)
def rmsnorm_triton_bwd_abstract(dY, r, X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
    return torch.empty_like(dY), torch.empty_like(dY)


class Fast_RMS_Layernorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W, eps):
        shape = X.shape
        dim = shape[-1]
        X = X.view(-1, dim)
        n_rows, n_cols = X.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)

        Y, r = torch.ops.flasht5.rmsnorm_triton_fwd(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps)

        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps  = num_warps
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)

    @staticmethod
    def backward(ctx, dY):
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        X, W, r = ctx.saved_tensors
        n_rows, n_cols = dY.shape
        dX = torch.empty_like(dY)
        dW = torch.empty_like(dY)

        dW, dX = torch.ops.flasht5.rmsnorm_triton_bwd(dY, r, X, W, ctx.eps, n_cols, n_rows, ctx.BLOCK_SIZE, ctx.num_warps)

        dX = dX.view(*shape)
        return dX, dW.sum(0), None

def fast_rms_layernorm(X, W, eps):
    out = Fast_RMS_Layernorm.apply(X, W, eps)
    return out