Initial release: Shard-40m-v1 (54.5M dense transformer, anneal final)
- README.md +131 -0
- code/config.py +109 -0
- code/model.py +373 -0
- code/muon.py +198 -0
- code/tokenizer.py +109 -0
- models/model.pt +3 -0
- models/pretrain.pt +3 -0
- models/tokenizer.json +0 -0
README.md
ADDED
@@ -0,0 +1,131 @@
---
license: apache-2.0
language:
- en
tags:
- small-lm
- gemma4-attention
- muon
- swiglu
- experimental
library_name: pytorch
---

# Shard-40m-v1

A 54.5M parameter dense transformer trained on consumer-grade compute (Thunder Compute pretrain + Colab anneal). Released as a research artifact and pipeline-validation reference. Not a deployable model.

This is the first checkpoint in the Shard series of small experimental transformers.

## Architecture

```
Total params: 54,538,752 (~54.5M)
Hidden dim: 512
Layers: 12
Attention heads: 8 (MHA, no GQA)
Head dim: 64
MLP intermediate: 2048 (SwiGLU)
Vocab size: 8192
Max sequence: 8192
Attention pattern: Gemma 4 alternating sliding window (window=1024) + global, last layer global
Norm: RMSNorm, pre-norm
Position encoding: RoPE on Q and K
Embeddings: tied input/output
Activation: SwiGLU
MoE: none
Engram: none
```
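
The listed shapes reproduce the parameter total exactly under tied embeddings and no Engram; a quick arithmetic check (added here for reference, not part of the original card):

```python
# Recompute the 54,538,752 total from the shapes above.
vocab, dim, layers, mlp = 8192, 512, 12, 2048

embedding       = vocab * dim    # 4,194,304 (shared with the output head, counted once)
attn_per_layer  = 4 * dim * dim  # W_q, W_k, W_v, W_o      -> 1,048,576
mlp_per_layer   = 3 * dim * mlp  # SwiGLU gate / up / down -> 3,145,728
norms_per_layer = 2 * dim        # two RMSNorms per block  ->     1,024
final_norm      = dim

total = embedding + layers * (attn_per_layer + mlp_per_layer + norms_per_layer) + final_norm
print(f"{total:,}")  # 54,538,752
```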

## Training

```
Phase 1 (pretrain):
  Compute: Thunder Compute single GPU
  Steps: 48,220 of a 100,000 step target (paused early)
  Throughput: 86,800 tokens per second
  Optimizer: Muon for hidden 2D weights, AdamW for embeddings and norms
  LR schedule: WSD (warmup-stable-decay)
  Stabilizers: lm_head logit cap 30, z-loss coefficient 1e-4

Phase 2 (anneal):
  Compute: Colab A100
  Steps: 20,000 (full anneal complete)
  Final cross-entropy: 3.27
  Mix: OpenWebMath, FineWeb-Edu carryover, NuminaMath, MetaMathQA, ArXiv, Cosmopedia, AI2 ARC
```
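
The exact warmup length and decay shape are not documented in this release; a minimal sketch of a WSD learning-rate multiplier, assuming linear warmup and a linear decay over the final 20% of steps (`wsd_decay_frac = 0.2` in `code/config.py`), looks like this:

```python
def wsd_multiplier(step: int, total_steps: int, warmup_steps: int, decay_frac: float = 0.2) -> float:
    """Warmup-stable-decay LR multiplier in [0, 1]. Sketch only: shapes assumed linear."""
    decay_start = int(total_steps * (1.0 - decay_frac))
    if step < warmup_steps:
        return step / max(1, warmup_steps)   # warmup ramp
    if step < decay_start:
        return 1.0                           # stable plateau
    return max(0.0, (total_steps - step) / max(1, total_steps - decay_start))  # final decay
```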

## Files

- `models/model.pt` — anneal final checkpoint (model state only, 105 MB bf16)
- `models/pretrain.pt` — pretrain step 47,500 (with optimizer state, 217 MB)
- `models/tokenizer.json` — custom 8192-vocab BPE
- `code/` — minimal loading code (model.py, config.py, tokenizer.py, muon.py)

## How to load

```python
import sys, torch
sys.path.insert(0, 'code')
from config import Config
from model import ToyLM
from tokenizer import load_tokenizer

ck = torch.load('models/model.pt', map_location='cpu', weights_only=False)
cfg = Config(**ck['cfg']) if isinstance(ck['cfg'], dict) else ck['cfg']
model = ToyLM(cfg).cuda().to(torch.bfloat16)
model.load_state_dict(ck['model'])
model.eval()

tok = load_tokenizer('models/tokenizer.json')
ids = torch.tensor([tok.encode('The capital of France is').ids], device='cuda')
with torch.no_grad():
    for _ in range(40):
        logits, _ = model(ids)
        nxt = logits[:, -1].argmax(-1, keepdim=True)
        ids = torch.cat([ids, nxt], 1)
print(tok.decode(ids[0].tolist()))
```
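
`code/model.py` also ships a `ToyLM.generate` helper with top-p sampling and a repetition penalty. The benchmark table below was sampled at temperature 0.7 and top_p 0.9; a call along these lines (other knobs left at the helper's defaults, which may not match the exact benchmark settings) reproduces that setup:

```python
prompt = 'Once upon a time, in a small village,'
ids = torch.tensor([tok.encode(prompt).ids], device='cuda')
out = model.generate(ids, max_new_tokens=60, temperature=0.7, top_p=0.9)
print(tok.decode(out[0].tolist()))
```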

## Benchmark

Greedy decode runs at 47 tokens per second on a single CUDA GPU. Model footprint is 109 MB in bf16, with 16 MB peak inference memory.

Sampled outputs at temperature 0.7, top_p 0.9:

| Prompt | Output |
|---|---|
| `The capital of France is` | `"covered by the Crown" (for example, the Great Seal of France...)` |
| `To compute 12 plus 7, we can` | `now use the first 6 as a reversible input...` |
| `Question: What is 23 + 19? Answer:` | `The answer is 23. Answer: 23. Answer: 23` (loops) |
| `def fibonacci(n):` | `// Appendix A. - S. B. V. Shanker. - S. M. P. Gerber...` |
| `Once upon a time, in a small village,` | `a woman is a gentleman in a village with an infinite wealth...` |
| `Solve: 17 * 23 = ?` | `?????\n*****` (breakdown) |

## What this artifact proves

The training pipeline runs end to end on consumer-grade hardware. The Muon + AdamW dual optimizer, the WSD schedule, Gemma 4 alternating attention, and an anneal phase mixing math, code, and prose all stayed stable. Loss decreases monotonically through pretrain. No NaN events, no divergence, no rank loss flagged by the Muon min-singular-value sentinel.

## What this artifact cannot do

Math (broken, hallucinates digits or loops). Code generation (gibberish). Factual grounding (hallucinates with grammatical confidence). Long-context retrieval (max sequence 8192 with sliding window 1024 means effective context is much shorter for non-global layers).

## Why release it

To document a reproducible recipe at this scale. The next iteration in this line moves to a 412M MoE with 3 routed experts, vocabulary 262144, distillation pretraining from frontier teachers, and a token budget that crosses the Chinchilla line. This artifact is the baseline against which that next model will be measured.

## License

Apache 2.0. Use freely. Attribution appreciated but not required.

## Citation

```
@misc{shard40mv1,
  author = {Shane (Crownelius)},
  title = {Shard-40m-v1: a 54.5M dense transformer trained on consumer compute},
  year = {2026},
  publisher = {HuggingFace},
  url = {https://huggingface.co/CompactAI-O/Shard-40m-v1}
}
```
code/config.py
ADDED
@@ -0,0 +1,109 @@
"""Config dataclass for the toy 50M LM.

Scaled up from the toy_1m_gemma4_dsv4 baseline. Architectural levers stay the
same (alternating SLIDE/GLOBAL Gemma 4 attention, optional Muon, optional
512-slot Engram, full v2 stabilisation), only the shape numbers change.

Two architectural variants are flag-gated:

attention_pattern:
  "all_global" -- every layer is full causal attention (baseline).
  "gemma4"     -- alternating SLIDE/GLOBAL across layers; last layer is GLOBAL.

optimizer:
  "adamw" -- AdamW for everything (baseline).
  "muon"  -- Muon for params with .dim() >= 2; AdamW for embeddings + 1D.

engram_enabled: optional 512-slot external memory bank with zero-init gate.

When attention_pattern == "all_global" and optimizer == "adamw" and engram_enabled
is False, training math is bit-identical to a plain causal transformer baseline.

Defaults
--------
* vocab=8192 (up from 4096): fresh BPE on a larger FineWeb-edu sample.
* dim=512, n_layers=12, n_heads=8, head_dim=64.
* mlp_hidden=2048 (4x dim, SwiGLU).
* max_seq_len=8192 (up from 4096).
* sliding_window=1024 ("larger model" Gemma 4 tier; 1M used 512).
* All v2 stabilisers ON: lm_head_logit_cap=30.0, z_loss_weight=1e-4, lr_schedule="wsd".
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal


AttentionPattern = Literal["all_global", "gemma4"]
OptimizerName = Literal["adamw", "muon"]
LRSchedule = Literal["cosine", "wsd"]


@dataclass
class Config:
    # ---------- model shape ----------
    vocab_size: int = 8192
    dim: int = 512
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64  # n_heads * head_dim must equal dim
    mlp_hidden: int = 2048
    max_seq_len: int = 8192

    # ---------- gemma4 SWA ----------
    attention_pattern: AttentionPattern = "gemma4"
    sliding_window: int = 1024

    # ---------- engram (off by default) ----------
    engram_enabled: bool = False
    engram_slots: int = 512
    engram_inject_layer: int = 6  # mid-stack for the 12-layer build

    # ---------- training ----------
    optimizer: OptimizerName = "muon"
    rope_base: float = 10000.0
    norm_eps: float = 1e-5
    dropout: float = 0.0
    tie_embeddings: bool = True

    # ---------- CE stabilisation (Gemma-2 logit cap + PaLM z-loss) ----------
    # ON by default at 50M scale -- the 1M project added these as a v2 bolt-on
    # but at 50M with bf16 they're standard practice (DeepSeek V2/3, Gemma 2/3,
    # PaLM). Bit-identical to the un-stabilised path when both knobs are 0/None.
    lm_head_logit_cap: float | None = 30.0
    z_loss_weight: float = 1e-4

    # ---------- LR schedule ----------
    # WSD by default at 50M (per Apr 2026 small-LM research; lets the head
    # decay over the last 20 % of post-warmup, much smoother than cosine).
    lr_schedule: LRSchedule = "wsd"
    wsd_decay_frac: float = 0.2

    # ---------- bookkeeping ----------
    init_std: float = 0.02

    def __post_init__(self) -> None:
        assert self.n_heads * self.head_dim == self.dim, (
            f"n_heads*head_dim={self.n_heads * self.head_dim} != dim={self.dim}"
        )
        assert self.attention_pattern in ("all_global", "gemma4")
        assert self.optimizer in ("adamw", "muon")
        assert self.lr_schedule in ("cosine", "wsd")
        assert 0.0 <= self.wsd_decay_frac <= 1.0
        assert self.z_loss_weight >= 0.0
        assert self.lm_head_logit_cap is None or self.lm_head_logit_cap > 0
        # Last layer must be GLOBAL when using gemma4 (canonical invariant).
        # Concretely: layer i is GLOBAL iff (i % 2 == 1) for i in [0, n_layers).
        # n_layers must be even, last index n_layers-1 must be odd.
        if self.attention_pattern == "gemma4":
            assert self.n_layers % 2 == 0 and self.n_layers >= 2, (
                "gemma4 pattern requires even n_layers >= 2 so the last layer is GLOBAL"
            )

    def attention_kind(self, layer_idx: int) -> Literal["slide", "global"]:
        """Return whether `layer_idx` is a sliding-window or global-attention layer."""
        if self.attention_pattern == "all_global":
            return "global"
        # gemma4: even idx = SLIDE, odd idx = GLOBAL. Last layer (n_layers-1) is odd
        # for any even n_layers, so it is GLOBAL.
        return "global" if (layer_idx % 2 == 1) else "slide"
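
# Example (added for illustration, not part of the uploaded file): the default
# 12-layer "gemma4" config alternates sliding-window and global layers, ending global.
from config import Config

cfg = Config()
print([cfg.attention_kind(i) for i in range(cfg.n_layers)])
# ['slide', 'global', 'slide', 'global', 'slide', 'global',
#  'slide', 'global', 'slide', 'global', 'slide', 'global']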
code/model.py
ADDED
@@ -0,0 +1,373 @@
"""Toy 50M-param transformer with Gemma 4 alternating SWA + optional engram memory.

Design notes
------------
* RMSNorm pre-norm, SwiGLU MLP, tied embedding/output (standard Llama-ish base).
* Causal mask is precomputed; sliding-window layers use the same code path with
  an additional window-restricted mask (purely a mask difference -- no kernel split).
* RoPE is applied to Q/K only (standard, no Gemma 4 dual-RoPE).
* Engram is an optional 512-slot static memory bank attended-to from one layer's
  output; injected via a multiplicative gate that is zero-initialised so it's a no-op
  at training start. Bit-identical to no-engram when `cfg.engram_enabled=False`.
"""
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import Config


# ---------------------------------------------------------------------------
# RMSNorm
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in float32 for stability; cast back to input dtype.
        dtype = x.dtype
        xf = x.float()
        rms = xf.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (xf * rms).to(dtype) * self.weight


# ---------------------------------------------------------------------------
# RoPE
# ---------------------------------------------------------------------------
def _build_rope_cache(seq_len: int, head_dim: int, base: float, device, dtype) -> tuple[torch.Tensor, torch.Tensor]:
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)  # (T, half)
    cos = freqs.cos().to(dtype)
    sin = freqs.sin().to(dtype)
    return cos, sin  # each (T, half)


def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (B, n_h, T, head_dim). cos/sin: (T, head_dim/2).
    x1, x2 = x.chunk(2, dim=-1)
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    rotated_x1 = x1 * cos_b - x2 * sin_b
    rotated_x2 = x1 * sin_b + x2 * cos_b
    return torch.cat([rotated_x1, rotated_x2], dim=-1)


# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class Attention(nn.Module):
    """MHA with RoPE and configurable causal-or-sliding mask.

    `kind == 'global'`: full causal attention.
    `kind == 'slide'` : causal attention restricted to the last `window` tokens.

    Both code paths use F.scaled_dot_product_attention for speed; the only
    difference is the additive mask. When kind=='global' we pass `is_causal=True`
    and skip building an explicit mask. When kind=='slide' we build a banded
    mask that is bit-identical to the global path with appropriate -inf entries
    outside the window.
    """

    def __init__(self, cfg: Config, kind: str):
        super().__init__()
        assert kind in ("global", "slide")
        self.cfg = cfg
        self.kind = kind
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.head_dim
        self.scale = self.head_dim**-0.5

        self.W_q = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.W_k = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.W_v = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.W_o = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        q = self.W_q(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Dh)
        k = self.W_k(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)

        if self.kind == "global":
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # Banded causal mask: token t may attend to tokens in [max(0, t-window+1), t].
            mask = _sliding_causal_mask(T, self.cfg.sliding_window, x.device, x.dtype)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)

        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)


def _sliding_causal_mask(T: int, window: int, device, dtype) -> torch.Tensor:
    """(T, T) additive mask: 0 inside window+causal, -inf outside.

    Token i attends to j iff j <= i and (i - j) < window.
    """
    i = torch.arange(T, device=device).unsqueeze(1)  # (T,1)
    j = torch.arange(T, device=device).unsqueeze(0)  # (1,T)
    causal = j <= i
    in_window = (i - j) < window
    keep = causal & in_window
    mask = torch.zeros((T, T), device=device, dtype=dtype)
    mask = mask.masked_fill(~keep, float("-inf"))
    # SDPA expects (..., T, T) broadcast over batch/heads.
    return mask


# ---------------------------------------------------------------------------
# MLP (SwiGLU)
# ---------------------------------------------------------------------------
class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


# ---------------------------------------------------------------------------
# Block
# ---------------------------------------------------------------------------
class Block(nn.Module):
    def __init__(self, cfg: Config, layer_idx: int):
        super().__init__()
        kind = cfg.attention_kind(layer_idx)
        self.norm1 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.attn = Attention(cfg, kind=kind)
        self.norm2 = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        self.mlp = SwiGLU(cfg.dim, cfg.mlp_hidden)
        self.kind = kind

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), cos, sin)
        x = x + self.mlp(self.norm2(x))
        return x


# ---------------------------------------------------------------------------
# Engram external memory
# ---------------------------------------------------------------------------
class Engram(nn.Module):
    """Static memory bank with single-head attention readout + zero-init gate.

    Bit-identical to no-engram at init (the multiplicative gate is zero so injection is 0).
    Becomes non-trivial only after the gate is trained away from zero.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Slot rows are normalised by RMSNorm at read time.
        self.slots = nn.Parameter(torch.randn(cfg.engram_slots, cfg.dim) * cfg.init_std)
        self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.o_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.norm = RMSNorm(cfg.dim, eps=cfg.norm_eps)
        # Zero-init gate scalar -> sigmoid(0) = 0.5? No, we want exact no-op at init.
        # Use a *raw* gate that we multiply rather than sigmoid; init to 0.
        self.gate = nn.Parameter(torch.zeros(cfg.dim))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (B, T, D). Read from memory.
        h_n = self.norm(h)
        q = self.q_proj(h_n)         # (B, T, D)
        k = self.k_proj(self.slots)  # (S, D)
        v = self.v_proj(self.slots)  # (S, D)
        scale = q.shape[-1] ** -0.5
        attn = torch.einsum("btd,sd->bts", q, k) * scale
        w = attn.softmax(dim=-1)
        retrieved = torch.einsum("bts,sd->btd", w, v)
        retrieved = self.o_proj(retrieved)
        # Multiplicative zero-init gate -> exact no-op at init.
        return h + self.gate * retrieved


# ---------------------------------------------------------------------------
# ToyLM
# ---------------------------------------------------------------------------
class ToyLM(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.blocks = nn.ModuleList([Block(cfg, i) for i in range(cfg.n_layers)])
        self.norm_f = RMSNorm(cfg.dim, eps=cfg.norm_eps)

        if cfg.engram_enabled:
            self.engram = Engram(cfg)
        else:
            self.engram = None

        if not cfg.tie_embeddings:
            self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        else:
            self.lm_head = None

        # RoPE cache; rebuilt lazily if the requested seq_len exceeds it.
        cos, sin = _build_rope_cache(cfg.max_seq_len, cfg.head_dim, cfg.rope_base, device="cpu", dtype=torch.float32)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self._init_weights()

    def _init_weights(self) -> None:
        std = self.cfg.init_std
        for p_name, p in self.named_parameters():
            if p.dim() >= 2:
                nn.init.normal_(p, mean=0.0, std=std)
            elif p_name.endswith(".weight") and "norm" in p_name.lower():
                nn.init.ones_(p)
            elif p_name == "engram.gate":
                nn.init.zeros_(p)
            else:
                nn.init.zeros_(p)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, T = idx.shape
        assert T <= self.cfg.max_seq_len, f"seq_len {T} > max {self.cfg.max_seq_len}"
        x = self.tok_emb(idx)

        cos = self.rope_cos[:T].to(device=x.device, dtype=x.dtype)
        sin = self.rope_sin[:T].to(device=x.device, dtype=x.dtype)

        for i, blk in enumerate(self.blocks):
            x = blk(x, cos, sin)
            if self.engram is not None and i == self.cfg.engram_inject_layer:
                x = self.engram(x)

        x = self.norm_f(x)

        if self.cfg.tie_embeddings:
            logits = F.linear(x, self.tok_emb.weight)
        else:
            logits = self.lm_head(x)

        # Gemma-2 logit soft-cap (bf16 stability + bounded softmax input).
        if self.cfg.lm_head_logit_cap is not None:
            cap = self.cfg.lm_head_logit_cap
            logits = cap * torch.tanh(logits / cap)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-100,
            )
            # PaLM-style z-loss: penalises log-partition magnitude. Keeps the
            # softmax denominator from drifting; small weight (~1e-4) costs ~0.
            # Computed only on non-ignored positions so it composes with masked SFT.
            if self.cfg.z_loss_weight > 0:
                lse = torch.logsumexp(logits.float(), dim=-1)  # (B, T)
                if targets is not None:
                    valid = targets.reshape(*lse.shape) != -100
                    if valid.any():
                        z = (lse[valid] ** 2).mean()
                    else:
                        z = lse.new_zeros(())
                else:
                    z = (lse ** 2).mean()
                loss = loss + self.cfg.z_loss_weight * z
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int = 80,
        *,
        temperature: float = 0.8,
        top_p: float = 0.9,
        rep_penalty: float = 1.3,
        stop_token_ids: Optional[set[int]] = None,
    ) -> torch.Tensor:
        """Sampling-based decode with top-p + repetition penalty.

        Defaults are tuned for sub-10M LMs: greedy alone collapses into
        token-level repetition loops at this scale (entropy stays high but
        argmax follows a self-amplifying trajectory). T=0.8 + top-p 0.9 +
        rep_penalty=1.3 reliably breaks the loop without going incoherent.
        Validated 2026-04-29 on the 12k-step toy 1M checkpoint.

        Pass `temperature=0.0` to recover greedy (without rep_penalty).
        """
        self.eval()
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1].float()  # (B, V)

            if rep_penalty != 1.0:
                # Per-batch element rep penalty over already-emitted tokens.
                for b in range(idx.size(0)):
                    seen = torch.unique(idx[b])
                    pos = logits[b, seen] > 0
                    logits[b, seen] = torch.where(pos,
                                                  logits[b, seen] / rep_penalty,
                                                  logits[b, seen] * rep_penalty)

            if temperature <= 0.0:
                nxt = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_p < 1.0:
                    sorted_logits, sorted_idx = logits.sort(descending=True)
                    cum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                    mask = cum > top_p
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False
                    logits = logits.scatter(1, sorted_idx,
                                            sorted_logits.masked_fill(mask, float('-inf')))
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, num_samples=1)

            idx = torch.cat([idx, nxt], dim=1)

            if stop_token_ids is not None and nxt[0, 0].item() in stop_token_ids:
                break
            if idx.size(1) >= self.cfg.max_seq_len:
                break

        return idx

    def num_params_breakdown(self) -> dict[str, int]:
        emb = sum(p.numel() for p in self.tok_emb.parameters())
        attn = 0
        mlp = 0
        norms = 0
        for blk in self.blocks:
            attn += sum(p.numel() for p in blk.attn.parameters())
            mlp += sum(p.numel() for p in blk.mlp.parameters())
            norms += sum(p.numel() for p in blk.norm1.parameters())
            norms += sum(p.numel() for p in blk.norm2.parameters())
        norms += sum(p.numel() for p in self.norm_f.parameters())
        engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0
        head = sum(p.numel() for p in self.lm_head.parameters()) if self.lm_head is not None else 0
        total = sum(p.numel() for p in self.parameters())
        return {
            "embedding": emb,
            "attention": attn,
            "mlp": mlp,
            "norms": norms,
            "engram": engram,
            "lm_head_extra": head,  # 0 when tied
            "total": total,
        }
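
# Example (added for illustration, not part of the uploaded file): the breakdown
# helper above reproduces the README parameter count for the default config.
from config import Config
from model import ToyLM

print(ToyLM(Config()).num_params_breakdown())
# {'embedding': 4194304, 'attention': 12582912, 'mlp': 37748736, 'norms': 12800,
#  'engram': 0, 'lm_head_extra': 0, 'total': 54538752}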
code/muon.py
ADDED
@@ -0,0 +1,198 @@
"""Muon optimizer for 2D matrices.

Reference: Keller Jordan, "Muon: An optimizer for hidden layers in neural networks"
https://kellerjordan.github.io/posts/muon/

Algorithm
---------
For each 2D parameter W with gradient G:
  1. Maintain momentum buffer M_t = beta * M_{t-1} + G_t
  2. Optionally apply Nesterov: G' = G_t + beta * M_t (or just M_t without Nesterov)
  3. Orthogonalise G' via 5 iterations of Newton-Schulz with the quintic polynomial
     coefficients (3.4445, -4.7750, 2.0315):
         X <- 3.4445 * X - 4.7750 * X X^T X + 2.0315 * (X X^T)^2 X
     after first dividing X by ||X||_F to bring its singular values into [0, ~1.5].
  4. Apply the orthogonalised update: W <- W - lr * adj_factor * O
     where adj_factor = max(1, fan_out / fan_in)**0.5 to scale shorter-dim params.

This optimiser is intended ONLY for parameters with .dim() >= 2. The recommended
recipe uses AdamW for embeddings and 1D tensors (norms, biases). The wrapper
class `HybridOptimizer` here packages that split.

Bit-identical guarantee
-----------------------
When the caller selects optimizer="adamw" in Config, the train script never
constructs Muon -- it builds a single AdamW over all params. The HybridOptimizer
exists only when optimizer="muon"; it is not a sneaky pass-through. This keeps
the two paths cleanly separated.
"""
from __future__ import annotations

from typing import Iterable

import torch
from torch.optim import Optimizer


# ---------------------------------------------------------------------------
# Newton-Schulz orthogonalisation
# ---------------------------------------------------------------------------
@torch.no_grad()
def newton_schulz_5(G: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Quintic Newton-Schulz, 5 iterations. Returns an approximately-orthogonal
    matrix with the same shape as G.

    Operates on the *transposed* shape if rows > cols so that XX^T stays the
    smaller matrix-multiply (canonical optimisation in the reference impl).
    """
    assert G.dim() >= 2
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()  # do all NS math in fp32 even if param is bf16
    if X.size(-2) > X.size(-1):
        X = X.transpose(-2, -1)
        transposed = True
    else:
        transposed = False

    # Normalise so ||X||_op <= ~1.5. Frobenius norm is an upper bound on the
    # spectral norm; dividing by it is safe and the standard choice.
    X = X / (X.norm() + eps)

    for _ in range(5):
        A = X @ X.transpose(-2, -1)
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    if transposed:
        X = X.transpose(-2, -1)
    return X.to(G.dtype)


# ---------------------------------------------------------------------------
# Muon
# ---------------------------------------------------------------------------
class Muon(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float = 3e-3,
        momentum: float = 0.95,
        nesterov: bool = True,
        weight_decay: float = 0.0,
    ):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
        super().__init__(params, defaults)
        for group in self.param_groups:
            for p in group["params"]:
                assert p.dim() >= 2, (
                    f"Muon expects 2D+ params; got shape {tuple(p.shape)}. "
                    "Wrap embeddings + 1D tensors with AdamW (use HybridOptimizer)."
                )

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            lr = group["lr"]
            beta = group["momentum"]
            nesterov = group["nesterov"]
            wd = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]
                buf.mul_(beta).add_(g)
                update = g + beta * buf if nesterov else buf

                # Reshape ND tensors (e.g. conv kernels) into 2D for orthogonalisation.
                # Embeddings are excluded by construction; here we expect Linear weights
                # which are already 2D, but keep the reshape for safety.
                orig_shape = update.shape
                if update.dim() > 2:
                    update = update.reshape(update.shape[0], -1)

                ortho = newton_schulz_5(update)

                # Scale by sqrt(max(1, fan_out/fan_in)) so updates have sane magnitude
                # across rectangular shapes. fan_out = rows, fan_in = cols.
                fan_out, fan_in = ortho.shape[-2], ortho.shape[-1]
                adj = max(1.0, fan_out / fan_in) ** 0.5

                if ortho.shape != orig_shape:
                    ortho = ortho.reshape(orig_shape)

                if wd != 0.0:
                    p.add_(p, alpha=-lr * wd)
                p.add_(ortho, alpha=-lr * adj)

        return loss


# ---------------------------------------------------------------------------
# Hybrid Muon + AdamW wrapper
# ---------------------------------------------------------------------------
class HybridOptimizer:
    """Routes 2D+ params to Muon and 1D / embedding params to AdamW.

    Mimics the torch.optim.Optimizer surface enough for our train loop:
    .step(), .zero_grad(set_to_none=True), .param_groups (for LR scheduling).
    """

    def __init__(
        self,
        named_params: Iterable[tuple[str, torch.nn.Parameter]],
        muon_lr: float,
        adamw_lr: float,
        muon_momentum: float = 0.95,
        adamw_betas: tuple[float, float] = (0.9, 0.95),
        weight_decay: float = 0.0,
    ):
        muon_params = []
        adamw_params = []
        for name, p in named_params:
            if not p.requires_grad:
                continue
            # Embeddings have dim() == 2 but should still go to AdamW per the recipe.
            is_embedding = "tok_emb" in name or "engram.slots" in name
            if p.dim() >= 2 and not is_embedding:
                muon_params.append(p)
            else:
                adamw_params.append(p)

        self.muon = Muon(
            muon_params,
            lr=muon_lr,
            momentum=muon_momentum,
            nesterov=True,
            weight_decay=weight_decay,
        )
        self.adamw = torch.optim.AdamW(
            adamw_params,
            lr=adamw_lr,
            betas=adamw_betas,
            weight_decay=weight_decay,
        )
        self.param_groups = self.muon.param_groups + self.adamw.param_groups

    def step(self, closure=None):
        if closure is not None:
            raise NotImplementedError("HybridOptimizer does not support a closure.")
        self.muon.step()
        self.adamw.step()

    def zero_grad(self, set_to_none: bool = True):
        self.muon.zero_grad(set_to_none=set_to_none)
        self.adamw.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return {"muon": self.muon.state_dict(), "adamw": self.adamw.state_dict()}

    def load_state_dict(self, sd):
        self.muon.load_state_dict(sd["muon"])
        self.adamw.load_state_dict(sd["adamw"])
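
# Example (added for illustration, not part of the uploaded file): wiring the
# Muon + AdamW split described in the module docstring. The learning rates here
# are placeholders, not the values used for the released checkpoints.
import torch

from config import Config
from model import ToyLM
from muon import HybridOptimizer

model = ToyLM(Config())
opt = HybridOptimizer(model.named_parameters(), muon_lr=3e-3, adamw_lr=3e-4)

idx = torch.randint(0, 8192, (2, 128))
_, loss = model(idx, targets=idx)   # smoke test only; real training uses shifted targets
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)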
code/tokenizer.py
ADDED
@@ -0,0 +1,109 @@
"""Train a fresh 8K BPE on a FineWeb-edu sample.

This is the 50M-scale variant of the 1M project's 4K BPE. We bump the default
vocab to 8192 and keep the document count at 50000 (same as the 1M run: that
doc count was already saturating BPE merge quality at 4K vocab, and doubling
the vocab needs roughly the same training set, not 2x more).

We do NOT reuse any FANT tokenizer here -- the point of this experiment family
is a clean small recipe with no external dependencies.

Output: tokenizer.json in the working dir (or wherever specified).
"""
from __future__ import annotations

import argparse
import time
from pathlib import Path

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel as BLPre
from tokenizers.decoders import ByteLevel as BLDec
from tokenizers.processors import ByteLevel as BLPost
from tokenizers.trainers import BpeTrainer


SPECIAL_TOKENS = [
    "<|pad|>",       # 0
    "<|bos|>",       # 1
    "<|eos|>",       # 2
    "<|unk|>",       # 3
    "<|im_start|>",  # 4 -- chat role open
    "<|im_end|>",    # 5 -- chat role close
]


def _iter_fineweb(n_docs: int):
    """Yield up to `n_docs` text strings from the FineWeb-edu streaming feed."""
    from datasets import load_dataset

    ds = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name="default",
        split="train",
        streaming=True,
    )
    n = 0
    for ex in ds:
        if n >= n_docs:
            return
        text = ex.get("text", "")
        if isinstance(text, str) and text.strip():
            n += 1
            yield text


def train_tokenizer(out_path: str = "tokenizer.json", vocab_size: int = 8192, n_docs: int = 50000) -> str:
    tok = Tokenizer(BPE(unk_token="<|unk|>"))
    tok.pre_tokenizer = BLPre(add_prefix_space=False)
    tok.decoder = BLDec()
    tok.post_processor = BLPost(trim_offsets=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=BLPre.alphabet(),
        show_progress=False,
    )

    print(f"[tokenizer] streaming up to {n_docs} FineWeb-edu docs...")
    t0 = time.time()
    docs = list(_iter_fineweb(n_docs))
    print(f"[tokenizer] collected {len(docs)} docs in {time.time() - t0:.1f}s")

    print(f"[tokenizer] training BPE vocab_size={vocab_size}...")
    t0 = time.time()
    tok.train_from_iterator(docs, trainer=trainer)
    print(f"[tokenizer] trained in {time.time() - t0:.1f}s; vocab={tok.get_vocab_size()}")

    out_dir = Path(out_path).parent
    if str(out_dir) and not out_dir.exists():
        out_dir.mkdir(parents=True, exist_ok=True)
    tok.save(out_path)
    print(f"[tokenizer] saved to {out_path}")
    return out_path


def load_tokenizer(path: str = "tokenizer.json") -> Tokenizer:
    return Tokenizer.from_file(path)


# Convenience accessors used by data.py / train.py
def special_token_id(tok: Tokenizer, name: str) -> int:
    tid = tok.token_to_id(name)
    assert tid is not None, f"{name} not in tokenizer"
    return tid


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="tokenizer.json")
    ap.add_argument("--vocab", type=int, default=8192)
    ap.add_argument("--docs", type=int, default=50000)
    args = ap.parse_args()
    train_tokenizer(args.out, args.vocab, args.docs)


if __name__ == "__main__":
    main()
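
# Example (added for illustration, not part of the uploaded file): retrain the 8K BPE
# from scratch, then sanity-check the result.
#   python code/tokenizer.py --out tokenizer.json --vocab 8192 --docs 50000
from tokenizer import load_tokenizer, special_token_id

tok = load_tokenizer("tokenizer.json")
print(special_token_id(tok, "<|bos|>"))           # 1, per SPECIAL_TOKENS above
print(tok.decode(tok.encode("hello world").ids))  # byte-level BPE round-trip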
models/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a037d9d847c357510b0a09d4dd6c169cacbd988dd24aba945c416a8f93397e7e
size 109112123
models/pretrain.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06dfe158490abfe41644ebf7f44942d98af32ebe3602892e25117cb8c623c49a
size 226660407
models/tokenizer.json
ADDED
The diff for this file is too large to render.