"""scripts/train.py — real Tilelli/Vanilla trainer on TinyStories.

Replaces the smoke ``train_demo.py``. Adds the things a serious run needs:

* train/val split (separate ``.bin`` files produced by ``prepare_tinystories.py``)
* AdamW (or Muon + AdamW via ``--optimizer muon``) + cosine LR with warmup
* gradient clipping
* periodic eval-loss against val
* periodic checkpointing + resume from last
* deterministic seed
* a per-run directory under ``runs/`` with config.json + log.jsonl

Models supported via ``--model``:

* ``tilelli-fp32`` — TilelliLM with quantize=False (architecture, FP32 weights)
* ``tilelli-ternary`` — TilelliLM with quantize=True (the default Tilelli model)
* ``vanilla-fp32`` — pre-norm Transformer baseline at the same param budget
* ``tilelli-lite-fp32`` — TilelliLiteLM with quantize=False
* ``tilelli-lite-ternary`` — TilelliLiteLM with quantize=True

The first three are param-matched at ~10 M each via the configs in
``scripts/configs.py``.
"""
from __future__ import annotations

import argparse
import json
import math
import os
import random
import sys
import time
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator

# Must run before the ``tilelli`` imports below: puts the directory two levels
# above this file on sys.path so the package resolves when run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import numpy as np
import torch
from torch import Tensor

from tilelli.baselines.vanilla import VanillaLM
from tilelli.core.tilelli_lm import TilelliLM
|
|
|
|
def _make_tilelli_lite(cfg, max_seq_len):
    """Construct a ``TilelliLiteLM`` from a model config.

    The lite model class is imported lazily, so it is only loaded when a
    ``tilelli_lite`` config is actually built.

    Args:
        cfg: config object providing ``d_model``, ``n_layers``, ``top_k``,
            ``dense_expand`` and ``quantize``; ``n_heads`` is optional.
        max_seq_len: forwarded unchanged to the model constructor.

    Returns:
        The constructed ``TilelliLiteLM`` instance.
    """
    from tilelli.core.tilelli_lite import TilelliLiteLM

    # Falsy values (0 / None / missing attribute) fall back to the lite
    # defaults: 8 heads, top-k 16, FFN expansion 4.
    lite_kwargs = dict(
        vocab_size=256,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_heads=getattr(cfg, "n_heads", 8) or 8,
        top_k=cfg.top_k or 16,
        ffn_expand=cfg.dense_expand or 4,
        max_seq_len=max_seq_len,
        quantize=cfg.quantize,
    )
    return TilelliLiteLM(**lite_kwargs)
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class ModelCfg:
    """Hyperparameters for one named model preset.

    ``builder`` selects the constructor used by ``build_model``; each builder
    reads only a subset of the fields below.
    """

    # --- identity / dispatch -------------------------------------------
    name: str
    builder: str  # one of "tilelli", "vanilla", "tilelli_lite"
    quantize: bool  # forwarded by the tilelli and tilelli_lite builders

    # --- core dimensions -----------------------------------------------
    d_model: int
    n_layers: int
    d_head: int  # forwarded by the tilelli builder
    top_k: int  # forwarded by the tilelli and tilelli_lite builders
    n_heads: int  # forwarded by the vanilla and tilelli_lite builders
    expand: int  # forwarded by the vanilla builder

    # --- optional switches, passed straight through to TilelliLM -------
    n_banks: int = 1
    per_row: bool = False
    hadamard: bool = False
    lsq: bool = False
    dense_expand: int = 2  # also used as ``ffn_expand`` by the tilelli_lite builder
    fp_attention: bool = False
    top_k_routing: int = 0
