OLMo-3-7B-Think Sparse Autoencoders

This repository contains sparse autoencoders (SAEs) trained on post-layer residual stream activations from allenai/Olmo-3-7B-Think.

Each folder is packaged for SAELens with cfg.json and sae_weights.safetensors. The original training checkpoints included optimizer state; this release contains inference-only weights.

Code

Training and collection code is available at daMS2005/olmo-sae-builder.

Model Details

  • Base model: allenai/Olmo-3-7B-Think
  • Activation type: post-layer residual stream
  • Activation width: 4096
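  • Hook convention: layer L is the output of model.layers[L], which matches outputs.hidden_states[L + 1] from Hugging Face Transformers for layers 7, 15, and 23 (the last hidden_states entry may have the model's final norm applied)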
  • Collection context length: mixed-length chunks up to 8192 tokens
  • Training data source: Dolma v1_6-sample plus Dolci-Instruct-SFT text
  • Architecture: tied TopK sparse autoencoder
  • Decoder/encoder directions are exported in SAELens tensor format
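The following NumPy sketch shows what a tied TopK SAE computes. It is illustrative only and is not the training code from daMS2005/olmo-sae-builder. It makes two assumptions. First, "tied" is taken to mean that the encoder matrix is the transpose of the decoder matrix. Second, the decoder bias is assumed to be subtracted before encoding, which the "raw_prebias" checkpoint names suggest but do not confirm. The function name topk_sae_forward is hypothetical.

```python
import numpy as np


def topk_sae_forward(x, W_dec, b_enc, b_dec, k):
    """Tied TopK SAE forward pass (illustrative sketch, not the training code).

    x:     [n, d_in] input activations
    W_dec: [d_sae, d_in] decoder directions; tied weights assumed, so W_enc = W_dec.T
    b_enc: [d_sae] encoder bias
    b_dec: [d_in] decoder bias, assumed to also be subtracted before encoding
    """
    pre = (x - b_dec) @ W_dec.T + b_enc          # [n, d_sae] pre-activations
    pre = np.maximum(pre, 0.0)                    # ReLU, so at most k latents end up positive
    # Column indices of the k largest pre-activations in each row.
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    z = np.zeros_like(pre)
    np.put_along_axis(z, idx, np.take_along_axis(pre, idx, axis=1), axis=1)
    x_hat = z @ W_dec + b_dec                     # [n, d_in] reconstruction
    return z, x_hat


rng = np.random.default_rng(0)
d_in, d_sae, k = 16, 64, 4
W_dec = rng.standard_normal((d_sae, d_in)) / np.sqrt(d_in)
x = rng.standard_normal((3, d_in))
z, x_hat = topk_sae_forward(x, W_dec, np.zeros(d_sae), np.zeros(d_in), k)
```

Applying ReLU before selecting the top k leaves the same nonzero set as selecting first and then applying ReLU. Every row of z therefore has at most k positive entries.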

Data Collection

Activations were collected with the local collect_data.py pipeline:

  • Downloaded all 103 Dolma v1_6-sample JSONL gzip shards from allenai/dolma.
  • Downloaded all 15 Dolci-Instruct-SFT train parquet shards from allenai/Dolci-Instruct-SFT.
  • Streamed local files from disk row by row; the full text dataset was not loaded into RAM or VRAM.
  • Tokenized text on CPU, inserted EOS separators, and packed mixed-length chunks.
  • Used three chunk-length buckets, given as (min, max) tokens: (512, 2048), (3500, 5000), and (7000, 8192).
  • Ran allenai/Olmo-3-7B-Think in bfloat16 with torch.no_grad(), eval(), and use_cache=False.
  • Captured olmo.model.layers[L] outputs for layers 7, 15, 23, and 31.
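  • For layers 7, 15, and 23 this is equivalent to using outputs.hidden_states[L + 1] from Hugging Face Transformers. The last hidden_states entry may have the model's final norm applied, so it may not match the raw layer-31 output.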
  • Removed padding positions with the attention mask before saving activations.
  • Saved activation tensors as float16 chunks with shape [num_real_tokens, 4096].
  • The release training runs used 380 activation chunks per layer, about 20.0M token activations.
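The packing step can be sketched in plain Python. This is a hypothetical illustration and not collect_data.py. The function pack_chunks, the random choice of bucket and target length for each chunk, and dropping the final partial chunk are all assumptions. The sketch shows how streamed documents can be joined with EOS separators into chunks whose lengths fall inside the listed buckets.

```python
import random

# (min_len, max_len) token buckets from the list above.
BUCKETS = [(512, 2048), (3500, 5000), (7000, 8192)]


def pack_chunks(token_docs, eos_id, buckets=BUCKETS, seed=0):
    """Hypothetical packer: join streamed documents with EOS into mixed-length chunks.

    token_docs is any iterable of token-id lists, so documents can be streamed
    from disk one at a time. Each chunk gets a target length drawn from a
    randomly chosen bucket (an assumption about how the buckets are used).
    """
    rng = random.Random(seed)

    def new_target():
        lo, hi = rng.choice(buckets)
        return rng.randint(lo, hi)

    buffer, target = [], new_target()
    for doc in token_docs:
        buffer.extend(doc)
        buffer.append(eos_id)  # EOS marks the boundary between packed documents
        while len(buffer) >= target:
            yield buffer[:target]
            buffer = buffer[target:]
            target = new_target()
    # A trailing partial chunk shorter than its target is dropped in this sketch.


docs = ([1] * 900 for _ in range(50))  # stand-in for tokenized documents
chunks = list(pack_chunks(docs, eos_id=0))
```

Because pack_chunks is a generator, only the current buffer stays in memory. This matches the row-by-row streaming described above.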

Recommended SAEs

Use these unless you specifically want to compare variants (raw versus layernorm preprocessing, or TopK64 versus TopK128 at layer 23):

  • layer_07/width_131k/topk_64_raw
  • layer_15/width_163k/topk_64_raw
  • layer_23/width_163k/topk_128_raw
  • layer_31/width_163k/topk_128_raw
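The SAE ids follow a layer_XX/width_Nk/topk_K_{raw,layernorm} pattern. A hypothetical helper such as parse_sae_id can recover the layer, k, and normalization from an id. It is not part of SAELens or this repository. One use is looking up the recommended SAE for a given layer.

```python
import re

# Hypothetical pattern for SAE ids in this release.
SAE_ID_RE = re.compile(r"layer_(\d+)/width_(\d+)k/topk_(\d+)_(raw|layernorm)")


def parse_sae_id(sae_id):
    """Split an SAE id into layer, width label, k, and activation normalization."""
    m = SAE_ID_RE.fullmatch(sae_id)
    if m is None:
        raise ValueError(f"unrecognized SAE id: {sae_id}")
    layer, width_k, k, norm = m.groups()
    return {
        "layer": int(layer),
        "width_label": f"{width_k}k",  # a label only: "163k" means 163840 latents
        "k": int(k),
        "normalize_activations": "layer_norm" if norm == "layernorm" else "none",
    }


RECOMMENDED = [
    "layer_07/width_131k/topk_64_raw",
    "layer_15/width_163k/topk_64_raw",
    "layer_23/width_163k/topk_128_raw",
    "layer_31/width_163k/topk_128_raw",
]
recommended_by_layer = {parse_sae_id(s)["layer"]: s for s in RECOMMENDED}
```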

Layer 31 SAE

layer_31/width_163k/topk_128_raw is the final-layer SAE in this release.

  • Layer: 31
  • Latents: 163840
  • TopK: 128
  • Input normalization: none (raw activations)
  • Batch size: 256
  • Learning rate: 3e-4
  • Final training mean normalized MSE: 0.326502

Validation after export:

  • SAE.from_pretrained(...) loaded successfully through SAELens
  • encode() and decode() ran on CPU and CUDA
  • Exported SAELens tensors matched the original training checkpoint with max absolute difference 0.0
  • A small real-activation smoke test on 256 layer-31 vectors produced finite reconstructions
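The card does not define "normalized MSE". This sketch assumes a common convention: divide the summed squared reconstruction error by the summed squared deviation of the inputs from their mean. Under that definition, predicting the mean activation scores 1.0 and a perfect reconstruction scores 0.0. The NumPy sketch pairs this assumed metric with a finiteness check similar to the smoke test above. The function names are hypothetical, and random data stands in for real layer-31 activations.

```python
import numpy as np


def normalized_mse(x, x_hat):
    """Squared error relative to a predict-the-mean baseline (assumed definition)."""
    err = np.sum((x - x_hat) ** 2)
    baseline = np.sum((x - x.mean(axis=0, keepdims=True)) ** 2)
    return float(err / baseline)


def smoke_test(x, x_hat):
    """Check that a reconstruction is finite, then score it."""
    if not np.all(np.isfinite(x_hat)):
        raise ValueError("reconstruction contains NaN or inf")
    return normalized_mse(x, x_hat)


rng = np.random.default_rng(0)
x = rng.standard_normal((256, 4096)).astype(np.float32)  # stand-in for 256 activation vectors
perfect = smoke_test(x, x.copy())
mean_only = smoke_test(x, np.broadcast_to(x.mean(axis=0), x.shape))
```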

All Included SAEs

SAE id                                | Layer | Width  | k   | Norm       | Final normalized MSE | Canonical
------------------------------------- | ----- | ------ | --- | ---------- | -------------------- | ---------
layer_07/width_131k/topk_64_raw       | 7     | 131072 | 64  | none       | 0.274970             | yes
layer_07/width_163k/topk_64_layernorm | 7     | 163840 | 64  | layer_norm | 0.277127             | no
layer_15/width_163k/topk_64_raw       | 15    | 163840 | 64  | none       | 0.303301             | yes
layer_15/width_163k/topk_64_layernorm | 15    | 163840 | 64  | layer_norm | 0.304524             | no
layer_23/width_163k/topk_64_raw       | 23    | 163840 | 64  | none       | 0.343949             | no
layer_23/width_163k/topk_128_raw      | 23    | 163840 | 128 | none       | 0.329776             | yes
layer_31/width_163k/topk_128_raw      | 31    | 163840 | 128 | none       | 0.326502             | yes

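The layernorm variants set SAELens's normalize_activations="layer_norm" option. SAELens then normalizes each input vector at runtime before encoding and undoes the normalization after decoding. The raw variants reconstruct raw activation vectors directly.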
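This is an approximate sketch of what SAELens's layer_norm option is understood to do. Each vector is centered and scaled, the statistics are kept, and the transform is undone after decoding. The function names and the eps value are assumptions, so check the SAELens source for its exact conventions.

```python
import numpy as np


def layer_norm_in(x, eps=1e-5):
    """Center and scale each activation vector; return the stats needed to undo it."""
    mu = x.mean(axis=-1, keepdims=True)
    centered = x - mu
    std = centered.std(axis=-1, keepdims=True)
    return centered / (std + eps), (mu, std)


def layer_norm_out(x_norm, stats, eps=1e-5):
    """Map a vector from normalized space back to the raw activation scale."""
    mu, std = stats
    return x_norm * (std + eps) + mu


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4096)) * 3.0 + 1.0  # stand-in residual vectors
x_norm, stats = layer_norm_in(x)
# In practice the SAE reconstruction of x_norm would be passed here instead of x_norm.
x_back = layer_norm_out(x_norm, stats)
```

The statistics are per vector, so with this preprocessing the SAE sees inputs with roughly zero mean and unit scale, whatever the raw residual norm.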

SAE Sections

layer_07/width_131k/topk_64_raw

Layer 7 raw-activation TopK64 SAE; best layer 7 run by final loss.

  • Status: recommended canonical SAE
  • Layer: 7
  • SAE width: 131072 latents
  • Input width: 4096
  • TopK: 64
  • Activation normalization: none
  • Final training mean normalized MSE: 0.274970
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-07/raw_prebias_131k_topk64/olmo_sae_layer07.pt

Raw residual stream activations; no runtime input normalization.

layer_07/width_163k/topk_64_layernorm

Layer 7 TopK64 SAE trained with runtime layernorm preprocessing.

  • Status: comparison variant
  • Layer: 7
  • SAE width: 163840 latents
  • Input width: 4096
  • TopK: 64
  • Activation normalization: layer_norm
  • Final training mean normalized MSE: 0.277127
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-07/olmo_sae_layer07.pt

Uses SAELens runtime layer normalization before encoding and reverses it after decoding.

layer_15/width_163k/topk_64_raw

Layer 15 raw-activation TopK64 SAE; best layer 15 run by final loss.

  • Status: recommended canonical SAE
  • Layer: 15
  • SAE width: 163840 latents
  • Input width: 4096
  • TopK: 64
  • Activation normalization: none
  • Final training mean normalized MSE: 0.303301
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-15/raw_prebias_163k_topk64/olmo_sae_layer15.pt

Raw residual stream activations; no runtime input normalization.

layer_15/width_163k/topk_64_layernorm

Layer 15 TopK64 SAE trained with runtime layernorm preprocessing.

  • Status: comparison variant
  • Layer: 15
  • SAE width: 163840 latents
  • Input width: 4096
  • TopK: 64
  • Activation normalization: layer_norm
  • Final training mean normalized MSE: 0.304524
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-15/olmo_sae_layer15.pt

Uses SAELens runtime layer normalization before encoding and reverses it after decoding.

layer_23/width_163k/topk_64_raw

Layer 23 raw-activation TopK64 SAE; kept for same-k comparisons.

  • Status: comparison variant
  • Layer: 23
  • SAE width: 163840 latents
  • Input width: 4096
  • TopK: 64
  • Activation normalization: none
  • Final training mean normalized MSE: 0.343949
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-23/raw_prebias_163k_topk64/olmo_sae_layer23.pt

Raw residual stream activations; no runtime input normalization.

layer_23/width_163k/topk_128_raw

Layer 23 raw-activation TopK128 SAE; best layer 23 run by final loss.

  • Status: recommended canonical SAE
  • Layer: 23
  • SAE width: 163840 latents
  • Input width: 4096
  • TopK: 128
  • Activation normalization: none
  • Final training mean normalized MSE: 0.329776
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-23/raw_prebias_163k_topk128/olmo_sae_layer23.pt

Raw residual stream activations; no runtime input normalization.

layer_31/width_163k/topk_128_raw

Layer 31 raw-activation TopK128 SAE.

  • Status: recommended canonical SAE
  • Layer: 31
  • SAE width: 163840 latents
  • Input width: 4096
  • TopK: 128
  • Activation normalization: none
  • Final training mean normalized MSE: 0.326502
  • Source checkpoint: dams2005/olmo-3-7b-think-sae-layer-31/raw_prebias_163k_topk128/olmo_sae_layer31.pt

Raw residual stream activations; no runtime input normalization.

SAELens Usage

from sae_lens import SAE

repo_id = "dams2005/olmo-3-7b-think-saes"
sae_id = "layer_31/width_163k/topk_128_raw"
sae = SAE.from_pretrained(repo_id, sae_id, device="cuda", dtype="float32")

Hugging Face Transformers Usage

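import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE

model_id = "allenai/Olmo-3-7B-Think"
layer = 31
sae_id = "layer_31/width_163k/topk_128_raw"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()
sae = SAE.from_pretrained("dams2005/olmo-3-7b-think-saes", sae_id, device="cuda")

# Capture the raw output of model.layers[layer], as in data collection.
# The last outputs.hidden_states entry may include the final norm, so a hook is safer for layer 31.
captured = {}

def save_output(module, inputs, output):
    # Decoder layers return a tensor or a tuple, depending on the Transformers version.
    captured["acts"] = output[0] if isinstance(output, tuple) else output

handle = model.model.layers[layer].register_forward_hook(save_output)
tokens = tokenizer("The Eiffel Tower is in", return_tensors="pt").to(model.device)
try:
    with torch.no_grad():
        model(**tokens, use_cache=False)
finally:
    handle.remove()

with torch.no_grad():
    acts = captured["acts"].to("cuda").float()  # [batch, seq, 4096]
    feature_acts = sae.encode(acts)              # sparse latents, at most k nonzero per token
    recons = sae.decode(feature_acts)            # reconstruction in the raw residual space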

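hidden_states[0] is the embedding output, so hidden_states[L + 1] is the output of layer L. For the final layer, Hugging Face Transformers may apply the model's final norm to the last hidden_states entry. Capturing the output of model.model.layers[L] with a forward hook matches how the training activations were collected.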
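For feature inspection, it helps to list the most active latents at each token. The sketch below is a hypothetical NumPy helper; top_latents is not a SAELens function. It assumes a [seq, d_sae] array, such as one batch element of feature_acts converted with feature_acts[0].detach().cpu().numpy(). A TopK SAE has at most k nonzero latents per token, so n should not exceed k.

```python
import numpy as np


def top_latents(feature_acts, n=5):
    """Return (latent_index, activation) pairs for the n largest latents at each position.

    feature_acts: [seq, d_sae] array of SAE latent activations.
    """
    order = np.argsort(-feature_acts, axis=-1)[:, :n]
    values = np.take_along_axis(feature_acts, order, axis=-1)
    return [list(zip(idx.tolist(), val.tolist())) for idx, val in zip(order, values)]


fa = np.zeros((2, 10), dtype=np.float32)  # toy [seq, d_sae] activations
fa[0, 3], fa[0, 7], fa[1, 1] = 2.0, 1.0, 5.0
result = top_latents(fa, n=2)
```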

Intended Use

These SAEs are intended for mechanistic interpretability research: feature inspection, activation analysis, and reconstruction experiments over OLMo 3 residual stream activations. They are not language models, classifiers, or safety filters.

Limitations

  • The SAEs were trained for one epoch over the collected activation set.
  • Reconstruction quality varies by layer: final normalized MSE rises from about 0.275 at layer 7 to about 0.33 at layers 23 and 31.
  • Features are not labeled or curated in this release.
  • The layernorm variants normalize each input vector at runtime, so their latents are fit to normalized rather than raw residual vectors. Prefer the raw canonical variants for direct raw-residual analysis.

Files

  • manifest.json: machine-readable list of included SAEs and training metrics
  • saelens_pretrained_saes.yaml: registry-style snippet for possible SAELens registration
  • */cfg.json: SAELens config for each SAE
  • */sae_weights.safetensors: SAELens weights for each SAE
  • */training.log: training log for the source run
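With a local copy of the repository (for example, one downloaded with huggingface_hub.snapshot_download), the included SAE ids can be recovered from the folder layout. This hypothetical helper assumes that each SAE id is the relative path of a folder containing both cfg.json and sae_weights.safetensors, as listed above. The directory tree in the example is a stand-in.

```python
import tempfile
from pathlib import Path


def list_sae_ids(root):
    """Return SAE ids: folder paths under root that hold cfg.json and sae_weights.safetensors."""
    root = Path(root)
    ids = []
    for cfg in sorted(root.rglob("cfg.json")):
        folder = cfg.parent
        if (folder / "sae_weights.safetensors").exists():
            ids.append(folder.relative_to(root).as_posix())
    return ids


# Stand-in directory tree; point list_sae_ids at a real local snapshot instead.
with tempfile.TemporaryDirectory() as tmp:
    for sae_id in ["layer_31/width_163k/topk_128_raw", "layer_07/width_131k/topk_64_raw"]:
        folder = Path(tmp) / sae_id
        folder.mkdir(parents=True)
        (folder / "cfg.json").write_text("{}")
        (folder / "sae_weights.safetensors").write_bytes(b"")
    found = list_sae_ids(tmp)
```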