
🐋⚡ SpikeWhale-SNN-216M

Figure: SpikeWhale-SNN architecture. Every neuron of the 4 stacked LIF layers, coloured by its dominant gene program and sized by its learned recurrent wiring, rendered straight from the checkpoint with src/scripts/visualize_architecture.py (no training data required).

A ~216M-parameter spiking neural network (SNN) language model, built entirely from scratch.

This is the original SpikeWhale project — the one that sparked all the other SpikeWhale projects. Every spiking primitive here is hand-written in plain PyTorch: the leaky integrate-and-fire (LIF) neuron dynamics, the fast-sigmoid surrogate gradient, and the backprop-through-time training loop. No snntorch, no spikingjelly, no norse, no bindsnet — the network is a genuine from-scratch SNN.

On top of the spiking core it carries two unusual ideas:

  1. A biological "bridge". Each neuron's excitability (firing threshold, membrane decay beta, input gain) is modulated by a frozen virtual-cell gene-program model. Every hidden neuron is assigned a gene knockdown; the cell model predicts that perturbation's effect on six gene programs, and the resulting program-identity vector sets the neuron's LIF parameters. The mapping directions are fixed (from biology); their strengths are learned. This is a deliberate, ablatable experiment, not a claim that gene expression literally sets neural thresholds. See "The virtual cell" below for the full story.
  2. The SpikeWhale / Byrne "family" traits: Engram (N-gram hash memory + DERF gate), HRM iterative refinement, multi-token prediction, spiking linear attention, an explicit program-semantic block, Kuramoto coupled-oscillator coupling, and a JEPA representation-prediction head. All are wired in where signals are continuous-valued rather than spikes, so the binary LIF spike dynamics stay intact, and each starts as a near no-op at initialisation, so it only contributes once training finds it useful.

A live demo is hosted at the companion Space SpikeWhale-SNN v1, which pulls these weights at runtime.


Model summary

| Property | Value |
|---|---|
| Parameters | 215,901,779 (~216M) |
| Architecture | Recurrent LIF spiking LM (SpikingCharLM) |
| Depth × width | 4 stacked LIF layers × 1488 hidden units |
| Readout | All-layer pre-reset membrane → linear head (per-position logits) |
| Weight tying | Embedding ⇄ output table tied |
| Tokenizer | Custom SpikeWhale byte-level "length-max" tokenizer, vocab 16,512 |
| Neuron model | Leaky integrate-and-fire + fast-sigmoid surrogate gradient (Neftci et al. 2019) |
| Modulation | program (neurons modulated by the frozen virtual-cell bridge) |
| Training data | Streamed 75% FineWeb-Edu + 25% FineMath |
| Training step (this ckpt) | 31,824 |
| Val bits/token | 3.977 (uniform baseline ≈ log₂ 16512 ≈ 14.0) |
| Mean firing rate | 41% (healthy: neither dead nor saturated) |
| Precision | float32 |
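Bits/token is the mean next-token cross-entropy expressed in base 2, so the table's numbers are easy to sanity-check. A small worked sketch, assuming the training loss is the usual natural-log cross-entropy (as PyTorch's cross_entropy returns); it converts between nats and bits and recomputes the uniform-guess baseline:

```python
import math

VOCAB_SIZE = 16_512  # tokenizer vocabulary from the table above


def nats_to_bits(ce_nats: float) -> float:
    """Convert a natural-log cross-entropy (nats/token) into bits/token."""
    return ce_nats / math.log(2)


# A model guessing uniformly over the vocabulary pays log2(V) bits per token.
uniform_bits = math.log2(VOCAB_SIZE)

# The nats/token value equivalent to the reported validation bits/token.
reported_bits = 3.977
reported_nats = reported_bits * math.log(2)

print(f"uniform baseline: {uniform_bits:.3f} bits/token")
print(f"reported: {reported_bits} bits/token = {reported_nats:.3f} nats/token")
print(f"perplexity: {2 ** reported_bits:.1f}")  # perplexity = 2 ** bits/token
```

Read this way, 3.977 bits/token corresponds to a perplexity of roughly 16, against roughly 16,512 for a uniform guess.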

Active "family" traits in this checkpoint

Engram · HRM · MTP · SpikeAttn · ProgSem · Kuramoto · JEPA (MoE off), plus an anti-death firing regulariser (firing_target=0.2, weight 50).
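The card gives only the anti-death regulariser's target and weight, not its formula. A minimal sketch of one plausible form, assuming a one-sided squared penalty on each layer whose mean firing rate falls below firing_target (the function name firing_regulariser is hypothetical, and the exact loss used in training may differ):

```python
import torch


def firing_regulariser(spikes_per_layer, firing_target=0.2, weight=50.0):
    """Hypothetical one-sided anti-death penalty on per-layer mean firing rate."""
    penalty = torch.zeros(())
    for spikes in spikes_per_layer:
        rate = spikes.float().mean()
        # Only under-firing is penalised; layers at or above target cost nothing.
        penalty = penalty + torch.relu(firing_target - rate) ** 2
    return weight * penalty


silent = torch.zeros(4, 8)   # a layer that has gone silent
active = torch.ones(4, 8)
active[:, :5] = 0.0          # fires 3/8 = 37.5% of the time
loss = firing_regulariser([silent, active])
print(loss.item())
```

Under this assumption the silent layer contributes 50 × 0.2² = 2.0, while a layer firing above target (like the checkpoint's 41% mean) adds nothing, so the penalty only pushes dead layers back into activity.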


How a token is processed

token id ─► embedding (input current)
         ─► LIF layer 0 ─► spikes ─► LIF layer 1 ─► … ─► LIF layer 3
                                                            │
         (per layer: mem[t] = beta·mem[t-1] + input; spike if mem ≥ threshold; soft reset)
                                                            │
   all layers' pre-reset membranes ─► tanh ─► linear readout ─► next-token logits
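The flow in the diagram can be sketched as a per-token loop. This is a minimal, inference-only sketch with made-up sizes, not the repo's SpikingCharLM: the class name TinySpikingLM is hypothetical, and it assumes each layer receives the previous layer's spikes through a linear projection plus its own recurrent spikes, uses a soft reset (subtract the threshold), and concatenates every layer's pre-reset membrane for the readout.

```python
import torch
import torch.nn as nn


class TinySpikingLM(nn.Module):
    """Hypothetical miniature of the diagram: embedding -> LIF stack -> readout."""

    def __init__(self, vocab=32, emb=16, hidden=24, layers=4, beta=0.9, thr=1.0):
        super().__init__()
        self.embed = nn.Embedding(vocab, emb)
        # Input current per layer: the embedding for layer 0, spikes after that.
        self.inp = nn.ModuleList(
            [nn.Linear(emb if i == 0 else hidden, hidden) for i in range(layers)])
        # Recurrent spike wiring inside each layer.
        self.rec = nn.ModuleList(
            [nn.Linear(hidden, hidden, bias=False) for _ in range(layers)])
        self.readout = nn.Linear(layers * hidden, vocab)
        self.beta, self.thr = beta, thr
        self.hidden, self.layers = hidden, layers

    @torch.no_grad()
    def forward(self, ids):                        # ids: [T] token ids
        mem = [torch.zeros(self.hidden) for _ in range(self.layers)]
        spk = [torch.zeros(self.hidden) for _ in range(self.layers)]
        logits = []
        for t in range(ids.shape[0]):             # one LIF time-step per token
            x = self.embed(ids[t])
            pre_reset = []
            for i in range(self.layers):
                current = self.inp[i](x) + self.rec[i](spk[i])
                mem[i] = self.beta * mem[i] + current     # leaky integration
                pre_reset.append(mem[i])                  # readout sees pre-reset membrane
                spk[i] = (mem[i] >= self.thr).float()     # fire on threshold crossing
                mem[i] = mem[i] - spk[i] * self.thr       # soft reset
                x = spk[i]                                # spikes drive the next layer
            logits.append(self.readout(torch.tanh(torch.cat(pre_reset))))
        return torch.stack(logits)                # [T, vocab] next-token logits


model = TinySpikingLM()
logits = model(torch.tensor([1, 5, 7]))
print(logits.shape)
```

Reading the membrane before the reset keeps sub-threshold information that binary spikes alone would throw away, which is one reason to read out membranes rather than spikes.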

"Time" is token position: one LIF time-step per token, with membrane potentials and recurrent spikes carrying memory across the sequence. Because the forward spike is a step function, its true derivative is zero almost everywhere, so gradients instead flow through a surrogate, 1/(1+slope·|mem−thresh|)². Since that surrogate depends on mem − thresh, it also passes gradient to each neuron's threshold, which is what lets the gene-program bridge train end-to-end.
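A minimal sketch of that surrogate as a custom autograd function. The slope value of 25 is a hypothetical choice (the card does not state it), and the names FastSigmoidSpike and spike are illustrative, not the repo's lif.py. The forward pass is a hard step; the backward pass substitutes the fast-sigmoid derivative, and because the input is mem − threshold, gradient also reaches the threshold.

```python
import torch


class FastSigmoidSpike(torch.autograd.Function):
    """Heaviside step in the forward pass, fast-sigmoid derivative in the backward pass."""

    @staticmethod
    def forward(ctx, v, slope):
        ctx.save_for_backward(v)
        ctx.slope = slope
        return (v >= 0).to(v.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        (v,) = ctx.saved_tensors
        surrogate = 1.0 / (1.0 + ctx.slope * v.abs()) ** 2
        return grad_out * surrogate, None        # no gradient for the slope argument


def spike(mem, threshold, slope=25.0):
    return FastSigmoidSpike.apply(mem - threshold, slope)


mem = torch.tensor([0.5, 0.98, 1.4], requires_grad=True)
threshold = torch.tensor(1.0, requires_grad=True)
s = spike(mem, threshold)
s.sum().backward()
print(s, mem.grad, threshold.grad)
```

The neuron sitting just below threshold (0.98) receives by far the largest gradient, and the threshold's gradient is the negated sum of the membrane gradients. That second property is what lets a learned, program-dependent threshold train end-to-end.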


The virtual cell (what the "bridge" connects to)

The "virtual cell" is a small, frozen perturbation-prediction model trained separately, before the SNN. It answers one question: if you knock down gene X, how does the cell's gene expression change? The SNN then borrows those predicted biological effects to configure its neurons.

The cell model

src/ml/VirtualCell.py — a deliberately tiny MLP that predicts a delta from a control expression profile:

pred_expression = control_expression + delta_net(gene_embedding)
  • Input: a per-gene ESM2-650M protein-language-model embedding (1280-d, mean-pooled over the gene's canonical human UniProt protein sequence; see src/scripts/compute_esm2_embeddings.py).
  • delta_net: Linear(1280→64) → LayerNorm → GELU → Linear(64→2177). The final layer is zero-initialised, so at init the model predicts "no change" (control) and learns perturbation effects from there.
  • Output: a 2177-gene expression profile (≈2000 highly variable genes, plus every program gene force-included even if it is not highly variable).
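The architecture in the list above fits in a few lines of PyTorch. This is a sketch of that description, not src/ml/VirtualCell.py itself; the class name DeltaCellModel is hypothetical. It also checks the zero-init property: before any training, the prediction equals the control profile exactly.

```python
import torch
import torch.nn as nn


class DeltaCellModel(nn.Module):
    """pred_expression = control_expression + delta_net(gene_embedding)."""

    def __init__(self, emb_dim=1280, bottleneck=64, n_genes=2177):
        super().__init__()
        self.delta_net = nn.Sequential(
            nn.Linear(emb_dim, bottleneck),
            nn.LayerNorm(bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, n_genes),
        )
        # Zero-init the final layer so the untrained model predicts "no change".
        nn.init.zeros_(self.delta_net[-1].weight)
        nn.init.zeros_(self.delta_net[-1].bias)

    def forward(self, control_expression, gene_embedding):
        return control_expression + self.delta_net(gene_embedding)


model = DeltaCellModel()
control = torch.rand(3, 2177)    # stand-in control expression profiles
esm2 = torch.randn(3, 1280)      # stand-in ESM2-650M gene embeddings
pred = model(control, esm2)
print(torch.equal(pred, control))
```

The 64-unit bottleneck keeps the model small relative to the 1280-d input, which suits a dataset with only one pseudobulk profile per perturbation.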

The data it's trained on

The bundled model was trained on the real Virtual Cell Challenge (VCC) dataset (source="real"): a Perturb-seq screen in the H1 human embryonic stem cell (hESC) line, whose fast-dividing, pluripotent cells are the canonical starting point for neuronal differentiation. Training (src/scripts/train_programs.py) total-count normalises and log1p-transforms expression, reduces it to per-perturbation pseudobulk profiles, and splits by held-out perturbation gene. The loss up-weights program genes and additionally matches per-program Δ-scores; validation reports the Pearson correlation between predicted and true Δ-program scores. The frozen result ships here as models/programs_real_20260628_022009.pth.
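The preprocessing and validation steps named above are standard single-cell operations. A numpy sketch under stated assumptions: a dense cells × genes count matrix, a normalisation target of 10,000 counts per cell (a common default, not confirmed by the card), and pseudobulk taken as the mean over cells sharing a perturbation label. The function names and the "non-targeting" control label are hypothetical.

```python
import numpy as np


def normalise_log1p(counts, target_sum=1e4):
    """Scale every cell to the same total count, then apply log1p."""
    totals = counts.sum(axis=1, keepdims=True)
    return np.log1p(counts / totals * target_sum)


def pseudobulk(expr, labels):
    """Average cells by perturbation label -> {label: mean profile}."""
    return {lab: expr[labels == lab].mean(axis=0) for lab in np.unique(labels)}


def pearson(a, b):
    return float(np.corrcoef(a, b)[0, 1])


rng = np.random.default_rng(0)
counts = rng.poisson(2.0, size=(6, 5)).astype(float) + 1.0   # toy cells x genes
labels = np.array(["non-targeting"] * 2 + ["GENE_A"] * 2 + ["GENE_B"] * 2)
profiles = pseudobulk(normalise_log1p(counts), labels)
delta_a = profiles["GENE_A"] - profiles["non-targeting"]     # Δ vs control

# Validation-style check: correlate predicted vs. true Δ-program scores
# across held-out perturbations (toy numbers).
pred_scores = np.array([0.1, 0.4, -0.2])
true_scores = np.array([0.2, 0.5, -0.1])
print(delta_a.round(3), pearson(pred_scores, true_scores))
```

Pseudobulking averages away much of the per-cell sampling noise, which is why a model this small can be fit to per-perturbation profiles at all.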

The six gene programs

src/ml/gene_programs.py groups genes into named, biologically meaningful programs (curated marker sets, e.g. the Tirosh et al. 2016 S and G2/M cell-cycle genes, pan-neuronal markers, pluripotency factors):

| Program | Meaning | High score means… |
|---|---|---|
| cell_division | S + G2/M cell-cycle / mitosis genes | actively dividing |
| neuron_identity | pan-neuronal markers (MAP2, RBFOX3, SYN1…) | neuron-like |
| pluripotency | hESC ground-state factors (POU5F1, NANOG, SOX2…) | stem-cell-like |
| endocrine | hormones / receptors / steroidogenesis | endocrine (odd for hESC; included deliberately) |
| metabolism | glycolysis, TCA, lipid/sterol, OxPhos | high metabolic drive |
| signaling_identity | RTK→RAS/MAPK/PI3K, Wnt, TGFβ, JAK/STAT | proliferative signalling |
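The card does not spell out how a program score is computed from a profile. One common choice, assumed here, is the mean expression change across a program's genes, so a program's Δ-score is the average predicted Δ over its markers. The gene lists below are only the examples named in the table, not the full curated sets in src/ml/gene_programs.py, and program_delta_scores is a hypothetical helper.

```python
import numpy as np

# Example markers from the table; the real programs contain many more genes.
PROGRAMS = {
    "neuron_identity": ["MAP2", "RBFOX3", "SYN1"],
    "pluripotency": ["POU5F1", "NANOG", "SOX2"],
}


def program_delta_scores(delta, gene_names, programs=PROGRAMS):
    """Hypothetical scorer: mean Δ over each program's genes present in the profile."""
    index = {g: i for i, g in enumerate(gene_names)}
    scores = {}
    for name, genes in programs.items():
        cols = [index[g] for g in genes if g in index]
        scores[name] = float(delta[cols].mean()) if cols else 0.0
    return scores


genes = ["MAP2", "RBFOX3", "SYN1", "POU5F1", "NANOG", "SOX2", "GAPDH"]
delta = np.array([0.3, 0.1, 0.2, -0.4, -0.6, -0.2, 0.0])   # toy predicted Δ
scores = program_delta_scores(delta, genes)
print(scores)
```

In this toy knockdown the profile moves toward neuron_identity and away from pluripotency, which is the kind of per-program signal the bridge turns into a neuron's identity vector.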

The original motivating idea was the "cell-dividing neuron" — the biologically unusual corner where cell_division and neuron_identity are both high.

The bridge itself

src/snn/cell_modulation.py turns those biological predictions into LIF parameters:

  1. Assign each of the 1488 hidden neurons a gene knockdown drawn from the dataset's perturbation pool.

  2. Predict — run the frozen cell model on (mean control, gene ESM2 embedding) and score the predicted profile's Δ against each of the 6 programs → a per-neuron identity vector. This checkpoint uses the concat strategy over 3 seed draws, so each neuron carries its identity under 3 independent gene assignments: the stored program matrix P is [1488 × 18] (6 programs × 3 seeds), z-scored.

  3. Map identity → LIF dynamics, with learnable per-program strengths:

    threshold = base_thr · exp(P · thr_gain)        clamp [0.25, 3.0]
    beta      = sigmoid(logit(base_beta) + P·beta_gain)  clamp [0.50, 0.97]
    gain      = exp(P · in_gain)                     clamp [0.30, 3.0]
    

    The directions are seeded from biology (neuron-identity → lower threshold, longer memory, stronger drive; all other programs start neutral); the strengths (thr_gain, beta_gain, in_gain) are nn.Parameters the SNN trains end-to-end — so the network learns how much to lean on the biology. Because gradient flows to the neuron threshold through the surrogate, this is fully differentiable.
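Steps 2 and 3 can be sketched together. This is a hypothetical module, not src/snn/cell_modulation.py: ProgramToLIF, the 0.1 seeding strength, the default base values, and the assumption that neuron_identity is the second column of each seed's 6-program block are all illustrative. It z-scores the concatenated [1488 × 18] matrix, seeds only the neuron_identity columns (lower threshold, higher beta, stronger gain), and applies the three clamped maps with learnable strengths.

```python
import torch
import torch.nn as nn


class ProgramToLIF(nn.Module):
    """Hypothetical sketch: per-neuron program identity P -> LIF threshold, beta, gain."""

    def __init__(self, P, neuron_cols, base_thr=1.0, base_beta=0.9, seed_strength=0.1):
        super().__init__()
        # z-score each column across neurons so every program shares one scale.
        P = (P - P.mean(dim=0)) / (P.std(dim=0) + 1e-8)
        self.register_buffer("P", P)                 # fixed identity matrix [N, K]
        k = P.shape[1]
        # Assumed seeding: only neuron_identity columns start non-zero.
        seed = torch.zeros(k)
        seed[neuron_cols] = seed_strength
        self.thr_gain = nn.Parameter(-seed.clone())  # lower threshold
        self.beta_gain = nn.Parameter(seed.clone())  # longer memory
        self.in_gain = nn.Parameter(seed.clone())    # stronger drive
        self.base_thr = base_thr
        self.base_beta_logit = torch.logit(torch.tensor(base_beta)).item()

    def forward(self):
        thr = (self.base_thr * torch.exp(self.P @ self.thr_gain)).clamp(0.25, 3.0)
        beta = torch.sigmoid(self.base_beta_logit + self.P @ self.beta_gain).clamp(0.50, 0.97)
        gain = torch.exp(self.P @ self.in_gain).clamp(0.30, 3.0)
        return thr, beta, gain


torch.manual_seed(0)
N, n_prog, n_seeds = 1488, 6, 3
# "concat" strategy: one [N, 6] identity block per independent gene-assignment draw.
P_raw = torch.cat([torch.randn(N, n_prog) for _ in range(n_seeds)], dim=1)  # [1488, 18]
neuron_cols = [s * n_prog + 1 for s in range(n_seeds)]  # assumed neuron_identity columns
bridge = ProgramToLIF(P_raw, neuron_cols)
thr, beta, gain = bridge()
print(thr.shape, thr.min().item(), thr.max().item())
```

In this sketch the --modulation none control corresponds to skipping the module entirely and giving every neuron base_thr, base_beta, and unit gain.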

The ablation: --modulation none zeroes the bridge, giving homogeneous LIF neurons — the control that tests whether the biological coupling actually helps. Nothing about the cell model is needed at inference: the computed matrix P and the learned gains are baked into the checkpoint.


Usage

The checkpoint records its own tokenizer, model shape, and modulation matrix, so loading rebuilds the exact architecture automatically.

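pip install torch numpy scipy transformers huggingface_hub

import sys

import torch
from huggingface_hub import snapshot_download

# Download the repo snapshot and make its bundled `snn` package importable.
repo = snapshot_download("Quazim0t0/SpikeWhale-SNN-216M")
sys.path.insert(0, f"{repo}/src")

from snn.text import SpikeVocab
from snn.model import SpikingCharLM
from snn.cell_modulation import NeuronModulation

# The checkpoint stores Python objects alongside tensors, hence weights_only=False.
# torch.load then unpickles arbitrary objects, so only load checkpoints you trust.
ckpt = torch.load(f"{repo}/snn_stream_program.pth", map_location="cpu", weights_only=False)

# Rebuild the tokenizer recorded in the checkpoint.
vocab = SpikeVocab((ckpt["vocab"] or {}).get("tokenizer_json"))

# Rebuild the neuron modulation: the baked-in program matrix P, or homogeneous neurons.
m = ckpt["modulation"]
mod = NeuronModulation(m["P"] if m["enabled"] else None, m["hidden"],
                       base_threshold=m["base_threshold"], base_beta=m["base_beta"],
                       enabled=m["enabled"])

# Rebuild the model with the recorded shape and family traits, then load the weights.
cfg = ckpt["config"]
model = SpikingCharLM(cfg["vocab_size"], cfg["emb_dim"], cfg["hidden"], mod,
                      readout_decay=cfg["readout_decay"], num_layers=cfg["num_layers"],
                      tie_embeddings=cfg["tie_embeddings"], **cfg.get("family", {}))
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Generate a continuation of the prompt "The " (temperature 0.7, top-k 10).
ids = model.generate(vocab.encode("The ").tolist(), 80, temperature=0.7,
                     device=torch.device("cpu"), top_k=10)
print(vocab.decode(ids))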

Or, from the bundled scripts:

PYTHONPATH=src python src/scripts/generate_snn.py \
    --checkpoint snn_stream_program.pth --prompt "The " --length 120

⚠️ This is a recurrent net: one LIF step per token, so CPU sampling is slow. A GPU is strongly recommended for both training and long generations.


Reproduce it from scratch

This repo is a complete, self-contained training package — clone it and run the same project end-to-end. See START_HERE.md and SNN.md for the full write-up.

pip install -r requirements.txt
pip install -e . --no-deps          # makes `snn` and `ml` importable

# train the ~216M spiking LM on the streamed FineWeb-Edu + FineMath blend,
# neurons bridged to the frozen virtual-cell program model:
python src/scripts/train_snn_stream.py --modulation program --data-dir data

# ablation — homogeneous neurons, no bridge (needs no cell data):
python src/scripts/train_snn_stream.py --modulation none --no-family

The datasets stream on the fly (nothing large is stored locally). Watch the firing rate printed each step: deep spiking stacks can go silent in the upper layers — the input projection of the deeper layers is scaled (spike_input_gain) to keep spikes propagating through all 4 layers.

What's in the repo

src/snn/     lif.py  model.py  cell_modulation.py  family.py  text.py  muon.py
src/ml/      the frozen virtual-cell gene-program model (read-only, for the bridge)
src/scripts/ train_snn_stream.py  train_snn.py  generate_snn.py  train_programs.py  …
spike_tokenizer.py  special_tokens.py  tokenizer.json   # the SpikeWhale tokenizer
data/  models/                                          # the cell-bridge bundle + frozen program model
research/                                               # analysis / health-check scripts
snn_stream_program.pth  latest.txt                      # the trained weights (this card)

Limitations & honest caveats

  • The gene-program bridge is experimental. It is a working, ablatable mechanism, not a validated claim about biology or a guaranteed quality win.
  • Recurrent SNN inference is slow (one time-step per token) and this model is large for an SNN; expect slow CPU generation.
  • Trained on English FineWeb-Edu + FineMath only; it inherits those corpora's biases and is a research artifact, not a production/aligned assistant.
  • Reported bits/token is the LM cross-entropy only, so it stays comparable across runs even when auxiliary family losses are active.

License

Apache-2.0.

Citation

@software{byrne_spikewhale_snn_2026,
  title  = {SpikeWhale-SNN-216M: a from-scratch spiking neural network language
            model bridged to a virtual-cell gene-program model},
  author = {Byrne, Dean (Quazim0t0)},
  year   = {2026},
  url    = {https://huggingface.co/Quazim0t0/SpikeWhale-SNN-216M}
}