SAEs for google/gemma-4-E4B-it (layers 7 / 21 / 35)

The official base-model SAEs for Gemma-4-E4B (decoderesearch/gemma-4-saes) fail catastrophically on the instruction-tuned model's activations: held-out variance explained (VE) is negative at every layer (−1008 / −268 / −69 on it_pile). This repo contains those same three SAEs, warm-started from the base weights and continue-trained on gemma-4-E4B-it activations, which recovers them to VE ≈ 0.81–0.87 at the original sparsity (L0 ≈ k = 100).

Matryoshka batch-topk SAEs (k=100, d_in=2560, d_sae=65536, widths [2048, 16384, 65536]), saved in sae_lens format (JumpReLU inference form: W_enc [2560, 65536], b_enc, W_dec [65536, 2560], b_dec, per-feature threshold [65536]), fp32, apply_b_dec_to_input: true, normalize_activations: none.
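The JumpReLU inference form listed above can be sketched in plain Python. This is a minimal illustration, assuming the usual sae_lens convention that apply_b_dec_to_input subtracts b_dec from the input before encoding, and that a latent fires only when its pre-activation exceeds its per-feature threshold. The function names jumprelu_encode and sae_decode and the toy 2-by-3 weights are hypothetical and exist only for this example; the shipped tensors are [2560, 65536] and [65536, 2560].

```python
def jumprelu_encode(x, W_enc, b_enc, b_dec, threshold):
    """Encode one activation vector x (length d_in) into d_sae latents."""
    # apply_b_dec_to_input: true -> centre the input on the decoder bias (assumed convention)
    centered = [xi - bi for xi, bi in zip(x, b_dec)]
    # Pre-activations: centered @ W_enc + b_enc, with W_enc shaped [d_in][d_sae]
    pre = [
        sum(c * W_enc[i][j] for i, c in enumerate(centered)) + b_enc[j]
        for j in range(len(b_enc))
    ]
    # JumpReLU: keep the (non-negative) pre-activation only above the per-feature threshold
    return [max(p, 0.0) if p > t else 0.0 for p, t in zip(pre, threshold)]


def sae_decode(feats, W_dec, b_dec):
    """Reconstruct the activation: feats @ W_dec + b_dec, with W_dec shaped [d_sae][d_in]."""
    return [
        sum(f * W_dec[j][i] for j, f in enumerate(feats)) + b_dec[i]
        for i in range(len(b_dec))
    ]


# Toy weights (hypothetical): d_in = 2, d_sae = 3
W_enc = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
b_enc = [0.0, 0.0, -1.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
b_dec = [0.5, 0.5]
threshold = [0.2, 0.2, 0.2]

feats = jumprelu_encode([1.5, 0.6], W_enc, b_enc, b_dec, threshold)
recon = sae_decode(feats, W_dec, b_dec)
l0 = sum(f > 0 for f in feats)  # number of active latents for this token
print(feats, recon, l0)
```

In the toy case only the first latent clears its threshold, so the token has L0 = 1. In the released SAEs the thresholds are what hold average L0 near k = 100 at inference time without a batch-level top-k.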

Results

Held-out evaluation: ~262k tokens each of pile-uncopyrighted and ultrachat_200k rendered with the Gemma chat template, at dataset offsets disjoint from the training stream, context 1024, BOS prepended, special-token positions excluded. VE = 1 − SSE / total variance; L0 = mean active features per token.
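Written out per token, and assuming (as in the usage snippet in this card) that total variance is measured about the per-dimension mean over the evaluated tokens, the two metrics are:

```latex
\mathrm{VE} \;=\; 1 - \frac{\sum_{t=1}^{T} \lVert x_t - \hat{x}_t \rVert_2^2}{\sum_{t=1}^{T} \lVert x_t - \bar{x} \rVert_2^2},
\qquad
\bar{x} = \frac{1}{T}\sum_{t=1}^{T} x_t,
\qquad
L_0 \;=\; \frac{1}{T}\sum_{t=1}^{T} \lVert f_t \rVert_0
```

Here x_t is the layer activation at token t, x̂_t its SAE reconstruction, and f_t the latent vector. Nothing bounds VE from below: a reconstruction whose squared error exceeds the total variance gives a negative value, so VE = −1008.4 corresponds to an SSE of roughly 1009 times the total variance.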

layer | before: base SAE VE (it_pile / it_chat) | after: it_pile VE (L0) | after: it_chat VE (L0)
7     | −1008.4 / −941.8                        | 0.8328 (99.3)          | 0.8674 (100.8)
21    | −268.1 / −195.9                         | 0.8229 (101.9)         | 0.8215 (102.9)
35    | −69.0 / −59.1                           | 0.8145 (93.7)          | 0.8451 (107.1)

The "after" numbers were round-trip verified: the pushed repo was re-downloaded, loaded with SAE.load_from_disk, and re-evaluated on the held-out sets, matching the local result JSONs.

Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE
from huggingface_hub import snapshot_download

LAYER = 21  # 7, 21, or 35
path = snapshot_download("atawfeek/gemma-4-e4b-it-saes")
sae = SAE.load_from_disk(f"{path}/btk-mat-layer-{LAYER}-k-100").to("cuda")

name = "google/gemma-4-E4B-it"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map="cuda")

# Hook point: forward OUTPUT of model.model.language_model.layers[LAYER]
acts = []
hook = lambda m, i, o: acts.append((o[0] if isinstance(o, tuple) else o).detach())
handle = model.model.language_model.layers[LAYER].register_forward_hook(hook)

text = "..."  # any reasonably long natural text (the SAEs were trained at context 1024)
enc = tok(text, return_tensors="pt", truncation=True, max_length=1024).to("cuda")  # BOS is prepended
with torch.no_grad():
    model(**enc)
handle.remove()

x = acts[0][0].float()  # [seq, 2560]; the SAE runs in fp32 on raw (unnormalized) activations
# Convention: EXCLUDE special-token positions (BOS, chat-control tokens) before the SAE
keep = ~torch.isin(enc.input_ids[0], torch.tensor(tok.all_special_ids, device=x.device))
x = x[keep]

feats = sae.encode(x)
recon = sae.decode(feats)
ve = 1 - (x - recon).pow(2).sum() / (x - x.mean(0)).pow(2).sum()
l0 = (feats > 0).float().sum(-1).mean()
print(f"VE {ve.item():.3f}, L0 {l0.item():.1f}")  # expect VE ≈ 0.8+, L0 ≈ 100 on natural text

On a few KB of ordinary English text this snippet should report VE in the ~0.8 range and L0 near 100 for each layer. Evaluating on very short or atypical text will give noisier numbers.

Training details

  • Warm start: initialized from decoderesearch/gemma-4-saes gemma-4-e4b/btk-mat-layer-{7,21,35}-k-100 (matryoshka batch-topk SAEs trained on 1B pile tokens against base google/gemma-4-E4B). The runner_cfg.json shipped in each layer directory is that original base-model training config, preserved as the recipe record; the continue-training metadata lives in each cfg.json.
  • Data: 50% monology/pile-uncopyrighted + 50% HuggingFaceH4/ultrachat_200k rendered with the Gemma chat template, E4B-it activations at context 1024, BOS prepended, special-token positions excluded; ~8.4M-token corpus.
  • Budget: 3 epochs ≈ 25M tokens seen per layer (plus one polish epoch for L21, below).
  • Objective: same matryoshka batch-topk objective as the source SAEs (k=100, widths [2048, 16384, 65536], decoder-norm rescaling, threshold EMA 0.01), Adam lr 1e-4 cosine, batch 4096 tokens. Dead-latent resampling at epoch boundaries at 0.2× median encoder norm.
  • The L21 incident (honest note): L21's first run resampled 7,343 dead latents at the epoch-1→2 boundary at full median encoder norm. The new latents flooded the batch-topk selection, spiked the threshold EMA from 1.64 to 5.76, and crashed held-out VE from 0.8176/0.8243 (step 4000) to 0.7349/0.7383; it recovered only to 0.7833/0.7804 by the end of the cosine schedule, below the 0.8 release gate. The fix was (1) switching resampling to the standard 0.2× norm (validated on L35, whose epoch-boundary resamples then caused no VE dip), and (2) one polish epoch for L21 from the saved checkpoint (lr 5e-5 cosine, no resampling), which reached the released 0.8229/0.8215. Pre-polish numbers are preserved in the training logs (L21_result_prepolish.json).
  • Dead latents (known blind spots): training-time dead-feature tracking at save reports 678 (L7), 3,329 (L21), 67 (L35) of 65,536. The shipped sparsity.safetensors (log₁₀ firing frequency on the ~524k-token held-out eval, floored at −10) shows a larger never-fired-on-eval tail: 4,342 / 15,080 / 2,738 latents at the floor. The difference is ultra-rare latents firing less than ~once per 500k tokens. L21's higher dead fraction is residue of the resampling incident.
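The counts and budget in the list above are easier to compare as fractions. The short sketch below only redoes that arithmetic from the figures quoted in this card; the corpus size is the approximate ~8.4M tokens stated above, so the tokens-seen figure is approximate too, and the helper name pct is hypothetical.

```python
D_SAE = 65_536  # dictionary width of each SAE

# Counts quoted in this card, keyed by layer
train_dead = {7: 678, 21: 3_329, 35: 67}         # dead at save (training-time tracking)
eval_floor = {7: 4_342, 21: 15_080, 35: 2_738}   # never fired on the ~524k-token eval


def pct(count, total=D_SAE):
    """Share of the dictionary, in percent, rounded to two decimals."""
    return round(100 * count / total, 2)


for layer in (7, 21, 35):
    print(
        f"L{layer}: dead at save {pct(train_dead[layer])}% | "
        f"silent on eval {pct(eval_floor[layer])}%"
    )

# Token budget: 3 epochs over the ~8.4M-token corpus (approximate figure)
corpus_tokens = 8.4e6
tokens_seen = 3 * corpus_tokens
print(f"tokens seen per layer ≈ {tokens_seen / 1e6:.1f}M")
```

Seen this way, L21 has about 5% of its dictionary dead at save and about 23% silent on the held-out eval, against roughly 1% and 7% for L7 and 0.1% and 4% for L35.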

Evaluation-methodology caveat (affects how "before" numbers are read)

The base SAEs were previously reported healthy on base-model activations (VE 0.87–0.89, L0 ≈ 88–94). That measurement was made on 7 short clean texts (~50 tokens, BOS dropped, fp32 in-memory). On real 1024-context pile data (262k held-out tokens, special tokens excluded) the same SAEs measure plain VE 0.194 / 0.752 / 0.790 (L0 120–145), but this gap is almost entirely an outlier-token effect: the top 0.1% of tokens by SSE carry 88% / 59% / 56% of total squared error, and excluding just those positions restores VE to 0.901 / 0.899 / 0.906. Notably, these outlier tokens are not high-norm (dropping the top 0.1% by activation norm changes nothing), so the effect is invisible to norm-based outlier screens. This is a finding about evaluation methodology (short-text evals can overstate robustness, and plain VE on long contexts can understate the SSE-robust ceiling), not a criticism of the original release: on its own terms the base release performs at a ~0.90 SSE-robust ceiling on its training distribution.

The adapted SAEs in this repo show no such pathology: their top-0.1%-by-SSE tokens carry only ~0.25–0.40% of SSE, and SSE-robust VE coincides with plain VE to within ±0.005 at every layer. The it-side failure of the base SAEs is uniform (VE stays ≈ −1008 / −268 / −69 under every exclusion scheme), i.e. a genuine distribution shift rather than an outlier artifact.
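The SSE-robust check described in this section can be sketched as follows. This is one plausible reading, not the exact evaluation script: it assumes the excluded tokens are removed from both the error and the total-variance terms, and that the mean is recomputed on the remaining tokens. The function names and the synthetic data are hypothetical, and plain Python lists stand in for the real [tokens, 2560] tensors.

```python
def variance_explained(x, recon):
    """Plain VE = 1 - SSE / total variance about the per-dimension mean."""
    n, d = len(x), len(x[0])
    mean = [sum(row[j] for row in x) / n for j in range(d)]
    sse = sum((a - b) ** 2 for xr, rr in zip(x, recon) for a, b in zip(xr, rr))
    total = sum((a - m) ** 2 for xr in x for a, m in zip(xr, mean))
    return 1.0 - sse / total


def sse_robust_ve(x, recon, drop_frac=0.001):
    """VE after dropping the top `drop_frac` of tokens ranked by per-token SSE.

    Returns (robust VE, share of total SSE carried by the dropped tokens).
    """
    per_token = [sum((a - b) ** 2 for a, b in zip(xr, rr)) for xr, rr in zip(x, recon)]
    n_drop = max(1, round(len(x) * drop_frac))
    ranked = sorted(range(len(x)), key=per_token.__getitem__, reverse=True)
    dropped = set(ranked[:n_drop])
    keep = [i for i in range(len(x)) if i not in dropped]
    share = sum(per_token[i] for i in dropped) / sum(per_token)
    return variance_explained([x[i] for i in keep], [recon[i] for i in keep]), share


# Synthetic demo: 1000 tokens, small error everywhere, one badly reconstructed token
x = [[float(i % 7), float(i % 3)] for i in range(1000)]
recon = [[row[0] + 0.1, row[1]] for row in x]
recon[0] = [x[0][0] + 50.0, x[0][1]]

plain = variance_explained(x, recon)
robust, share = sse_robust_ve(x, recon)
print(f"plain VE {plain:.3f} | SSE-robust VE {robust:.3f} | outlier SSE share {share:.1%}")
```

In the synthetic case a single token carries almost all of the squared error, so plain VE is low while the SSE-robust figure is near 1. That is the same qualitative pattern reported for the base SAEs on long-context pile data; ranking by activation norm instead of per-token SSE would not isolate such tokens.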

License & lineage

  • Weights lineage: warm-started from decoderesearch/gemma-4-saes (credit to the decode research team: these adaptations exist because their base SAEs were good enough to recover with a ~25M-token continue-train). That repository does not declare an explicit license.
  • Model lineage: trained on activations of google/gemma-4-E4B-it, which is released under Apache-2.0.
  • This repo is released under Apache-2.0, matching the upstream model.
  • Training data: monology/pile-uncopyrighted and HuggingFaceH4/ultrachat_200k.

Limitations

  • Feature explanations / auto-interp not yet generated; Neuronpedia submission pending. Dashboards will be linked here once accepted.
  • Downstream-CE delta: not yet measured (splice-in reconstruction effect on model loss). Held-out VE / L0 are the quality record to date.
  • Dead latents: see counts above; L21 in particular has an elevated dead fraction.
  • Trained on stock google/gemma-4-E4B-it only; no claim of transfer to derivative fine-tunes.
  • Single seed, single training run per layer; layers 7, 21, 35 only.
  • Evaluated on pile + ultrachat-style chat data; other domains (code-heavy, multilingual, long-context > 1024) unmeasured.

Citation

@misc{tawfeek2026gemma4itsaes,
  title        = {Warm-started sparse autoencoders for gemma-4-E4B-it (layers 7/21/35)},
  author       = {Tawfeek, Andrew R.},
  year         = {2026},
  howpublished = {\url{https://huggingface.co/atawfeek/gemma-4-e4b-it-saes}},
  note         = {Warm-start adaptation of decoderesearch/gemma-4-saes to instruction-tuned activations}
}

Feature dashboards will live on Neuronpedia once the upload is accepted.
