from typing import Optional

import torch
import torch.nn as nn

from .attn import FlexAttention
from .modules import Attention, MLP
# Prefer the fused Triton SwiGLU MLP from liger-kernel when it is installed;
# the flag records whether it is available.
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

# Prefer the fused Triton RMSNorm from liger-kernel when it is installed;
# otherwise fall back to PyTorch's built-in RMSNorm.
try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False
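
# The flags above only record kernel availability. One typical way to consume
# them (sketched here with hypothetical names NormClass / MLPClass; not used
# elsewhere in this file) would be:
#     NormClass = TritonNorm if triton_norm else nn.RMSNorm
#     MLPClass = TritonMLP if triton_mlp else MLP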

class AttentionLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual,
    then RMSNorm -> MLP -> residual."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Residual connections around the normalized attention and MLP sub-blocks.
        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
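

if __name__ == "__main__":
    # Usage sketch (illustrative only): the config fields and the mask_mod
    # signature below are assumptions about what FlexAttention and MLP read,
    # not guarantees from this repository.
    from dataclasses import dataclass

    @dataclass
    class _DemoConfig:
        dim: int = 512      # model width; AttentionLayer reads config.dim
        n_heads: int = 8    # assumed to be consumed inside FlexAttention

    def _causal_mask(b, h, q_idx, kv_idx):
        # flex-attention-style mask_mod: each query attends to itself and earlier keys
        return q_idx >= kv_idx

    layer = AttentionLayer(_DemoConfig(), mask_mod=_causal_mask)
    x = torch.randn(2, 128, 512)   # (batch, seq_len, dim)
    out = layer(x)                 # freqs_cis is optional and defaults to None
    print(out.shape)               # expected: torch.Size([2, 128, 512])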