File size: 3,626 Bytes
ac3a9cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import torch.nn as nn

from .modules import STU
from .modules import MLP
from .modules import Attention
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False


class STULayer(nn.Module):
    """Pre-norm residual block: ``x + STU(norm(x))`` then ``x + MLP(norm(x))``.

    Args:
        config: Model config. ``n_embd`` sets the RMSNorm width and
            ``torch_dtype`` (a ``torch.dtype`` or its name, e.g. ``"bfloat16"``
            or ``"torch.bfloat16"``) sets the dtype every sub-module is cast to.
            The config is also forwarded to the ``STU`` and MLP constructors.
        phi: Forwarded unchanged to ``STU``.
        n: Forwarded unchanged to ``STU``.
    """

    def __init__(self, config, phi, n):
        super().__init__()
        torch_dtype = self._resolve_dtype(config.torch_dtype)

        # Registration order (stu_norm, stu, mlp_norm, mlp) is kept stable so
        # state_dict keys / parameters() order match existing checkpoints.
        self.stu_norm = self._make_norm(config.n_embd, torch_dtype)
        self.stu = STU(config, phi, n).to(dtype=torch_dtype)
        self.mlp_norm = self._make_norm(config.n_embd, torch_dtype)
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        # The Liger modules are constructed without a dtype argument, so every
        # norm/MLP is cast after construction (a no-op for the vanilla modules).
        self.stu_norm = self.stu_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    @staticmethod
    def _resolve_dtype(dtype):
        """Return ``dtype`` as a ``torch.dtype``.

        Strings such as ``"bfloat16"`` or ``"torch.bfloat16"`` are looked up on
        the ``torch`` module; any other value is returned unchanged.
        """
        if isinstance(dtype, str):
            name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
            return getattr(torch, name)
        return dtype

    @staticmethod
    def _make_norm(dim, dtype):
        """Build an RMSNorm of width ``dim`` (Liger Triton kernel when importable)."""
        if triton_norm:
            return TritonNorm(dim)
        return RMSNorm(dim, dtype=dtype)

    @staticmethod
    def _param_dtype(module, default):
        """Return the dtype of ``module``'s first floating-point parameter.

        Falls back to ``default`` when the module has no floating-point
        parameters. Used instead of a hard-coded attribute (e.g.
        ``M_inputs`` / ``gate_proj``) so the lookup does not depend on which
        parameter names a particular STU/MLP variant defines.
        """
        return next(
            (p.dtype for p in module.parameters() if p.is_floating_point()),
            default,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the STU and MLP sub-blocks with pre-norm residual connections.

        Each sub-block's normalized input is cast to that sub-module's parameter
        dtype, and its output is cast back to ``x.dtype`` before the residual
        add, so the residual stream keeps the caller's dtype.
        """
        # Normalize and apply STU.
        stu_dtype = self._param_dtype(self.stu, x.dtype)
        x_stu = self.stu(self.stu_norm(x).to(dtype=stu_dtype)).to(dtype=x.dtype)
        x = x + x_stu

        # Normalize and apply MLP.
        mlp_dtype = self._param_dtype(self.mlp, x.dtype)
        x_mlp = self.mlp(self.mlp_norm(x).to(dtype=mlp_dtype)).to(dtype=x.dtype)
        x = x + x_mlp

        return x

class AttentionLayer(nn.Module):
    """Pre-norm residual block: ``x + Attention(norm(x))`` then ``x + MLP(norm(x))``.

    Args:
        config: Model config. ``n_embd`` sets the RMSNorm width and
            ``torch_dtype`` (a ``torch.dtype`` or its name, e.g. ``"bfloat16"``
            or ``"torch.bfloat16"``) sets the dtype every sub-module is cast to.
            The config is also forwarded to the ``Attention`` and MLP constructors.
    """

    def __init__(self, config) -> None:
        super().__init__()
        torch_dtype = self._resolve_dtype(config.torch_dtype)

        # Registration order (attn_norm, attn, mlp_norm, mlp) is kept stable so
        # state_dict keys / parameters() order match existing checkpoints.
        self.attn_norm = self._make_norm(config.n_embd, torch_dtype)
        self.attn = Attention(config).to(dtype=torch_dtype)
        self.mlp_norm = self._make_norm(config.n_embd, torch_dtype)
        self.mlp = (
            TritonMLP(config) if triton_mlp else MLP(config, dtype=torch_dtype)
        )

        # TODO: Write Issue in Liger-Kernel repo to support user-defined dtype for MLP
        # The Liger modules are constructed without a dtype argument, so every
        # norm/MLP is cast once after construction (previously the MLP was cast
        # twice; nn.Module.to is in-place, so once is enough).
        self.attn_norm = self.attn_norm.to(dtype=torch_dtype)
        self.mlp = self.mlp.to(dtype=torch_dtype)
        self.mlp_norm = self.mlp_norm.to(dtype=torch_dtype)

    @staticmethod
    def _resolve_dtype(dtype):
        """Return ``dtype`` as a ``torch.dtype``.

        Strings such as ``"bfloat16"`` or ``"torch.bfloat16"`` are looked up on
        the ``torch`` module; any other value is returned unchanged.
        """
        if isinstance(dtype, str):
            name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
            return getattr(torch, name)
        return dtype

    @staticmethod
    def _make_norm(dim, dtype):
        """Build an RMSNorm of width ``dim`` (Liger Triton kernel when importable)."""
        if triton_norm:
            return TritonNorm(dim)
        return RMSNorm(dim, dtype=dtype)

    @staticmethod
    def _param_dtype(module, default):
        """Return the dtype of ``module``'s first floating-point parameter.

        Falls back to ``default`` when the module has no floating-point
        parameters.
        """
        return next(
            (p.dtype for p in module.parameters() if p.is_floating_point()),
            default,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with pre-norm residual connections.

        As in ``STULayer``, each sub-block's normalized input is cast to that
        sub-module's parameter dtype and its output is cast back to ``x.dtype``
        before the residual add. When the dtypes already match, these casts are
        no-ops.
        """
        attn_dtype = self._param_dtype(self.attn, x.dtype)
        x_attn = self.attn(self.attn_norm(x).to(dtype=attn_dtype)).to(dtype=x.dtype)
        x = x + x_attn

        mlp_dtype = self._param_dtype(self.mlp, x.dtype)
        x_mlp = self.mlp(self.mlp_norm(x).to(dtype=mlp_dtype)).to(dtype=x.dtype)
        x = x + x_mlp

        return x