import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
NUM_TOKENS = [7, 83, 256, 2048]  # Arbitrary values for testing
D = [1, 7, 512, 13824]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


# Pure-PyTorch reference: add the residual, RMS-normalize, and return norm + hidden state.
def add_rms_norm_all_naive(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float) -> torch.Tensor:
    h = x + residual
    return torch.nn.functional.rms_norm(h, weight.shape, weight, eps) + h


# Partial reference: residual add in PyTorch, normalization via the standalone rms_norm kernel.
def add_rms_norm_partial_naive(x: torch.Tensor, residual: torch.Tensor,
                               weight: torch.Tensor,
                               eps: float) -> torch.Tensor:
    h = x + residual
    return activation.rms_norm(h, weight, eps) + h


# Fused kernel under test: returns (normalized output, residual-added hidden state);
# summing them makes the result directly comparable to the reference paths.
def fused_add_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, eps: float) -> torch.Tensor:
    out, h = activation.fused_add_rms_norm(x, residual, weight, eps)
    return out + h


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_add_rms_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)

    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    residual = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(d, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    residual.retain_grad()
    weight.retain_grad()
    # Clone the inputs so each reference implementation accumulates its own gradients.

    x_ref = x.detach().clone().requires_grad_(True)
    residual_ref = residual.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)

    x_ref2 = x.detach().clone().requires_grad_(True)
    residual_ref2 = residual.detach().clone().requires_grad_(True)
    weight_ref2 = weight.detach().clone().requires_grad_(True)

    # References: full PyTorch (torch_fn) and kernel-assisted (torch_fn2).
    torch_fn = add_rms_norm_all_naive
    torch_fn2 = add_rms_norm_partial_naive

    op = activation.ops.fused_add_rms_norm
    fn = fused_add_rms_norm

    layer = activation.layers.FusedAddRMSNorm(d, eps)
    layer.weight = torch.nn.Parameter(weight)

    # Validate the raw custom op with preallocated output buffers via opcheck.
    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    add_out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, add_out, x, residual, weight, eps))

    out = fn(x, residual, weight, eps)
    mod_out, mod_a_out = layer(x, residual)
    mod_out = mod_out + mod_a_out
    ref_out = torch_fn(x_ref, residual_ref, weight_ref, eps)
    ref_out2 = torch_fn2(x_ref2, residual_ref2, weight_ref2, eps)

    assert_close(out, ref_out, atol=0.05, rtol=0.05)
    assert_close(out, ref_out2)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # Test the backward pass: push the same normalized upstream gradient through each path.
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    ref_out2.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    assert_close(x.grad, x_ref2.grad)
    assert_close(residual.grad, residual_ref.grad)
    assert_close(residual.grad, residual_ref2.grad)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)