Crownelius committed on
Commit 025878f · verified · 1 Parent(s): 1e9d33d

Initial release: Shard-40m-v1 (54.5M dense transformer, anneal final)

Files changed (8)
  1. README.md +131 -0
  2. code/config.py +109 -0
  3. code/model.py +373 -0
  4. code/muon.py +198 -0
  5. code/tokenizer.py +109 -0
  6. models/model.pt +3 -0
  7. models/pretrain.pt +3 -0
  8. models/tokenizer.json +0 -0
README.md ADDED
@@ -0,0 +1,131 @@
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - small-lm
+ - gemma4-attention
+ - muon
+ - swiglu
+ - experimental
+ library_name: pytorch
+ ---
+
+ # Shard-40m-v1
+
+ A 54.5M parameter dense transformer trained on consumer-grade compute (Thunder Compute pretrain + Colab anneal). Released as a research artifact and pipeline-validation reference. Not a deployable model.
+
+ This is the first checkpoint in the Shard series of small experimental transformers.
+
+ ## Architecture
+
+ ```
+ Total params: 54,538,752 (~54.5M)
+ Hidden dim: 512
+ Layers: 12
+ Attention heads: 8 (MHA, no GQA)
+ Head dim: 64
+ MLP intermediate: 2048 (SwiGLU)
+ Vocab size: 8192
+ Max sequence: 8192
+ Attention pattern: Gemma 4 alternating sliding window (window=1024) + global, last layer global
+ Norm: RMSNorm, pre-norm
+ Position encoding: RoPE on Q and K
+ Embeddings: tied input/output
+ Activation: SwiGLU
+ MoE: none
+ Engram: none
+ ```
+
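+ As a sanity check, the total can be recomputed from the shape numbers above (with tied embeddings the LM head adds no extra matrix):
+
+ ```python
+ # Recompute the 54,538,752 total from the shapes above.
+ dim, layers, mlp, vocab = 512, 12, 2048, 8192
+ emb    = vocab * dim    # tied input/output embedding
+ attn   = 4 * dim * dim  # W_q, W_k, W_v, W_o per layer
+ swiglu = 3 * dim * mlp  # gate, up, down per layer
+ norms  = 2 * dim        # two RMSNorms per layer
+ print(emb + layers * (attn + swiglu + norms) + dim)  # + final norm -> 54538752
+ ```
+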
+ ## Training
+
+ ```
+ Phase 1 (pretrain):
+   Compute: Thunder Compute single GPU
+   Steps: 48,220 of a 100,000 step target (paused early)
+   Throughput: 86,800 tokens per second
+   Optimizer: Muon for hidden 2D weights, AdamW for embeddings and norms
+   LR schedule: WSD (warmup-stable-decay)
+   Stabilizers: lm_head logit cap 30, z-loss coefficient 1e-4
+
+ Phase 2 (anneal):
+   Compute: Colab A100
+   Steps: 20,000 (full anneal complete)
+   Final cross-entropy: 3.27
+   Mix: OpenWebMath, FineWeb-Edu carryover, NuminaMath, MetaMathQA, ArXiv, Cosmopedia, AI2 ARC
+ ```
+
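+ For reference, a minimal sketch of the WSD shape as configured here (`wsd_decay_frac=0.2` in `code/config.py`; the warmup length and the linear decay form are illustrative assumptions, not values read from the release):
+
+ ```python
+ def wsd_lr(step, total_steps, base_lr, warmup=1000, decay_frac=0.2):
+     if step < warmup:
+         return base_lr * step / warmup  # linear warmup
+     decay_start = warmup + (1 - decay_frac) * (total_steps - warmup)
+     if step < decay_start:
+         return base_lr                  # stable plateau
+     return base_lr * (total_steps - step) / (total_steps - decay_start)  # decay to 0
+ ```
+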
+ ## Files
+
+ - `models/model.pt` — anneal final checkpoint (model state only, 105 MB bf16)
+ - `models/pretrain.pt` — pretrain step 47,500 (with optimizer state, 217 MB)
+ - `models/tokenizer.json` — custom 8192-vocab BPE
+ - `code/` — minimal loading code (model.py, config.py, tokenizer.py, muon.py)
+
+ ## How to load
+
+ ```python
+ import sys, torch
+ sys.path.insert(0, 'code')
+ from config import Config
+ from model import ToyLM
+ from tokenizer import load_tokenizer
+
+ ck = torch.load('models/model.pt', map_location='cpu', weights_only=False)
+ cfg = Config(**ck['cfg']) if isinstance(ck['cfg'], dict) else ck['cfg']
+ model = ToyLM(cfg).cuda().to(torch.bfloat16)
+ model.load_state_dict(ck['model'])
+ model.eval()
+
+ tok = load_tokenizer('models/tokenizer.json')
+ ids = torch.tensor([tok.encode('The capital of France is').ids], device='cuda')
+ with torch.no_grad():
+     for _ in range(40):
+         logits, _ = model(ids)
+         nxt = logits[:, -1].argmax(-1, keepdim=True)
+         ids = torch.cat([ids, nxt], 1)
+ print(tok.decode(ids[0].tolist()))
+ ```
+
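+ The release also ships a sampling loop: `ToyLM.generate` in `code/model.py` applies temperature, top-p, and a repetition penalty (greedy decoding alone tends to loop at this scale; see the docstring there). The same prompt through that path:
+
+ ```python
+ # Sampled decode; defaults are temperature=0.8, top_p=0.9, rep_penalty=1.3.
+ out = model.generate(ids, max_new_tokens=40, temperature=0.7, top_p=0.9)
+ print(tok.decode(out[0].tolist()))
+ ```
+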
+ ## Benchmark
+
+ Greedy decode runs at 47 tokens per second on a single CUDA GPU. The model footprint is 109 MB in bf16, with 16 MB peak inference memory.
+
+ Sampled outputs at temperature 0.7, top_p 0.9:
+
+ | Prompt | Output |
+ |---|---|
+ | `The capital of France is` | `"covered by the Crown" (for example, the Great Seal of France...)` |
+ | `To compute 12 plus 7, we can` | `now use the first 6 as a reversible input...` |
+ | `Question: What is 23 + 19? Answer:` | `The answer is 23. Answer: 23. Answer: 23` (loops) |
+ | `def fibonacci(n):` | `// Appendix A. - S. B. V. Shanker. - S. M. P. Gerber...` |
+ | `Once upon a time, in a small village,` | `a woman is a gentleman in a village with an infinite wealth...` |
+ | `Solve: 17 * 23 = ?` | `?????\n*****` (breakdown) |
+
+ ## What this artifact proves
+
+ The training pipeline runs end to end on consumer-grade hardware. The Muon + AdamW dual optimizer, the WSD schedule, Gemma 4 alternating attention, and an anneal phase mixing math, code, and prose were all stable. Loss decreased monotonically through pretrain: no NaN events, no divergence, no rank loss flagged by the Muon min-singular-value sentinel.
+
+ ## What this artifact cannot do
+
+ - Math: broken; hallucinates digits or loops.
+ - Code generation: gibberish.
+ - Factual grounding: hallucinates with grammatical confidence.
+ - Long-context retrieval: max sequence is 8192, but with a sliding window of 1024 the effective context is much shorter in the non-global layers.
+
+ ## Why release it
+
+ To document a reproducible recipe at this scale. The next iteration in this line moves to a 412M MoE with 3 routed experts, a 262,144-token vocabulary, distillation pretraining from frontier teachers, and a token budget that crosses the Chinchilla line. This artifact is the baseline against which that next model will be measured.
+
+ ## License
+
+ Apache 2.0. Use freely. Attribution appreciated but not required.
+
+ ## Citation
+
+ ```
+ @misc{shard40mv1,
+   author    = {Shane (Crownelius)},
+   title     = {Shard-40m-v1: a 54.5M dense transformer trained on consumer compute},
+   year      = {2026},
+   publisher = {HuggingFace},
+   url       = {https://huggingface.co/CompactAI-O/Shard-40m-v1}
+ }
+ ```
code/config.py ADDED
@@ -0,0 +1,109 @@
+ """Config dataclass for the toy 50M LM.
+
+ Scaled up from the toy_1m_gemma4_dsv4 baseline. Architectural levers stay the
+ same (alternating SLIDE/GLOBAL Gemma 4 attention, optional Muon, optional
+ 512-slot Engram, full v2 stabilisation); only the shape numbers change.
+
+ Two architectural variants are flag-gated:
+
+ attention_pattern:
+   "all_global" -- every layer is full causal attention (baseline).
+   "gemma4"     -- alternating SLIDE/GLOBAL across layers; last layer is GLOBAL.
+
+ optimizer:
+   "adamw" -- AdamW for everything (baseline).
+   "muon"  -- Muon for params with .dim() >= 2; AdamW for embeddings + 1D.
+
+ engram_enabled: optional 512-slot external memory bank with zero-init gate.
+
+ When attention_pattern == "all_global", optimizer == "adamw", and
+ engram_enabled is False, the training math is bit-identical to a plain causal
+ transformer baseline.
+
+ Defaults
+ --------
+ * vocab=8192 (up from 4096): fresh BPE on a larger FineWeb-edu sample.
+ * dim=512, n_layers=12, n_heads=8, head_dim=64.
+ * mlp_hidden=2048 (4x dim, SwiGLU).
+ * max_seq_len=8192 (up from 4096).
+ * sliding_window=1024 ("larger model" Gemma 4 tier; 1M used 512).
+ * All v2 stabilisers ON: lm_head_logit_cap=30.0, z_loss_weight=1e-4, lr_schedule="wsd".
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+
+ AttentionPattern = Literal["all_global", "gemma4"]
+ OptimizerName = Literal["adamw", "muon"]
+ LRSchedule = Literal["cosine", "wsd"]
+
+
+ @dataclass
+ class Config:
+     # ---------- model shape ----------
+     vocab_size: int = 8192
+     dim: int = 512
+     n_layers: int = 12
+     n_heads: int = 8
+     head_dim: int = 64  # n_heads * head_dim must equal dim
+     mlp_hidden: int = 2048
+     max_seq_len: int = 8192
+
+     # ---------- gemma4 SWA ----------
+     attention_pattern: AttentionPattern = "gemma4"
+     sliding_window: int = 1024
+
+     # ---------- engram (off by default) ----------
+     engram_enabled: bool = False
+     engram_slots: int = 512
+     engram_inject_layer: int = 6  # mid-stack for the 12-layer build
+
+     # ---------- training ----------
+     optimizer: OptimizerName = "muon"
+     rope_base: float = 10000.0
+     norm_eps: float = 1e-5
+     dropout: float = 0.0
+     tie_embeddings: bool = True
+
+     # ---------- CE stabilisation (Gemma-2 logit cap + PaLM z-loss) ----------
+     # ON by default at 50M scale -- the 1M project added these as a v2 bolt-on,
+     # but at 50M with bf16 they're standard practice (DeepSeek V2/3, Gemma 2/3,
+     # PaLM). Bit-identical to the un-stabilised path when both knobs are 0/None.
+     lm_head_logit_cap: float | None = 30.0
+     z_loss_weight: float = 1e-4
+
+     # ---------- LR schedule ----------
+     # WSD by default at 50M (per Apr 2026 small-LM research); decays the LR
+     # over the last 20% of post-warmup steps, much smoother than cosine.
+     lr_schedule: LRSchedule = "wsd"
+     wsd_decay_frac: float = 0.2
+
+     # ---------- bookkeeping ----------
+     init_std: float = 0.02
+
+     def __post_init__(self) -> None:
+         assert self.n_heads * self.head_dim == self.dim, (
+             f"n_heads*head_dim={self.n_heads * self.head_dim} != dim={self.dim}"
+         )
+         assert self.attention_pattern in ("all_global", "gemma4")
+         assert self.optimizer in ("adamw", "muon")
+         assert self.lr_schedule in ("cosine", "wsd")
+         assert 0.0 <= self.wsd_decay_frac <= 1.0
+         assert self.z_loss_weight >= 0.0
+         assert self.lm_head_logit_cap is None or self.lm_head_logit_cap > 0
+         # Last layer must be GLOBAL when using gemma4 (canonical invariant).
+         # Concretely: layer i is GLOBAL iff (i % 2 == 1) for i in [0, n_layers).
+         # n_layers must be even, so the last index n_layers-1 is odd.
+         if self.attention_pattern == "gemma4":
+             assert self.n_layers % 2 == 0 and self.n_layers >= 2, (
+                 "gemma4 pattern requires even n_layers >= 2 so the last layer is GLOBAL"
+             )
+
+     def attention_kind(self, layer_idx: int) -> Literal["slide", "global"]:
+         """Return whether `layer_idx` is a sliding-window or global-attention layer."""
+         if self.attention_pattern == "all_global":
+             return "global"
+         # gemma4: even idx = SLIDE, odd idx = GLOBAL. Last layer (n_layers-1) is odd
+         # for any even n_layers, so it is GLOBAL.
+         return "global" if (layer_idx % 2 == 1) else "slide"
code/model.py ADDED
@@ -0,0 +1,373 @@
+ """Toy transformer with Gemma 4 alternating SWA + optional engram memory.
+
+ Design notes
+ ------------
+ * RMSNorm pre-norm, SwiGLU MLP, tied embedding/output (standard Llama-ish base).
+ * Global layers rely on SDPA's built-in causal masking; sliding-window layers use
+   the same code path with an additional window-restricted mask (purely a mask
+   difference -- no kernel split).
+ * RoPE is applied to Q/K only (standard, no Gemma 4 dual-RoPE).
+ * Engram is an optional 512-slot static memory bank attended-to from one layer's
+   output; injected via a multiplicative gate that is zero-initialised so it's a
+   no-op at training start. Bit-identical to no-engram when `cfg.engram_enabled=False`.
+ """
+ from __future__ import annotations
+
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from config import Config
+
+
+ # ---------------------------------------------------------------------------
+ # RMSNorm
+ # ---------------------------------------------------------------------------
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Compute in float32 for stability; cast back to input dtype.
+         dtype = x.dtype
+         xf = x.float()
+         rms = xf.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
+         return (xf * rms).to(dtype) * self.weight
+
+
+ # ---------------------------------------------------------------------------
+ # RoPE
+ # ---------------------------------------------------------------------------
+ def _build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype) -> tuple[torch.Tensor, torch.Tensor]:
+     assert head_dim % 2 == 0, "head_dim must be even for RoPE"
+     half = head_dim // 2
+     inv_freq = 1.0 / (base ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
+     t = torch.arange(seq_len, device=device, dtype=torch.float32)
+     freqs = torch.einsum("i,j->ij", t, inv_freq)  # (T, half)
+     cos = freqs.cos().to(dtype)
+     sin = freqs.sin().to(dtype)
+     return cos, sin  # each (T, half)
+
+
+ def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+     # x: (B, n_h, T, head_dim). cos/sin: (T, head_dim/2).
+     x1, x2 = x.chunk(2, dim=-1)
+     cos_b = cos[None, None, :, :]
+     sin_b = sin[None, None, :, :]
+     rotated_x1 = x1 * cos_b - x2 * sin_b
+     rotated_x2 = x1 * sin_b + x2 * cos_b
+     return torch.cat([rotated_x1, rotated_x2], dim=-1)
+
+
+ # ---------------------------------------------------------------------------
+ # Attention
+ # ---------------------------------------------------------------------------
+ class Attention(nn.Module):
+     """MHA with RoPE and a configurable causal-or-sliding mask.
+
+     `kind == 'global'`: full causal attention.
+     `kind == 'slide'` : causal attention restricted to the last `window` tokens.
+
+     Both code paths use F.scaled_dot_product_attention for speed; the only
+     difference is the additive mask. When kind=='global' we pass `is_causal=True`
+     and skip building an explicit mask. When kind=='slide' we build a banded
+     mask that matches the global path exactly, apart from -inf entries outside
+     the window.
+     """
+
+     def __init__(self, cfg: Config, kind: str):
+         super().__init__()
+         assert kind in ("global", "slide")
+         self.cfg = cfg
+         self.kind = kind
+         self.n_heads = cfg.n_heads
+         self.head_dim = cfg.head_dim
+         self.scale = self.head_dim**-0.5
+
+         self.W_q = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.W_k = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.W_v = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.W_o = nn.Linear(cfg.dim, cfg.dim, bias=False)
+
+     def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+         B, T, D = x.shape
+         q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Dh)
+         k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+
+         q = _apply_rope(q, cos, sin)
+         k = _apply_rope(k, cos, sin)
+
+         if self.kind == "global":
+             out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         else:
+             # Banded causal mask: token t may attend to tokens in [max(0, t-window+1), t].
+             mask = _sliding_causal_mask(T, self.cfg.sliding_window, x.device, x.dtype)
+             out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
+
+         out = out.transpose(1, 2).contiguous().view(B, T, D)
+         return self.W_o(out)
+
+
+ def _sliding_causal_mask(T: int, window: int, device, dtype) -> torch.Tensor:
+     """(T, T) additive mask: 0 inside window+causal, -inf outside.
+
+     Token i attends to j iff j <= i and (i - j) < window.
+     """
+     i = torch.arange(T, device=device).unsqueeze(1)  # (T,1)
+     j = torch.arange(T, device=device).unsqueeze(0)  # (1,T)
+     causal = j <= i
+     in_window = (i - j) < window
+     keep = causal & in_window
+     mask = torch.zeros((T, T), device=device, dtype=dtype)
+     mask = mask.masked_fill(~keep, float("-inf"))
+     # SDPA expects (..., T, T) broadcast over batch/heads.
+     return mask
+
+
+ # ---------------------------------------------------------------------------
+ # MLP (SwiGLU)
+ # ---------------------------------------------------------------------------
+ class SwiGLU(nn.Module):
+     def __init__(self, dim: int, hidden: int):
+         super().__init__()
+         self.w_gate = nn.Linear(dim, hidden, bias=False)
+         self.w_up = nn.Linear(dim, hidden, bias=False)
+         self.w_down = nn.Linear(hidden, dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
+
+
+ # ---------------------------------------------------------------------------
+ # Block
+ # ---------------------------------------------------------------------------
+ class Block(nn.Module):
+     def __init__(self, cfg: Config, layer_idx: int):
+         super().__init__()
+         kind = cfg.attention_kind(layer_idx)
+         self.norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+         self.attn = Attention(cfg, kind=kind)
+         self.norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+         self.mlp = SwiGLU(cfg.dim, cfg.mlp_hidden)
+         self.kind = kind
+
+     def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+         x = x + self.attn(self.norm1(x), cos, sin)
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ # ---------------------------------------------------------------------------
+ # Engram external memory
+ # ---------------------------------------------------------------------------
+ class Engram(nn.Module):
+     """Static memory bank with single-head attention readout + zero-init gate.
+
+     Bit-identical to no-engram at init (the gate is zero, so the injection is 0).
+     Becomes non-trivial only after the gate is trained away from zero.
+     """
+
+     def __init__(self, cfg: Config):
+         super().__init__()
+         self.cfg = cfg
+         # The incoming hidden state is RMSNorm-ed at read time (see forward).
+         self.slots = nn.Parameter(torch.randn(cfg.engram_slots, cfg.dim) * cfg.init_std)
+         self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.o_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
+         self.norm = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+         # A raw multiplicative gate rather than a sigmoid gate: sigmoid(0) = 0.5,
+         # which would not be a no-op, so we multiply by a zero-initialised vector.
+         self.gate = nn.Parameter(torch.zeros(cfg.dim))
+
+     def forward(self, h: torch.Tensor) -> torch.Tensor:
+         # h: (B, T, D). Read from memory.
+         h_n = self.norm(h)
+         q = self.q_proj(h_n)         # (B, T, D)
+         k = self.k_proj(self.slots)  # (S, D)
+         v = self.v_proj(self.slots)  # (S, D)
+         scale = q.shape[-1] ** -0.5
+         attn = torch.einsum("btd,sd->bts", q, k) * scale
+         w = attn.softmax(dim=-1)
+         retrieved = torch.einsum("bts,sd->btd", w, v)
+         retrieved = self.o_proj(retrieved)
+         # Multiplicative zero-init gate -> exact no-op at init.
+         return h + self.gate * retrieved
+
+
+ # ---------------------------------------------------------------------------
+ # ToyLM
+ # ---------------------------------------------------------------------------
+ class ToyLM(nn.Module):
+     def __init__(self, cfg: Config):
+         super().__init__()
+         self.cfg = cfg
+         self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.dim)
+         self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
+         self.norm_f = RMSNorm(cfg.dim, eps=cfg.norm_eps)
+
+         if cfg.engram_enabled:
+             self.engram = Engram(cfg)
+         else:
+             self.engram = None
+
+         if not cfg.tie_embeddings:
+             self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
+         else:
+             self.lm_head = None
+
+         # RoPE cache; rebuilt lazily if the requested seq_len exceeds it.
+         cos, sin = _build_rope_cache(cfg.max_seq_len, cfg.head_dim, cfg.rope_base, device="cpu", dtype=torch.float32)
+         self.register_buffer("rope_cos", cos, persistent=False)
+         self.register_buffer("rope_sin", sin, persistent=False)
+
+         self._init_weights()
+
+     def _init_weights(self) -> None:
+         std = self.cfg.init_std
+         for p_name, p in self.named_parameters():
+             if p.dim() >= 2:
+                 nn.init.normal_(p, mean=0.0, std=std)
+             elif p_name.endswith(".weight") and "norm" in p_name.lower():
+                 nn.init.ones_(p)
+             elif p_name == "engram.gate":
+                 nn.init.zeros_(p)
+             else:
+                 nn.init.zeros_(p)
+
+     def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         B, T = idx.shape
+         assert T <= self.cfg.max_seq_len, f"seq_len {T} > max {self.cfg.max_seq_len}"
+         x = self.tok_emb(idx)
+
+         cos = self.rope_cos[:T].to(device=x.device, dtype=x.dtype)
+         sin = self.rope_sin[:T].to(device=x.device, dtype=x.dtype)
+
+         for i, blk in enumerate(self.blocks):
+             x = blk(x, cos, sin)
+             if self.engram is not None and i == self.cfg.engram_inject_layer:
+                 x = self.engram(x)
+
+         x = self.norm_f(x)
+
+         if self.cfg.tie_embeddings:
+             logits = F.linear(x, self.tok_emb.weight)
+         else:
+             logits = self.lm_head(x)
+
+         # Gemma-2 logit soft-cap (bf16 stability + bounded softmax input).
+         if self.cfg.lm_head_logit_cap is not None:
+             cap = self.cfg.lm_head_logit_cap
+             logits = cap * torch.tanh(logits / cap)
+
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(
+                 logits.reshape(-1, logits.size(-1)),
+                 targets.reshape(-1),
+                 ignore_index=-100,
+             )
+             # PaLM-style z-loss: penalises log-partition magnitude. Keeps the
+             # softmax denominator from drifting; small weight (~1e-4) costs ~0.
+             # Computed only on non-ignored positions so it composes with masked SFT.
+             if self.cfg.z_loss_weight > 0:
+                 lse = torch.logsumexp(logits.float(), dim=-1)  # (B, T)
+                 if targets is not None:
+                     valid = targets.reshape(*lse.shape) != -100
+                     if valid.any():
+                         z = (lse[valid] ** 2).mean()
+                     else:
+                         z = lse.new_zeros(())
+                 else:
+                     z = (lse ** 2).mean()
+                 loss = loss + self.cfg.z_loss_weight * z
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(
+         self,
+         idx: torch.Tensor,
+         max_new_tokens: int = 80,
+         *,
+         temperature: float = 0.8,
+         top_p: float = 0.9,
+         rep_penalty: float = 1.3,
+         stop_token_ids: Optional[set[int]] = None,
+     ) -> torch.Tensor:
+         """Sampling-based decode with top-p + repetition penalty.
+
+         Defaults are tuned for sub-10M LMs: greedy alone collapses into
+         token-level repetition loops at this scale (entropy stays high but
+         argmax follows a self-amplifying trajectory). T=0.8 + top-p 0.9 +
+         rep_penalty=1.3 reliably breaks the loop without going incoherent.
+         Validated 2026-04-29 on the 12k-step toy 1M checkpoint.
+
+         Pass `temperature=0.0` (and `rep_penalty=1.0`) to recover pure greedy.
+         """
+         self.eval()
+         for _ in range(max_new_tokens):
+             logits, _ = self(idx)
+             logits = logits[:, -1].float()  # (B, V)
+
+             if rep_penalty != 1.0:
+                 # Per-batch-element rep penalty over already-emitted tokens.
+                 for b in range(idx.size(0)):
+                     seen = torch.unique(idx[b])
+                     pos = logits[b, seen] > 0
+                     logits[b, seen] = torch.where(pos,
+                                                   logits[b, seen] / rep_penalty,
+                                                   logits[b, seen] * rep_penalty)
+
+             if temperature <= 0.0:
+                 nxt = logits.argmax(dim=-1, keepdim=True)
+             else:
+                 logits = logits / temperature
+                 if top_p < 1.0:
+                     sorted_logits, sorted_idx = logits.sort(descending=True)
+                     cum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+                     mask = cum > top_p
+                     mask[..., 1:] = mask[..., :-1].clone()
+                     mask[..., 0] = False
+                     logits = logits.scatter(1, sorted_idx,
+                                             sorted_logits.masked_fill(mask, float('-inf')))
+                 probs = F.softmax(logits, dim=-1)
+                 nxt = torch.multinomial(probs, num_samples=1)
+
+             idx = torch.cat([idx, nxt], dim=1)
+
+             if stop_token_ids is not None and nxt[0, 0].item() in stop_token_ids:
+                 break
+             if idx.size(1) >= self.cfg.max_seq_len:
+                 break
+
+         return idx
+
+     def num_params_breakdown(self) -> dict[str, int]:
+         emb = sum(p.numel() for p in self.tok_emb.parameters())
+         attn = 0
+         mlp = 0
+         norms = 0
+         for blk in self.blocks:
+             attn += sum(p.numel() for p in blk.attn.parameters())
+             mlp += sum(p.numel() for p in blk.mlp.parameters())
+             norms += sum(p.numel() for p in blk.norm1.parameters())
+             norms += sum(p.numel() for p in blk.norm2.parameters())
+         # Final norm is counted once, outside the per-block loop.
+         norms += sum(p.numel() for p in self.norm_f.parameters())
+         engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0
+         head = sum(p.numel() for p in self.lm_head.parameters()) if self.lm_head is not None else 0
+         total = sum(p.numel() for p in self.parameters())
+         return {
+             "embedding": emb,
+             "attention": attn,
+             "mlp": mlp,
+             "norms": norms,
+             "engram": engram,
+             "lm_head_extra": head,  # 0 when tied
+             "total": total,
+         }
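A tiny check of the sliding-mask semantics (illustrative; `1` marks the positions `_sliding_causal_mask` keeps):

```python
import torch
from model import _sliding_causal_mask

# T=5, window=3: token i attends to j iff j <= i and (i - j) < 3.
m = _sliding_causal_mask(5, 3, device="cpu", dtype=torch.float32)
print((m == 0).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```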
code/muon.py ADDED
@@ -0,0 +1,198 @@
+ """Muon optimizer for 2D matrices.
+
+ Reference: Keller Jordan, "Muon: An optimizer for hidden layers in neural networks"
+ https://kellerjordan.github.io/posts/muon/
+
+ Algorithm
+ ---------
+ For each 2D parameter W with gradient G:
+   1. Maintain a momentum buffer M_t = beta * M_{t-1} + G_t.
+   2. Optionally apply Nesterov: G' = G_t + beta * M_t (or just M_t without Nesterov).
+   3. Orthogonalise G' via 5 iterations of Newton-Schulz with the quintic polynomial
+      coefficients (3.4445, -4.7750, 2.0315):
+          X <- 3.4445 * X - 4.7750 * (X X^T) X + 2.0315 * (X X^T)^2 X
+      after first dividing X by ||X||_F to bring its singular values into [0, 1].
+   4. Apply the orthogonalised update: W <- W - lr * adj_factor * O
+      where adj_factor = max(1, fan_out / fan_in)**0.5 to scale shorter-dim params.
+
+ This optimiser is intended ONLY for parameters with .dim() >= 2. The recommended
+ recipe uses AdamW for embeddings and 1D tensors (norms, biases). The wrapper
+ class `HybridOptimizer` here packages that split.
+
+ Bit-identical guarantee
+ -----------------------
+ When the caller selects optimizer="adamw" in Config, the train script never
+ constructs Muon -- it builds a single AdamW over all params. The HybridOptimizer
+ exists only when optimizer="muon"; it is not a sneaky pass-through. This keeps
+ the two paths cleanly separated.
+ """
+ from __future__ import annotations
+
+ from typing import Iterable
+
+ import torch
+ from torch.optim import Optimizer
+
+
+ # ---------------------------------------------------------------------------
+ # Newton-Schulz orthogonalisation
+ # ---------------------------------------------------------------------------
+ @torch.no_grad()
+ def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+     """Quintic Newton-Schulz, 5 iterations. Returns an approximately-orthogonal
+     matrix with the same shape as G.
+
+     Operates on the *transposed* shape if rows > cols so that XX^T stays the
+     smaller matrix-multiply (canonical optimisation in the reference impl).
+     """
+     assert G.dim() >= 2
+     a, b, c = 3.4445, -4.7750, 2.0315
+     X = G.float()  # do all NS math in fp32 even if the param is bf16
+     if X.size(-2) > X.size(-1):
+         X = X.transpose(-2, -1)
+         transposed = True
+     else:
+         transposed = False
+
+     # Normalise so ||X||_op <= 1: the Frobenius norm is an upper bound on the
+     # spectral norm, so dividing by it is safe and the standard choice.
+     X = X / (X.norm() + eps)
+
+     for _ in range(5):
+         A = X @ X.transpose(-2, -1)
+         B = b * A + c * (A @ A)
+         X = a * X + B @ X
+
+     if transposed:
+         X = X.transpose(-2, -1)
+     return X.to(G.dtype)
+
+
+ # ---------------------------------------------------------------------------
+ # Muon
+ # ---------------------------------------------------------------------------
+ class Muon(Optimizer):
+     def __init__(
+         self,
+         params: Iterable[torch.Tensor],
+         lr: float = 3e-3,
+         momentum: float = 0.95,
+         nesterov: bool = True,
+         weight_decay: float = 0.0,
+     ):
+         defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
+         super().__init__(params, defaults)
+         for group in self.param_groups:
+             for p in group["params"]:
+                 assert p.dim() >= 2, (
+                     f"Muon expects 2D+ params; got shape {tuple(p.shape)}. "
+                     "Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)."
+                 )
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         loss = closure() if closure is not None else None
+
+         for group in self.param_groups:
+             lr = group["lr"]
+             beta = group["momentum"]
+             nesterov = group["nesterov"]
+             wd = group["weight_decay"]
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 g = p.grad
+                 state = self.state[p]
+                 if "momentum_buffer" not in state:
+                     state["momentum_buffer"] = torch.zeros_like(p)
+                 buf = state["momentum_buffer"]
+                 buf.mul_(beta).add_(g)
+                 update = g + beta * buf if nesterov else buf
+
+                 # Reshape ND tensors (e.g. conv kernels) into 2D for orthogonalisation.
+                 # Embeddings are excluded by construction; here we expect Linear weights
+                 # which are already 2D, but keep the reshape for safety.
+                 orig_shape = update.shape
+                 if update.dim() > 2:
+                     update = update.reshape(update.shape[0], -1)
+
+                 ortho = newton_schulz_5(update)
+
+                 # Scale by sqrt(max(1, fan_out/fan_in)) so updates have sane magnitude
+                 # across rectangular shapes. fan_out = rows, fan_in = cols.
+                 fan_out, fan_in = ortho.shape[-2], ortho.shape[-1]
+                 adj = max(1.0, fan_out / fan_in) ** 0.5
+
+                 if ortho.shape != orig_shape:
+                     ortho = ortho.reshape(orig_shape)
+
+                 if wd != 0.0:
+                     p.add_(p, alpha=-lr * wd)
+                 p.add_(ortho, alpha=-lr * adj)
+
+         return loss
+
+
+ # ---------------------------------------------------------------------------
+ # Hybrid Muon + AdamW wrapper
+ # ---------------------------------------------------------------------------
+ class HybridOptimizer:
+     """Routes 2D+ params to Muon and 1D / embedding params to AdamW.
+
+     Mimics the torch.optim.Optimizer surface enough for our train loop:
+     .step(), .zero_grad(set_to_none=True), .param_groups (for LR scheduling).
+     """
+
+     def __init__(
+         self,
+         named_params: Iterable[tuple[str, torch.nn.Parameter]],
+         muon_lr: float,
+         adamw_lr: float,
+         muon_momentum: float = 0.95,
+         adamw_betas: tuple[float, float] = (0.9, 0.95),
+         weight_decay: float = 0.0,
+     ):
+         muon_params = []
+         adamw_params = []
+         for name, p in named_params:
+             if not p.requires_grad:
+                 continue
+             # Embeddings have dim() == 2 but should still go to AdamW per the recipe.
+             is_embedding = "tok_emb" in name or "engram.slots" in name
+             if p.dim() >= 2 and not is_embedding:
+                 muon_params.append(p)
+             else:
+                 adamw_params.append(p)
+
+         self.muon = Muon(
+             muon_params,
+             lr=muon_lr,
+             momentum=muon_momentum,
+             nesterov=True,
+             weight_decay=weight_decay,
+         )
+         self.adamw = torch.optim.AdamW(
+             adamw_params,
+             lr=adamw_lr,
+             betas=adamw_betas,
+             weight_decay=weight_decay,
+         )
+         self.param_groups = self.muon.param_groups + self.adamw.param_groups
+
+     def step(self, closure=None):
+         if closure is not None:
+             raise NotImplementedError("HybridOptimizer does not support a closure.")
+         self.muon.step()
+         self.adamw.step()
+
+     def zero_grad(self, set_to_none: bool = True):
+         self.muon.zero_grad(set_to_none=set_to_none)
+         self.adamw.zero_grad(set_to_none=set_to_none)
+
+     def state_dict(self):
+         return {"muon": self.muon.state_dict(), "adamw": self.adamw.state_dict()}
+
+     def load_state_dict(self, sd):
+         self.muon.load_state_dict(sd["muon"])
+         self.adamw.load_state_dict(sd["adamw"])
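A quick sanity check that the Newton-Schulz output is close to orthogonal (five iterations only approximate the polar factor, so singular values land near 1 rather than exactly 1):

```python
import torch
from muon import newton_schulz_5

G = torch.randn(512, 2048)
O = newton_schulz_5(G)
s = torch.linalg.svdvals(O.float())
print(f"min={s.min():.3f} max={s.max():.3f}")  # both near 1.0
```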
code/tokenizer.py ADDED
@@ -0,0 +1,109 @@
+ """Train a fresh 8K BPE on a FineWeb-edu sample.
+
+ This is the 50M-scale variant of the 1M project's 4K BPE. We bump the default
+ vocab to 8192 and keep the document count at 50000 (same as the 1M run: that
+ doc count was already saturating BPE merge quality at 4K vocab, and doubling
+ the vocab needs roughly the same training set, not 2x more).
+
+ We do NOT reuse any FANT tokenizer here -- the point of this experiment family
+ is a clean small recipe with no external dependencies.
+
+ Output: tokenizer.json in the working dir (or wherever specified).
+ """
+ """
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import time
17
+ from pathlib import Path
18
+
19
+ from tokenizers import Tokenizer
20
+ from tokenizers.models import BPE
21
+ from tokenizers.pre_tokenizers import ByteLevel as BLPre
22
+ from tokenizers.decoders import ByteLevel as BLDec
23
+ from tokenizers.processors import ByteLevel as BLPost
24
+ from tokenizers.trainers import BpeTrainer
25
+
26
+
27
+ SPECIAL_TOKENS = [
28
+ "<|pad|>", # 0
29
+ "<|bos|>", # 1
30
+ "<|eos|>", # 2
31
+ "<|unk|>", # 3
32
+ "<|im_start|>", # 4 -- chat role open
33
+ "<|im_end|>", # 5 -- chat role close
34
+ ]
35
+
36
+
37
+ def _iter_fineweb(n_docs: int):
38
+ """Yield up to `n_docs` text strings from the FineWeb-edu streaming feed."""
39
+ from datasets import load_dataset
40
+
41
+ ds = load_dataset(
42
+ "HuggingFaceFW/fineweb-edu",
43
+ name="default",
44
+ split="train",
45
+ streaming=True,
46
+ )
47
+ n = 0
48
+ for ex in ds:
49
+ if n >= n_docs:
50
+ return
51
+ text = ex.get("text", "")
52
+ if isinstance(text, str) and text.strip():
53
+ n += 1
54
+ yield text
55
+
56
+
57
+ def train_tokenizer(out_path: str = "tokenizer.json", vocab_size: int = 8192, n_docs: int = 50000) -> str:
58
+ tok = Tokenizer(BPE(unk_token="<|unk|>"))
59
+ tok.pre_tokenizer = BLPre(add_prefix_space=False)
60
+ tok.decoder = BLDec()
61
+ tok.post_processor = BLPost(trim_offsets=False)
62
+
63
+ trainer = BpeTrainer(
64
+ vocab_size=vocab_size,
65
+ special_tokens=SPECIAL_TOKENS,
66
+ initial_alphabet=BLPre.alphabet(),
67
+ show_progress=False,
68
+ )
69
+
70
+ print(f"[tokenizer] streaming up to {n_docs} FineWeb-edu docs...")
71
+ t0 = time.time()
72
+ docs = list(_iter_fineweb(n_docs))
73
+ print(f"[tokenizer] collected {len(docs)} docs in {time.time() - t0:.1f}s")
74
+
75
+ print(f"[tokenizer] training BPE vocab_size={vocab_size}...")
76
+ t0 = time.time()
77
+ tok.train_from_iterator(docs, trainer=trainer)
78
+ print(f"[tokenizer] trained in {time.time() - t0:.1f}s; vocab={tok.get_vocab_size()}")
79
+
80
+ out_dir = Path(out_path).parent
81
+ if str(out_dir) and not out_dir.exists():
82
+ out_dir.mkdir(parents=True, exist_ok=True)
83
+ tok.save(out_path)
84
+ print(f"[tokenizer] saved to {out_path}")
85
+ return out_path
86
+
87
+
88
+ def load_tokenizer(path: str = "tokenizer.json") -> Tokenizer:
89
+ return Tokenizer.from_file(path)
90
+
91
+
92
+ # Convenience accessors used by data.py / train.py
93
+ def special_token_id(tok: Tokenizer, name: str) -> int:
94
+ tid = tok.token_to_id(name)
95
+ assert tid is not None, f"{name} not in tokenizer"
96
+ return tid
97
+
98
+
99
+ def main():
100
+ ap = argparse.ArgumentParser()
101
+ ap.add_argument("--out", default="tokenizer.json")
102
+ ap.add_argument("--vocab", type=int, default=8192)
103
+ ap.add_argument("--docs", type=int, default=50000)
104
+ args = ap.parse_args()
105
+ train_tokenizer(args.out, args.vocab, args.docs)
106
+
107
+
108
+ if __name__ == "__main__":
109
+ main()
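Round-trip usage against the released tokenizer (a sketch; the `<|bos|>` id follows the SPECIAL_TOKENS order above):

```python
from tokenizer import load_tokenizer, special_token_id

tok = load_tokenizer('models/tokenizer.json')
enc = tok.encode('The capital of France is')
print(enc.ids)                           # token ids
print(tok.decode(enc.ids))               # round-trips the input text
print(special_token_id(tok, '<|bos|>'))  # 1, per SPECIAL_TOKENS order
```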
models/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a037d9d847c357510b0a09d4dd6c169cacbd988dd24aba945c416a8f93397e7e
+ size 109112123
models/pretrain.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06dfe158490abfe41644ebf7f44942d98af32ebe3602892e25117cb8c623c49a
+ size 226660407
models/tokenizer.json ADDED
The diff for this file is too large to render.