jonmabe committed on
Commit e70d699 · verified · 1 Parent(s): 443ebb1

Upload model.py with huggingface_hub

Files changed (1)
  1. model.py +193 -0
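
The commit message says the file was pushed with huggingface_hub. For reference, a minimal sketch of that kind of upload (an assumed workflow, not the exact command used for this commit; the repo id below is a placeholder):

# Sketch of uploading a single file with huggingface_hub.
# The repo_id is a placeholder, not taken from this commit.
from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="model.py",          # local file to push
    path_in_repo="model.py",             # destination path inside the repo
    repo_id="your-username/your-model",  # placeholder repo id
    commit_message="Upload model.py with huggingface_hub",
)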
model.py ADDED
@@ -0,0 +1,193 @@

"""
Tiny Transformer with modern components:
- RoPE (Rotary Position Embeddings)
- RMSNorm
- SwiGLU activation
- Weight tying
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 512, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, x, seq_len: int):
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Attention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        self.rotary = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary(x, T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Scaled dot-product attention
        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)


class TransformerBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads, dropout)
        self.norm2 = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, intermediate_size)

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x


class TinyLLM(nn.Module):
    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 512,
        num_layers: int = 12,
        num_heads: int = 8,
        intermediate_size: int = 1408,
        max_position_embeddings: int = 512,
        dropout: float = 0.0,
        tie_weights: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, intermediate_size, dropout)
            for _ in range(num_layers)
        ])
        self.norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        if tie_weights:
            self.lm_head.weight = self.embed_tokens.weight

        # Causal mask
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_position_embeddings, max_position_embeddings))
        )

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        B, T = input_ids.shape

        x = self.embed_tokens(input_ids)
        mask = self.causal_mask[:T, :T]

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )

        return {"loss": loss, "logits": logits}

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters())


if __name__ == "__main__":
    # Test model
    model = TinyLLM()
    print(f"Parameters: {model.count_parameters() / 1e6:.2f}M")

    x = torch.randint(0, 32000, (2, 128))
    out = model(x, labels=x)
    print(f"Loss: {out['loss'].item():.4f}")
    print(f"Logits shape: {out['logits'].shape}")
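
Once the file is on the Hub, one way to pull it back down and use the class is sketched below. This is not part of the commit; the repo id is again a placeholder, and hf_hub_download simply caches the file locally before it is imported as a module.

# Sketch of fetching model.py from the Hub and instantiating TinyLLM.
import importlib.util
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="your-username/your-model", filename="model.py")  # placeholder repo id
spec = importlib.util.spec_from_file_location("tiny_llm_model", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

model = module.TinyLLM()
print(f"Parameters: {model.count_parameters() / 1e6:.2f}M")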