Gemma 2b - IT - Residual Stream SAEs
These SAEs are a follow-up to my other Gemma-2b SAEs, which were trained on the base model.
These SAEs were trained with SAE Lens, and the library version used is recorded in cfg.json.
All training hyperparameters are specified in cfg.json.
They can be loaded with SAE Lens in a few ways. The preferred method is the following:
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE, ActivationsStore
torch.set_grad_enabled(False)  # inference only, so skip gradient tracking
device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("gemma-2b-it", device=device)
sae, cfg, sparsity = SAE.from_pretrained(
    release="gemma-2b-it-res-jb",  # to see the list of available releases, go to: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
    sae_id="blocks.12.hook_resid_post",  # change this to another specific SAE ID in the release if desired.
    device=device,
)
# For loading activations or tokens from the training dataset.
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # Fairly conservative parameters, so the same settings can be reused
    # for larger models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=device,
)
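Once loaded, the SAE maps each blocks.12.hook_resid_post activation to sparse feature activations and back. The sketch below illustrates that round trip for a standard ReLU SAE in plain NumPy. It is not SAE Lens's implementation. It assumes the common convention of subtracting the decoder bias before encoding. The helper names (sae_encode, sae_decode), the toy dimensions and the random weights are all hypothetical, and the actual architecture settings for this release are whatever cfg.json records.

```python
import numpy as np

def sae_encode(x, W_enc, b_enc, b_dec):
    """Map activations x of shape (n_tokens, d_in) to features of shape (n_tokens, d_sae)."""
    # Assumed convention: subtract the decoder bias before the encoder matmul.
    pre = (x - b_dec) @ W_enc + b_enc
    return np.maximum(pre, 0.0)  # ReLU: only positively firing features survive

def sae_decode(f, W_dec, b_dec):
    """Reconstruct activations (n_tokens, d_in) from features (n_tokens, d_sae)."""
    return f @ W_dec + b_dec

# Toy sizes; this release uses d_in = 2048 and d_sae = 16384 (expansion factor 8).
rng = np.random.default_rng(0)
d_in, expansion = 16, 8
d_sae = d_in * expansion
W_enc = rng.normal(scale=0.1, size=(d_in, d_sae))  # hypothetical random weights
W_dec = W_enc.T.copy()  # tied weights purely for illustration
b_enc = np.zeros(d_sae)
b_dec = np.zeros(d_in)

x = rng.normal(size=(4, d_in))  # stand-in for 4 tokens of residual-stream activations
f = sae_encode(x, W_enc, b_enc, b_dec)
x_hat = sae_decode(f, W_dec, b_dec)
print(f.shape, x_hat.shape)  # (4, 128) (4, 16)
```

With the real SAE, x would come from the model's cache at the hook point and d_in, d_sae would be 2048 and 16384.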
SAEs
Resid Post 12
Stats:
- 16384 features (expansion factor 8).
- CE loss score of 98.13%.
- Mean L0 of 58 (in practice, L0 is log-normally distributed and heavily right-tailed).
- Dead features: fewer than 500.
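The sketch below shows how stats like these are typically computed. It assumes the usual definitions:

- The CE loss score is the fraction of the cross-entropy gap between zero-ablating the hook point and running the clean model that is recovered when the SAE reconstruction is spliced in.
- L0 is the number of non-zero features per token.
- A dead feature is one that never fires on the sampled tokens.

All helper names and numbers here are hypothetical toy values, not this SAE's evaluation data.

```python
import numpy as np

def ce_loss_score(ce_clean, ce_with_sae, ce_zero_ablated):
    """Fraction of the CE-loss gap (zero-ablation vs. clean model) recovered with the SAE spliced in."""
    return (ce_zero_ablated - ce_with_sae) / (ce_zero_ablated - ce_clean)

def l0_per_token(feature_acts):
    """Count of active (non-zero) features per token; feature_acts has shape (n_tokens, d_sae)."""
    return (feature_acts > 0).sum(axis=1)

def count_dead_features(feature_acts):
    """Number of features that never fire on any token in the sample."""
    return int((~(feature_acts > 0).any(axis=0)).sum())

# Hypothetical CE losses, chosen only to illustrate the formula.
print(ce_loss_score(ce_clean=2.0, ce_with_sae=2.25, ce_zero_ablated=10.0))  # 0.96875

# Toy feature activations: 3 tokens x 5 features.
acts = np.array([
    [0.0, 1.2, 0.0, 0.0, 0.3],
    [0.0, 0.0, 0.0, 0.0, 0.8],
    [0.0, 2.5, 0.4, 0.0, 0.1],
])
print(l0_per_token(acts))         # [2 1 3]
print(count_dead_features(acts))  # 2 (features 0 and 3 never fire)

# A right-tailed (e.g. log-normal) L0 distribution has its median below its mean,
# so a typical token uses fewer features than the mean L0 suggests.
# The parameters below are illustrative, not fitted to this SAE.
sim_l0 = np.random.default_rng(0).lognormal(mean=np.log(40), sigma=0.8, size=100_000)
print(np.median(sim_l0) < sim_l0.mean())  # True
```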
Notes:
- This SAE was trained on a tokenized version of open-web-text.
- The sparsity JSON was computed from too few samples, so I wouldn't trust it.
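To see why a small sample makes sparsity estimates unreliable, consider the sketch below. It assumes the sparsity file holds per-feature log10 firing densities. The sketch computes those densities from a sample of feature activations. It also estimates, using a binomial approximation, how many tokens a rare feature needs for a stable estimate. log10_feature_density, tokens_needed and the eps floor are hypothetical helpers, not SAE Lens functions.

```python
import numpy as np

def log10_feature_density(feature_acts, eps=1e-10):
    """log10 of the fraction of sampled tokens on which each feature fires.

    eps keeps features that never fired in the sample finite (an assumed convention).
    """
    density = (feature_acts > 0).mean(axis=0)
    return np.log10(density + eps)

def tokens_needed(density, rel_err=0.1):
    """Tokens needed for the estimate's standard error to equal rel_err * density.

    Uses the binomial variance density * (1 - density) / n and solves for n.
    """
    return (1 - density) / (density * rel_err**2)

# 4 tokens x 3 features: feature 0 fires once, feature 1 always, feature 2 never.
acts = np.zeros((4, 3))
acts[0, 0] = 1.0
acts[:, 1] = 0.5
print(log10_feature_density(acts))  # values ≈ -0.602, 0 and -10

# With n tokens, the smallest non-zero density that can be observed is 1 / n.
# A feature firing once per 100,000 tokens needs on the order of 10 million
# tokens for roughly 10% relative error.
print(f"{tokens_needed(1e-5):,.0f}")
```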