abi96062 committed
Commit 0e3e3d6 · verified · 1 Parent(s): dbd2371

Create components.py

Files changed (1): components.py (added, +388 −0)
"""
components.py
=============
Architectural components for SmolLM2-135M implementation

Components:
- RMSNorm: Root Mean Square Layer Normalization
- RotaryEmbedding: Rotary Position Embeddings (RoPE)
- GroupedQueryAttention: Grouped Query Attention (9 Q heads, 3 KV heads)
- SwiGLU_FFN: SwiGLU Feed-Forward Network
- TransformerBlock: Complete transformer block with pre-norm architecture
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization

    Simpler and faster than LayerNorm:
    - No mean centering
    - No bias term
    - 10-15% faster than LayerNorm

    Formula: output = input * rsqrt(mean(input²) + eps) * weight
    """

    def __init__(self, hidden_size, eps=1e-5):
        """
        Args:
            hidden_size (int): Dimension of the input
            eps (float): Small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Normalized tensor of same shape as input
        """
        # Calculate variance (mean of squares)
        variance = x.pow(2).mean(-1, keepdim=True)

        # Normalize: x / sqrt(variance + eps)
        x = x * torch.rsqrt(variance + self.eps)

        # Scale by learned weight
        return self.weight * x

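# Usage sketch for RMSNorm (shapes below are illustrative, not tied to any
# particular config; with the default all-ones weight, each position of the
# output has RMS ~= 1):
#
#   norm = RMSNorm(hidden_size=576)
#   x = torch.randn(2, 8, 576)
#   y = norm(x)              # same shape as x: [2, 8, 576]
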
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE)

    Encodes position by rotating Q and K vectors in 2D subspaces.
    Enables relative position encoding and extrapolation to longer sequences.

    Key properties:
    - Applied only to Q and K, not V
    - Different rotation frequencies for different dimension pairs
    - Enables length extrapolation beyond training sequences
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        """
        Args:
            dim (int): Dimension of each attention head (typically hidden_size / num_heads)
            max_position_embeddings (int): Maximum sequence length
            base (float): Base for inverse frequency calculation (theta)
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Calculate inverse frequencies for rotation
        # inv_freq[i] = 1 / (base^(2i/dim)) for i in [0, dim/2)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """
        Args:
            x (torch.Tensor): Input tensor (used for device/dtype)
            position_ids (torch.Tensor): Position indices [batch, seq_len] or [seq_len]

        Returns:
            tuple: (cos, sin) embeddings of shape [batch, seq_len, dim]
        """
        # Ensure position_ids has batch dimension
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        # Calculate rotation angles: position_ids × inv_freq
        # Shape: [batch, seq_len, dim/2]
        freqs = torch.einsum('bi,j->bij', position_ids.float(), self.inv_freq)

        # Duplicate frequencies for both sin and cos
        # Shape: [batch, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        # Return cos and sin, preserving input dtype
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

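# Shape sketch for RotaryEmbedding (dim=64 mirrors the head_dim used later; the
# dummy tensor below only supplies dtype/device and is otherwise unused):
#
#   rope = RotaryEmbedding(dim=64)
#   dummy = torch.zeros(1)
#   cos, sin = rope(dummy, torch.arange(16))
#   # cos.shape == sin.shape == torch.Size([1, 16, 64])
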
def rotate_half(x):
    """
    Rotate half the hidden dimensions

    For RoPE, we rotate pairs of dimensions. This function rearranges
    the tensor to prepare for rotation.

    Args:
        x (torch.Tensor): Input of shape [..., dim]

    Returns:
        torch.Tensor: Rotated tensor where the second half is negated and moved to the front
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Apply rotary position embeddings to queries and keys

    Rotation formula:
        q_rotated = q * cos + rotate_half(q) * sin
        k_rotated = k * cos + rotate_half(k) * sin

    Args:
        q (torch.Tensor): Query tensor [batch, num_heads, seq_len, head_dim]
        k (torch.Tensor): Key tensor [batch, num_kv_heads, seq_len, head_dim]
        cos (torch.Tensor): Cosine embeddings [batch, seq_len, head_dim]
        sin (torch.Tensor): Sine embeddings [batch, seq_len, head_dim]

    Returns:
        tuple: (q_rotated, k_rotated) with rotary embeddings applied
    """
    # Add dimensions for broadcasting
    # cos/sin: [batch, seq_len, dim] -> [batch, 1, seq_len, dim]
    if cos.dim() == 2:
        cos = cos.unsqueeze(0)
        sin = sin.unsqueeze(0)
    if cos.dim() == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

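# Putting the two helpers together (illustrative shapes; the 9 Q heads / 3 KV
# heads mirror the GQA layout described below, broadcasting handles the
# differing head counts):
#
#   rope = RotaryEmbedding(dim=64)
#   q = torch.randn(2, 9, 16, 64)     # [batch, num_heads, seq_len, head_dim]
#   k = torch.randn(2, 3, 16, 64)     # [batch, num_kv_heads, seq_len, head_dim]
#   cos, sin = rope(q, torch.arange(16))
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # shapes unchanged
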
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    Memory-efficient attention where multiple query heads share KV heads.
    SmolLM2-135M uses 9 query heads and 3 KV heads (3:1 ratio).

    Benefits:
    - Reduces KV cache memory by 66% vs full MHA
    - Maintains most of multi-head attention's expressiveness
    - Used in Llama 2, Mistral, and other modern LLMs

    Architecture:
    - 9 query heads (each head_dim=64)
    - 3 KV heads (each head_dim=64)
    - Each KV head is repeated 3 times to serve 3 query heads
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - num_attention_heads: Number of query heads (9)
                - num_key_value_heads: Number of KV heads (3)
                - max_position_embeddings: Max sequence length
                - rope_theta: RoPE base frequency
        """
        super().__init__()
        self.hidden_size = config.hidden_size                      # 576
        self.num_heads = config.num_attention_heads                # 9
        self.num_kv_heads = config.num_key_value_heads             # 3
        self.num_kv_groups = self.num_heads // self.num_kv_heads   # 3
        self.head_dim = self.hidden_size // self.num_heads         # 64

        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        # Projections (no bias in any linear layers)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass of grouped query attention

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask (expected to
                already encode causal masking when provided)
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)

        # Q, K, V projections
        query_states = self.q_proj(hidden_states)   # [batch, seq_len, 576]
        key_states = self.k_proj(hidden_states)     # [batch, seq_len, 192]
        value_states = self.v_proj(hidden_states)   # [batch, seq_len, 192]

        # Reshape to separate heads
        # Q: [batch, seq_len, 9, 64] -> [batch, 9, seq_len, 64]
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # K, V: [batch, seq_len, 3, 64] -> [batch, 3, seq_len, 64]
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat K and V for GQA (3 KV heads -> 9 to match Q heads)
        # Each KV head is repeated 3 times: [batch, 3, seq, 64] -> [batch, 9, seq, 64]
        key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1)

        # Scaled dot-product attention (PyTorch 2.0+), which dispatches to fused
        # kernels such as FlashAttention when available.
        # Note: SDPA does not allow an explicit attn_mask together with
        # is_causal=True, so enable is_causal only when no mask is supplied.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=attention_mask is None  # Causal masking for autoregressive generation
        )

        # Reshape back: [batch, 9, seq_len, 64] -> [batch, seq_len, 576]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)

        # Output projection
        attn_output = self.o_proj(attn_output)

        return attn_output

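# Usage sketch with a minimal stand-in config (the SimpleNamespace and its field
# values are purely illustrative; the real config object may be a different
# class with the same attribute names):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(hidden_size=576, num_attention_heads=9,
#                         num_key_value_heads=3, max_position_embeddings=2048,
#                         rope_theta=10000.0)
#   attn = GroupedQueryAttention(cfg)
#   out = attn(torch.randn(2, 16, 576))   # -> [2, 16, 576], causal by default
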
class SwiGLU_FFN(nn.Module):
    """
    SwiGLU Feed-Forward Network

    Uses Swish-Gated Linear Units instead of a standard FFN.
    Formula: FFN(x) = down_proj(SiLU(gate_proj(x)) ⊙ up_proj(x))

    Key differences from a standard FFN:
    - 3 linear projections instead of 2 (gate, up, down)
    - Element-wise gating mechanism (⊙)
    - ~50% more parameters at the same intermediate size, but better quality in practice
    - Used in Llama, PaLM, and most modern LLMs
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - intermediate_size: FFN intermediate dimension (1536)
        """
        super().__init__()
        self.hidden_size = config.hidden_size               # 576
        self.intermediate_size = config.intermediate_size   # 1536

        # Three projections (no bias)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        # Swish/SiLU activation
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """
        Forward pass: down(SiLU(gate) * up)

        Args:
            x (torch.Tensor): Input [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Gate path: apply SiLU activation
        gate = self.act_fn(self.gate_proj(x))

        # Up path: linear transformation
        up = self.up_proj(x)

        # Element-wise multiplication (gating)
        gated = gate * up

        # Down projection
        return self.down_proj(gated)

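# Usage sketch (stand-in config as above; 576 -> 1536 -> 576 mirrors the
# dimensions quoted in the docstring):
#
#   cfg = SimpleNamespace(hidden_size=576, intermediate_size=1536)
#   ffn = SwiGLU_FFN(cfg)
#   out = ffn(torch.randn(2, 16, 576))    # -> [2, 16, 576]
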
class TransformerBlock(nn.Module):
    """
    Complete Transformer Block with Pre-Norm Architecture

    Architecture:
        1. x -> RMSNorm -> Attention -> Add residual
        2. x -> RMSNorm -> FFN -> Add residual

    Pre-norm (norm before sublayer) is standard in modern transformers
    as it provides better gradient flow in deep networks.
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration
        """
        super().__init__()

        # Layer normalization (pre-norm)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Self-attention
        self.self_attn = GroupedQueryAttention(config)

        # Post-attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Feed-forward network
        self.mlp = SwiGLU_FFN(config)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass through transformer block

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Self-attention with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask, position_ids)
        hidden_states = residual + hidden_states

        # FFN with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
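

if __name__ == "__main__":
    # Minimal smoke test with a stand-in config. SimpleNamespace and the field
    # values below are illustrative (they mirror the numbers quoted in the
    # docstrings above); the config object used by the full model may differ.
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=576,
        num_attention_heads=9,
        num_key_value_heads=3,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        intermediate_size=1536,
        rms_norm_eps=1e-5,
    )

    block = TransformerBlock(config)
    x = torch.randn(2, 16, config.hidden_size)
    y = block(x)
    print(y.shape)  # expected: torch.Size([2, 16, 576])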