import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import PreTrainedModel
from .configuration_custom_mbz_test import CustomConfig
from transformers.modeling_outputs import CausalLMOutput


class RotaryPositionalEncoding(nn.Module):
    """
    Rotary Position Embeddings (RoPE) - efficient implementation
    """
    def __init__(self, d_head, max_seq_len=8192, base=10000.0):
        super().__init__()
        self.d_head = d_head
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer('inv_freq', inv_freq, persistent=False)

        # Precompute cos and sin for maximum sequence length
        self._precompute_freqs(max_seq_len)

    def _precompute_freqs(self, seq_len):
        """Precompute cos and sin values for positions"""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)  # (seq_len, d_head/2)

        # Create cos and sin embeddings
        freqs_cos = torch.cos(freqs)
        freqs_sin = torch.sin(freqs)

        # Duplicate each frequency so cos/sin line up with the interleaved
        # (even, odd) channel pairs -> (seq_len, d_head)
        self.register_buffer('freqs_cos', freqs_cos.repeat_interleave(2, dim=-1), persistent=False)
        self.register_buffer('freqs_sin', freqs_sin.repeat_interleave(2, dim=-1), persistent=False)

    def rotate_half(self, x):
        """Rotate half the hidden dims of the input"""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack([-x2, x1], dim=-1).flatten(-2)

    def forward(self, q, k, start_pos=0):
        """
        Apply rotary embeddings to query and key tensors
        Args:
            q: (batch_size, n_heads, seq_len, d_head)
            k: (batch_size, n_heads, seq_len, d_head)
            start_pos: starting position for caching scenarios
        Returns:
            q_rot, k_rot with rotary embeddings applied
        """
        seq_len = q.shape[2]

        # Get the precomputed frequencies for these positions, extending the
        # cached tables if the requested range exceeds what was precomputed
        if start_pos + seq_len > self.freqs_cos.shape[0]:
            self._precompute_freqs(start_pos + seq_len)
        freqs_cos = self.freqs_cos[start_pos:start_pos + seq_len]
        freqs_sin = self.freqs_sin[start_pos:start_pos + seq_len]

        # Apply rotary embeddings
        q_rot = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rot = k * freqs_cos + self.rotate_half(k) * freqs_sin

        return q_rot, k_rot
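
# A minimal shape/invariance check for the RoPE module above -- an illustrative
# sketch, not part of the original model. The helper name `_rope_sanity_check`
# is hypothetical and is never called at import time. Assumed layout:
# q and k are (batch_size, n_heads, seq_len, d_head).
def _rope_sanity_check():
    rope = RotaryPositionalEncoding(d_head=64, max_seq_len=128)
    q = torch.randn(2, 4, 16, 64)
    k = torch.randn(2, 4, 16, 64)
    q_rot, k_rot = rope(q, k)
    # RoPE rotates each (even, odd) channel pair, so shapes and per-token norms are preserved
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)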

class Attention(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head

        self.Wq = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wk = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wv = nn.Linear(d_model, n_heads * d_head, bias=False)
        self.Wo = nn.Linear(n_heads * d_head, d_model, bias=False)

        # Initialize RoPE
        self.rope = RotaryPositionalEncoding(d_head)

    def forward(self, x):
        # x is shape batch_size, seq_len, d_model
        batch_size, seq_len, d_model = x.shape
        q = self.Wq(x) # q is shape batch_size, seq_len, n_heads * d_head
        k = self.Wk(x)
        v = self.Wv(x)

        # reshape to batch_size, n_heads, seq_len, d_head
        q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # apply rotary position embeddings to queries and keys
        q, k = self.rope(q, k)
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # force the FlashAttention backend
            # a is (batch_size, n_heads, seq_len, d_head)
            a = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        a = a.transpose(1, 2)  # (batch_size, seq_len, n_heads, d_head)
        a = a.reshape(batch_size, seq_len, self.n_heads * self.d_head)
        out = self.Wo(a)  # (batch_size, seq_len, d_model)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head

        self.attn = Attention(d_model, n_heads, d_head)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4*d_model), nn.ReLU(), nn.Linear(4*d_model, d_model))

        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)

    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.mlp(self.norm2(x)) + x
        return x

class GPT(nn.Module):
    def __init__(self, d_model, n_heads, d_head, n_vocab, n_layers):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        self.n_vocab = n_vocab

        self.embed = nn.Embedding(n_vocab, d_model)

        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_head) for _ in range(n_layers)])

        self.norm = nn.RMSNorm(d_model)
        self.out_head = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.out_head(self.norm(x))
        return x
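
# A hedged smoke-test sketch for the bare GPT module (not part of the original
# file). It only runs when CUDA is available, because Attention forces the
# FlashAttention SDPA backend, which expects half-precision inputs on GPU.
# The helper name `_gpt_smoke_test` is illustrative and never called at import time.
def _gpt_smoke_test():
    if not torch.cuda.is_available():
        return
    model = GPT(d_model=64, n_heads=4, d_head=16, n_vocab=100, n_layers=2).cuda()
    tokens = torch.randint(0, 100, (2, 12), device='cuda')  # (batch_size, seq_len) token ids
    with torch.autocast('cuda', dtype=torch.bfloat16):
        logits = model(tokens)
    assert logits.shape == (2, 12, 100)  # (batch_size, seq_len, n_vocab)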

class CustomModelForCausalLM(PreTrainedModel):
    config_class = CustomConfig
    _supports_attention_backend = True

    def __init__(self, config):
        super().__init__(config)
        self.model = GPT(config.d_model, config.n_heads, config.d_head, config.n_vocab, config.n_layers)

    def forward(self, input_ids, **kwargs):
        # Extra keyword arguments (e.g. attention_mask) are accepted but ignored here.
        # bfloat16 autocast assumes a CUDA device.
        with torch.autocast('cuda', dtype=torch.bfloat16):
            logits = self.model(input_ids)
            return CausalLMOutput(logits=logits)

    def get_input_embeddings(self):
        return self.model.embed

    def set_input_embeddings(self, x):
        self.model.embed = x
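
# End-to-end sketch of the HF wrapper (illustrative, not part of the original file).
# It assumes CustomConfig accepts these fields as constructor keyword arguments,
# which is the usual PretrainedConfig pattern but is not shown here, and it needs
# a CUDA device because forward() forces bfloat16 autocast and the attention layers
# force the FlashAttention backend.
if __name__ == "__main__":
    if torch.cuda.is_available():
        config = CustomConfig(d_model=256, n_heads=8, d_head=32, n_vocab=32000, n_layers=4)
        model = CustomModelForCausalLM(config).cuda()
        input_ids = torch.randint(0, config.n_vocab, (1, 16), device='cuda')
        out = model(input_ids)
        print(out.logits.shape)  # expected: torch.Size([1, 16, 32000])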