lemms committed
Commit e15a2f9 · verified · 1 Parent(s): f1b5f6b

Add OpenLLM model.py source file

Files changed (1)
  1. model.py +641 -0
model.py ADDED
@@ -0,0 +1,641 @@
#!/usr/bin/env python3
# Copyright (C) 2024 Louis Chua Bean Chong
#
# This file is part of OpenLLM.
#
# OpenLLM is dual-licensed:
# 1. For open source use: GNU General Public License v3.0
# 2. For commercial use: Commercial License (contact for details)
#
# See LICENSE and docs/LICENSES.md for full license information.

"""
GPT-style Language Model Architecture

This module implements a standard GPT (Generative Pre-trained Transformer) architecture
using pure PyTorch. The model is a decoder-only transformer designed for autoregressive
language modeling (next-token prediction).

ARCHITECTURE OVERVIEW:
- Token Embedding: Maps token IDs to dense vectors
- Positional Embedding: Adds position information to token embeddings
- Transformer Blocks: Stack of multi-head attention + feed-forward layers
- Layer Normalization: Pre-norm placement for training stability
- Output Head: Linear projection to vocabulary for next-token prediction

FEATURES:
- Configurable model size (small/medium/large)
- Dropout for regularization
- Causal (autoregressive) attention masking
- Compatible with our SentencePiece tokenizer
- Memory-efficient implementation for training on limited hardware

Usage:
    from model import GPTConfig, GPTModel

    config = GPTConfig(vocab_size=32000, n_layer=12, n_head=12, n_embd=768)
    model = GPTModel(config)

    # Forward pass (pass targets to get a loss; logits then cover every position)
    logits, loss = model(input_ids, targets)  # logits: (batch_size, seq_len, vocab_size)

Hardware Requirements:
- Small Model (25M params): 4-8GB RAM, CPU/integrated GPU
- Medium Model (117M params): 8-16GB RAM, dedicated GPU recommended
- Large Model (350M params): 16GB+ RAM, high-end GPU required

Author: Louis Chua Bean Chong
License: GPLv3
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class GPTConfig:
    """
    Configuration class for GPT model hyperparameters.

    This class defines all the architectural parameters needed to instantiate
    a GPT model. Use the provided class methods to get pre-configured setups
    for different model sizes.
    """

    # Model architecture
    vocab_size: int = 32000  # Vocabulary size (from tokenizer)
    n_layer: int = 12  # Number of transformer layers
    n_head: int = 12  # Number of attention heads
    n_embd: int = 768  # Embedding dimension

    # Sequence and context
    block_size: int = 1024  # Maximum sequence length

    # Training hyperparameters
    dropout: float = 0.1  # Dropout probability
    bias: bool = True  # Use bias in linear layers

    # Model size identifier
    model_name: str = "gpt-medium"  # Human-readable model identifier

    @classmethod
    def small(cls) -> 'GPTConfig':
        """Small model configuration (~25M parameters) - Good for CPU training"""
        return cls(
            vocab_size=32000,
            n_layer=6,
            n_head=8,
            n_embd=512,
            block_size=1024,
            dropout=0.1,
            model_name="gpt-small"
        )

    @classmethod
    def medium(cls) -> 'GPTConfig':
        """Medium model configuration (~117M parameters) - Balanced performance"""
        return cls(
            vocab_size=32000,
            n_layer=12,
            n_head=12,
            n_embd=768,
            block_size=2048,
            dropout=0.1,
            model_name="gpt-medium"
        )

    @classmethod
    def large(cls) -> 'GPTConfig':
        """Large model configuration (~350M parameters) - High performance"""
        return cls(
            vocab_size=32000,
            n_layer=24,
            n_head=16,
            n_embd=1024,
            block_size=2048,
            dropout=0.1,
            model_name="gpt-large"
        )

    def estimate_parameters(self) -> int:
        """
        Estimate the total number of trainable parameters.

        Returns:
            int: Estimated parameter count
        """
        # Token embeddings
        token_emb = self.vocab_size * self.n_embd

        # Position embeddings
        pos_emb = self.block_size * self.n_embd

        # Transformer layers
        # Each layer: attention (4 * n_embd^2) + mlp (8 * n_embd^2) + layer_norms
        layer_params = self.n_layer * (12 * self.n_embd**2 + 4 * self.n_embd)

        # Output head
        output_head = self.vocab_size * self.n_embd

        total = token_emb + pos_emb + layer_params + output_head
        return total
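    # Illustrative arithmetic for the medium preset (rough, like the estimate itself:
    # it ignores bias terms, the final LayerNorm, and the weight tying done in GPTModel):
    #   token_emb    = 32000 * 768                   ~ 24.6M
    #   pos_emb      =  2048 * 768                   ~  1.6M
    #   layer_params = 12 * (12 * 768**2 + 4 * 768)  ~ 85.0M
    #   output_head  = 32000 * 768                   ~ 24.6M
    #   total                                        ~ 135.7M
    # get_num_params() will report fewer, since lm_head shares its weight matrix with wte.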


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention mechanism.

    This implements the core attention mechanism of the transformer, with causal
    masking to ensure autoregressive behavior (tokens can only attend to previous
    tokens, not future ones).
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by number of heads"

        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head

        # Key, query, value projections for all heads (batched)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Dropout
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask - lower triangular matrix
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )
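        # For example, with block_size=4 the lower-triangular mask (before the
        # .view broadcast) would be:
        #   [[1, 0, 0, 0],
        #    [1, 1, 0, 0],
        #    [1, 1, 1, 0],
        #    [1, 1, 1, 1]]
        # Row i marks the positions j <= i that token i is allowed to attend to.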

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of causal self-attention.

        This method implements the scaled dot-product attention mechanism with causal masking.
        The attention mechanism allows each token to attend to all previous tokens in the sequence,
        but not to future tokens, maintaining the autoregressive property essential for language modeling.

        Mathematical formulation:
            Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
            where Q, K, V are query, key, value matrices derived from input x

        Implementation details:
        - Uses batch matrix multiplication for efficiency
        - Applies causal mask to prevent future token attention
        - Implements multi-head attention by reshaping and parallel processing
        - Applies dropout for regularization during training

        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embd)
               Contains embedded token representations from previous layer

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
        """
        # Extract tensor dimensions for clear variable naming and validation
        # B = batch size (number of sequences processed in parallel)
        # T = sequence length (number of tokens in each sequence)
        # C = embedding dimensionality (n_embd from config)
        B, T, C = x.size()

        # Generate query, key, and value projections for all attention heads
        # The c_attn linear layer outputs 3 * n_embd features, which we split into Q, K, V
        # This batched approach is more efficient than separate linear layers
        # Input shape: (B, T, C) -> Output shape: (B, T, 3*C) -> Split to 3x (B, T, C)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape tensors for multi-head attention computation
        # Transform from (B, T, C) to (B, nh, T, hs) where:
        # - nh = number of heads (self.n_head)
        # - hs = head size (self.head_dim = C // nh)
        # The transpose(1, 2) moves the head dimension before sequence dimension for efficient computation
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Compute scaled dot-product attention scores
        # Matrix multiplication: Q @ K^T gives attention affinities between all token pairs
        # Scaling by 1/sqrt(head_dim) prevents softmax saturation for large embedding dimensions
        # Shape: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
        # The resulting (T, T) matrix represents attention weights from each token to every other token
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

        # Apply causal masking to enforce autoregressive property
        # The causal mask ensures that token i can only attend to tokens j where j <= i
        # This prevents the model from "cheating" by looking at future tokens during training
        # We use -inf for masked positions so they become 0 after softmax
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))

        # Convert attention scores to probabilities using softmax
        # Each row of the attention matrix now sums to 1, representing a probability distribution
        # over which tokens to attend to for each query position
        att = F.softmax(att, dim=-1)

        # Apply dropout to attention weights for regularization
        # This randomly zeros some attention connections during training to prevent overfitting
        att = self.attn_dropout(att)

        # Apply attention weights to value vectors
        # This weighted combination produces the actual output of the attention mechanism
        # Shape: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
        # Each output position is a weighted sum of all value vectors, with weights from attention
        y = att @ v

        # Concatenate multi-head outputs back to original embedding dimension
        # Transform from (B, nh, T, hs) back to (B, T, C) where C = nh * hs
        # The transpose moves head dimension back, and contiguous() ensures memory layout efficiency
        # This combines information from all attention heads into a single representation
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Apply final output projection and residual dropout
        # The output projection allows the model to learn how to best combine multi-head information
        # Residual dropout provides additional regularization before the residual connection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    """
    Multi-Layer Perceptron (Feed-Forward Network) for Transformer.

    This implements the position-wise feed-forward network that appears in each transformer layer.
    The MLP provides additional non-linear transformation capacity beyond what attention provides.

    Architecture:
        Input -> Linear(n_embd -> 4*n_embd) -> GELU -> Linear(4*n_embd -> n_embd) -> Dropout -> Output

    Design rationale:
    - 4x expansion is standard in transformers (from "Attention Is All You Need")
    - GELU activation provides smoother gradients than ReLU for language modeling
    - Dropout prevents overfitting in the feed-forward layers
    - Two linear layers allow complex non-linear transformations of attention outputs

    Parameters:
    - First linear layer: n_embd * 4*n_embd parameters (expansion)
    - Second linear layer: 4*n_embd * n_embd parameters (projection back)
    - Total: 8 * n_embd^2 parameters (significant portion of model size)
    """

    def __init__(self, config: GPTConfig):
        super().__init__()

        # First linear layer: expand embedding dimension by 4x
        # This expansion gives the network more representational capacity
        # The 4x factor is a standard choice that balances capacity vs efficiency
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)

        # GELU (Gaussian Error Linear Unit) activation function
        # GELU provides smoother gradients compared to ReLU and works better for language modeling
        # It's approximately: GELU(x) = x * Φ(x) where Φ is the CDF of standard normal distribution
        self.gelu = nn.GELU()

        # Second linear layer: project back to original embedding dimension
        # This projection allows the network to combine information from the expanded representation
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)

        # Dropout for regularization in the feed-forward network
        # Applied after the final projection to prevent overfitting
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the feed-forward network.

        This method applies a two-layer MLP with GELU activation to transform
        the attention outputs. The MLP operates independently on each position
        in the sequence, providing position-wise non-linear transformations.

        Mathematical operation:
            MLP(x) = Dropout(Linear₂(GELU(Linear₁(x))))
            where Linear₁: R^n_embd -> R^4*n_embd and Linear₂: R^4*n_embd -> R^n_embd

        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embd)
               Contains attended representations from the attention layer

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
                Contains transformed representations ready for residual connection
        """
        # First linear transformation: expand from n_embd to 4*n_embd dimensions
        # This expansion provides the network with a higher-dimensional space for computation
        # Shape: (batch_size, seq_len, n_embd) -> (batch_size, seq_len, 4*n_embd)
        x = self.c_fc(x)

        # Apply GELU activation function for non-linearity
        # GELU is smoother than ReLU and provides better gradients for language modeling
        # It introduces non-linearity while maintaining differentiability everywhere
        x = self.gelu(x)

        # Second linear transformation: project back to original n_embd dimensions
        # This projection combines information from the expanded representation
        # Shape: (batch_size, seq_len, 4*n_embd) -> (batch_size, seq_len, n_embd)
        x = self.c_proj(x)

        # Apply dropout for regularization before residual connection
        # Dropout randomly zeros some neurons during training to prevent overfitting
        # This is particularly important in the feed-forward layers which have many parameters
        x = self.dropout(x)

        return x


class Block(nn.Module):
    """
    Single Transformer block.

    Consists of:
    1. Layer normalization
    2. Multi-head causal self-attention
    3. Residual connection
    4. Layer normalization
    5. MLP (feed-forward network)
    6. Residual connection

    Uses pre-norm architecture for better training stability.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of transformer block.

        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embd)

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, n_embd)
        """
        # Pre-norm attention with residual connection
        x = x + self.attn(self.ln_1(x))

        # Pre-norm MLP with residual connection
        x = x + self.mlp(self.ln_2(x))

        return x


class GPTModel(nn.Module):
    """
    Complete GPT Language Model.

    This is the main model class that combines all components:
    - Token and positional embeddings
    - Stack of transformer blocks
    - Final layer normalization
    - Language modeling head

    The model can be used for:
    - Training from scratch on text data
    - Fine-tuning on downstream tasks
    - Text generation (inference)
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.vocab_size is not None, "vocab_size must be specified"
        assert config.block_size is not None, "block_size must be specified"

        self.config = config

        # Embeddings
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),  # Token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),  # Position embeddings
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),  # Transformer blocks
            ln_f = nn.LayerNorm(config.n_embd),  # Final layer norm
        ))

        # Language modeling head (maps hidden states to vocabulary)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Tie weights between token embeddings and output head (common practice)
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)

        # Report parameter count
        print(f"Model initialized: {self.config.model_name}")
        print(f"Parameters: {self.get_num_params():,}")
        print(f"Estimated: {self.config.estimate_parameters():,}")

    def _init_weights(self, module):
        """Initialize model weights using standard practices."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self, non_embedding: bool = False) -> int:
        """
        Count the number of parameters in the model.

        Args:
            non_embedding: If True, subtract embedding parameters

        Returns:
            int: Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
            n_params -= self.transformer.wte.weight.numel()
        return n_params

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass of the GPT model.

        Args:
            idx: Input token indices of shape (batch_size, seq_len)
            targets: Optional target tokens for loss calculation (batch_size, seq_len)

        Returns:
            Tuple containing:
            - logits: Output logits of shape (batch_size, seq_len, vocab_size) when targets
              are given; (batch_size, 1, vocab_size) otherwise, since only the last position
              is projected
            - loss: Cross-entropy loss if targets provided, None otherwise
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"

        # Token embeddings
        tok_emb = self.transformer.wte(idx)  # (b, t, n_embd)

        # Position embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # (t,)
        pos_emb = self.transformer.wpe(pos)  # (t, n_embd)

        # Combine embeddings and apply dropout
        x = self.transformer.drop(tok_emb + pos_emb)

        # Pass through transformer blocks
        for block in self.transformer.h:
            x = block(x)

        # Final layer normalization
        x = self.transformer.ln_f(x)

        # Language modeling head
        if targets is not None:
            # If we have targets, compute loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # If no targets, only compute logits for the last token (more efficient for generation)
            logits = self.lm_head(x[:, [-1], :])  # Note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            idx: Starting token indices (batch_size, seq_len)
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: If set, only sample from top-k most likely tokens

        Returns:
            torch.Tensor: Generated sequence (batch_size, seq_len + max_new_tokens)
        """
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop sequence if it exceeds block size
                idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

                # Forward pass
                logits, _ = self(idx_cond)

                # Get logits for the last token and apply temperature
                logits = logits[:, -1, :] / temperature

                # Optionally crop to top-k most likely tokens
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')

                # Apply softmax and sample
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append to sequence
                idx = torch.cat((idx, idx_next), dim=1)

        self.train()  # Return to training mode
        return idx
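    # Illustrative call (the prompt tensor here is a placeholder, not part of this file):
    #   prompt = torch.tensor([[10, 20, 30]])  # (1, 3) token ids
    #   out = model.generate(prompt, max_new_tokens=50, temperature=0.8, top_k=40)
    #   out.shape -> torch.Size([1, 53])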

    def estimate_memory_usage(self, batch_size: int = 1, seq_len: Optional[int] = None) -> dict:
        """
        Estimate memory usage for training and inference.

        Args:
            batch_size: Batch size for estimation
            seq_len: Sequence length (defaults to block_size)

        Returns:
            dict: Memory usage estimates in MB
        """
        if seq_len is None:
            seq_len = self.config.block_size

        # Model parameters (weights)
        param_memory = self.get_num_params() * 4 / (1024**2)  # 4 bytes per float32

        # Activations (rough estimate)
        activation_memory = (
            batch_size * seq_len * self.config.n_embd * self.config.n_layer * 8  # Rough estimate
        ) / (1024**2)

        # Gradients (same size as parameters during training)
        gradient_memory = param_memory

        return {
            "parameters_mb": param_memory,
            "activations_mb": activation_memory,
            "gradients_mb": gradient_memory,
            "total_training_mb": param_memory + activation_memory + gradient_memory,
            "total_inference_mb": param_memory + activation_memory * 0.5,  # No gradients needed
        }


def create_model(model_size: str = "medium") -> GPTModel:
    """
    Factory function to create a GPT model with predefined configurations.

    Args:
        model_size: Size of model to create ("small", "medium", "large")

    Returns:
        GPTModel: Initialized model
    """
    configs = {
        "small": GPTConfig.small(),
        "medium": GPTConfig.medium(),
        "large": GPTConfig.large(),
    }

    if model_size not in configs:
        raise ValueError(f"Unknown model size: {model_size}. Choose from {list(configs.keys())}")

    config = configs[model_size]
    model = GPTModel(config)

    return model


if __name__ == "__main__":
    # Example usage
    print("🧠 GPT Model Architecture")
    print("=" * 50)

    # Create models of different sizes
    for size in ["small", "medium", "large"]:
        print(f"\n{size.upper()} MODEL:")
        model = create_model(size)

        # Show memory estimates
        memory = model.estimate_memory_usage(batch_size=4, seq_len=512)
        print(f"Memory (4 batch, 512 seq): {memory['total_training_mb']:.1f}MB training, {memory['total_inference_mb']:.1f}MB inference")

        # Test forward pass
        x = torch.randint(0, 32000, (2, 64))  # Batch size 2, sequence length 64
        with torch.no_grad():
            logits, _ = model(x)
        print(f"Test forward pass: {x.shape} -> {logits.shape} ✓")
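
For reference, a minimal training-step sketch built on top of this module; the random batch, the AdamW optimizer, and the learning rate below are illustrative assumptions rather than part of this commit:

import torch
from model import GPTConfig, GPTModel

config = GPTConfig.small()
model = GPTModel(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Dummy next-token batch; in practice the ids come from the SentencePiece tokenizer.
tokens = torch.randint(0, config.vocab_size, (4, 129))  # (batch, seq_len + 1)
inputs, targets = tokens[:, :-1], tokens[:, 1:]          # shift by one position

logits, loss = model(inputs, targets)  # logits: (4, 128, vocab_size)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"training loss: {loss.item():.4f}")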