🐋⚡ SpikeWhale-SNN-216M
Every neuron of the 4 stacked LIF layers, coloured by its dominant gene program and sized by its learned recurrent wiring. Rendered straight from the checkpoint with src/scripts/visualize_architecture.py; no training data required.
A ~216M-parameter spiking neural network (SNN) language model, built entirely from scratch.
This is the original SpikeWhale project — the one that sparked all the other
SpikeWhale projects. Every spiking primitive here is hand-written in plain
PyTorch: the leaky integrate-and-fire (LIF) neuron dynamics, the fast-sigmoid
surrogate gradient, and the backprop-through-time training loop. No snntorch,
no spikingjelly, no norse, no bindsnet — the network is a genuine
from-scratch SNN.
On top of the spiking core it carries two unusual ideas:
- A biological "bridge". Each neuron's excitability (firing threshold, membrane decay beta, input gain) is modulated by a frozen virtual-cell gene-program model. Every hidden neuron is assigned a gene knockdown; the cell model predicts that perturbation's effect on six gene programs, and that identity vector sets the neuron's LIF parameters. The mapping directions are fixed (from biology); the strengths are learned. This is a deliberate, ablatable experiment, not a claim that gene expression literally sets neural thresholds. See "The virtual cell" below for the full story.
- The SpikeWhale / Byrne "family" traits: Engram (N-gram hash memory + DERF gate), HRM iterative refinement, multi-token prediction, spiking linear attention, an explicit program-semantic block, Kuramoto coupled-oscillator coupling, and a JEPA representation-prediction head. All are wired in at the network's continuous points so the binary LIF spike dynamics stay intact, and each is a near no-op at init, so it wakes up only as it helps.
A live demo is hosted at the companion Space SpikeWhale-SNN v1, which pulls these weights at runtime.
Model summary
| Property | Value |
|---|---|
| Parameters | 215,901,779 (~216M) |
| Architecture | Recurrent LIF spiking LM (SpikingCharLM) |
| Depth × width | 4 stacked LIF layers × 1488 hidden |
| Readout | All-layer pre-reset membrane → linear head (per-position logits) |
| Weight tying | Embedding ⇄ output table tied |
| Tokenizer | Custom SpikeWhale byte-level "length-max" tokenizer, vocab 16,512 |
| Neuron model | Leaky integrate-and-fire + fast-sigmoid surrogate gradient (Neftci et al. 2019) |
| Modulation | program (neurons modulated by the frozen virtual-cell bridge) |
| Training data | Streamed 75% FineWeb-Edu + 25% FineMath |
| Training step (this ckpt) | 31,824 |
| Val bits/token | 3.977 (uniform baseline: log₂ 16512 ≈ 14.01) |
| Mean firing rate | ≈ 41% (healthy: not dead, not saturated) |
| Precision | float32 |
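The bits/token figure is plain arithmetic on the cross-entropy. The sketch below reproduces the two reference numbers, assuming the loss is computed with the natural log (as torch.nn.functional.cross_entropy does) and converted to bits by dividing by ln 2. The helper names are illustrative and do not come from the repo.

```python
import math

VOCAB_SIZE = 16_512         # tokenizer vocabulary (model summary table)
VAL_BITS_PER_TOKEN = 3.977  # reported validation bits/token


def nats_to_bits(ce_nats: float) -> float:
    """Convert a natural-log cross-entropy (nats/token) to bits/token."""
    return ce_nats / math.log(2)


def uniform_baseline_bits(vocab_size: int) -> float:
    """Bits/token of a model that gives every token the same probability."""
    return math.log2(vocab_size)


def perplexity_from_bits(bits: float) -> float:
    """Per-token perplexity implied by a bits/token figure."""
    return 2.0 ** bits


print(f"uniform baseline: {uniform_baseline_bits(VOCAB_SIZE):.3f} bits/token")  # 14.011
print(f"val perplexity:   {perplexity_from_bits(VAL_BITS_PER_TOKEN):.2f}")       # 15.75
print(f"val loss in nats: {VAL_BITS_PER_TOKEN * math.log(2):.3f}")
```

So 3.977 bits/token means the model is, on average, about as uncertain as a uniform choice among roughly 16 tokens, compared with 16,512 for the uniform baseline.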
Active "family" traits in this checkpoint
Engram · HRM · MTP · SpikeAttn · ProgSem · Kuramoto · JEPA (MoE off), plus an
anti-death firing regulariser (firing_target=0.2, weight 50).
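The card does not give the regulariser's formula. The following is a minimal sketch that assumes a one-sided squared hinge, which penalises only layers firing below firing_target (so a healthy ~41% rate costs nothing), averaged over layers and scaled by the weight. firing_rate and firing_regulariser are hypothetical names. In PyTorch, the rate would simply be the mean of a layer's spike tensor.

```python
def firing_rate(spikes):
    """Fraction of (time-step, neuron) entries that spiked; spikes is a list of 0/1 rows."""
    total = sum(len(row) for row in spikes)
    return sum(sum(row) for row in spikes) / total


def firing_regulariser(layer_rates, target=0.2, weight=50.0):
    """Hypothetical anti-death penalty (assumed form, not the repo's code).

    Each layer is penalised by its squared shortfall below `target`; layers at or
    above target contribute nothing. The result is averaged over layers and scaled.
    """
    shortfalls = [max(0.0, target - rate) for rate in layer_rates]
    return weight * sum(s * s for s in shortfalls) / len(layer_rates)


# A nearly silent upper stack is penalised; healthy lower layers are not.
rates = [0.41, 0.35, 0.05, 0.0]
print(firing_rate([[1, 0, 1, 0], [0, 0, 0, 1]]))  # 0.375
print(round(firing_regulariser(rates), 5))        # 0.78125
```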
How a token is processed
```text
token id ─► embedding (input current)
         ─► LIF layer 0 ─► spikes ─► LIF layer 1 ─► … ─► LIF layer 3
                                                            │
   (per layer: mem[t] = beta·mem[t-1] + input; spike if mem ≥ threshold; soft reset)
                                                            │
all layers' pre-reset membranes ─► tanh ─► linear readout ─► next-token logits
```
"Time" is token position: one LIF time-step per token, with membrane
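potentials and recurrent spikes carrying memory across the sequence. Because the
forward spike is a step function (its derivative is zero almost everywhere), the
backward pass substitutes a fast-sigmoid surrogate derivative,
1/(1+slope·|mem−thresh|)². Since mem and thresh enter it only through their
difference, the same surrogate also passes gradient (with the opposite sign) to
each neuron's threshold, which is what lets the gene-program bridge train end-to-end.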
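One LIF neuron's recurrence and its surrogate gradient can be sketched in plain Python. In PyTorch, the spike function and its surrogate would normally be paired in a custom torch.autograd.Function. Here the forward step and the backward derivative are kept as two separate functions so the arithmetic is visible. Several details are assumptions and are not read from the checkpoint: the slope of 25, the example beta and threshold values, and the function names.

```python
def spike_forward(mem, thresh):
    """Forward: Heaviside step, 1.0 once the membrane reaches threshold."""
    return 1.0 if mem >= thresh else 0.0


def spike_surrogate_grad(mem, thresh, slope=25.0):
    """Backward: fast-sigmoid surrogate for d(spike)/d(mem).

    d(spike)/d(thresh) is the negative of this value, which is how gradient
    reaches each neuron's (gene-program-modulated) threshold.
    """
    return 1.0 / (1.0 + slope * abs(mem - thresh)) ** 2


def lif_neuron(currents, beta, thresh):
    """Run one LIF neuron over a sequence, one time-step per token.

    `currents` stands in for the total input current at each step
    (feed-forward plus recurrent). Returns (spikes, pre_reset_membranes).
    """
    mem = 0.0
    spikes, pre_reset = [], []
    for x in currents:
        mem = beta * mem + x        # leaky integration
        pre_reset.append(mem)       # the value the readout sees
        s = spike_forward(mem, thresh)
        spikes.append(s)
        mem -= s * thresh           # soft (subtractive) reset keeps the overshoot
    return spikes, pre_reset


spikes, mems = lif_neuron([0.6, 0.6, 0.6, 0.0], beta=0.9, thresh=1.0)
print(spikes)                          # [0.0, 1.0, 0.0, 0.0]
print([round(v, 3) for v in mems])
print(spike_surrogate_grad(1.0, 1.0))  # 1.0: the gradient peaks exactly at threshold
```

The soft reset subtracts the threshold instead of zeroing the membrane, so the 0.14 overshoot at step 1 carries into step 2.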
The virtual cell (what the "bridge" connects to)
The "virtual cell" is a small, frozen perturbation-prediction model trained separately, before the SNN. It answers one question: if you knock down gene X, how does the cell's gene expression change? The SNN then borrows those predicted biological effects to configure its neurons.
The cell model
src/ml/VirtualCell.py is a deliberately tiny MLP that predicts a delta from a control expression profile:
pred_expression = control_expression + delta_net(gene_embedding)
- Input: a per-gene ESM2-650M protein-language-model embedding (1280-d, mean-pooled over the gene's canonical human UniProt protein sequence; see src/scripts/compute_esm2_embeddings.py).
- delta_net: Linear(1280→64) → LayerNorm → GELU → Linear(64→2177). The final layer is zero-initialised, so at init the model predicts "no change" (the control profile) and learns perturbation effects from there.
- Output: a 2177-gene expression profile (≈2000 highly-variable genes plus every program gene, whitelisted in).
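The zero-initialised output layer is what makes the cell model start from "no change". Below is a toy-sized sketch of delta_net in plain Python. It assumes PyTorch's default LayerNorm init (scale 1, shift 0, so the affine part is omitted) and the exact erf form of GELU. The dimensions are shrunk from 1280→64→2177 to keep it readable, and TinyDeltaNet is a hypothetical name.

```python
import math
import random


def linear(x, W, b):
    """y = W x + b for a list-of-rows weight matrix."""
    return [sum(w * v for w, v in zip(row, x)) + bias for row, bias in zip(W, b)]


def layer_norm(x, eps=1e-5):
    """LayerNorm at PyTorch's default init (scale 1, shift 0)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def gelu(v):
    """Exact (erf-based) GELU."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


class TinyDeltaNet:
    """Toy Linear -> LayerNorm -> GELU -> Linear with a zero-initialised last layer."""

    def __init__(self, d_in, d_hidden, n_genes, seed=0):
        rng = random.Random(seed)
        bound = 1.0 / math.sqrt(d_in)  # similar in scale to PyTorch's default Linear init
        self.W1 = [[rng.uniform(-bound, bound) for _ in range(d_in)] for _ in range(d_hidden)]
        self.b1 = [0.0] * d_hidden
        self.W2 = [[0.0] * d_hidden for _ in range(n_genes)]  # zero init, so delta == 0
        self.b2 = [0.0] * n_genes

    def delta(self, emb):
        h = [gelu(v) for v in layer_norm(linear(emb, self.W1, self.b1))]
        return linear(h, self.W2, self.b2)


def predict_expression(control, emb, net):
    """pred_expression = control_expression + delta_net(gene_embedding)."""
    return [c + d for c, d in zip(control, net.delta(emb))]


net = TinyDeltaNet(d_in=8, d_hidden=4, n_genes=5)  # real model: 1280 -> 64 -> 2177
control = [0.5, 1.0, 0.0, 2.0, 0.25]
emb_rng = random.Random(1)
emb = [emb_rng.gauss(0.0, 1.0) for _ in range(8)]
print(predict_expression(control, emb, net) == control)  # True at init
```

Training only has to move W2 and b2 away from zero to express a perturbation effect. The hidden features are already informative from the first step.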
The data it's trained on
The bundled model was trained on the real Virtual Cell Challenge (VCC) dataset
(source="real") — a Perturb-seq screen in the H1 human embryonic stem cell
(hESC) line, a fast-dividing, pluripotent cell that is the canonical starting
point for neuronal differentiation. Training (src/scripts/train_programs.py):
expression is total-count normalised + log1p, reduced to per-perturbation
pseudobulk profiles, and split by held-out perturbation gene. The loss
up-weights program genes and additionally matches per-program Δ-scores; validation
reports the Pearson correlation of predicted vs. true Δ-program score. The frozen
result ships here as models/programs_real_20260628_022009.pth.
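The preprocessing and the split can be sketched on toy data. This sketch makes three assumptions that the card does not confirm: a per-cell target of 10,000 counts, pseudobulk taken as the mean of each perturbation's normalised cells, and a random gene-level split. All function names and the GENE_A/GENE_B labels are hypothetical. Splitting by perturbation gene rather than by cell is what makes the validation score measure generalisation to knockdowns the model has never seen.

```python
import math
import random


def normalize_log1p(counts, target_sum=1e4):
    """Total-count normalise one cell to `target_sum` counts, then log1p."""
    total = sum(counts)
    if total == 0:
        return [0.0] * len(counts)
    return [math.log1p(c * target_sum / total) for c in counts]


def pseudobulk(cells):
    """Average normalised cells per perturbation gene.

    `cells` is a list of (perturbation_gene, expression_vector) pairs.
    """
    sums, counts = {}, {}
    for gene, expr in cells:
        acc = sums.setdefault(gene, [0.0] * len(expr))
        for i, v in enumerate(expr):
            acc[i] += v
        counts[gene] = counts.get(gene, 0) + 1
    return {g: [v / counts[g] for v in acc] for g, acc in sums.items()}


def split_by_perturbation(genes, val_fraction=0.2, seed=0):
    """Hold out whole perturbation genes so validation knockdowns are unseen in training."""
    pool = sorted(set(genes))
    random.Random(seed).shuffle(pool)
    n_val = max(1, int(len(pool) * val_fraction))
    return set(pool[n_val:]), set(pool[:n_val])


def pearson(x, y):
    """Pearson correlation, e.g. of predicted vs. true delta-program scores."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / math.sqrt(vx * vy)


raw = [("GENE_A", [3, 1, 0]), ("GENE_A", [1, 1, 2]), ("GENE_B", [0, 4, 0])]
profiles = pseudobulk([(g, normalize_log1p(c)) for g, c in raw])
train_genes, val_genes = split_by_perturbation(profiles)
print(sorted(profiles), train_genes & val_genes)  # ['GENE_A', 'GENE_B'] set()
```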
The six gene programs
src/ml/gene_programs.py groups genes into named,
biologically-meaningful programs (curated marker sets, e.g. Tirosh et al. 2016
S / G2-M cell-cycle genes, pan-neuronal markers, pluripotency factors):
| Program | Meaning | High score means… |
|---|---|---|
| cell_division | S + G2/M cell-cycle / mitosis genes | actively dividing |
| neuron_identity | pan-neuronal markers (MAP2, RBFOX3, SYN1…) | neuron-like |
| pluripotency | hESC ground-state factors (POU5F1, NANOG, SOX2…) | stem-cell-like |
| endocrine | hormones / receptors / steroidogenesis | endocrine (odd for hESC; deliberate) |
| metabolism | glycolysis, TCA, lipid/sterol, OxPhos | high metabolic drive |
| signaling_identity | RTK→RAS/MAPK/PI3K, Wnt, TGFβ, JAK/STAT | proliferative signalling |
The original motivating idea was the "cell-dividing neuron" — the biologically
unusual corner where cell_division and neuron_identity are both high.
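Scoring a predicted profile against a program reduces many genes to one number. The card does not give the exact scoring rule. This sketch assumes the program Δ-score is the mean predicted-minus-control change over the program's genes that appear in the profile. The marker subsets reuse genes named in the programs table, and program_delta_scores is a hypothetical name.

```python
def program_delta_scores(pred, control, gene_names, programs):
    """Assumed rule: mean (pred - control) over each program's genes present in the profile."""
    index = {g: i for i, g in enumerate(gene_names)}
    scores = {}
    for name, genes in programs.items():
        hits = [index[g] for g in genes if g in index]
        scores[name] = (sum(pred[i] - control[i] for i in hits) / len(hits)) if hits else 0.0
    return scores


# Tiny illustrative subsets; the real curated sets live in src/ml/gene_programs.py.
PROGRAMS = {
    "neuron_identity": ["MAP2", "RBFOX3", "SYN1"],
    "pluripotency": ["POU5F1", "NANOG", "SOX2"],
}
genes = ["MAP2", "RBFOX3", "SYN1", "POU5F1", "NANOG", "SOX2", "GAPDH"]
control = [1.0] * len(genes)
pred = [1.3, 1.3, 1.3, 0.4, 0.4, 0.4, 1.0]  # neuronal markers up, pluripotency factors down
scores = program_delta_scores(pred, control, genes, PROGRAMS)
print({k: round(v, 3) for k, v in scores.items()})
# {'neuron_identity': 0.3, 'pluripotency': -0.6}
```

Under this rule, a "cell-dividing neuron" knockdown would be one whose cell_division and neuron_identity scores are both clearly positive.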
The bridge itself
src/snn/cell_modulation.py turns those biological
predictions into LIF parameters:
1. Assign each of the 1488 hidden neurons a gene knockdown drawn from the dataset's perturbation pool.
2. Predict: run the frozen cell model on (mean control, gene ESM2 embedding) and score the predicted profile's Δ against each of the 6 programs, giving a per-neuron identity vector. This checkpoint uses the concat strategy over 3 seed draws, so each neuron carries its identity under 3 independent gene assignments: the stored program matrix P is [1488 × 18] (6 programs × 3 seeds), z-scored.
3. Map identity → LIF dynamics, with learnable per-program strengths:
   threshold = base_thr · exp(P · thr_gain), clamped to [0.25, 3.0]
   beta = sigmoid(logit(base_beta) + P · beta_gain), clamped to [0.50, 0.97]
   gain = exp(P · in_gain), clamped to [0.30, 3.0]
   The directions are seeded from biology (neuron identity → lower threshold, longer memory, stronger drive; all other programs start neutral). The strengths (thr_gain, beta_gain, in_gain) are nn.Parameters the SNN trains end-to-end, so the network learns how much to lean on the biology. Because gradient reaches the neuron threshold through the surrogate, the whole mapping is differentiable.
The ablation: --modulation none zeroes the bridge, giving homogeneous LIF
neurons — the control that tests whether the biological coupling actually helps.
Nothing about the cell model is needed at inference: the computed matrix P and
the learned gains are baked into the checkpoint.
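The mapping from the program matrix P to per-neuron LIF parameters is short enough to sketch directly. This sketch makes several assumptions. P is z-scored per column, and there is one gain entry per column of P. The base values (1.0, 0.9), the two column labels, and the gain magnitudes are made up to illustrate the biology-seeded directions. zscore_columns and lif_params are hypothetical names. With every gain at zero, each neuron gets the base parameters (up to float rounding), which is the homogeneous behaviour the --modulation none ablation is meant to compare against.

```python
import math


def zscore_columns(P):
    """Z-score each column (program) of P across neurons."""
    n, k = len(P), len(P[0])
    out = [[0.0] * k for _ in range(n)]
    for j in range(k):
        col = [row[j] for row in P]
        mu = sum(col) / n
        sd = math.sqrt(sum((v - mu) ** 2 for v in col) / n) or 1.0  # guard constant columns
        for i in range(n):
            out[i][j] = (P[i][j] - mu) / sd
    return out


def clamp(v, lo, hi):
    return max(lo, min(hi, v))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lif_params(P, base_thr, base_beta, thr_gain, beta_gain, in_gain):
    """Map each neuron's identity row of P to (threshold, beta, input gain)."""
    base_logit = math.log(base_beta / (1.0 - base_beta))
    params = []
    for row in P:
        thr = clamp(base_thr * math.exp(dot(row, thr_gain)), 0.25, 3.0)
        beta = clamp(1.0 / (1.0 + math.exp(-(base_logit + dot(row, beta_gain)))), 0.50, 0.97)
        gain = clamp(math.exp(dot(row, in_gain)), 0.30, 3.0)
        params.append((thr, beta, gain))
    return params


# 3 toy neurons x 2 hypothetical columns: [neuron_identity, cell_division]
P = zscore_columns([[2.0, 0.0], [0.0, 1.0], [-2.0, -1.0]])
# Neuron identity lowers threshold, lengthens memory and strengthens drive;
# the other program starts neutral (gain magnitudes are illustrative).
for thr, beta, gain in lif_params(P, 1.0, 0.9, [-0.2, 0.0], [0.5, 0.0], [0.3, 0.0]):
    print(f"thr={thr:.3f} beta={beta:.3f} gain={gain:.3f}")
```

Parameterising the threshold and gain through exp and the decay through a logit keeps all three in a valid range for any gain value. The clamps then bound how far training can push a single neuron.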
Usage
The checkpoint records its own tokenizer, model shape, and modulation matrix, so loading rebuilds the exact architecture automatically.
```bash
pip install torch numpy scipy transformers huggingface_hub
```
```python
import sys
import torch
from huggingface_hub import snapshot_download

# Download the repo (weights + source) and make its src/ package importable.
repo = snapshot_download("Quazim0t0/SpikeWhale-SNN-216M")
sys.path.insert(0, f"{repo}/src")

from snn.text import SpikeVocab
from snn.model import SpikingCharLM
from snn.cell_modulation import NeuronModulation

# The checkpoint stores tokenizer, config and modulation matrix alongside the weights;
# weights_only=False lets torch.load unpickle that non-tensor metadata (trusted files only).
ckpt = torch.load(f"{repo}/snn_stream_program.pth", map_location="cpu", weights_only=False)
vocab = SpikeVocab((ckpt["vocab"] or {}).get("tokenizer_json"))

# Rebuild the gene-program bridge from the baked-in program matrix P.
m = ckpt["modulation"]
mod = NeuronModulation(m["P"] if m["enabled"] else None, m["hidden"],
                       base_threshold=m["base_threshold"], base_beta=m["base_beta"],
                       enabled=m["enabled"])

# Rebuild the exact architecture (including any family traits) and load the weights.
cfg = ckpt["config"]
model = SpikingCharLM(cfg["vocab_size"], cfg["emb_dim"], cfg["hidden"], mod,
                      readout_decay=cfg["readout_decay"], num_layers=cfg["num_layers"],
                      tie_embeddings=cfg["tie_embeddings"], **cfg.get("family", {}))
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Generate a continuation of the prompt with temperature + top-k sampling.
ids = model.generate(vocab.encode("The ").tolist(), 80, temperature=0.7,
                     device=torch.device("cpu"), top_k=10)
print(vocab.decode(ids))
```
Or, from the bundled scripts:
```bash
PYTHONPATH=src python src/scripts/generate_snn.py \
  --checkpoint snn_stream_program.pth --prompt "The " --length 120
```
⚠️ This is a recurrent net: one LIF step per token, so CPU sampling is slow. A GPU is strongly recommended for both training and long generations.
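The generate call above samples with temperature=0.7 and top_k=10. For readers unfamiliar with those knobs, here is a self-contained sketch of standard temperature-scaled top-k sampling on one step's logits. It illustrates the general technique and is not the repo's generate implementation. sample_top_k is a hypothetical name.

```python
import math
import random


def sample_top_k(logits, temperature=0.7, top_k=10, rng=random):
    """Temperature-scaled top-k sampling over one step's logits; returns a token id."""
    if temperature <= 0:
        return max(range(len(logits)), key=lambda i: logits[i])  # greedy fallback
    k = min(top_k, len(logits))
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [0.0, 5.0, 1.0, -2.0]
draws = [sample_top_k(logits, temperature=0.7, top_k=2, rng=rng) for _ in range(1000)]
print({i: draws.count(i) for i in sorted(set(draws))})
```

Only the two highest-scoring ids (1 and 2) can ever be drawn. A temperature below 1 sharpens the distribution further toward id 1.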
Reproduce it from scratch
This repo is a complete, self-contained training package — clone it and run
the same project end-to-end. See START_HERE.md and
SNN.md for the full write-up.
```bash
pip install -r requirements.txt
pip install -e . --no-deps   # makes `snn` and `ml` importable
# train the ~216M spiking LM on the streamed FineWeb-Edu + FineMath blend,
# neurons bridged to the frozen virtual-cell program model:
python src/scripts/train_snn_stream.py --modulation program --data-dir data
# ablation: homogeneous neurons, no bridge (needs no cell data):
python src/scripts/train_snn_stream.py --modulation none --no-family
```
The datasets stream on the fly (nothing large is stored locally). Watch the
firing rate printed each step: deep spiking stacks can go silent in the upper
layers — the input projection of the deeper layers is scaled (spike_input_gain)
to keep spikes propagating through all 4 layers.
What's in the repo
```text
src/snn/        lif.py model.py cell_modulation.py family.py text.py muon.py
src/ml/         the frozen virtual-cell gene-program model (read-only, for the bridge)
src/scripts/    train_snn_stream.py train_snn.py generate_snn.py train_programs.py …
spike_tokenizer.py special_tokens.py tokenizer.json   # the SpikeWhale tokenizer
data/ models/                                         # the cell-bridge bundle + frozen program model
research/                                             # analysis / health-check scripts
snn_stream_program.pth latest.txt                     # the trained weights (this card)
```
Limitations & honest caveats
- The gene-program bridge is experimental. It is a working, ablatable mechanism, not a validated claim about biology or a guaranteed quality win.
- Recurrent SNN inference is slow (one time-step per token) and this model is large for an SNN; expect slow CPU generation.
- Trained on English FineWeb-Edu + FineMath only; it inherits those corpora's biases and is a research artifact, not a production/aligned assistant.
- Reported bits/token is the LM cross-entropy only, so it stays comparable across runs even when auxiliary family losses are active.
License
Apache-2.0.
Citation
```bibtex
@software{byrne_spikewhale_snn_2026,
  title  = {SpikeWhale-SNN-216M: a from-scratch spiking neural network language
            model bridged to a virtual-cell gene-program model},
  author = {Byrne, Dean (Quazim0t0)},
  year   = {2026},
  url    = {https://huggingface.co/Quazim0t0/SpikeWhale-SNN-216M}
}
```