|
|
|
|
# Registry of the presets selectable via ``--model``. Each entry is keyed by
# its own ``name``; insertion order is the order shown in the CLI choices.
MODEL_CFGS: dict[str, ModelCfg] = {
    preset.name: preset
    for preset in (
        # TilelliLM, full-precision weights.
        ModelCfg(
            name="tilelli-fp32", builder="tilelli", quantize=False,
            d_model=512, n_layers=7, d_head=64, top_k=8,
            n_heads=0, expand=0,
        ),
        # TilelliLM with quantize=True.
        ModelCfg(
            name="tilelli-ternary", builder="tilelli", quantize=True,
            d_model=512, n_layers=7, d_head=64, top_k=8,
            n_heads=0, expand=0,
        ),
        # VanillaLM baseline.
        ModelCfg(
            name="vanilla-fp32", builder="vanilla", quantize=False,
            d_model=320, n_layers=8, d_head=40, top_k=0,
            n_heads=8, expand=4,
        ),
        # TilelliLiteLM variants (dense_expand is passed on as ffn_expand).
        ModelCfg(
            name="tilelli-lite-fp32", builder="tilelli_lite", quantize=False,
            d_model=256, n_layers=8, d_head=32, top_k=16,
            n_heads=8, expand=0, dense_expand=4,
        ),
        ModelCfg(
            name="tilelli-lite-ternary", builder="tilelli_lite", quantize=True,
            d_model=256, n_layers=8, d_head=32, top_k=16,
            n_heads=8, expand=0, dense_expand=4,
        ),
    )
}
|
|
|
|
|
|
def build_model(cfg: ModelCfg, max_seq_len: int) -> torch.nn.Module:
    """Instantiate the model described by ``cfg``.

    Args:
        cfg: preset whose ``builder`` field selects the model class.
        max_seq_len: forwarded to the selected constructor.

    Returns:
        A ``TilelliLM``, ``VanillaLM`` or ``TilelliLiteLM`` instance; every
        builder uses a fixed vocabulary of 256.

    Raises:
        ValueError: if ``cfg.builder`` is not a known builder name.
    """
    kind = cfg.builder

    if kind == "tilelli_lite":
        return _make_tilelli_lite(cfg, max_seq_len)

    if kind == "vanilla":
        vanilla_kwargs = {
            "vocab_size": 256,
            "d_model": cfg.d_model,
            "n_layers": cfg.n_layers,
            "n_heads": cfg.n_heads,
            "expand": cfg.expand,
            "max_seq_len": max_seq_len,
        }
        return VanillaLM(**vanilla_kwargs)

    if kind == "tilelli":
        tilelli_kwargs = {
            "vocab_size": 256,
            "d_model": cfg.d_model,
            "n_layers": cfg.n_layers,
            "d_head": cfg.d_head,
            "top_k": cfg.top_k,
            "max_seq_len": max_seq_len,
            "quantize": cfg.quantize,
            "n_banks": cfg.n_banks,
            "per_row": cfg.per_row,
            "hadamard": cfg.hadamard,
            "lsq": cfg.lsq,
            "dense_expand": cfg.dense_expand,
            "fp_attention": cfg.fp_attention,
            "top_k_routing": cfg.top_k_routing,
        }
        return TilelliLM(**tilelli_kwargs)

    raise ValueError(f"unknown builder {cfg.builder!r}")
