import math

import torch
import torch.nn as nn

from .optimized_diffattn import MultiheadDiffAttn

# --- Tokenizer Definition ---
# Vocabulary: 256 bytes + IM_START_TOKEN + IM_END_TOKEN + <pad>
IM_START_TOKEN = "<|im_start|>"
IM_END_TOKEN = "<|im_end|>"
PAD_TOKEN = "<pad>"

SPECIAL_TOKENS = [IM_START_TOKEN, IM_END_TOKEN, PAD_TOKEN]
VOCAB_SIZE = 256 + len(SPECIAL_TOKENS)

# Create token to id mapping
token_to_id = {}
id_to_token = {}

for i in range(256):
    token_to_id[bytes([i])] = i
    id_to_token[i] = bytes([i])

for i, token_str in enumerate(SPECIAL_TOKENS):
    token_id = 256 + i
    token_to_id[token_str] = token_id
    id_to_token[token_id] = token_str

PAD_ID = token_to_id[PAD_TOKEN]
IM_START_ID = token_to_id[IM_START_TOKEN]
IM_END_ID = token_to_id[IM_END_TOKEN]
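# Resulting id layout: 0-255 raw byte values, 256 <|im_start|>, 257 <|im_end|>, 258 <pad>.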


class ByteTokenizer:
    def __init__(self):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.vocab_size = VOCAB_SIZE
        self.pad_id = PAD_ID
        self.im_start_id = IM_START_ID
        self.im_end_id = IM_END_ID

    def encode(self, text_bytes: bytes, add_special_tokens=True):
        ids = [self.token_to_id[bytes([b])] for b in text_bytes]
        if add_special_tokens:
            return [self.im_start_id] + ids + [self.im_end_id]
        return ids

    def decode(self, ids: list[int]):
        tokens = []
        for i in ids:
            token = self.id_to_token.get(i)
            if token is None:
                # Unknown token id: substitute a placeholder byte
                tokens.append(b"?")
            elif isinstance(token, bytes):
                tokens.append(token)
            # Special tokens (str entries) are skipped when decoding back to raw bytes
        return b"".join(tokens)


# --- RoPE Embeddings --- (Reused from previous script)
def get_rotary_embeddings(seq_len, dim_model, theta=10000.0):
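    """Precompute RoPE angle tables.

    Returns (cos, sin), each of shape (seq_len, dim_model // 2). Here dim_model is
    the per-head rotary dimension passed in by the caller, not the model width.
    """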
    if dim_model % 2 != 0:
        raise ValueError(f"dim_model must be even, got {dim_model}")
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2).float() * -(math.log(theta) / dim_model)
    )
    angles = position * div_term
    cos_emb = torch.cos(angles)
    sin_emb = torch.sin(angles)
    return cos_emb, sin_emb
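
# How the tables are typically consumed (illustrative; the actual application lives
# inside MultiheadDiffAttn): each query/key feature pair (x1, x2) at position t is
# rotated as
#   x1' = x1 * cos[t] - x2 * sin[t]
#   x2' = x1 * sin[t] + x2 * cos[t]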


# --- Model Definition ---
class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        return self.fc2(self.dropout(self.act(self.fc1(x))))


class DiffTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, depth, ffn_hidden_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiheadDiffAttn(embed_dim, depth, num_heads, dropout=dropout)
        self.ffn = FeedForward(embed_dim, ffn_hidden_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, rel_pos, attn_mask=None):
        # Pre-norm
        attn_out = self.attn(self.norm1(x), rel_pos, attn_mask)
        x = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ffn_out)
        return x


class DiffTransformerLLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ffn_hidden_dim,
        max_seq_len,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        # Positional embeddings are handled by RoPE, so no separate nn.Embedding for positions
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList(
            [
                DiffTransformerBlock(
                    embed_dim, num_heads, depth, ffn_hidden_dim, dropout
                )
                for depth in range(num_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Tie input embedding and output projection weights (one shared (vocab_size, embed_dim) matrix)
        self.token_embeddings.weight = self.lm_head.weight

        # RoPE precomputation.
        # MultiheadDiffAttn splits each head into two components whose attention maps
        # are subtracted, so its per-component head_dim is embed_dim // num_heads // 2.
        self.rope_head_dim = embed_dim // num_heads // 2
        cos_emb, sin_emb = get_rotary_embeddings(max_seq_len, self.rope_head_dim)
        self.register_buffer("cos_emb", cos_emb, persistent=False)
        self.register_buffer("sin_emb", sin_emb, persistent=False)

    def forward(self, input_ids, attn_mask=None):
        batch_size, seq_len = input_ids.shape

        x = self.token_embeddings(input_ids) * math.sqrt(self.embed_dim)
        x = self.dropout(x)

        # Ensure RoPE embeddings are on the same device *and* dtype as activations
        rel_pos = (
            self.cos_emb[:seq_len, :].to(x.device, dtype=x.dtype),
            self.sin_emb[:seq_len, :].to(x.device, dtype=x.dtype),
        )

        # Build the additive attention mask. MultiheadDiffAttn expects 0 at positions
        # that may attend and -inf at masked positions, so a standard causal mask is
        # used for autoregressive decoding. A user-supplied attn_mask (e.g. a Hugging
        # Face-style padding mask with 1 = attend, 0 = pad) is currently not merged in;
        # padding is instead excluded from training via the loss's ignore_index.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype),
            diagonal=1,
        )

        for layer in self.layers:
            x = layer(x, rel_pos, attn_mask=causal_mask)

        x = self.norm_out(x)
        logits = self.lm_head(x)
        return logits
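
    # Training note (illustrative sketch): padding tokens are excluded from the loss
    # rather than from attention, e.g. with torch.nn.functional.cross_entropy and
    # next-token targets derived from input_ids:
    #   loss = cross_entropy(
    #       logits[:, :-1].reshape(-1, logits.size(-1)),
    #       input_ids[:, 1:].reshape(-1),
    #       ignore_index=PAD_ID,
    #   )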

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
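

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; small, arbitrary dimensions). It assumes
    # MultiheadDiffAttn accepts the arguments used in DiffTransformerBlock above, and
    # it must be run as a module (python -m ...) because of the relative import.
    tokenizer = ByteTokenizer()
    model = DiffTransformerLLM(
        vocab_size=tokenizer.vocab_size,
        embed_dim=64,
        num_layers=2,
        num_heads=2,
        ffn_hidden_dim=128,
        max_seq_len=128,
    )
    input_ids = torch.tensor([tokenizer.encode(b"hello world")])
    logits = model(input_ids)
    print(f"params: {model.count_parameters():,}, logits shape: {tuple(logits.shape)}")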