File size: 1,547 Bytes
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
a1e5ca8
 
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from common.diff_engine import DiffCase

import activation


class FusedAddRMSNorm(torch.nn.Module):
    """Reference add + RMSNorm built from an explicit sum and ``activation.rms_norm``.

    ``forward`` adds ``residual`` to ``x`` in PyTorch, normalizes that sum with
    ``activation.rms_norm`` using the learnable ``weight`` and ``eps``, and
    returns both the normalized tensor and the pre-norm sum.

    Args:
        d: Length of the ``weight`` vector. In this file it is constructed
            with the test's ``hidden`` size.
        eps: Epsilon forwarded unchanged to ``activation.rms_norm``.
        dtype: dtype of the initial all-ones ``weight``.
    """

    def __init__(self, d, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.eps = eps
        initial = torch.ones(d, dtype=dtype)
        self.weight = torch.nn.Parameter(initial)

    def forward(self, x, residual):
        """Return ``(rms_norm(x + residual), x + residual)``.

        The second element is the un-normalized sum. It is returned so that a
        fused implementation's updated residual can be compared against it.
        """
        summed = x + residual
        normed = activation.rms_norm(summed, self.weight, self.eps)
        return normed, summed


class AddRMS(DiffCase):
    """Diff-test case pairing a reference add+RMSNorm with the packaged layer.

    The reference side is ``FusedAddRMSNorm`` from this file: an explicit add
    followed by ``activation.rms_norm``. The other side is
    ``activation.layers.FusedAddRMSNorm``. Both modules receive their own copy
    of the same ``weight`` tensor. How outputs and gradients are compared
    (tolerances, devices) is decided by ``DiffCase`` in ``common.diff_engine``,
    which is not visible here.
    """

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        """Create the shared input dict for one test configuration.

        Keys:
            x, residual: random ``(bs, sl, hidden)`` tensors with
                ``requires_grad=True``.
            weight: all-ones ``(hidden,)`` tensor.
            dim, eps, dtype: passed through for module construction.
        """
        shape = (bs, sl, hidden)
        # x is drawn before residual; keep this order so seeded runs reproduce.
        x = torch.randn(*shape, dtype=dtype, requires_grad=True)
        residual = torch.randn(*shape, dtype=dtype, requires_grad=True)
        # NOTE(review): an all-ones weight cannot reveal a bug in how the
        # weight is applied; a random weight would be a stronger check.
        weight = torch.ones(hidden, dtype=dtype)
        return {
            "x": x,
            "residual": residual,
            "weight": weight,
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    @staticmethod
    def _with_weight(module, inputs):
        """Install a detached copy of ``inputs["weight"]`` as ``module.weight``.

        Each module gets its own Parameter (``requires_grad=True`` by default),
        so the two implementations never share weight storage.
        """
        module.weight = torch.nn.Parameter(inputs["weight"].detach().clone())
        return module

    def make_naive(self, I):
        """Build the reference module with a private copy of the shared weight."""
        module = FusedAddRMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
        return self._with_weight(module, I)

    def make_cuda(self, I):
        """Build the ``activation.layers`` module with a private copy of the weight."""
        layer = activation.layers.FusedAddRMSNorm(
            I["dim"], I["eps"], dtype=I["dtype"]
        )
        return self._with_weight(layer, I)

    def forward(self, obj, I):
        """Call ``obj`` on the shared ``x`` / ``residual`` pair and return its output."""
        x, residual = I["x"], I["residual"]
        return obj(x, residual)

    def grad_inputs(self, I):
        """Return ``[x, residual]``, the inputs built with ``requires_grad=True``.

        Presumably the engine compares the gradients of these tensors.
        Confirm against ``common.diff_engine``.
        """
        return [I[key] for key in ("x", "residual")]


# Module-level instance of this diff case. Presumably the common.diff_engine
# runner discovers it by the name CASE; confirm against the runner.
CASE = AddRMS()