|
|
|
|
| |
| |
| |
|
|
|
|
class ByteShard:
    """Read-only memmap of a packed uint8 token shard.

    Attributes:
        path: location of the file backing the shard.
        data: ``np.memmap`` over the file, dtype ``uint8``, opened read-only.
        n: number of bytes (one token per byte) in the shard.
    """

    def __init__(self, path: Path) -> None:
        self.path = path
        self.data = np.memmap(path, dtype=np.uint8, mode="r")
        self.n = int(self.data.size)

    def sample_batch(self, batch_size: int, seq_len: int, rng: np.random.Generator) -> Tensor:
        """Draw ``batch_size`` random windows of ``seq_len + 1`` tokens.

        The extra token lets callers slice inputs ``[:, :-1]`` and targets
        ``[:, 1:]`` from the same tensor.

        Args:
            batch_size: number of windows to draw.
            seq_len: model sequence length; each window holds one more token.
            rng: NumPy generator that picks the window start offsets.

        Returns:
            An ``int64`` tensor of shape ``(batch_size, seq_len + 1)``.

        Raises:
            ValueError: if the shard holds fewer than ``seq_len + 1`` tokens.
        """
        window = seq_len + 1
        last_start = self.n - window
        if last_start < 0:
            raise ValueError(
                f"shard {self.path} has {self.n} tokens, need at least {window} "
                f"for seq_len={seq_len}"
            )
        # endpoint=True makes ``last_start`` itself reachable; with an exclusive
        # bound the final window was never sampled and a shard of exactly
        # ``window`` tokens could not be sampled at all.
        starts = rng.integers(0, last_start, size=batch_size, endpoint=True)
        # Fill an int64 buffer directly; assigning uint8 rows widens them.
        out = np.empty((batch_size, window), dtype=np.int64)
        for row, start in zip(out, starts):
            row[:] = self.data[start : start + window]
        return torch.from_numpy(out)

    def iter_eval_batches(
        self, batch_size: int, seq_len: int, n_batches: int, rng: np.random.Generator
    ) -> Iterator[Tensor]:
        """Yield ``n_batches`` batches, each drawn with ``sample_batch`` from ``rng``."""
        for _ in range(n_batches):
            yield self.sample_batch(batch_size, seq_len, rng)
|
|
|
|
class InductionStream:
    """In-memory generator that emits synthetic induction-heads sequences.

    Wire-compatible with ByteShard (same .sample_batch / .iter_eval_batches
    interface). Each batch is freshly generated by
    `tilelli.sherlock.induction_heads.make_dense_induction_batch`, so every
    training step sees new data. Per the original author's description, the
    sequences combine a random body with planted KEY-VALUE patterns that
    only a model with working in-context recall can exploit; the generator
    itself lives outside this file.

    Attributes:
        vocab_size: vocabulary size passed to the generator.
        min_gap: ``min_gap`` argument passed to the generator.
        n_keys: ``n_keys`` argument passed to the generator (default 16).
        n: fixed placeholder (1_000_000) mirroring ``ByteShard.n`` so both
            sources expose the same attribute; it does not bound how much
            data can be generated.
    """

    def __init__(self, vocab_size: int = 256, min_gap: int = 8, n_keys: int = 16) -> None:
        self.vocab_size = vocab_size
        self.min_gap = min_gap
        self.n_keys = n_keys
        self.n = 1_000_000

    def sample_batch(self, batch_size: int, seq_len: int, rng: np.random.Generator) -> Tensor:
        """Generate one batch of ``seq_len + 1``-token sequences.

        Args:
            batch_size: number of sequences to generate.
            seq_len: model sequence length; one extra token is generated so
                callers can slice shifted inputs and targets.
            rng: NumPy generator; a single draw from it seeds the torch
                generator handed to the batch builder, so the stream is
                reproducible from ``rng`` alone.

        Returns:
            Whatever ``make_dense_induction_batch`` returns — presumably an
            integer tensor of shape ``(batch_size, seq_len + 1)``; confirm
            against that function.
        """
        # Imported lazily so the bin data source does not need this module.
        from tilelli.sherlock.induction_heads import make_dense_induction_batch

        seed = int(rng.integers(0, 2**31 - 1))
        tgen = torch.Generator()
        tgen.manual_seed(seed)
        return make_dense_induction_batch(
            batch_size=batch_size,
            seq_len=seq_len + 1,
            rng=tgen,
            vocab_size=self.vocab_size,
            n_keys=self.n_keys,
            min_gap=self.min_gap,
        )

    def iter_eval_batches(
        self, batch_size: int, seq_len: int, n_batches: int, rng: np.random.Generator
    ) -> Iterator[Tensor]:
        """Yield ``n_batches`` freshly generated batches drawn via ``sample_batch``."""
        for _ in range(n_batches):
            yield self.sample_batch(batch_size, seq_len, rng)
