"""Config dataclass for the toy 50M LM.
Scaled up from the toy_1m_gemma4_dsv4 baseline. Architectural levers stay the
same (alternating SLIDE/GLOBAL Gemma 4 attention, optional Muon, optional
512-slot Engram, full v2 stabilisation), only the shape numbers change.
Two architectural variants are flag-gated:
attention_pattern:
"all_global" -- every layer is full causal attention (baseline).
"gemma4" -- alternating SLIDE/GLOBAL across layers; last layer is GLOBAL.
optimizer:
"adamw" -- AdamW for everything (baseline).
"muon" -- Muon for params with .dim() >= 2; AdamW for embeddings + 1D.
engram_enabled: optional 512-slot external memory bank with zero-init gate.
When attention_pattern == "all_global" and optimizer == "adamw" and engram_enabled
is False, training math is bit-identical to a plain causal transformer baseline.
Defaults
--------
* vocab=8192 (up from 4096): fresh BPE on a larger FineWeb-edu sample.
* dim=512, n_layers=12, n_heads=8, head_dim=64.
* mlp_hidden=2048 (4x dim, SwiGLU).
* max_seq_len=8192 (up from 4096).
* sliding_window=1024 ("larger model" Gemma 4 tier; 1M used 512).
* All v2 stabilisers ON: lm_head_logit_cap=30.0, z_loss_weight=1e-4, lr_schedule="wsd".
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

AttentionPattern = Literal["all_global", "gemma4"]
OptimizerName = Literal["adamw", "muon"]
LRSchedule = Literal["cosine", "wsd"]


@dataclass
class Config:
    # ---------- model shape ----------
    vocab_size: int = 8192
    dim: int = 512
    n_layers: int = 12
    n_heads: int = 8
    head_dim: int = 64  # n_heads * head_dim must equal dim
    mlp_hidden: int = 2048
    max_seq_len: int = 8192
    # ---------- gemma4 SWA ----------
    attention_pattern: AttentionPattern = "gemma4"
    sliding_window: int = 1024
    # ---------- engram (off by default) ----------
    engram_enabled: bool = False
    engram_slots: int = 512
    engram_inject_layer: int = 6  # mid-stack for the 12-layer build
    # ---------- training ----------
    optimizer: OptimizerName = "muon"
    rope_base: float = 10000.0
    norm_eps: float = 1e-5
    dropout: float = 0.0
    tie_embeddings: bool = True
    # ---------- CE stabilisation (Gemma-2 logit cap + PaLM z-loss) ----------
    # ON by default at 50M scale -- the 1M project added these as a v2 bolt-on,
    # but at 50M with bf16 they're standard practice (DeepSeek V2/3, Gemma 2/3,
    # PaLM). Bit-identical to the un-stabilised path when both knobs are 0/None.
    # (Illustrative loss sketch at the end of this module.)
    lm_head_logit_cap: float | None = 30.0
    z_loss_weight: float = 1e-4
    # ---------- LR schedule ----------
    # WSD by default at 50M (per Apr 2026 small-LM research): hold the LR flat
    # after warmup, then decay it over the last 20 % of post-warmup steps --
    # much smoother than cosine. (Illustrative schedule sketch at the end of
    # this module.)
    lr_schedule: LRSchedule = "wsd"
    wsd_decay_frac: float = 0.2
    # ---------- bookkeeping ----------
    init_std: float = 0.02

    def __post_init__(self) -> None:
        assert self.n_heads * self.head_dim == self.dim, (
            f"n_heads*head_dim={self.n_heads * self.head_dim} != dim={self.dim}"
        )
        assert self.attention_pattern in ("all_global", "gemma4")
        assert self.optimizer in ("adamw", "muon")
        assert self.lr_schedule in ("cosine", "wsd")
        assert 0.0 <= self.wsd_decay_frac <= 1.0
        assert self.z_loss_weight >= 0.0
        assert self.lm_head_logit_cap is None or self.lm_head_logit_cap > 0
        # Last layer must be GLOBAL when using gemma4 (canonical invariant).
        # Concretely: layer i is GLOBAL iff (i % 2 == 1) for i in [0, n_layers).
        # n_layers must be even, last index n_layers-1 must be odd.
        if self.attention_pattern == "gemma4":
            assert self.n_layers % 2 == 0 and self.n_layers >= 2, (
                "gemma4 pattern requires even n_layers >= 2 so the last layer is GLOBAL"
            )

    def attention_kind(self, layer_idx: int) -> Literal["slide", "global"]:
        """Return whether `layer_idx` is a sliding-window or global-attention layer."""
        if self.attention_pattern == "all_global":
            return "global"
        # gemma4: even idx = SLIDE, odd idx = GLOBAL. Last layer (n_layers-1) is odd
        # for any even n_layers, so it is GLOBAL.
        return "global" if (layer_idx % 2 == 1) else "slide"