pragsyy1729 committed
Commit 00b97f1 · 1 Parent(s): b7b55de

Add SmolLM2-135M model and Gradio app


- app.py: Gradio interface for text generation
- requirements.txt: Dependencies (torch, tiktoken, gradio)
- README.md: Model architecture and training documentation
- smollm2_135m_final.pt: Trained model weights (~135M params)

Model trained from scratch on Suits TV series scripts for 5,050 steps.
Architecture: RMSNorm, RoPE, GQA (9 query, 3 KV heads), SwiGLU MLP.

Files changed (4)
  1. README.md +323 -9
  2. app.py +401 -0
  3. requirements.txt +3 -0
  4. smollm2_135m_final.pt +3 -0
README.md CHANGED
@@ -1,12 +1,326 @@
 ---
- title: SmolLM2 135M From Scratch
- emoji: 🐨
- colorFrom: yellow
- colorTo: yellow
- sdk: gradio
- sdk_version: 6.3.0
- app_file: app.py
- pinned: false
 ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SmolLM2-135M Training from Scratch
+
+ This repository contains a from-scratch implementation of the SmolLM2-135M model architecture, trained on custom text data.
+
+ ## Table of Contents
+ - [Model Architecture](#model-architecture)
+ - [Parameter Calculation](#parameter-calculation)
+ - [Training Data](#training-data)
+ - [Training Details](#training-details)
+ - [Speedups Used](#speedups-used)
+ - [Results](#results)
+ - [Usage](#usage)
+
+ ---
+
+ ## Model Architecture
+
+ SmolLM2-135M is a **Llama-based decoder-only transformer** model. Unlike GPT-2, it incorporates modern architectural improvements that have become standard in recent language models.
+
+ ### Architecture Overview
+
+ ```
+ Input Tokens
+          │
+          ▼
+ ┌─────────────────┐
+ │ Token Embedding │  (No position embeddings - RoPE applied in attention)
+ └────────┬────────┘
+          │
+          ▼
+ ┌─────────────────────────────────────┐
+ │        Transformer Block ×30        │
+ │  ┌─────────────────────────────┐    │
+ │  │           RMSNorm           │    │
+ │  │              ↓              │    │
+ │  │   Grouped Query Attention   │    │
+ │  │ (9 query heads, 3 KV heads) │    │
+ │  │           + RoPE            │    │
+ │  │              ↓              │    │
+ │  │     Residual Connection     │    │
+ │  │              ↓              │    │
+ │  │           RMSNorm           │    │
+ │  │              ↓              │    │
+ │  │         SwiGLU MLP          │    │
+ │  │              ↓              │    │
+ │  │     Residual Connection     │    │
+ │  └─────────────────────────────┘    │
+ └────────────────┬────────────────────┘
+                  │
+                  ▼
+ ┌─────────────────┐
+ │  Final RMSNorm  │
+ └────────┬────────┘
+          │
+          ▼
+ ┌─────────────────┐
+ │     LM Head     │  (Tied with token embeddings)
+ └────────┬────────┘
+          │
+          ▼
+ Output Logits
+ ```
+
+ ### Key Components
+
+ #### 1. RMSNorm (Root Mean Square Normalization)
+ Unlike LayerNorm, RMSNorm doesn't center activations (no mean subtraction), making it more computationally efficient.
+
+ ```python
+ # Formula: x * weight / sqrt(mean(x²) + eps)
+ def forward(self, x):
+     rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+     return x * rms * self.weight
+ ```
+
+ #### 2. Rotary Position Embedding (RoPE)
+ RoPE encodes position information by rotating query and key vectors. The dot product of rotated vectors naturally encodes relative positions.
+
+ - **Advantage**: No learned position embeddings needed
+ - **Advantage**: Better extrapolation to longer sequences
+ - **theta**: 10,000 (base frequency)
+
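+ A minimal sketch of the rotation, mirroring the `rotate_half` helpers defined in `app.py` below:
+
+ ```python
+ import torch
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     # Swap the two halves of the head dimension, negating the second half
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope(q, k, cos, sin):
+     # Rotate queries and keys by position-dependent angles; dot products
+     # between rotated vectors then depend only on relative positions
+     return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+ ```
+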
+ #### 3. Grouped Query Attention (GQA)
+ GQA reduces memory bandwidth by sharing key-value heads across multiple query heads.
+
+ | Component | Count |
+ |-----------|-------|
+ | Query Heads | 9 |
+ | Key-Value Heads | 3 |
+ | KV Groups | 3 (each KV head shared by 3 query heads) |
+ | Head Dimension | 64 |
+
+ ```python
+ # Q:    (B, T, 9 heads, 64)
+ # K, V: (B, T, 3 heads, 64) → repeated to match 9 query heads
+ ```
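+
+ In the attention forward pass (see `app.py` below), the expansion is a single `repeat_interleave` before Flash Attention:
+
+ ```python
+ # k, v: (B, 3, T, 64) → (B, 9, T, 64); each KV head serves 3 query heads
+ k = k.repeat_interleave(3, dim=1)
+ v = v.repeat_interleave(3, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+ ```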
+
+ #### 4. SwiGLU MLP
+ SwiGLU is a gated linear unit variant using SiLU activation, shown to improve model quality.
+
+ ```python
+ # Formula: down_proj(silu(gate_proj(x)) * up_proj(x))
+ def forward(self, x):
+     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ ```
+
+ ### Configuration
+
+ | Parameter | Value |
+ |-----------|-------|
+ | `vocab_size` | 50,304 (GPT-2 compatible) |
+ | `hidden_size` | 576 |
+ | `intermediate_size` | 1,536 |
+ | `num_hidden_layers` | 30 |
+ | `num_attention_heads` | 9 |
+ | `num_key_value_heads` | 3 |
+ | `max_position_embeddings` | 2,048 |
+ | `rms_norm_eps` | 1e-5 |
+ | `rope_theta` | 10,000 |
+ | `hidden_act` | SiLU |
+ | `tie_word_embeddings` | True |
+
+ ---
+
+ ## Parameter Calculation
+
+ ### Component-by-Component Breakdown
+
+ #### 1. Token Embeddings
+ ```
+ vocab_size × hidden_size = 50,304 × 576 = 28,975,104 parameters
+ ```
+
+ #### 2. Attention (per layer)
+ ```
+ Q projection: hidden_size × (num_heads × head_dim)    = 576 × 576 = 331,776
+ K projection: hidden_size × (num_kv_heads × head_dim) = 576 × 192 = 110,592
+ V projection: hidden_size × (num_kv_heads × head_dim) = 576 × 192 = 110,592
+ O projection: (num_heads × head_dim) × hidden_size    = 576 × 576 = 331,776
+ ──────────────────────────────────────────────────────────────────────────
+ Total per layer:      884,736 parameters
+ Total for 30 layers:  884,736 × 30 = 26,542,080 parameters
+ ```
+
+ #### 3. MLP (per layer)
+ ```
+ gate_proj: hidden_size × intermediate_size = 576 × 1,536 = 884,736
+ up_proj:   hidden_size × intermediate_size = 576 × 1,536 = 884,736
+ down_proj: intermediate_size × hidden_size = 1,536 × 576 = 884,736
+ ─────────────────────────────────────────────────────────────────
+ Total per layer:      2,654,208 parameters
+ Total for 30 layers:  2,654,208 × 30 = 79,626,240 parameters
+ ```
+
+ #### 4. RMSNorm (per layer + final)
+ ```
+ input_layernorm:          hidden_size = 576
+ post_attention_layernorm: hidden_size = 576
+ ─────────────────────────────────────────────
+ Total per layer:      1,152 parameters
+ Total for 30 layers:  1,152 × 30 = 34,560 parameters
+ Final norm:           576 parameters
+ Total normalization:  35,136 parameters
+ ```
+
+ #### 5. LM Head
+ ```
+ Tied with token embeddings: 0 additional parameters
+ ```
+
+ ### Total Parameter Count
+
+ | Component | Parameters | Percentage |
+ |-----------|------------|------------|
+ | Embedding | 28,975,104 | 21.4% |
+ | Attention | 26,542,080 | 19.6% |
+ | MLP | 79,626,240 | 58.9% |
+ | Normalization | 35,136 | 0.03% |
+ | **Total** | **135,178,560** | **100%** |
+
+ > **Note**: Our implementation uses vocab_size=50,304 instead of the original 49,152, adding ~663K parameters to the embedding layer. Original SmolLM2-135M has ~134.5M parameters.
+
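+ The totals above can be reproduced with a short script (a sketch built from the configuration values in this README):
+
+ ```python
+ vocab, hidden, inter, layers = 50_304, 576, 1_536, 30
+ n_heads, n_kv, head_dim = 9, 3, 64
+
+ embed = vocab * hidden                                         # 28,975,104
+ attn  = layers * hidden * head_dim * (2 * n_heads + 2 * n_kv)  # 26,542,080
+ mlp   = layers * 3 * hidden * inter                            # 79,626,240
+ norms = layers * 2 * hidden + hidden                           #     35,136
+ print(embed + attn + mlp + norms)                              # 135,178,560
+ ```
+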
+ ---
+
+ ## Training Data
+
+ ### Dataset Description
+ The model was trained on dialogue scripts from the television series **"Suits"**, a legal drama that follows characters working at a fictional New York law firm.
+
+ ### Characteristics
+ - **Content Type**: Television dialogue scripts
+ - **Genre**: Legal drama
+ - **Language Style**: Professional legal terminology mixed with casual dialogue
+ - **Text Format**: Character names followed by dialogue
+
+ ### Tokenization
+ - **Tokenizer**: GPT-2 BPE tokenizer (tiktoken)
+ - **Vocabulary Size**: 50,257 tokens (padded to 50,304 for GPU efficiency)
+
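+ Encoding an illustrative line in the script format shows the tokenizer in action (standard `tiktoken` calls only):
+
+ ```python
+ import tiktoken
+
+ enc = tiktoken.get_encoding("gpt2")  # 50,257-token GPT-2 BPE vocabulary
+ tokens = enc.encode("HARVEY: I don't have dreams, I have goals.")
+ print(len(tokens), enc.decode(tokens))
+ ```
+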
+ ---
+
+ ## Training Details
+
+ ### Hyperparameters
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Total Steps | 5,000 + 50 (resumed) |
+ | Batch Size | 16 (effective) |
+ | Micro Batch Size | 4 |
+ | Gradient Accumulation | 4 steps |
+ | Sequence Length | 1,024 tokens |
+ | Learning Rate | 6e-4 (max) |
+ | LR Schedule | Cosine with warmup (sketched below) |
+ | Warmup Steps | 500 |
+ | Weight Decay | 0.1 |
+ | Gradient Clipping | 1.0 |
+ | Optimizer | AdamW (fused) |
+ | Precision | bfloat16 |
+
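+ A sketch of the cosine-with-warmup schedule implied by the table (the `min_lr` floor is an assumption; the table only gives the maximum):
+
+ ```python
+ import math
+
+ max_lr, warmup_steps, max_steps = 6e-4, 500, 5_000
+ min_lr = max_lr * 0.1  # assumed floor, not stated in the table
+
+ def get_lr(step: int) -> float:
+     if step < warmup_steps:              # linear warmup
+         return max_lr * (step + 1) / warmup_steps
+     if step >= max_steps:                # past the schedule: hold the floor
+         return min_lr
+     ratio = (step - warmup_steps) / (max_steps - warmup_steps)
+     return min_lr + 0.5 * (1.0 + math.cos(math.pi * ratio)) * (max_lr - min_lr)
+ ```
+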
+ ### Checkpointing
+ - Checkpoints saved every **500 steps** (see the save sketch below)
+ - Text generation with fixed prompts at each checkpoint
+ - Final checkpoint at step 5,050 (after resume demonstration)
+
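+ A save call consistent with the checkpoint fields loaded in the Usage section below (the `step` field is an assumed extra):
+
+ ```python
+ torch.save({
+     'step': step,  # assumed bookkeeping field
+     'model_state_dict': model.state_dict(),
+     'optimizer_state_dict': optimizer.state_dict(),
+ }, f'checkpoints/checkpoint_step_{step}.pt')
+ ```
+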
+ ### Fixed Evaluation Prompts
+ ```
+ 1. "Once upon a time"
+ 2. "The meaning of life is"
+ 3. "In a galaxy far away"
+ ```
+
 ---
+
+ ## Speedups Used
+
+ | Speedup | Implementation |
+ |---------|----------------|
+ | **Flash Attention** | `F.scaled_dot_product_attention(is_causal=True)` |
+ | **Mixed Precision** | `torch.autocast(dtype=torch.bfloat16)` |
+ | **torch.compile** | JIT compilation for CUDA |
+ | **TF32 Precision** | `torch.set_float32_matmul_precision('high')` |
+ | **Gradient Accumulation** | 4 micro-batches per step |
+ | **Fused AdamW** | `fused=True` for CUDA |
+ | **Padded Vocab** | 50,304 (a multiple of 128, vs. 50,257) for efficient GPU memory access |
+
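+ Combined, one optimizer step looks roughly like this (a sketch; `get_batch` is an assumed data loader, not code from this repo):
+
+ ```python
+ optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4,
+                               weight_decay=0.1, fused=True)
+
+ optimizer.zero_grad(set_to_none=True)
+ for micro_step in range(4):                      # gradient accumulation
+     x, y = get_batch()                           # assumed helper
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+         logits, loss = model(x, y)
+     (loss / 4).backward()                        # average over micro-batches
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ ```
+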
 ---

+ ## Results
+
+ ### Training Progress
+ - **Initial Loss**: ~10.8 (random initialization)
+ - **Final Loss**: Significantly reduced after 5,050 steps
+ - **Checkpoint Resume**: Successfully demonstrated loading from step 5,000 and continuing training
+
+ ### Checkpoints Saved
+ ```
+ checkpoints/
+ ├── checkpoint_step_500.pt
+ ├── checkpoint_step_1000.pt
+ ├── checkpoint_step_1500.pt
+ ├── checkpoint_step_2000.pt
+ ├── checkpoint_step_2500.pt
+ ├── checkpoint_step_3000.pt
+ ├── checkpoint_step_3500.pt
+ ├── checkpoint_step_4000.pt
+ ├── checkpoint_step_4500.pt
+ ├── checkpoint_step_5000.pt
+ └── checkpoint_step_5050.pt
+ ```
+
+ ---
+
+ ## Usage
+
+ ### Requirements
+ ```bash
+ pip install torch tiktoken matplotlib
+ ```
+
+ ### Training
+ 1. Upload `input.txt` (training data) to the working directory
+ 2. Run the notebook cells sequentially
+ 3. Checkpoints will be saved to the `checkpoints/` directory
+
+ ### Loading a Checkpoint
+ ```python
+ checkpoint = torch.load('checkpoints/checkpoint_step_5000.pt')
+ model.load_state_dict(checkpoint['model_state_dict'])
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ ```
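+
+ If the checkpoint was saved on GPU and loaded on a CPU-only machine, pass `map_location` (as `app.py` does):
+
+ ```python
+ checkpoint = torch.load('checkpoints/checkpoint_step_5000.pt', map_location='cpu')
+ ```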
+
+ ### Text Generation
+ ```python
+ generated = generate_text(
+     model,
+     prompt="Once upon a time",
+     max_new_tokens=50,
+     temperature=0.8,
+     top_k=50
+ )
+ print(generated)
+ ```
+
+ ---
+
+ ## References
+
+ - [SmolLM2 - HuggingFace](https://huggingface.co/HuggingFaceTB/SmolLM2-135M)
+ - [RoPE: Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
+ - [GQA: Grouped Query Attention](https://arxiv.org/abs/2305.13245)
+ - [SwiGLU Activation](https://arxiv.org/abs/2002.05202)
+ - [RMSNorm](https://arxiv.org/abs/1910.07467)
+
+ ---
+
+ ## License
+
+ This project is for educational purposes.
+
+ ---
+
+ ## Acknowledgments
+
+ - HuggingFace for the SmolLM2 model and training recipes
+ - Andrej Karpathy for nanoGPT inspiration
app.py ADDED
@@ -0,0 +1,401 @@
+ """
+ SmolLM2-135M Text Generation - Gradio App
+ Trained from scratch on Suits TV series scripts
+ """
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+ import tiktoken
+ import gradio as gr
+
+ # ============================================================================
+ # Model Architecture
+ # ============================================================================
+
+ @dataclass
+ class SmolLM2Config:
+     """SmolLM2-135M Configuration"""
+     vocab_size: int = 50304
+     hidden_size: int = 576
+     intermediate_size: int = 1536
+     num_hidden_layers: int = 30
+     num_attention_heads: int = 9
+     num_key_value_heads: int = 3
+     max_position_embeddings: int = 2048
+     rms_norm_eps: float = 1e-5
+     rope_theta: float = 10000.0
+     hidden_act: str = "silu"
+     initializer_range: float = 0.041666666666666664
+     tie_word_embeddings: bool = True
+     bos_token_id: int = 0
+     eos_token_id: int = 0
+
+     @property
+     def head_dim(self) -> int:
+         return self.hidden_size // self.num_attention_heads
+
+
+ class RMSNorm(nn.Module):
+     """Root Mean Square Layer Normalization"""
+     def __init__(self, hidden_size: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return x * rms * self.weight
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE)"""
+     def __init__(self, dim: int, max_position_embeddings: int = 2048, theta: float = 10000.0):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.theta = theta
+
+         inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self._set_cos_sin_cache(max_position_embeddings)
+
+     def _set_cos_sin_cache(self, seq_len: int):
+         self.max_seq_len_cached = seq_len
+         # Build on the same device as inv_freq so a rebuild after .to(device) works
+         t = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
+         freqs = torch.outer(t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos(), persistent=False)
+         self.register_buffer("sin_cached", emb.sin(), persistent=False)
+
+     def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len)
+         return (
+             self.cos_cached[:seq_len].to(x.dtype),
+             self.sin_cached[:seq_len].to(x.dtype)
+         )
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+     cos = cos.unsqueeze(0).unsqueeze(0)
+     sin = sin.unsqueeze(0).unsqueeze(0)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class GroupedQueryAttention(nn.Module):
+     """Grouped Query Attention (GQA)"""
+     def __init__(self, config: SmolLM2Config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.head_dim = config.head_dim
+         self.num_kv_groups = self.num_heads // self.num_kv_heads
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.rotary_emb = RotaryEmbedding(
+             self.head_dim,
+             max_position_embeddings=config.max_position_embeddings,
+             theta=config.rope_theta
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, T, C = x.size()
+
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         v = v.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         cos, sin = self.rotary_emb(x, T)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # Expand KV heads to match the query heads (GQA)
+         k = k.repeat_interleave(self.num_kv_groups, dim=1)
+         v = v.repeat_interleave(self.num_kv_groups, dim=1)
+
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         y = y.transpose(1, 2).contiguous().view(B, T, self.hidden_size)
+         y = self.o_proj(y)
+         return y
+
+
+ class SwiGLUMLP(nn.Module):
+     """SwiGLU Feed-Forward Network"""
+     def __init__(self, config: SmolLM2Config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class SmolLM2Block(nn.Module):
+     """SmolLM2 Transformer Block"""
+     def __init__(self, config: SmolLM2Config):
+         super().__init__()
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.self_attn = GroupedQueryAttention(config)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.mlp = SwiGLUMLP(config)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.self_attn(self.input_layernorm(x))
+         x = x + self.mlp(self.post_attention_layernorm(x))
+         return x
+
+
+ class SmolLM2(nn.Module):
+     """SmolLM2-135M Model"""
+     def __init__(self, config: SmolLM2Config):
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([SmolLM2Block(config) for _ in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         if config.tie_word_embeddings:
+             self.lm_head.weight = self.embed_tokens.weight
+
+     def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         B, T = input_ids.size()
+         x = self.embed_tokens(input_ids)
+
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.norm(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+         return logits, loss
+
+
+ # ============================================================================
+ # Load Model
+ # ============================================================================
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Load model
+ config = SmolLM2Config()
+ model = SmolLM2(config)
+
+ # Load trained weights
+ checkpoint = torch.load("smollm2_135m_final.pt", map_location=device)
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.to(device)
+ model.eval()
+
+ print(f"Model loaded successfully! Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+ # Load tokenizer
+ tokenizer = tiktoken.get_encoding('gpt2')
+
+
+ # ============================================================================
+ # Generation Function
+ # ============================================================================
+
+ def generate_text(
+     prompt: str,
+     max_new_tokens: int = 100,
+     temperature: float = 0.8,
+     top_k: int = 50,
+     top_p: float = 0.9,
+ ) -> str:
+     """Generate text from a prompt"""
+     if not prompt.strip():
+         return "Please enter a prompt."
+
+     # Gradio sliders may deliver floats; sampling needs integers here
+     max_new_tokens = int(max_new_tokens)
+     top_k = int(top_k)
+
+     model.eval()
+
+     # Encode prompt
+     tokens = tokenizer.encode(prompt)
+     tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
+
+     # Generate tokens
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             # Crop to max position embeddings
+             idx_cond = tokens[:, -config.max_position_embeddings:]
+
+             # Get predictions
+             logits, _ = model(idx_cond)
+             logits = logits[:, -1, :] / temperature
+
+             # Top-k filtering
+             if top_k is not None and top_k > 0:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = float('-inf')
+
+             # Top-p (nucleus) filtering
+             if top_p is not None and top_p < 1.0:
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 # Shift right so the first token above the threshold is kept
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                 logits[indices_to_remove] = float('-inf')
+
+             # Sample
+             probs = F.softmax(logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             tokens = torch.cat([tokens, next_token], dim=1)
+
+     # Decode
+     generated = tokenizer.decode(tokens[0].tolist())
+     return generated
+
+
+ # ============================================================================
+ # Gradio Interface
+ # ============================================================================
+
+ title = "SmolLM2-135M Text Generator"
+ description = """
+ ## About This Model
+
+ This is a **SmolLM2-135M** model trained from scratch on dialogue scripts from the TV series "Suits".
+
+ ### Model Architecture
+ - **Type**: Llama-based decoder-only transformer
+ - **Parameters**: ~135M
+ - **Features**: RMSNorm, RoPE, Grouped Query Attention (GQA), SwiGLU MLP
+
+ ### Training Details
+ - Trained for 5,050 steps
+ - Sequence length: 1024 tokens
+ - Uses GPT-2 tokenizer
+
+ Enter a prompt below and adjust the generation parameters to see what the model generates!
+ """
+
+ examples = [
+     ["Harvey walked into the office and said,"],
+     ["The legal case was complicated because"],
+     ["Once upon a time"],
+     ["In a world where lawyers"],
+     ["Mike looked at the contract and noticed"],
+ ]
+
+ # Create interface
+ with gr.Blocks(title=title, theme=gr.themes.Soft()) as demo:
+     gr.Markdown(f"# {title}")
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             prompt_input = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Enter your prompt here...",
+                 lines=3,
+             )
+
+             with gr.Row():
+                 max_tokens_slider = gr.Slider(
+                     minimum=10,
+                     maximum=500,
+                     value=100,
+                     step=10,
+                     label="Max New Tokens",
+                 )
+                 temperature_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=2.0,
+                     value=0.8,
+                     step=0.1,
+                     label="Temperature",
+                 )
+
+             with gr.Row():
+                 top_k_slider = gr.Slider(
+                     minimum=1,
+                     maximum=100,
+                     value=50,
+                     step=1,
+                     label="Top-K",
+                 )
+                 top_p_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.05,
+                     label="Top-P (Nucleus)",
+                 )
+
+             generate_btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column(scale=2):
+             output_text = gr.Textbox(
+                 label="Generated Text",
+                 lines=15,
+                 show_copy_button=True,
+             )
+
+     gr.Markdown("### Example Prompts")
+     gr.Examples(
+         examples=examples,
+         inputs=prompt_input,
+     )
+
+     # Connect the generate button
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider],
+         outputs=output_text,
+     )
+
+     # Also generate on Enter key
+     prompt_input.submit(
+         fn=generate_text,
+         inputs=[prompt_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider],
+         outputs=output_text,
+     )
+
+     gr.Markdown("""
+     ---
+     ### Parameter Guide
+     - **Temperature**: Higher = more creative/random, Lower = more focused/deterministic
+     - **Top-K**: Only sample from the top K most likely tokens
+     - **Top-P**: Sample from the smallest set of tokens whose cumulative probability exceeds P
+     - **Max New Tokens**: Maximum number of tokens to generate
+
+     ---
+     *Model trained from scratch using PyTorch. Architecture based on SmolLM2-135M (Llama-style).*
+     """)
+
+
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch>=2.0.0
+ tiktoken>=0.5.0
+ gradio>=4.0.0
smollm2_135m_final.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e2fbf5fa84c5eaf792ef34b906d27ff195684c490f2b8ad572b1f03b3d3b3ee
+ size 540826769