|
|
|
|
| |
| |
| |
|
|
|
|
class _MultiOptim:
    """Forwards zero_grad / step / state_dict / load_state_dict to a list of
    underlying optimizers. Exposes a concatenated param_groups, with each group
    annotated with its own peak_lr so the cosine schedule can scale them
    proportionally (Muon's effective LR is ~60× AdamW's).
    """

    def __init__(self, optims, peak_lrs):
        """Wrap ``optims`` and tag each one's param groups with its peak LR.

        Args:
            optims: iterable of optimizers, stepped in the given order.
            peak_lrs: one peak learning rate per optimizer, same order.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        self._optims = list(optims)
        peaks = list(peak_lrs)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(self._optims) != len(peaks):
            raise ValueError(
                f"got {len(self._optims)} optimizers but {len(peaks)} peak learning rates"
            )
        for opt, peak in zip(self._optims, peaks):
            for group in opt.param_groups:
                group["peak_lr"] = peak

    @property
    def param_groups(self):
        """Fresh list holding every wrapped optimizer's param-group dicts.

        The dicts themselves are the optimizers' own, so writing
        ``group["lr"]`` through this list takes effect.
        """
        groups = []
        for opt in self._optims:
            groups.extend(opt.param_groups)
        return groups

    def zero_grad(self, set_to_none=True):
        """Clear gradients on every wrapped optimizer."""
        for opt in self._optims:
            opt.zero_grad(set_to_none=set_to_none)

    def step(self, closure=None):
        """Step every wrapped optimizer once.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is called exactly once, with grad enabled, before
                any optimizer steps, rather than once per optimizer.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for opt in self._optims:
            opt.step()
        return loss

    def state_dict(self):
        """Return ``{"optims": [...]}`` with one state dict per optimizer, in order."""
        return {"optims": [opt.state_dict() for opt in self._optims]}

    def load_state_dict(self, sd):
        """Restore every wrapped optimizer from a dict produced by ``state_dict``.

        Raises:
            ValueError: if ``sd["optims"]`` does not hold exactly one entry per
                wrapped optimizer (``zip`` would otherwise drop the surplus
                silently and leave some optimizers un-restored).
        """
        states = sd["optims"]
        if len(states) != len(self._optims):
            raise ValueError(
                f"checkpoint holds {len(states)} optimizer states, expected {len(self._optims)}"
            )
        for opt, state in zip(self._optims, states):
            opt.load_state_dict(state)
|
|
|
|
| |
| |
| |
|
|
|
|
def lr_at(step: int, total_steps: int, peak_lr: float, warmup: int, min_ratio: float) -> float:
    """Learning rate for ``step``: linear warmup, then cosine decay.

    Args:
        step: zero-based optimizer step.
        total_steps: step count at which the cosine reaches its floor.
        peak_lr: learning rate reached at the end of warmup.
        warmup: number of warmup steps; 0 skips the ramp entirely.
        min_ratio: floor of the schedule as a fraction of ``peak_lr``.

    Returns:
        ``peak_lr * (step + 1) / warmup`` during warmup; afterwards a cosine
        from ``peak_lr`` down to ``peak_lr * min_ratio``, held at the floor
        once ``step`` passes ``total_steps``.
    """
    in_warmup = step < warmup
    if in_warmup:
        ramp_len = max(1, warmup)
        return peak_lr * (step + 1) / ramp_len

    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_len = max(1, total_steps - warmup)
    raw_fraction = (step - warmup) / decay_len
    fraction = min(1.0, max(0.0, raw_fraction))

    cosine_weight = (1.0 + math.cos(math.pi * fraction)) * 0.5
    scale = min_ratio + (1.0 - min_ratio) * cosine_weight
    return peak_lr * scale
