hidude562 committed on
Commit 7a2fc07 · verified · 1 Parent(s): bc6d73d

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +85 -0
  2. best.pt +3 -0
  3. tokenizer.model +3 -0
  4. tokenizer.vocab +192 -0
  5. train.py +553 -0
README.md ADDED
@@ -0,0 +1,85 @@
+ ---
+ license: mit
+ language:
+ - en
+ tags:
+ - 1-bit
+ - bitnet
+ - tiny
+ - language-model
+ - tinystories
+ datasets:
+ - roneneldan/TinyStories
+ pipeline_tag: text-generation
+ ---
+
+ # tiny-tiny-stories
+
+ A **1-bit (ternary {-1, 0, +1}) transformer language model** trained on [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories).
+
+ ## Specs
+
+ | | |
+ |---|---|
+ | **Parameters** | 998,784 (< 1M) |
+ | **Weight precision** | 1.58-bit ternary (BitNet b1.58) |
+ | **Tokenizer** | SentencePiece unigram, 192 vocab |
+ | **Context length** | 512 tokens |
+ | **Best val loss** | 1.2087 (perplexity 3.35) |
+ | **Training** | 100K steps on 2.1M TinyStories |
+ | **Checkpoint size** | 3.9 MB (FP32 latent), ~350 KB quantized |
+
+ ## Architecture
+
+ - **d_model**: 128
+ - **Heads**: 4 (head_dim=32)
+ - **Layers**: 5
+ - **FFN**: SwiGLU (d_ff=336)
+ - **Position encoding**: RoPE (no learned positional embeddings)
+ - **Normalization**: RMSNorm
+ - **Embeddings**: Tied input/output, full precision
+ - **All linear layers**: BitLinear with ternary quantization + straight-through estimator
+
+ ## How it works
+
+ All Q/K/V/O attention projections and SwiGLU FFN matrices use **BitLinear**: weights are quantized to {-1, 0, +1} during the forward pass via `round(W / mean(|W|))`, with gradients flowing through a straight-through estimator to full-precision latent weights during training.
+
+ ## Usage
+
+ ```python
+ import torch
+ import sentencepiece as spm
+
+ # Load tokenizer
+ sp = spm.SentencePieceProcessor(model_file='tokenizer.model')
+
+ # Load model (see train.py for the BitLM class definition)
+ from train import BitLM, Config
+ cfg = Config()
+ cfg.vocab_size = 192
+ model = BitLM(cfg)
+ ckpt = torch.load('best.pt', map_location='cpu', weights_only=True)
+ state = ckpt['model']
+ if any(k.startswith('_orig_mod.') for k in state):
+     state = {k.replace('_orig_mod.', ''): v for k, v in state.items()}
+ model.load_state_dict(state)
+ model.eval()
+
+ # Generate
+ ids = [sp.bos_id()] + sp.encode("Once upon a time")
+ idx = torch.tensor([ids])
+ out = model.generate(idx, max_new=200, temp=0.8, top_k=40, eos_id=sp.eos_id())
+ print(sp.decode(out[0].tolist()))
+ ```
+
+ ## Sample output
+
+ > Once upon a time, there was a squirrel. He was very curious and loved to play in the park. One day, he noticed a big tree in the sky. He was already laughing, but he was stronger under his houses. The squirrel was glue of all the trees, exploring the walls...
+
+ ## Training
+
+ Trained on 2x RTX 2080 Ti using mixed precision (FP16) with the AdamW optimizer, a cosine LR schedule (1.5e-3 peak, 800-step warmup), and gradient accumulation (effective batch size 384).
+
+ ```bash
+ python train.py --exp-dir ./output --device cuda:0 --compile
+ ```
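The quantization rule described in the "How it works" section above can be exercised in isolation. This is a minimal sketch mirroring the `BitLinear` forward pass shipped in train.py; the helper name `quantize_ternary` is ours for illustration, not part of the repo:

```python
import torch

def quantize_ternary(w: torch.Tensor) -> torch.Tensor:
    """BitNet b1.58-style ternary quantization with a straight-through estimator."""
    alpha = w.abs().mean().clamp(min=1e-5)                   # per-tensor scale
    wq = torch.clamp(torch.round(w / alpha), -1, 1) * alpha  # values in {-a, 0, +a}
    # STE: the forward pass computes with wq, but gradients flow to the
    # full-precision latent weights w (the detach() blocks the rest)
    return w + (wq - w).detach()

torch.manual_seed(0)
w = torch.randn(4, 4, requires_grad=True)
wq = quantize_ternary(w)

# Quantized weights take only three distinct values, scaled by alpha
alpha = w.detach().abs().mean()
levels = set((wq.detach() / alpha).round().flatten().tolist())
assert levels <= {-1.0, 0.0, 1.0}

# Gradients pass through unchanged to the latent FP weights
wq.sum().backward()
assert torch.allclose(w.grad, torch.ones_like(w))
```

The STE trick is what lets the optimizer keep updating full-precision weights while the model only ever "sees" ternary ones.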
best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:554a19e1c1aeebcc99d7cbc97dd06c1bd814da2791255aaeb845a809179bcac0
+ size 4076196
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7ceada97ec7b97394c18e546a326ecf0ad019de24d10d5cc308629c95614d36
+ size 2435
tokenizer.vocab ADDED
@@ -0,0 +1,192 @@
+ <pad> 0
+ <s> 0
+ </s> 0
+ <unk> 0
+ ▁ -2.3513
+ e -3.12018
+ t -3.18323
+ o -3.26053
+ a -3.28909
+ . -3.35098
+ s -3.36613
+ i -3.55784
+ r -3.69739
+ ▁the -3.76187
+ d -3.76731
+ l -3.84743
+ m -3.8612
+ u -3.87521
+ y -3.89653
+ n -3.91297
+ ▁and -4.06639
+ , -4.09488
+ ▁to -4.18126
+ ▁a -4.23056
+ ed -4.24938
+ h -4.33062
+ ▁s -4.41898
+ p -4.46433
+ ▁w -4.6028
+ f -4.60994
+ re -4.6245
+ an -4.63177
+ g -4.64491
+ ▁was -4.65317
+ k -4.65523
+ ▁c -4.81843
+ c -4.88971
+ ar -4.90083
+ ing -4.94361
+ ▁b -5.01726
+ on -5.08089
+ ▁p -5.11917
+ ll -5.13761
+ er -5.14355
+ ve -5.24712
+ ▁it -5.25316
+ ▁He -5.33791
+ le -5.38252
+ ▁he -5.38734
+ ▁her -5.38974
+ en -5.39365
+ ▁She -5.41582
+ ▁" -5.43498
+ w -5.46043
+ in -5.47842
+ b -5.49025
+ ' -5.49061
+ ▁The -5.49273
+ ▁They -5.52583
+ ▁so -5.56655
+ ▁be -5.58785
+ ow -5.5919
+ ▁in -5.61707
+ th -5.67593
+ ▁said -5.69551
+ ch -5.71423
+ ▁she -5.74618
+ ▁with -5.79963
+ ▁of -5.84465
+ ▁Lily -5.848
+ ▁his -5.86868
+ ▁you -5.90209
+ ck -5.9039
+ ▁day -5.9183
+ ▁that -5.94745
+ " -5.99307
+ ▁go -5.99341
+ ▁for -6.0249
+ ▁had -6.04481
+ I -6.04967
+ S -6.05925
+ ▁play -6.05977
+ ▁do -6.07343
+ ▁mom -6.11001
+ sh -6.21814
+ x -6.23546
+ ! -6.29597
+ ▁very -6.29711
+ ▁time -6.30942
+ ▁little -6.32547
+ ▁di -6.33418
+ M -6.34077
+ ▁happy -6.37051
+ ▁big -6.38665
+ T -6.39556
+ ▁It -6.40006
+ v -6.40084
+ ▁but -6.41073
+ W -6.43543
+ ▁saw -6.43608
+ ▁friend -6.44581
+ ▁One -6.48876
+ ▁Once -6.53329
+ ▁were -6.55827
+ J -6.56237
+ ▁look -6.56312
+ ▁like -6.56992
+ !" -6.5937
+ A -6.60251
+ ▁him -6.6133
+ ▁upon -6.66722
+ ▁girl -6.66883
+ ful -6.7123
+ ▁gr -6.73809
+ ▁could -6.74623
+ ▁Tom -6.77214
+ L -6.77457
+ Y -6.77569
+ ▁Ben -6.81974
+ z -6.84743
+ ▁have -6.84834
+ ▁went -6.85576
+ ight -6.91741
+ ▁But -6.93221
+ ▁help -6.94166
+ B -7.20363
+ D -7.26142
+ â -7.27982
+ € -7.27982
+ H -7.28585
+ F -7.37123
+ C -7.73309
+ N -7.81128
+ O -7.83985
+ P -8.19519
+ E -8.21406
+ - -8.22522
+ G -8.2278
+ œ -8.32156
+ ? -8.45174
+ R -8.55357
+ K -8.63802
+ : -8.89354
+ V -10.1554
+ Z -10.4014
+ U -10.4874
+ ; -10.9446
+ Q -11.5159
+ 1 -12.2185
+ 0 -12.4038
+ 2 -12.8522
+ 5 -13.0644
+ X -13.1254
+ 4 -13.5024
+ Ã -13.8668
+ 9 -14.1729
+ 7 -14.2667
+ 6 -14.4457
+ Â -14.5274
+ 8 -14.5709
+ © -14.5709
+ / -14.8222
+ ) -14.881
+ ` -14.881
+ * -15.0101
+ ( -15.1585
+ ‰ -15.5438
+ ± -15.5438
+   -15.6688
+ Š -15.8117
+ « -16.1784
+ » -16.1784
+ ¡ -16.4284
+ & -16.4284
+ ­ -16.7617
+ ´ -16.7617
+ + -17.2617
+ ] -17.2617
+ ‹ -18.2607
+ # -18.2608
+ $ -18.2609
+ ¦ -18.261
+ ˜ -18.2611
+ ” -18.2612
+ “ -18.2613
+ 3 -18.2614
+ ™ -18.2615
+ q -18.2616
+ ~ -18.2617
+ _ -18.2617
+ j -18.2617
+ ³ -18.2617
train.py ADDED
@@ -0,0 +1,553 @@
+ #!/usr/bin/env python3
+ """
+ 1-Bit Transformer LM on TinyStories
+ < 1M params | < 200 vocab | BitNet b1.58 ternary weights {-1, 0, +1}
+
+ Architecture: RoPE, RMSNorm, SwiGLU, tied embeddings
+ Tokenizer: SentencePiece unigram (192 vocab)
+ """
+
+ import os, json, math, time, random, argparse
+ from pathlib import Path
+ from dataclasses import dataclass, asdict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ import sentencepiece as spm
+
+
+ # ================================================================
+ # Config
+ # ================================================================
+ @dataclass
+ class Config:
+     # Model
+     vocab_size: int = 192  # < 200
+     d_model: int = 128
+     n_heads: int = 4  # head_dim = 32
+     n_layers: int = 5
+     d_ff: int = 336  # SwiGLU intermediate
+     max_seq_len: int = 512
+
+     # Training
+     batch_size: int = 96
+     grad_accum: int = 4  # effective batch = 384
+     lr: float = 1.5e-3
+     min_lr: float = 1e-5
+     warmup_steps: int = 800
+     max_steps: int = 100_000
+     weight_decay: float = 0.1
+     grad_clip: float = 1.0
+
+     # Logging / eval
+     eval_interval: int = 1000
+     eval_steps: int = 50
+     log_interval: int = 100
+     gen_interval: int = 5000
+     save_interval: int = 5000
+
+     # Misc
+     seed: int = 42
+     device: str = "cuda:0"
+     compile: bool = False
+     num_workers: int = 0
+
+
+ # ================================================================
+ # Model
+ # ================================================================
+ class RMSNorm(nn.Module):
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.w = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.w
+
+
+ class BitLinear(nn.Module):
+     """Linear layer with ternary {-1, 0, +1} weight quantization (BitNet b1.58).
+     Full-precision latent weights are kept for optimizer updates.
+     Forward uses quantized weights via straight-through estimator."""
+
+     def __init__(self, in_f, out_f):
+         super().__init__()
+         self.weight = nn.Parameter(torch.empty(out_f, in_f))
+         nn.init.normal_(self.weight, std=0.02)
+
+     def forward(self, x):
+         alpha = self.weight.abs().mean().clamp(min=1e-5)
+         wq = torch.clamp(torch.round(self.weight / alpha), -1, 1) * alpha
+         w = self.weight + (wq - self.weight).detach()  # STE
+         return F.linear(x, w)
+
+
+ def _rope_freqs(dim, max_len, base=10000.0):
+     f = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+     t = torch.arange(max_len, dtype=torch.float32)
+     ang = torch.outer(t, f)
+     return torch.cos(ang), torch.sin(ang)
+
+
+ def _apply_rope(x, c, s):
+     d = x.shape[-1] // 2
+     x1, x2 = x[..., :d], x[..., d:]
+     return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)
+
+
+ class Block(nn.Module):
+     def __init__(self, d, h, ff):
+         super().__init__()
+         self.n1 = RMSNorm(d)
+         self.n2 = RMSNorm(d)
+         # Attention
+         self.q = BitLinear(d, d)
+         self.k = BitLinear(d, d)
+         self.v = BitLinear(d, d)
+         self.o = BitLinear(d, d)
+         # SwiGLU FFN
+         self.gate = BitLinear(d, ff)
+         self.up = BitLinear(d, ff)
+         self.down = BitLinear(ff, d)
+         self.nh = h
+         self.hd = d // h
+
+     def forward(self, x, cos, sin):
+         B, T, C = x.shape
+         h = self.n1(x)
+         q = self.q(h).view(B, T, self.nh, self.hd).transpose(1, 2)
+         k = self.k(h).view(B, T, self.nh, self.hd).transpose(1, 2)
+         v = self.v(h).view(B, T, self.nh, self.hd).transpose(1, 2)
+         q = _apply_rope(q, cos, sin)
+         k = _apply_rope(k, cos, sin)
+         a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         x = x + self.o(a.transpose(1, 2).contiguous().view(B, T, C))
+         h = self.n2(x)
+         x = x + self.down(F.silu(self.gate(h)) * self.up(h))
+         return x
+
+
+ class BitLM(nn.Module):
+     def __init__(self, cfg: Config):
+         super().__init__()
+         self.cfg = cfg
+         self.emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
+         self.layers = nn.ModuleList(
+             [Block(cfg.d_model, cfg.n_heads, cfg.d_ff) for _ in range(cfg.n_layers)]
+         )
+         self.norm = RMSNorm(cfg.d_model)
+         self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
+         self.head.weight = self.emb.weight  # weight tying
+
+         hd = cfg.d_model // cfg.n_heads
+         c, s = _rope_freqs(hd, cfg.max_seq_len)
+         self.register_buffer("rc", c)
+         self.register_buffer("rs", s)
+         nn.init.normal_(self.emb.weight, std=0.02)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         x = self.emb(idx)
+         c = self.rc[:T].unsqueeze(0).unsqueeze(0)  # (1,1,T,hd/2)
+         s = self.rs[:T].unsqueeze(0).unsqueeze(0)
+         for layer in self.layers:
+             x = layer(x, c, s)
+         logits = self.head(self.norm(x))
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0
+             )
+         return logits, loss
+
+     def param_count(self):
+         seen = set()
+         total = 0
+         for p in self.parameters():
+             pid = id(p)
+             if pid not in seen:
+                 seen.add(pid)
+                 total += p.numel()
+         return total
+
+     @torch.no_grad()
+     def generate(self, idx, max_new=200, temp=0.8, top_k=40, eos_id=2):
+         for _ in range(max_new):
+             ic = idx[:, -self.cfg.max_seq_len:]
+             logits, _ = self(ic)
+             logits = logits[:, -1] / temp
+             if top_k > 0:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = float("-inf")
+             probs = F.softmax(logits, dim=-1)
+             nxt = torch.multinomial(probs, 1)
+             idx = torch.cat([idx, nxt], dim=1)
+             if nxt.item() == eos_id:
+                 break
+         return idx
+
+
+ # ================================================================
+ # Dataset
+ # ================================================================
+ class ChunkedDataset(Dataset):
+     """Flat token tensor split into fixed-length chunks."""
+     def __init__(self, tokens, seq_len):
+         self.tokens = tokens
+         self.seq_len = seq_len
+         self.n = (len(tokens) - 1) // seq_len
+
+     def __len__(self):
+         return self.n
+
+     def __getitem__(self, i):
+         s = i * self.seq_len
+         c = self.tokens[s : s + self.seq_len + 1]
+         return c[:-1], c[1:]
+
+
+ # ================================================================
+ # Tokenizer helpers
+ # ================================================================
+ def train_tokenizer(texts, exp_dir, vocab_size=192, n_train=100_000):
+     """Train SentencePiece unigram tokenizer with <200 vocab."""
+     data_file = exp_dir / "sp_train.txt"
+     prefix = str(exp_dir / "tokenizer")
+
+     print(f"Writing {min(n_train, len(texts))} texts for tokenizer training...")
+     with open(data_file, "w", encoding="utf-8") as f:
+         for t in texts[:n_train]:
+             f.write(t.strip().replace("\n", " ") + "\n")
+
+     print("Training SentencePiece unigram tokenizer...")
+     spm.SentencePieceTrainer.train(
+         input=str(data_file),
+         model_prefix=prefix,
+         vocab_size=vocab_size,
+         model_type="unigram",
+         character_coverage=1.0,
+         pad_id=0, bos_id=1, eos_id=2, unk_id=3,
+         byte_fallback=False,
+         normalization_rule_name="identity",
+         max_sentence_length=8192,
+         num_threads=os.cpu_count() or 4,
+         train_extremely_large_corpus=False,
+     )
+     data_file.unlink(missing_ok=True)
+
+     sp = spm.SentencePieceProcessor(model_file=prefix + ".model")
+     print(f"Tokenizer ready: {sp.get_piece_size()} tokens")
+     return sp
+
+
+ def encode_texts(sp, texts, desc="data"):
+     """Encode all texts into a single flat token tensor (BOS story EOS ...)."""
+     bos, eos = sp.bos_id(), sp.eos_id()
+     all_ids = []
+     t0 = time.time()
+     for i, t in enumerate(texts):
+         all_ids.append(bos)
+         all_ids.extend(sp.encode(t))
+         all_ids.append(eos)
+         if (i + 1) % 500_000 == 0:
+             print(f"  {desc}: {i+1}/{len(texts)} ({len(all_ids)/1e6:.1f}M tok)")
+     elapsed = time.time() - t0
+     print(f"  {desc}: {len(all_ids)/1e6:.2f}M tokens, {elapsed:.1f}s")
+     return torch.tensor(all_ids, dtype=torch.long)
+
+
+ # ================================================================
+ # LR schedule
+ # ================================================================
+ def get_lr(step, cfg):
+     if step < cfg.warmup_steps:
+         return cfg.lr * step / cfg.warmup_steps
+     if step >= cfg.max_steps:
+         return cfg.min_lr
+     r = (step - cfg.warmup_steps) / (cfg.max_steps - cfg.warmup_steps)
+     return cfg.min_lr + 0.5 * (cfg.lr - cfg.min_lr) * (1 + math.cos(math.pi * r))
+
+
+ # ================================================================
+ # Eval
+ # ================================================================
+ @torch.no_grad()
+ def evaluate(model, loader, device, steps=50):
+     model.eval()
+     total, n = 0.0, 0
+     for x, y in loader:
+         if n >= steps:
+             break
+         x, y = x.to(device), y.to(device)
+         with torch.amp.autocast("cuda", dtype=torch.float16):
+             _, loss = model(x, y)
+         total += loss.item()
+         n += 1
+     model.train()
+     return total / max(n, 1)
+
+
+ # ================================================================
+ # Main
+ # ================================================================
+ def main():
+     parser = argparse.ArgumentParser(description="Train 1-bit Transformer LM")
+     parser.add_argument("--exp-dir", default="/root/experiments/1m-model")
+     parser.add_argument("--max-steps", type=int, default=100_000)
+     parser.add_argument("--batch-size", type=int, default=96)
+     parser.add_argument("--lr", type=float, default=1.5e-3)
+     parser.add_argument("--device", default="cuda:0")
+     parser.add_argument("--compile", action="store_true")
+     parser.add_argument("--generate", action="store_true")
+     parser.add_argument("--prompt", default="Once upon a time")
+     args = parser.parse_args()
+
+     cfg = Config()
+     cfg.batch_size = args.batch_size
+     cfg.max_steps = args.max_steps
+     cfg.lr = args.lr
+     cfg.device = args.device
+     cfg.compile = args.compile
+
+     exp = Path(args.exp_dir)
+     exp.mkdir(parents=True, exist_ok=True)
+
+     torch.manual_seed(cfg.seed)
+     random.seed(cfg.seed)
+     torch.backends.cudnn.benchmark = True
+
+     # ---- Tokenizer ----
+     tok_model = exp / "tokenizer.model"
+     if tok_model.exists():
+         print("Loading tokenizer...")
+         sp = spm.SentencePieceProcessor(model_file=str(tok_model))
+     else:
+         from datasets import load_dataset
+         print("Loading TinyStories for tokenizer training...")
+         ds = load_dataset("roneneldan/TinyStories", split="train")
+         subset = [ds[i]["text"] for i in range(min(100_000, len(ds)))]
+         sp = train_tokenizer(subset, exp, vocab_size=cfg.vocab_size)
+         del subset, ds
+
+     cfg.vocab_size = sp.get_piece_size()
+     print(f"Vocab size: {cfg.vocab_size}")
+     assert cfg.vocab_size < 200, f"Tokenizer too large: {cfg.vocab_size}"
+
+     # ---- Generate mode ----
+     if args.generate:
+         model = BitLM(cfg).to(cfg.device)
+         ckpt = torch.load(exp / "best.pt", map_location=cfg.device, weights_only=True)
+         state = ckpt["model"]
+         if any(k.startswith("_orig_mod.") for k in state):
+             state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
+         model.load_state_dict(state)
+         model.eval()
+         print(f"Loaded best model (step {ckpt['step']}, val_loss={ckpt['val_loss']:.4f})")
+
+         ids = [sp.bos_id()] + sp.encode(args.prompt)
+         idx = torch.tensor([ids], device=cfg.device)
+         out = model.generate(idx, max_new=500, temp=0.8, top_k=40, eos_id=sp.eos_id())
+         text = sp.decode(out[0].tolist())
+         print(f"\n--- Generated ---\n{text}\n")
+         return
+
+     # ---- Data ----
+     train_cache = exp / "train_tokens.pt"
+     val_cache = exp / "val_tokens.pt"
+
+     if train_cache.exists() and val_cache.exists():
+         print("Loading cached tokens...")
+         train_tok = torch.load(train_cache, weights_only=True)
+         val_tok = torch.load(val_cache, weights_only=True)
+     else:
+         from datasets import load_dataset
+         print("Loading TinyStories...")
+         train_ds = load_dataset("roneneldan/TinyStories", split="train")
+         val_ds = load_dataset("roneneldan/TinyStories", split="validation")
+
+         train_texts = [ex["text"] for ex in train_ds]
+         val_texts = [ex["text"] for ex in val_ds]
+         print(f"Train: {len(train_texts):,} stories, Val: {len(val_texts):,} stories")
+
+         train_tok = encode_texts(sp, train_texts, "train")
+         val_tok = encode_texts(sp, val_texts, "val")
+
+         print("Saving cached tokens...")
+         torch.save(train_tok, train_cache)
+         torch.save(val_tok, val_cache)
+         del train_texts, val_texts
+
+     train_data = ChunkedDataset(train_tok, cfg.max_seq_len)
+     val_data = ChunkedDataset(val_tok, cfg.max_seq_len)
+     print(f"Train: {len(train_data):,} chunks, Val: {len(val_data):,} chunks")
+
+     train_loader = DataLoader(
+         train_data, batch_size=cfg.batch_size, shuffle=True,
+         num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
+     )
+     val_loader = DataLoader(
+         val_data, batch_size=cfg.batch_size, shuffle=False,
+         num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
+     )
+
+     # ---- Model ----
+     model = BitLM(cfg).to(cfg.device)
+     n_params = model.param_count()
+     print(f"\nModel: {n_params:,} parameters ({n_params/1e6:.3f}M)")
+     print(f"  d_model={cfg.d_model}, n_heads={cfg.n_heads}, n_layers={cfg.n_layers}, "
+           f"d_ff={cfg.d_ff}, max_seq_len={cfg.max_seq_len}")
+     assert n_params < 1_000_000, f"Model too large: {n_params:,} params >= 1M"
+
+     if cfg.compile:
+         print("Compiling model with torch.compile...")
+         model = torch.compile(model)
+
+     # ---- Optimizer ----
+     decay_params, nodecay_params = [], []
+     for name, p in model.named_parameters():
+         if p.requires_grad:
+             if "norm" in name or "emb" in name:
+                 nodecay_params.append(p)
+             else:
+                 decay_params.append(p)
+
+     opt = torch.optim.AdamW(
+         [
+             {"params": decay_params, "weight_decay": cfg.weight_decay},
+             {"params": nodecay_params, "weight_decay": 0.0},
+         ],
+         lr=cfg.lr, betas=(0.9, 0.95),
+     )
+     scaler = torch.amp.GradScaler("cuda")
+
+     # ---- Resume ----
+     step = 0
+     best_val = float("inf")
+     ckpt_path = exp / "latest.pt"
+     if ckpt_path.exists():
+         print(f"Resuming from {ckpt_path}...")
+         ck = torch.load(ckpt_path, map_location=cfg.device)
+         # Handle compiled model keys
+         state = ck["model"]
+         if any(k.startswith("_orig_mod.") for k in state):
+             state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
+         model.load_state_dict(state)
+         opt.load_state_dict(ck["optimizer"])
+         scaler.load_state_dict(ck["scaler"])
+         step = ck["step"]
+         best_val = ck.get("best_val", float("inf"))
+         print(f"Resumed at step {step}, best_val={best_val:.4f}")
+
+     # ---- Training loop ----
+     print(f"\nTraining for {cfg.max_steps:,} steps "
+           f"(batch={cfg.batch_size}, accum={cfg.grad_accum}, "
+           f"eff_batch={cfg.batch_size * cfg.grad_accum})\n")
+
+     model.train()
+     train_iter = iter(train_loader)
+     running_loss = 0.0
+     t0 = time.time()
+     tokens_since_log = 0
+
+     while step < cfg.max_steps:
+         # Get batch (auto-restart on epoch boundary)
+         try:
+             x, y = next(train_iter)
+         except StopIteration:
+             train_iter = iter(train_loader)
+             x, y = next(train_iter)
+
+         x, y = x.to(cfg.device, non_blocking=True), y.to(cfg.device, non_blocking=True)
+
+         # LR schedule
+         lr = get_lr(step, cfg)
+         for pg in opt.param_groups:
+             pg["lr"] = lr
+
+         # Forward + backward (mixed precision FP16)
+         with torch.amp.autocast("cuda", dtype=torch.float16):
+             _, loss = model(x, y)
+             scaled_loss = loss / cfg.grad_accum
+
+         scaler.scale(scaled_loss).backward()
+         running_loss += loss.item()
+         tokens_since_log += x.numel()
+
+         # Optimizer step every grad_accum mini-batches
+         if (step + 1) % cfg.grad_accum == 0:
+             scaler.unscale_(opt)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
+             scaler.step(opt)
+             scaler.update()
+             opt.zero_grad(set_to_none=True)
+
+         step += 1
+
+         # ---- Logging ----
+         if step % cfg.log_interval == 0:
+             avg = running_loss / cfg.log_interval
+             elapsed = time.time() - t0
+             tps = tokens_since_log / elapsed
+             ppl = math.exp(min(avg, 20))  # cap for display
+             print(
+                 f"step {step:>6d}/{cfg.max_steps} | "
+                 f"loss {avg:.4f} | ppl {ppl:>8.2f} | "
+                 f"lr {lr:.2e} | {tps/1e3:.0f}K tok/s"
+             )
+             running_loss = 0.0
+             tokens_since_log = 0
+             t0 = time.time()
+
+         # ---- Evaluation ----
+         if step % cfg.eval_interval == 0:
+             vl = evaluate(model, val_loader, cfg.device, cfg.eval_steps)
+             vppl = math.exp(min(vl, 20))
+             improved = vl < best_val
+             tag = "  ** NEW BEST **" if improved else ""
+             print(f"  >>> val_loss={vl:.4f} val_ppl={vppl:.2f}{tag}")
+             if improved:
+                 best_val = vl
+                 save_dict = {"model": model.state_dict(), "step": step,
+                              "val_loss": vl, "config": asdict(cfg)}
+                 torch.save(save_dict, exp / "best.pt")
+             model.train()
+
+         # ---- Generate samples ----
+         if step % cfg.gen_interval == 0:
+             model.eval()
+             for prompt in ["Once upon a time", "The little dog", "She was very happy"]:
+                 ids = [sp.bos_id()] + sp.encode(prompt)
+                 idx = torch.tensor([ids], device=cfg.device)
+                 out = model.generate(idx, max_new=150, temp=0.8, top_k=40,
+                                      eos_id=sp.eos_id())
+                 text = sp.decode(out[0].tolist())
+                 print(f"  GEN [{prompt[:20]}] → {text[:250]}")
+             model.train()
+
+         # ---- Checkpoint ----
+         if step % cfg.save_interval == 0:
+             torch.save(
+                 {
+                     "model": model.state_dict(),
+                     "optimizer": opt.state_dict(),
+                     "scaler": scaler.state_dict(),
+                     "step": step,
+                     "best_val": best_val,
+                     "config": asdict(cfg),
+                 },
+                 ckpt_path,
+             )
+
+     # ---- Final save ----
+     torch.save(
+         {"model": model.state_dict(), "step": step, "config": asdict(cfg)},
+         exp / "final.pt",
+     )
+     print(f"\nTraining complete! Best val loss: {best_val:.4f} (ppl {math.exp(best_val):.2f})")
+
+
+ if __name__ == "__main__":
+     main()
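As a quick arithmetic check, the 998,784-parameter figure quoted in the README follows directly from the dimensions in `Config` above (tied embeddings counted once); this sketch simply re-derives the sum:

```python
# Re-derive the README's parameter count from the Config dimensions:
# d_model=128, d_ff=336, n_layers=5, vocab=192, tied embeddings.
d_model, d_ff, n_layers, vocab = 128, 336, 5, 192

emb = vocab * d_model                         # tied input/output embedding
attn = 4 * d_model * d_model                  # Q, K, V, O projections
ffn = 2 * d_model * d_ff + d_ff * d_model     # SwiGLU gate, up, down
norms = 2 * d_model                           # two RMSNorms per block
per_block = attn + ffn + norms
total = emb + n_layers * per_block + d_model  # + final RMSNorm

print(total)  # 998784, i.e. just under 1M
```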