cayley-large-2L-mlp_in-20B

A 1.31B-param GPT (24 layers, 16 heads, d=2048) with a 2-level CayleySAE (L0=2048, L1=32768) inserted at mlp_in in every block, trained on FineWeb-Edu (sample-100BT slice) for 20B tokens.

Best val_loss: 2.8081 (CE). Final-eval headline metrics: pile_ppl 17.65, HellaSwag 0.385, LAMBADA 0.311, CE-Bench interp 32.71.

This is the new canonical cayley-large-2L checkpoint for the 20B-token campaign. It supersedes the prior cayley-24L2048-32k-2L-mlp_in-20b checkpoint (val 2.7933), which used RoPE; this run uses learned absolute positional encodings, per project direction.

Training recipe

n_layer=24, n_head=16, n_embd=2048
cayley levels: "11,16,0;15,32,256"  (L0=2048, L1=32768; per-parent budget)
location: mlp_in
seq_len=1024, pos_encoding=learned
optimizer: Muon + AdamW (decoupled)
muon_lr=8e-3 → 1e-4, adamw_lr=3e-4 → 1e-5
schedule: linear_warmdown, warmup=200, wf=0.5
batch_size=32, grad_accum=48 (tokens/iter = 1,572,864)
max_tokens = 20B, max_iters = 12,716
score standardization + forward standardization (zombie fix)
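
The token accounting in the recipe can be checked with a few lines of Python. This is a sketch under two assumptions: that batch_size × grad_accum × seq_len is the global token count per iteration (which the recipe's own tokens/iter figure implies), and that the first field of each cayley level is log2 of the dictionary size (which matches L0=2048 and L1=32768). The meaning of the remaining fields is not spelled out here, so they are left unparsed.

```python
import math

# Values copied from the recipe above.
batch_size, grad_accum, seq_len = 32, 48, 1024
max_tokens = 20_000_000_000
levels = "11,16,0;15,32,256"

# Tokens consumed per optimizer step, and the steps needed to reach the budget.
tokens_per_iter = batch_size * grad_accum * seq_len
max_iters = math.ceil(max_tokens / tokens_per_iter)

# Assumption: the first field of each level is log2(dictionary size).
level_sizes = [2 ** int(level.split(",")[0]) for level in levels.split(";")]

print(tokens_per_iter)  # 1572864
print(max_iters)        # 12716
print(level_sizes)      # [2048, 32768]
```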

Hardware: 8× RTX PRO 6000 Blackwell (96 GB). Throughput ~248k tok/s. Wall ~22 h. The run was paused at iter 2000 to swap the data source from FineWeb-Edu-10B (which would have wrapped) to a fresh 25B-token slice of sample-100BT. It resumed from the same checkpoint, with the LR schedule and optimizer state unperturbed.

Val loss progression

iter    tokens   val_loss
500     0.79B    3.8128
1000    1.57B    3.3703
1500    2.36B    3.2416
2000    3.15B    3.1689
2500    3.93B    3.0948
5000    7.86B    ~3.00
7500    11.79B   2.92
10000   15.73B   2.83
12500   19.66B   2.78
12716   20B      2.8081 (best)

Warmdown began at iter ~6358 (wf=0.5).
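
The shape of the schedule can be sketched as follows. This is an illustration, not the trainer's code: the function name `lr_at` is hypothetical, and it assumes a linear warmup from near zero, a flat phase at the peak rate, and a linear decay to the final rate over the last wf fraction of training. The actual linear_warmdown implementation may differ in details such as the warmup start value.

```python
def lr_at(it, peak_lr, final_lr, max_iters=12716, warmup=200, warmdown_frac=0.5):
    """Hypothetical linear-warmup / flat / linear-warmdown schedule."""
    warmdown_start = int(max_iters * (1.0 - warmdown_frac))  # 6358 for wf=0.5
    if it < warmup:
        # Linear ramp up to the peak learning rate.
        return peak_lr * (it + 1) / warmup
    if it < warmdown_start:
        # Flat phase at the peak learning rate.
        return peak_lr
    # Linear interpolation from peak_lr to final_lr over the warmdown phase.
    progress = (it - warmdown_start) / (max_iters - warmdown_start)
    return peak_lr + (final_lr - peak_lr) * min(progress, 1.0)


# Muon learning rate at a few points, using the values from the recipe.
for it in (0, 200, 6358, 9537, 12716):
    print(it, lr_at(it, peak_lr=8e-3, final_lr=1e-4))
```

With these assumptions the Muon rate stays at 8e-3 until iter 6358 and reaches 1e-4 at the final iteration; the AdamW rate follows the same shape between 3e-4 and 1e-5.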

Comparison

model                                                tokens  val_loss  pile_ppl  hellaswag  lambada
this (cayley-large-2L-mlp_in-20B)                    20B     2.8081    17.65     0.385      0.311
cayley-24L2048-32k-2L-mlp_in-20b (prior, RoPE)       20B     2.7933    20.32     0.383      0.304
cayley-24L2048-32k-2L-mlp_in-v3 (10B intermediate)   10B     2.9055    -         -          -
vanilla-24L2048-parity-cold (1.3B baseline)          3.8B    2.7926    18.89     0.379      0.304

Notes: this run has a slightly worse val_loss than its RoPE predecessor (Δ = +0.0148 nats) but materially better Pile perplexity (17.65 vs 20.32) and slightly higher accuracy on HellaSwag and LAMBADA. The val/pile divergence likely reflects the change in data source: the RoPE run trained on a 10B-token corpus that wrapped, while this run sees 20B fresh tokens from the larger sample-100BT slice.
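
To put the val_loss gap in perspective, a cross-entropy difference in nats converts to a perplexity ratio via the exponential. The snippet below is plain arithmetic on the two val_loss figures from the table. It assumes both runs were validated on comparable held-out data, which the note above suggests may not strictly hold.

```python
import math

# Validation cross-entropy in nats, from the comparison table.
val_loss_this, val_loss_rope = 2.8081, 2.7933

delta = val_loss_this - val_loss_rope
ppl_ratio = math.exp(delta)  # ratio of validation perplexities

print(f"delta = {delta:+.4f} nats")         # delta = +0.0148 nats
print(f"val ppl ratio = {ppl_ratio:.4f}")   # val ppl ratio = 1.0149
```

A gap of 0.0148 nats therefore corresponds to roughly 1.5% higher validation perplexity, small next to the Pile perplexity difference in the same table.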

Architecture notes

  • CayleySAE is a parameter-free dictionary; only per-feature biases are learned (~0.9M total). The model trains around the fixed algebraic structure.
  • Output is dense d=2048; sparsity lives in the intermediate code, not the output.
  • Positive and negative activations are orthogonal in meaning at every level; most analysis tools treat each feature as two virtual features (+ and -).
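
The "two virtual features" convention from the last bullet can be sketched in a few lines. The function name `split_signed` and the interleaved output layout are assumptions made for illustration; the project's analysis tools may order the virtual features differently.

```python
def split_signed(code):
    """Split signed activations into (+) and (-) virtual features.

    The output is twice as long as the input: index 2*i holds the positive
    part of feature i, index 2*i + 1 holds the magnitude of its negative part.
    """
    virtual = []
    for a in code:
        virtual.append(a if a > 0 else 0.0)   # (+) virtual feature
        virtual.append(-a if a < 0 else 0.0)  # (-) virtual feature
    return virtual


print(split_signed([0.5, -1.25, 0.0]))  # [0.5, 0.0, 0.0, 1.25, 0.0, 0.0]
```

Each virtual feature is non-negative, and at most one of a feature's two halves is active for a given input, so the two signs can be analyzed as separate features.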

Loading

import torch
from sparse_nanogpt.train import load_checkpoint  # via deeptopk package

ckpt = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
# ckpt contains: model, optimizer_states, iter_num, best_val_loss, config, ...

The checkpoint includes optimizer state (Muon momentum + AdamW state), suitable for further training. Inference-only consumers should pull ckpt["model"] and build the model from config.json.
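
For inference-only use, a minimal sketch of that path might look like the following. The helper names are hypothetical. It assumes ckpt["model"] is a plain state dict and that its keys may carry the `_orig_mod.` prefix that torch.compile adds to wrapped modules. Constructing the model from the config is left out, since the model class is not named here.

```python
import json


def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Drop the prefix torch.compile adds to parameter names, if present."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_for_inference(ckpt_path="ckpt.pt", config_path="config.json"):
    """Return (state_dict, config), leaving the optimizer state behind."""
    import torch  # imported lazily so strip_compile_prefix works without torch

    # weights_only=False unpickles arbitrary objects: only load trusted files.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    with open(config_path) as f:
        config = json.load(f)
    # Assumption: ckpt["model"] is a state dict; optimizer_states is discarded.
    return strip_compile_prefix(ckpt["model"]), config
```

The returned state dict can then be passed to `load_state_dict` on a model built from the config.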

Project context

Part of the Sparse NanoGPT / CayleySAE research program. The goal is to measure whether a sparsity-enforcing CayleySAE bottleneck improves interpretability of residual-stream representations vs a vanilla transformer of the same size and token budget. This model is part of the 20B-token canonical campaign that standardizes four configs (large/small × 2L/3L) on a single token budget for cross-model comparison.