|
|
|
|
| |
| |
| |
|
|
|
|
def evaluate(
    model: torch.nn.Module,
    val: ByteShard,
    batch_size: int,
    seq_len: int,
    n_batches: int,
    rng: np.random.Generator,
    device: torch.device,
    autocast_dtype=None,
) -> float:
    """Mean next-token loss of ``model`` over ``n_batches`` validation batches.

    Args:
        model: module exposing ``loss(inputs, targets)``.
        val: data source exposing ``iter_eval_batches`` (a ``ByteShard`` or
            any object with the same interface, e.g. ``InductionStream``).
        batch_size: sequences per batch.
        seq_len: tokens per input sequence; batches carry one extra token so
            inputs and targets can be sliced from the same tensor.
        n_batches: number of batches to average over.
        rng: generator handed to the data source for sampling.
        device: device the batches are moved to.
        autocast_dtype: if not ``None``, run the forward pass under
            ``torch.amp.autocast`` with this dtype.

    Returns:
        The unweighted mean of the per-batch losses, or NaN when no batch was
        evaluated (``n_batches == 0``).
    """
    # Remember the caller's mode and restore it even if a batch raises, so a
    # failed eval cannot leave the model stuck in eval mode.
    was_training = model.training
    model.eval()
    losses: list[float] = []
    try:
        with torch.no_grad():
            for chunk in val.iter_eval_batches(batch_size, seq_len, n_batches, rng):
                chunk = chunk.to(device, non_blocking=True)
                if autocast_dtype is None:
                    precision_ctx = nullcontext()
                else:
                    precision_ctx = torch.amp.autocast(device.type, dtype=autocast_dtype)
                with precision_ctx:
                    loss = model.loss(chunk[:, :-1], chunk[:, 1:])
                losses.append(float(loss.item()))
    finally:
        model.train(was_training)
    if not losses:
        # np.mean([]) would also give NaN, but with a RuntimeWarning.
        return float("nan")
    return float(np.mean(losses))
