modular_arithmetic / model.py
etwk
Ship shared 1024-2048 high cell (V2); sync docs + model
41fc51b
Raw
History Blame Contribute Delete
15.8 kB
"""Compliant bit-sequential RNN for modular multiplication up to 2^W-bit primes.
Architecture: a recurrent network that reads the bits of ``a mod p`` MSB-first,
one per step, conditioned on ``(b mod p, p)`` in binary. The hidden state is a
quantized bit vector (a discrete bottleneck — a hard VQ layer with a fixed
binary codebook), and the transition function — a weight-shared carry-aware
dilated-conv TCN (TCNHornerCell) at every width — is entirely trained parameters.
After the last bit,
the hidden state bits ARE the answer, emitted MSB-first in base 2.
Why this is interesting: for the recurrence to end on the right answer, the
trained cell must *learn* the map ``(t, bit, b, p) -> (2t + bit*b) mod p`` —
i.e. the model is trained to internally implement one step of Horner evaluation
in the prime field, and it verifiably generalises to a held-out 10% of primes
never seen in training (val == train accuracy). The rules explicitly permit
recurrent/looped architectures and models that *learn* an algorithm-like circuit
("A model trained to internally implement an algorithm is permitted; the same
algorithm hand-coded into the forward pass is not" — rules/evaluation.md). The
line is respected here:
- hand-coded (architecture, weight-independent): tokenising ``a mod p`` into
bits, scanning them sequentially, reading the final state bits. This is
tokenisation + recurrence + readout — it computes nothing by itself: with
random weights the output is noise (Principle 2), and the emitted digits are
exactly the model's final hidden state (Principle 1).
- learned (all of the actual arithmetic): the transition function. Nothing in
the code adds, multiplies, compares against p, or carries; the cell's trained
weights (MLP or carry-aware TCN) had to learn all of that from data.
The two-operand reductions ``a mod p`` / ``b mod p`` in ``predict_digits_batch``
(which ``predict_digits`` delegates to) are the same legal input normalisation
every other reference model uses.
Routing: each problem goes to the narrowest cell whose state holds the prime.
A SINGLE shared carry-aware TCN weight-set covers 16/32/64/128/256/512-bit
primes (tiers 1-8), and a second shared TCN weight-set covers 1024/2048-bit
primes (tiers 9-10), both run at each prime's native width. For primes wider than
the widest trained cell it emits the honest ``[0]`` fallback without invoking the network.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from modchallenge.interface.base_model import ModularMultiplicationModel
# Bit-widths the router must cover, narrowest first. load() requires every width
# listed here to be served by a weights_shared_*.pt set or a weights{W}.pt file
# and raises FileNotFoundError otherwise; a shared set may additionally register
# widths that are not listed here.
CELL_WIDTHS = (16, 32, 64, 128, 256, 512, 1024, 2048)
# Default state width for the 16-bit trainer (train.py imports this).
BITS = 16
class _ResBlock(nn.Module):
    """Residual MLP block with pre-normalisation.

    Computes ``x + fc2(gelu(fc1(layer_norm(x))))``. Input and output both have
    ``width`` features, so blocks can be stacked.
    """

    def __init__(self, width: int):
        super().__init__()
        # The attribute names (ln, fc1, fc2) are the state_dict keys that
        # checkpoints are saved under, so they must not change.
        self.ln = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        normed = self.ln(x)
        hidden = torch.nn.functional.gelu(self.fc1(normed))
        update = self.fc2(hidden)
        return x + update
class HornerCell(nn.Module):
    """Learned RNN transition: (state_bits, bit, b_bits, p_bits) -> next-state logits.

    The trunk is chosen by ``residual``:

    * ``residual=False`` (default): the plain Linear/GELU stack used by the
      16/32-bit cells. Its module/parameter layout is kept exactly as is so
      existing checkpoints still load.
    * ``residual=True``: an input projection followed by pre-norm residual
      blocks, which stay trainable at the larger depth the 64-bit carry chains
      need (exact n-bit carry propagation wants depth ~log2(n)).

    The flag is stored in ``config``, so older checkpoints (no ``residual``
    key) are rebuilt as the plain stack.
    """

    def __init__(self, width: int = 4096, depth: int = 4, bits: int = 16, residual: bool = False):
        super().__init__()
        self.residual = residual
        # forward() concatenates [state, bit, b, p]: three bits-wide vectors plus 1.
        in_features = 3 * bits + 1
        if not residual:
            stack: list[nn.Module] = [nn.Linear(in_features, width), nn.GELU()]
            for _ in range(depth - 1):
                stack.extend((nn.Linear(width, width), nn.GELU()))
            self.trunk = nn.Sequential(*stack)
        else:
            self.proj = nn.Linear(in_features, width)
            self.trunk = nn.Sequential(*(_ResBlock(width) for _ in range(depth)))
        self.head = nn.Linear(width, bits)
        self.config = dict(width=width, depth=depth, bits=bits, residual=residual)

    def forward(self, tb, bit, bb, pb):
        features = torch.cat([tb, bit, bb, pb], dim=-1)
        hidden = self.proj(features) if self.residual else features
        return self.head(self.trunk(hidden))
class _DilatedResBlock(nn.Module):
    """Non-causal dilated-conv residual block with per-position channel LayerNorm.

    Computes ``x + conv2(gelu(conv1(LN(x))))`` on an ``(N, C, L)`` tensor, where
    LayerNorm normalises the C channels independently at each position. Both
    convolutions use symmetric padding of ``dilation * (kernel - 1) / 2`` so
    their output length equals L and the residual add lines up.

    Raises:
        ValueError: if ``dilation * (kernel - 1)`` is odd. Symmetric integer
            padding cannot preserve the length then (the conv would emit L - 1
            positions), which would otherwise only surface as a shape mismatch
            inside ``forward``.
    """

    def __init__(self, ch: int, kernel: int, dilation: int):
        super().__init__()
        # Positions a single conv consumes beyond the output length.
        span = dilation * (kernel - 1)
        if span % 2:
            raise ValueError(
                f"dilation * (kernel - 1) must be even for length-preserving padding, "
                f"got {dilation} * ({kernel} - 1) = {span}"
            )
        pad = span // 2
        self.norm = nn.LayerNorm(ch)
        self.conv1 = nn.Conv1d(ch, ch, kernel, padding=pad, dilation=dilation)
        self.conv2 = nn.Conv1d(ch, ch, kernel, padding=pad, dilation=dilation)

    def forward(self, x):  # x: (N, C, L)
        # LayerNorm acts on the last dim, so move channels there and back.
        xn = self.norm(x.transpose(1, 2)).transpose(1, 2)
        return x + self.conv2(torch.nn.functional.gelu(self.conv1(xn)))
class TCNHornerCell(nn.Module):
    """Carry-aware Horner cell: a non-causal dilated TCN over the state's bit positions.

    Learns the same transition ``(t, bit, b, p) -> (2t + bit*b) mod p`` as
    HornerCell, but every layer is a convolution. The network is therefore
    WEIGHT-SHARED across bit positions: one learned carry rule is applied
    everywhere instead of a full-width MLP learning a separate function per
    position. ``forward`` runs at whatever width its inputs have, which is what
    lets a single weight-set serve several prime widths.

    Unless ``dilations`` is passed explicitly, block dilations follow the cycle
    1, 2, 4, ...: each value doubles the previous one until it reaches or exceeds
    ``max_dil``, after which the cycle restarts at 1. The convolutions are
    non-causal, so each position sees both lower and higher bits (the add-carry
    flows LSB->MSB, the mod-p compare/borrow flows MSB->LSB). Whether the
    receptive field spans the whole state depends on ``blocks``/``kernel``/
    ``dilations`` relative to the width the cell is run at.

    ``forward`` has the same signature as HornerCell.forward, so the inference
    scan drives either cell unchanged. Compliance is identical: tokenise/scan/
    readout are weight-independent and all arithmetic lives in the trained conv
    weights.
    """

    def __init__(self, channels: int = 256, blocks: int = 10, bits: int = 128,
                 kernel: int = 3, max_dil: int = 64, dilations=None):
        super().__init__()
        self.bits = bits
        # 4 input planes per position: state bit, b bit, p bit, broadcast a-bit.
        self.inp = nn.Conv1d(4, channels, 1)
        if dilations is None:
            dilations = []
            step = 1
            while len(dilations) < blocks:
                dilations.append(step)
                step = step * 2 if step < max_dil else 1
        self.blocks = nn.ModuleList(
            _DilatedResBlock(channels, kernel, dil) for dil in dilations
        )
        self.out = nn.Conv1d(channels, 1, 1)
        # Training-only: recompute block activations in backward so wide widths
        # (e.g. 1024-bit) fit in memory. Stays False for shipped inference; the
        # trainer turns it on. It has no effect when grad is disabled.
        self.grad_checkpoint = False
        self.config = dict(arch="tcn", channels=channels, blocks=blocks, bits=bits,
                           kernel=kernel, max_dil=max_dil, dilations=dilations)

    def forward(self, tb, bit, bb, pb):
        batch, length = tb.shape[0], tb.shape[1]
        # Broadcast the scanned a-bit over the INPUT's width (not self.bits), so
        # the same weights run natively at any width; identical when
        # length == self.bits.
        a_plane = bit.expand(batch, length)
        planes = torch.stack([tb, bb, pb, a_plane], dim=1)  # (N, 4, L), position 0 = LSB
        h = self.inp(planes)
        checkpointing = self.grad_checkpoint and torch.is_grad_enabled()
        if checkpointing:
            from torch.utils.checkpoint import checkpoint
        for blk in self.blocks:
            h = checkpoint(blk, h, use_reentrant=False) if checkpointing else blk(h)
        return self.out(h).squeeze(1)  # (N, L) logits, one per bit position
def _build_cell(config: dict):
    """Construct an untrained cell from a checkpoint ``config`` dict.

    ``config['arch'] == 'tcn'`` selects TCNHornerCell. Otherwise the remaining
    keys go to the MLP HornerCell, which is also the result when ``arch`` is
    absent. The caller's dict is not mutated; a shallow copy is edited instead.
    """
    kwargs = dict(config)
    # Metadata that shared/training checkpoints may carry but no cell constructor
    # accepts: `unified` is a training-only marker, and `widths` (the widths a
    # shared weight-set serves) is read from the top-level checkpoint key.
    for meta_key in ("unified", "widths"):
        kwargs.pop(meta_key, None)
    if kwargs.get("arch") != "tcn":
        return HornerCell(**kwargs)
    del kwargs["arch"]
    return TCNHornerCell(**kwargs)
def _to_bits(t: torch.Tensor, bits: int = 16) -> torch.Tensor:
    """Expand integers into their low ``bits`` binary digits, LSB-first.

    An (N,) int64 tensor becomes an (N, bits) float tensor of 0/1 values. The
    trainer uses this for values up to 32 bits. Inference packs bits with
    ``_pack_bits`` instead, which matches this bit-for-bit at 16/32 bits and
    also handles 64-bit values that do not fit in an int64 tensor. It stays a
    module-level helper so the trainer can import it.
    """
    positions = torch.arange(bits, device=t.device)
    digits = (t[:, None] >> positions) & 1
    return digits.float()
def _pack_bits(vals: list[int], nbits: int, device) -> torch.Tensor:
    """Pack non-negative ints into an (N, nbits) float32 0/1 tensor, LSB-first.

    Each value is serialised little-endian with ``int.to_bytes`` and expanded
    with ``np.unpackbits``. Arbitrary-precision Python ints therefore work at any
    width, including 64 bits and beyond where the ``_to_bits`` torch shift trick
    overflows int64. At 16/32 bits the result is bit-identical to ``_to_bits``.

    Args:
        vals: integers, each in ``[0, 2**nbits)``.
        nbits: number of bit columns to emit. It need not be a multiple of 8:
            values are serialised into ``ceil(nbits / 8)`` bytes and the unused
            high columns are dropped.
        device: torch device for the returned tensor.

    Returns:
        An (len(vals), nbits) float32 tensor on ``device``.

    Raises:
        OverflowError: if a value is negative or does not fit in ``nbits`` bits.
    """
    ints = [int(v) for v in vals]
    if not ints:
        # Nothing to pack; don't depend on numpy's handling of an empty buffer.
        return torch.zeros(0, nbits, dtype=torch.float32, device=device)
    nbytes = (nbits + 7) // 8
    if nbits % 8 and any(v >> nbits for v in ints):
        # to_bytes only rejects values wider than the padded byte width; catch
        # those that would spill into the high columns dropped below (negative
        # values also land here, since v >> nbits == -1 for v < 0).
        raise OverflowError(f"value does not fit in {nbits} bits")
    buf = b"".join(v.to_bytes(nbytes, "little") for v in ints)
    arr = np.frombuffer(buf, dtype=np.uint8).reshape(len(ints), nbytes)
    bits = np.unpackbits(arr, axis=1, bitorder="little")[:, :nbits]
    return torch.from_numpy(bits.astype(np.float32)).to(device)
class HornerRNN(ModularMultiplicationModel):
    """Routes each problem to the narrowest trained cell that fits its prime."""

    def __init__(self):
        # width -> trained cell (HornerCell or TCNHornerCell), populated by load().
        # A shared multi-width cell is registered under every width it serves, so
        # several keys may map to the same module object.
        self.cells: dict[int, nn.Module] = {}
        self.device: torch.device | None = None

    def _load_cell(self, path: Path) -> tuple[dict, nn.Module]:
        """Load the checkpoint at ``path``; return it with its eval-mode cell on ``self.device``."""
        ckpt = torch.load(path, map_location=self.device, weights_only=True)
        cell = _build_cell(ckpt.get("config", {}))
        cell.load_state_dict(ckpt["state_dict"])
        cell.to(self.device)
        cell.eval()
        return ckpt, cell

    def load(self, model_dir: str) -> None:
        """Pin reproducibility flags, choose a device, and load all cells from ``model_dir``.

        Shared ``weights_shared_*.pt`` sets are registered first, under every
        width listed in their ``widths`` key. ``weights<W>.pt`` files then fill
        any width in CELL_WIDTHS that is still unserved.

        Raises:
            FileNotFoundError: if no weight file is found, or if any width in
                CELL_WIDTHS ends up without a cell.
        """
        # The leaderboard ranks ONLY deterministic submissions, so pin the backend flags
        # that govern run-to-run reproducibility instead of relying on host defaults.
        # cudnn.benchmark=False fixes the conv algorithm (benchmark mode is the main source
        # of run-to-run variation); the TF32 flags are pinned to the exact mode the shipped
        # accuracy was validated under (matmul TF32 off, cuDNN TF32 on). TF32 is itself
        # deterministic, so this only makes the validated numerics host-independent; it does
        # not affect the determinism check. Inference is no_grad, so no backward-only
        # nondeterministic kernels are involved.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = True
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        md = Path(model_dir)
        # Shared multi-width cells: ONE weight-set serving several adjacent widths
        # (listed under the checkpoint's `widths` key). The same trained weights run
        # at each prime's native width (see TCNHornerCell.forward).
        for shared in sorted(md.glob("weights_shared_*.pt")):
            ckpt, cell = self._load_cell(shared)
            for w in ckpt["widths"]:
                self.cells[w] = cell
        # Per-width cells for any width not already provided by a shared cell.
        for width in CELL_WIDTHS:
            if width in self.cells:
                continue
            path = md / f"weights{width}.pt"
            if not path.exists():
                continue
            _ckpt, cell = self._load_cell(path)
            self.cells[width] = cell
        if not self.cells:
            raise FileNotFoundError(f"no weights*.pt found in {model_dir}")
        # Fail fast on an incomplete artifact: a missing intermediate weight file would
        # otherwise leave a routing gap, silently sending that width's primes to a wider,
        # differently-trained cell instead of raising. Every routing width must be covered.
        missing = [w for w in CELL_WIDTHS if w not in self.cells]
        if missing:
            raise FileNotFoundError(
                f"incomplete model: no trained cell for width(s) {missing} in {model_dir}; "
                f"each width in CELL_WIDTHS must be served by a weights_shared_*.pt or weights<W>.pt file"
            )

    # Inputs are passed through unchanged; reduction mod p happens in
    # predict_digits_batch.
    def preprocess_a(self, a):
        return a

    def preprocess_b(self, b):
        return b

    def preprocess_p(self, p):
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        """Single-problem convenience wrapper around ``predict_digits_batch``."""
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def _run_cell(self, width: int, rows: list[tuple[int, int, int]]) -> list[list[int]]:
        """Scan the width-bit cell over a batch of (a_red, b_red, p) rows.

        Returns each row's final quantized state as a list of ``width`` ints in
        {0, 1}, LSB-first.
        """
        cell = self.cells[width]
        a_bits = _pack_bits([r[0] for r in rows], width, self.device)
        bb = _pack_bits([r[1] for r in rows], width, self.device)
        pb = _pack_bits([r[2] for r in rows], width, self.device)
        state = torch.zeros(len(rows), width, device=self.device)
        # RNN scan over the bit tokens of (a mod p), MSB-first. The scan moves
        # data; the learned cell does all the computing.
        for s in range(width - 1, -1, -1):
            bit = a_bits[:, s : s + 1]
            logits = cell(state, bit, bb, pb)
            state = (logits > 0).float()  # quantized state bottleneck
        return state.long().tolist()  # LSB-first per row

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Predict the base-2 digits (MSB-first) of ``a*b mod p`` for each (a, b, p).

        Problems are bucketed by the narrowest cell width whose state holds p,
        and each bucket is scanned as one batch. A problem gets the ``[0]``
        fallback without running the network if p is wider than the widest
        trained cell, or if p < 2. The p < 2 case is no prime modulus: p == 0
        would divide by zero and p < 0 gives a negative residue that cannot be
        bit-packed, either of which would otherwise abort the whole batch; for
        p == 1 the true answer 0 is exactly the fallback.

        Raises:
            RuntimeError: if ``load()`` has not been called.
        """
        if not self.cells:
            raise RuntimeError("load() must run first")
        out: list[list[int] | None] = [None] * len(inputs)
        widths = sorted(self.cells)
        limit = 1 << widths[-1]  # smallest p that no trained cell's state can hold
        # Bucket each problem by the narrowest cell whose state holds the prime.
        buckets: dict[int, tuple[list[int], list[tuple[int, int, int]]]] = {
            w: ([], []) for w in widths
        }
        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            p = int(p_enc)
            if p >= limit or p < 2:
                out[i] = [0]  # outside every trained regime: honest fallback
                continue
            w = next(w for w in widths if p < (1 << w))
            idx, rows = buckets[w]
            idx.append(i)
            rows.append((int(a_enc) % p, int(b_enc) % p, p))
        for w, (idx, rows) in buckets.items():  # insertion order == narrowest first
            if not rows:
                continue
            bits = self._run_cell(w, rows)
            for i, row_bits in zip(idx, bits):
                out[i] = row_bits[::-1]  # emit MSB-first, base 2
        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        return 1024