SAEs for google/gemma-4-E4B-it (layers 7 / 21 / 35)
The official base-model SAEs for Gemma-4-E4B
(decoderesearch/gemma-4-saes) fail
catastrophically on the instruction-tuned model's activations: held-out variance explained
is negative at every layer (−1008 / −268 / −69 on it-pile). This repo contains those same
three SAEs, warm-started from the base weights and continue-trained on gemma-4-E4B-it
activations, which recovers them to VE ≈ 0.81–0.87 at the original sparsity (L0 ≈ k = 100).
The SAEs are matryoshka batch-topk SAEs (k=100, d_in=2560, d_sae=65536, widths [2048, 16384, 65536]),
saved in sae_lens format (JumpReLU inference form: W_enc [2560, 65536], b_enc,
W_dec [65536, 2560], b_dec, per-feature threshold [65536]) in fp32, with
apply_b_dec_to_input: true and normalize_activations: none.
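As a rough illustration of what the JumpReLU inference form computes, here is a minimal NumPy sketch on toy dimensions. It assumes the usual convention: b_dec is subtracted from the input (because apply_b_dec_to_input is true), and a ReLU is then gated by the per-feature threshold. The function names `jumprelu_encode` and `decode` and the toy sizes are illustrative only; they are not part of sae_lens or of this repo.

```python
import numpy as np

def jumprelu_encode(x, W_enc, b_enc, b_dec, threshold):
    """Encode activations x [n, d_in] into sparse features [n, d_sae]."""
    # apply_b_dec_to_input: true -> centre the input on the decoder bias first
    pre = (x - b_dec) @ W_enc + b_enc
    # JumpReLU: keep a pre-activation only where it exceeds that feature's threshold
    return np.maximum(pre, 0.0) * (pre > threshold)

def decode(feats, W_dec, b_dec):
    """Map sparse features [n, d_sae] back to activation space [n, d_in]."""
    return feats @ W_dec + b_dec

# Toy sizes for illustration (the released SAEs use d_in=2560, d_sae=65536)
rng = np.random.default_rng(0)
d_in, d_sae = 4, 8
W_enc = rng.normal(size=(d_in, d_sae))
W_dec = rng.normal(size=(d_sae, d_in))
b_enc = np.zeros(d_sae)
b_dec = rng.normal(size=d_in)
threshold = np.full(d_sae, 0.5)

x = rng.normal(size=(3, d_in))
feats = jumprelu_encode(x, W_enc, b_enc, b_dec, threshold)
recon = decode(feats, W_dec, b_dec)
print(feats.shape, recon.shape)  # (3, 8) (3, 4)
```

In practice `sae.encode` and `sae.decode` from sae_lens should be used on the released weights; the sketch is only meant to make the saved tensors' roles concrete.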
Results
Held-out evaluation: ~262k tokens each of pile-uncopyrighted and ultrachat_200k rendered with the Gemma chat template, at dataset offsets disjoint from the training stream, context 1024, BOS prepended, special-token positions excluded. VE = 1 − SSE / total variance; L0 = mean active features per token.
| layer | before: base SAE on it_pile / it_chat VE | after: it_pile VE (L0) | after: it_chat VE (L0) |
|---|---|---|---|
| 7 | −1008.4 / −941.8 | 0.8328 (99.3) | 0.8674 (100.8) |
| 21 | −268.1 / −195.9 | 0.8229 (101.9) | 0.8215 (102.9) |
| 35 | −69.0 / −59.1 | 0.8145 (93.7) | 0.8451 (107.1) |
The "after" numbers were round-trip verified: the pushed repo was re-downloaded, loaded with
SAE.load_from_disk, and re-evaluated on the held-out sets, matching the local result JSONs.
Usage
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE
from huggingface_hub import snapshot_download
LAYER = 21 # 7, 21, or 35
path = snapshot_download("atawfeek/gemma-4-e4b-it-saes")
sae = SAE.load_from_disk(f"{path}/btk-mat-layer-{LAYER}-k-100").to("cuda")
name = "google/gemma-4-E4B-it"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map="cuda")
# Hook point: forward OUTPUT of model.model.language_model.layers[LAYER]
acts = []
hook = lambda m, i, o: acts.append((o[0] if isinstance(o, tuple) else o).detach())
handle = model.model.language_model.layers[LAYER].register_forward_hook(hook)
text = "..." # any reasonably long natural text (the SAEs were trained at context 1024)
enc = tok(text, return_tensors="pt", truncation=True, max_length=1024).to("cuda") # BOS is prepended
with torch.no_grad():
    model(**enc)
handle.remove()
x = acts[0][0].float() # [seq, 2560]; the SAE runs in fp32 on raw (unnormalized) activations
# Convention: EXCLUDE special-token positions (BOS, chat-control tokens) before the SAE
keep = ~torch.isin(enc.input_ids[0], torch.tensor(tok.all_special_ids, device=x.device))
x = x[keep]
feats = sae.encode(x)
recon = sae.decode(feats)
ve = 1 - (x - recon).pow(2).sum() / (x - x.mean(0)).pow(2).sum()
l0 = (feats > 0).float().sum(-1).mean()
print(f"VE {ve.item():.3f}, L0 {l0.item():.1f}") # expect VE ≈ 0.8+, L0 ≈ 100 on natural text
On a few KB of ordinary English text this snippet should report VE in the ~0.8 range and L0 near 100 for each layer. Evaluating on very short or atypical text will give noisier numbers.
Training details
- Warm start: initialized from decoderesearch/gemma-4-saes, gemma-4-e4b/btk-mat-layer-{7,21,35}-k-100 (matryoshka batch-topk SAEs trained on 1B pile tokens against base google/gemma-4-E4B). The runner_cfg.json shipped in each layer directory is that original base-model training config, preserved as the recipe record; the continue-training metadata lives in each cfg.json.
- Data: 50% monology/pile-uncopyrighted + 50% HuggingFaceH4/ultrachat_200k rendered with the Gemma chat template, E4B-it activations at context 1024, BOS prepended, special-token positions excluded; ~8.4M-token corpus.
- Budget: 3 epochs ≈ 25M tokens seen per layer (plus one polish epoch for L21, below).
- Objective: same matryoshka batch-topk objective as the source SAEs (k=100, widths [2048, 16384, 65536], decoder-norm rescaling, threshold EMA 0.01), Adam lr 1e-4 cosine, batch 4096 tokens. Dead-latent resampling at epoch boundaries at 0.2× median encoder norm.
- The L21 incident (honest note): L21's first run resampled 7,343 dead latents at the
epoch-1→2 boundary at full median encoder norm. The new latents flooded the batch-topk
selection, spiked the threshold EMA from 1.64 to 5.76, and crashed held-out VE from
0.8176/0.8243 (step 4000) to 0.7349/0.7383; it recovered only to 0.7833/0.7804 by the end
of the cosine schedule, below the 0.8 release gate. The fix was (1) switching resampling
to the standard 0.2× norm (validated on L35, whose epoch-boundary resamples then caused no
VE dip), and (2) one polish epoch for L21 from the saved checkpoint (lr 5e-5 cosine, no
resampling), which reached the released 0.8229/0.8215. Pre-polish numbers are preserved in
the training logs (L21_result_prepolish.json).
- Dead latents (known blind spots): training-time dead-feature tracking at save reports
678 (L7), 3,329 (L21), 67 (L35) of 65,536. The shipped sparsity.safetensors (log₁₀ firing frequency on the ~524k-token held-out eval, floored at −10) shows a larger never-fired-on-eval tail: 4,342 / 15,080 / 2,738 latents at the floor; the difference is ultra-rare latents that fire less than about once per 500k tokens. L21's higher dead fraction is residue of the resampling incident.
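The batch-topk selection, threshold EMA, and resampling norm mentioned in the list above can be sketched on toy data as follows. This is not the training code used for these SAEs. It assumes that the threshold EMA tracks the smallest activation that survives selection, that resampled encoder directions are random Gaussian vectors, and that "median encoder norm" means the median norm over live encoder columns. All function names are hypothetical.

```python
import numpy as np

def batch_topk(pre_acts, k):
    """Keep the k * batch_size largest positive pre-activations across the whole batch."""
    n_keep = k * pre_acts.shape[0]
    flat = pre_acts.ravel()
    # Value of the n_keep-th largest entry in the batch
    cutoff = np.partition(flat, flat.size - n_keep)[flat.size - n_keep]
    return np.where((pre_acts >= cutoff) & (pre_acts > 0), pre_acts, 0.0)

def update_threshold(threshold, kept, ema=0.01):
    """Assumed rule: move the threshold toward the smallest surviving activation."""
    positive = kept[kept > 0]
    if positive.size == 0:
        return threshold
    return (1.0 - ema) * threshold + ema * positive.min()

def resample_dead(W_enc, dead, scale, rng):
    """Re-initialise dead encoder columns at scale * median norm of the live columns."""
    W = W_enc.copy()
    alive = np.setdiff1d(np.arange(W.shape[1]), dead)
    target = scale * np.median(np.linalg.norm(W[:, alive], axis=0))
    new = rng.normal(size=(W.shape[0], len(dead)))
    W[:, dead] = new / np.linalg.norm(new, axis=0) * target
    return W

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 16))        # toy batch of 64 tokens, d_in = 16
W_enc = rng.normal(size=(16, 200))   # toy encoder with 200 latents
dead = np.arange(150, 200)           # pretend the last 50 latents are dead

shares = {}
for scale in (1.0, 0.2):
    W = resample_dead(W_enc, dead, scale, rng)
    kept = batch_topk(x @ W, k=10)
    shares[scale] = (kept[:, dead] > 0).sum() / (kept > 0).sum()
    thr = update_threshold(0.0, kept)
    print(f"scale {scale}: resampled share of top-k slots {shares[scale]:.1%}, threshold after one step {thr:.4f}")
```

On this toy setup, latents resampled at full norm compete for batch-topk slots on equal terms with live latents, while those resampled at 0.2× rarely win a slot. That is only a qualitative analogue of the L21 incident, since a random encoder does not behave like a trained one.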
Evaluation-methodology caveat (affects how "before" numbers are read)
The base SAEs were previously reported healthy on base-model activations (VE 0.87–0.89, L0 ≈ 88–94). That measurement was made on 7 short clean texts (~50 tokens, BOS dropped, fp32 in-memory). On real 1024-context pile data (262k held-out tokens, special tokens excluded) the same SAEs measure plain VE 0.194 / 0.752 / 0.790 (L0 120–145), but this gap is almost entirely an outlier-token effect: the top 0.1% of tokens by SSE carry 88% / 59% / 56% of total squared error, and excluding just those positions restores VE to 0.901 / 0.899 / 0.906. Notably, these outlier tokens are not high-norm: dropping the top 0.1% by activation norm changes nothing, so the effect is invisible to norm-based outlier screens. This is a finding about evaluation methodology (short-text evals can overstate robustness, and plain VE on long contexts can understate the SSE-robust ceiling), not a criticism of the original release: on its own terms the base release performs at a ~0.90 SSE-robust ceiling on its training distribution.
The adapted SAEs in this repo show no such pathology: their top-0.1%-by-SSE tokens carry only ~0.25–0.40% of SSE, and SSE-robust VE coincides with plain VE to within ±0.005 at every layer. The it-side failure of the base SAEs is uniform (VE stays ≈ −1008 / −268 / −69 under every exclusion scheme), i.e. a genuine distribution shift rather than an outlier artifact.
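A minimal sketch of the outlier check described above, on synthetic data: it reports plain VE, the share of SSE carried by the top 0.1% of tokens by SSE, and VE recomputed with those tokens removed. It assumes that "SSE-robust VE" means dropping those positions from both the SSE and the total-variance term; the exact definition used for the reported numbers may differ, and the function names are hypothetical.

```python
import numpy as np

def variance_explained(x, recon):
    """VE = 1 - SSE / total variance, pooled over all tokens and dimensions."""
    sse = ((x - recon) ** 2).sum()
    total = ((x - x.mean(axis=0)) ** 2).sum()
    return 1.0 - sse / total

def sse_outlier_report(x, recon, top_frac=0.001):
    """Compare plain VE with VE after dropping the top_frac of tokens by SSE."""
    per_token_sse = ((x - recon) ** 2).sum(axis=1)
    n_top = max(1, int(round(top_frac * per_token_sse.size)))
    order = np.argsort(per_token_sse)           # ascending by SSE
    top, rest = order[-n_top:], order[:-n_top]
    return {
        "plain_ve": variance_explained(x, recon),
        "top_sse_share": per_token_sse[top].sum() / per_token_sse.sum(),
        "robust_ve": variance_explained(x[rest], recon[rest]),
    }

# Synthetic example: a good reconstruction everywhere except two badly wrong tokens
rng = np.random.default_rng(0)
x = rng.normal(size=(2000, 16))
recon = x + 0.1 * rng.normal(size=x.shape)
recon[:2] += 50.0
report = sse_outlier_report(x, recon)
print({name: round(float(value), 3) for name, value in report.items()})
```

With real activations, `x` and `recon` would be the special-token-filtered activations and their SAE reconstructions from the Usage snippet, converted to NumPy arrays.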
License & lineage
- Weights lineage: warm-started from decoderesearch/gemma-4-saes (credit to the decode research team; these adaptations exist because their base SAEs were good enough to recover with a ~25M-token continue-train). That repository does not declare an explicit license.
- Model lineage: trained on activations of google/gemma-4-E4B-it, which is released under Apache-2.0.
- This repo is released under Apache-2.0, matching the upstream model.
- Training data: monology/pile-uncopyrighted and HuggingFaceH4/ultrachat_200k.
Limitations
- Feature explanations / auto-interp not yet generated; Neuronpedia submission pending, and dashboards will be linked here once accepted.
- Downstream-CE delta: not yet measured (splice-in reconstruction effect on model loss). Held-out VE / L0 are the quality record to date.
- Dead latents: see counts above; L21 in particular has an elevated dead fraction.
- Trained on stock google/gemma-4-E4B-it only; no claim of transfer to derivative fine-tunes.
- Single seed, single training run per layer; layers 7, 21, 35 only.
- Evaluated on pile + ultrachat-style chat data; other domains (code-heavy, multilingual, long-context > 1024) unmeasured.
Citation
@misc{tawfeek2026gemma4itsaes,
title = {Warm-started sparse autoencoders for gemma-4-E4B-it (layers 7/21/35)},
author = {Tawfeek, Andrew R.},
year = {2026},
howpublished = {\url{https://huggingface.co/atawfeek/gemma-4-e4b-it-saes}},
note = {Warm-start adaptation of decoderesearch/gemma-4-saes to instruction-tuned activations}
}
Feature dashboards will live on Neuronpedia once the upload is accepted.