|
|
|
|
def _save_checkpoint(payload: dict, path: Path) -> None:
    """Write ``payload`` to ``path`` via a sibling temp file plus rename.

    A crash during ``torch.save`` then leaves the previous checkpoint intact
    instead of a truncated file.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(payload, tmp_path)
    tmp_path.replace(path)


def main() -> None:
    """Parse CLI arguments and run one training job end to end.

    Sets up the device and seeds, creates the run directory, builds the data
    source, model and optimizer, optionally resumes from ``last.pt``, then
    trains for ``--steps`` steps with periodic logging, validation and
    checkpointing. Writes ``config.json``, ``log.jsonl``, ``last.pt`` and
    ``best.pt`` into the run directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, choices=list(MODEL_CFGS.keys()))
    ap.add_argument("--data-dir", type=Path, default=Path("data/tinystories"))
    ap.add_argument("--steps", type=int, default=50_000)
    ap.add_argument("--seq-len", type=int, default=256)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--peak-lr", type=float, default=3e-4)
    ap.add_argument("--min-lr-ratio", type=float, default=0.01)
    ap.add_argument("--warmup", type=int, default=500)
    ap.add_argument("--weight-decay", type=float, default=0.01)
    ap.add_argument("--grad-clip", type=float, default=1.0)
    ap.add_argument("--eval-every", type=int, default=1000)
    ap.add_argument("--eval-batches", type=int, default=20)
    ap.add_argument("--ckpt-every", type=int, default=2000)
    ap.add_argument("--log-every", type=int, default=50)
    ap.add_argument("--seed", type=int, default=1234)
    ap.add_argument("--threads", type=int, default=8)
    ap.add_argument("--device", default="auto",
                    help="auto | cuda | cpu | cuda:0 etc.")
    ap.add_argument("--autocast", default="off",
                    choices=["off", "bf16", "fp16"],
                    help="Mixed-precision autocast for forward+backward (CUDA only)")
    ap.add_argument("--run-dir", type=Path, default=None,
                    help="Directory for this run. Defaults to runs/<model>_<timestamp>.")
    ap.add_argument("--resume", action="store_true",
                    help="Resume from runs/<run-dir>/last.pt if present.")
    ap.add_argument("--optimizer", default="adamw", choices=["adamw", "muon"],
                    help="adamw (default) | muon (Muon for 2D+, AdamW for 1D)")
    ap.add_argument("--muon-lr-mult", type=float, default=60.0,
                    help="Muon LR multiplier vs AdamW peak_lr; per Keller Jordan ~60×")
    ap.add_argument("--data-source", default="bin",
                    choices=["bin", "induction"],
                    help="bin: memmap train.bin/valid.bin (default). "
                         "induction: generate synthetic induction-heads sequences "
                         "on the fly (no data-dir needed).")
    args = ap.parse_args()

    # --- device, precision and seeding ---------------------------------
    if args.device == "auto":
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(args.device)
    if device.type == "cpu":
        torch.set_num_threads(args.threads)
    if device.type == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # NOTE(review): fp16 autocast is used without a GradScaler, so small
    # gradients may underflow; bf16 does not need one.
    autocast_dtype = {"off": None, "bf16": torch.bfloat16, "fp16": torch.float16}[args.autocast]
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # --- run directory --------------------------------------------------
    if args.run_dir is None:
        ts = time.strftime("%Y-%m-%d_%H-%M-%S")
        args.run_dir = Path("runs") / f"{args.model}_{ts}"
    args.run_dir.mkdir(parents=True, exist_ok=True)
    log_path = args.run_dir / "log.jsonl"
    cfg_path = args.run_dir / "config.json"
    last_ckpt = args.run_dir / "last.pt"
    best_ckpt = args.run_dir / "best.pt"

    # --- data -----------------------------------------------------------
    if args.data_source == "induction":
        train = InductionStream(vocab_size=256, min_gap=8)
        val = InductionStream(vocab_size=256, min_gap=8)
        print("data: induction-heads (synthetic, vocab=256, min_gap=8)")
    else:
        train = ByteShard(args.data_dir / "train.bin")
        val = ByteShard(args.data_dir / "valid.bin")
        print(f"train: {train.n:,} tokens val: {val.n:,} tokens")

    # --- model ----------------------------------------------------------
    cfg = MODEL_CFGS[args.model]
    model = build_model(cfg, max_seq_len=args.seq_len).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"model {cfg.name}: {n_params:,} params ({n_params/1e6:.2f}M) on {device}")

    # --- optimizer ------------------------------------------------------
    if args.optimizer == "muon":
        from tilelli.optimisers import Muon, split_params_for_muon
        muon_params, adamw_params = split_params_for_muon(model)
        muon_peak_lr = args.peak_lr * args.muon_lr_mult
        optim_muon = Muon(
            muon_params, lr=muon_peak_lr, momentum=0.95,
            weight_decay=args.weight_decay, nesterov=True, ns_steps=5,
        )
        optim_adamw = torch.optim.AdamW(
            adamw_params, lr=args.peak_lr,
            weight_decay=args.weight_decay, betas=(0.9, 0.95),
        )
        optim = _MultiOptim([optim_muon, optim_adamw], peak_lrs=[muon_peak_lr, args.peak_lr])
        print(f"optimizer: muon ({len(muon_params)} 2D params, lr {muon_peak_lr:.1e}) + adamw ({len(adamw_params)} 1D params, lr {args.peak_lr:.1e})")
    else:
        optim = torch.optim.AdamW(
            model.parameters(),
            lr=args.peak_lr,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.95),
        )

    # --- resume ---------------------------------------------------------
    start_step = 0
    best_val = float("inf")
    if args.resume and last_ckpt.exists():
        sd = torch.load(last_ckpt, map_location="cpu")
        model.load_state_dict(sd["model"])
        optim.load_state_dict(sd["optim"])
        start_step = int(sd.get("step", 0))
        best_val = float(sd.get("best_val", float("inf")))
        print(f"resumed from {last_ckpt} at step {start_step}, best_val {best_val:.4f}")
    elif args.resume:
        # Say so instead of silently training from scratch (e.g. when
        # --run-dir was omitted and a fresh timestamped directory was made).
        print(f"--resume given but {last_ckpt} not found; starting from step 0")

    # --- config snapshot --------------------------------------------------
    cfg_path.write_text(json.dumps({
        "model_cfg": asdict(cfg),
        "args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
        "n_params": n_params,
    }, indent=2))

    # Fresh runs keep the original seeding. A resumed run mixes the resume
    # step into the seed so it does not replay the batches the first segment
    # of the run already trained on.
    train_seed = args.seed + 1 if start_step == 0 else [args.seed + 1, start_step]
    rng_train = np.random.default_rng(train_seed)
    rng_eval = np.random.default_rng(args.seed + 2)

    model.train()
    t0 = time.time()
    last_log_t = t0
    running_loss = 0.0
    running_n = 0

    # ``with`` guarantees the log is flushed and closed if training raises.
    with log_path.open("a", buffering=1) as log:
        for step in range(start_step, args.steps):
            # ``lr`` is the base (AdamW) schedule and is what gets logged; each
            # param group follows the same shape scaled to its own peak.
            lr = lr_at(step, args.steps, args.peak_lr, args.warmup, args.min_lr_ratio)
            for g in optim.param_groups:
                peak = g.get("peak_lr", args.peak_lr)
                g["lr"] = lr_at(step, args.steps, peak, args.warmup, args.min_lr_ratio)

            chunk = train.sample_batch(args.batch_size, args.seq_len, rng_train).to(device, non_blocking=True)
            optim.zero_grad()
            if autocast_dtype is not None:
                with torch.amp.autocast(device.type, dtype=autocast_dtype):
                    loss = model.loss(chunk[:, :-1], chunk[:, 1:])
            else:
                loss = model.loss(chunk[:, :-1], chunk[:, 1:])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optim.step()

            running_loss += float(loss.item())
            running_n += 1

            # A period of 0 (or less) disables that periodic action instead of
            # raising ZeroDivisionError on the modulo.
            if args.log_every > 0 and (step + 1) % args.log_every == 0:
                now = time.time()
                ms = (now - last_log_t) / args.log_every * 1000
                avg = running_loss / max(1, running_n)
                print(f"step {step+1:>6d}/{args.steps} loss {avg:.4f} lr {lr:.2e} {ms:.0f} ms/step")
                log.write(json.dumps({"event": "train", "step": step+1, "loss": avg, "lr": lr, "ms_per_step": ms}) + "\n")
                running_loss = 0.0
                running_n = 0
                last_log_t = now

            if args.eval_every > 0 and (step + 1) % args.eval_every == 0:
                v = evaluate(model, val, args.batch_size, args.seq_len, args.eval_batches, rng_eval, device, autocast_dtype)
                print(f" val loss {v:.4f} best {min(best_val, v):.4f}")
                log.write(json.dumps({"event": "val", "step": step+1, "val_loss": v, "best_val": min(best_val, v)}) + "\n")
                if v < best_val:
                    best_val = v
                    _save_checkpoint({
                        "model": model.state_dict(),
                        "step": step + 1,
                        "best_val": best_val,
                        "model_cfg": asdict(cfg),
                    }, best_ckpt)

            if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
                _save_checkpoint({
                    "model": model.state_dict(),
                    "optim": optim.state_dict(),
                    "step": step + 1,
                    "best_val": best_val,
                    "model_cfg": asdict(cfg),
                }, last_ckpt)

        # --- final evaluation and checkpoint ----------------------------
        v_final = evaluate(model, val, args.batch_size, args.seq_len, args.eval_batches, rng_eval, device, autocast_dtype)
        final_best = min(best_val, v_final)
        log.write(json.dumps({"event": "final", "step": args.steps, "val_loss": v_final, "best_val": final_best, "wall_seconds": time.time()-t0}) + "\n")
        if v_final < best_val:
            # Keep best.pt in step with the best_val reported above.
            _save_checkpoint({
                "model": model.state_dict(),
                "step": args.steps,
                "best_val": final_best,
                "model_cfg": asdict(cfg),
            }, best_ckpt)
        _save_checkpoint({
            "model": model.state_dict(),
            "optim": optim.state_dict(),
            "step": args.steps,
            "best_val": final_best,
            "model_cfg": asdict(cfg),
        }, last_ckpt)

    print(f"done. final val {v_final:.4f} best val {final_best:.4f} wall {(time.time()-t0)/3600:.2f}h")


# Run the trainer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|