Sparse Autoencoders for Ministral-3-3B-Instruct-2512
This repository contains high-dimensional Sparse Autoencoders (SAEs) trained on the residual stream of mistralai/Ministral-3-3B-Instruct-2512.
These SAEs are intended for mechanistic interpretability: they decompose the language model's dense residual-stream activations into sparse, more interpretable features. Because the training mixture emphasizes math, alignment, and safety data, the features are expected to be especially useful for studying mathematical reasoning (Chain-of-Thought), alignment/RLHF behavior, and safety/refusal mechanisms.
Architecture & Configuration
| Property | Value | Description |
|---|---|---|
| Target Model | mistralai/Ministral-3-3B-Instruct-2512 | Base LLM |
| Target Layers | [19, 21] | Residual stream layers |
| Expansion Factor | 16x | Ratio of SAE features to model dimensions |
| Dictionary Size ($d_{sae}$) | 49,152 | Total number of features |
| Input Dimension ($d_{model}$) | 3,072 | Residual stream width |
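The sizes in the table are linked by the expansion factor, $d_{sae} = 16 \times d_{model} = 16 \times 3{,}072 = 49{,}152$. The card does not state the encoder and decoder equations; assuming the standard ReLU SAE form used by sae-lens (some variants skip subtracting $b_{dec}$ before encoding), each 3,072-dimensional residual-stream vector $x \in \mathbb{R}^{d_{model}}$ would be mapped to a mostly-zero 49,152-dimensional feature vector $f$ and reconstructed as $\hat{x}$:

$$
\begin{aligned}
f(x) &= \mathrm{ReLU}\big((x - b_{dec})\, W_{enc} + b_{enc}\big), \qquad W_{enc} \in \mathbb{R}^{d_{model} \times d_{sae}},\; b_{enc} \in \mathbb{R}^{d_{sae}} \\
\hat{x} &= f(x)\, W_{dec} + b_{dec}, \qquad W_{dec} \in \mathbb{R}^{d_{sae} \times d_{model}},\; b_{dec} \in \mathbb{R}^{d_{model}}
\end{aligned}
$$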
Training Corpus & Context
The SAEs were trained on a specialized heterogeneous token mixture (~600M tokens) optimized to surface features related to reasoning and safety:
- NuminaMath-CoT (40%): Focuses on Chain-of-Thought and advanced mathematical reasoning.
- FineWeb (20%): High-quality web text for general world knowledge and linguistics.
- HH-RLHF (20%): Anthropic's Helpful and Harmless preference data, included to surface alignment-related features.
- Alpaca (10%): General instruction-following data.
- Safety/Jailbreak Mix (30%): Prompts combined from JailbreakBench, HarmBench, and AdvBench, included to isolate refusal vectors and representations of malicious intent.
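The listed shares add up to 120%, so exact per-source token counts cannot be read off directly. One possible reading, which is an assumption and not stated in the card, is that the percentages are relative sampling weights. The sketch below splits the ~600M-token budget proportionally under that reading; `token_budget` is a hypothetical helper.

```python
# Token budget per source, treating the card's listed shares as relative weights (assumption).
TOTAL_TOKENS = 600_000_000  # "~600M tokens" from the card

listed_shares = {  # percentages exactly as listed; they sum to 120
    "NuminaMath-CoT": 40,
    "FineWeb": 20,
    "HH-RLHF": 20,
    "Alpaca": 10,
    "Safety/Jailbreak Mix": 30,
}


def token_budget(shares: dict[str, float], total: int) -> dict[str, int]:
    """Split `total` tokens across sources in proportion to their weights (hypothetical helper)."""
    weight_sum = sum(shares.values())
    budget = {name: int(total * w / weight_sum) for name, w in shares.items()}
    # Assign any rounding remainder to the largest source so the parts sum to `total`.
    largest = max(shares, key=shares.get)
    budget[largest] += total - sum(budget.values())
    return budget


for name, tokens in token_budget(listed_shares, TOTAL_TOKENS).items():
    print(f"{name:>22}: {tokens / 1e6:6.1f}M tokens")
```

Under this reading the split would be 200M / 100M / 100M / 50M / 150M tokens in the order listed.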
Training Hyperparameters
- Learning Rate: 0.0002 (Cosine Annealing with Warm Restarts)
- Sparsity Penalty ($\lambda$): 0.0008
- Batch Size: 2048 tokens
- Optimization: Adam ($\beta_1 = 0.9$, $\beta_2 = 0.999$)
- Resampling: Dead features are resampled every 25,000 steps (threshold: 25,000).
- Precision & Hardware: Trained in `bfloat16` on an NVIDIA RTX A6000.
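The card lists hyperparameters but not the loss or the resampling procedure. The following is a minimal PyTorch sketch, assuming a standard ReLU SAE trained on squared reconstruction error plus $\lambda$ times the L1 norm of the feature activations. `TinySAE`, `train_step`, `dead_features`, and the restart period `T_0` are hypothetical, and the 25,000 threshold is read here as "steps without firing". It only detects dead features; how the card's training code re-initializes them is not described.

```python
import math
import torch
import torch.nn.functional as F

# Values from the card. T_0 (the restart period) is NOT stated there; 10_000 is a placeholder.
LR, L1_COEFF, BATCH_TOKENS = 2e-4, 8e-4, 2048
DEAD_THRESHOLD = 25_000  # read here as "steps without firing" (assumption)
T_0 = 10_000             # hypothetical


class TinySAE(torch.nn.Module):
    """Standard ReLU SAE (assumed architecture, sae-lens-style parameter names)."""

    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = torch.nn.Parameter(torch.randn(d_in, d_sae) / math.sqrt(d_in))
        self.W_dec = torch.nn.Parameter(self.W_enc.detach().T.clone())
        self.b_enc = torch.nn.Parameter(torch.zeros(d_sae))
        self.b_dec = torch.nn.Parameter(torch.zeros(d_in))

    def forward(self, x: torch.Tensor):
        f = F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        return f @ self.W_dec + self.b_dec, f


def train_step(sae, opt, sched, x, steps_since_fired):
    x_hat, f = sae(x)
    mse = (x_hat - x).pow(2).sum(dim=-1).mean()  # squared L2 reconstruction error per token
    l1 = f.abs().sum(dim=-1).mean()              # L1 sparsity penalty per token
    loss = mse + L1_COEFF * l1
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()
    # Count how many steps each feature has gone without a nonzero activation.
    with torch.no_grad():
        steps_since_fired += 1
        steps_since_fired[(f > 0).any(dim=0)] = 0
    return loss.item()


def dead_features(steps_since_fired: torch.Tensor) -> torch.Tensor:
    """Boolean mask of features that would be candidates at the next resampling step."""
    return steps_since_fired >= DEAD_THRESHOLD


# Toy usage on random data (d_in=64 with the card's 16x expansion).
torch.manual_seed(0)
d_in = 64
sae = TinySAE(d_in, 16 * d_in)
opt = torch.optim.Adam(sae.parameters(), lr=LR, betas=(0.9, 0.999))
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=T_0)
steps_since_fired = torch.zeros(16 * d_in, dtype=torch.long)
losses = [train_step(sae, opt, sched, torch.randn(BATCH_TOKENS, d_in), steps_since_fired)
          for _ in range(5)]
print(losses, int(dead_features(steps_since_fired).sum()))
```

In a real run, every `RESAMPLE_EVERY`-style interval (25,000 steps per the card) the mask from `dead_features` would select which encoder/decoder columns to re-initialize, and the corresponding Adam state would usually be reset as well.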
Usage with sae-lens
These SAEs are packaged for direct loading with sae-lens. Because the release contains one SAE per layer, select the layer with the sae_id argument (e.g. "layer_19" or "layer_21"):
```python
from sae_lens import SAE

# Load the SAE for layer 19
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="analist/ministral-3b-sae",
    sae_id="layer_19",
    device="cuda"
)
# sae.encode(activations)
```
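Once you have feature activations, two quick sanity checks are the average number of active features per token (L0) and the fraction of variance the reconstruction explains. The helpers below are a generic sketch on plain tensors, not part of sae-lens; the commented usage assumes `activations` is a layer-19 residual-stream tensor you captured yourself.

```python
import torch


def l0_norm(feature_acts: torch.Tensor) -> float:
    """Mean number of nonzero SAE features per token."""
    return (feature_acts != 0).float().sum(dim=-1).mean().item()


def explained_variance(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """1 - (squared reconstruction error / total variance), pooled over tokens and dimensions."""
    x, x_hat = x.float().reshape(-1, x.shape[-1]), x_hat.float().reshape(-1, x_hat.shape[-1])
    residual = (x - x_hat).pow(2).sum()
    total = (x - x.mean(dim=0)).pow(2).sum()
    return (1.0 - residual / total).item()


# Usage with the sae-lens SAE loaded above (activations: [..., 3072] residual stream):
# feature_acts = sae.encode(activations)
# x_hat = sae.decode(feature_acts)
# print(l0_norm(feature_acts), explained_variance(activations, x_hat))
```

Both metrics are cast to float32 first, since activations captured in `bfloat16` can make the variance sums noticeably imprecise.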
Usage with PyTorch
If you prefer not to use sae-lens, you can load the safetensors weights directly:
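```python
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

layer = 19
model_path = hf_hub_download(repo_id="analist/ministral-3b-sae", filename=f"layer_{layer}/sae_weights.safetensors")
state_dict = load_file(model_path)

# Check tensor names and shapes before mapping them onto your own module
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape), tensor.dtype)

# Initialize your custom SAE class
# sae = SparseAutoencoder(d_model=3072, d_sae=49152, dtype=torch.bfloat16, layer_idx=layer)
# sae.load_state_dict(state_dict)  # keep strict=True (the default) so mismatched keys raise
# sae.eval()
```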
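The `SparseAutoencoder` class referenced in the comments is not defined in this card. A minimal stand-in, assuming the safetensors file uses sae-lens-style names and shapes (`W_enc` [d_model, d_sae], `b_enc` [d_sae], `W_dec` [d_sae, d_model], `b_dec` [d_model]) and a plain ReLU encoder; compare against the tensor names in the downloaded file before relying on it. `MinimalSAE` is a hypothetical name.

```python
import torch
import torch.nn as nn


class MinimalSAE(nn.Module):
    """Hypothetical minimal ReLU SAE; parameter names follow the sae-lens convention (assumed)."""

    def __init__(self, d_model: int, d_sae: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.W_enc = nn.Parameter(torch.empty(d_model, d_sae, dtype=dtype))
        self.b_enc = nn.Parameter(torch.zeros(d_sae, dtype=dtype))
        self.W_dec = nn.Parameter(torch.empty(d_sae, d_model, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_model, dtype=dtype))

    @classmethod
    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "MinimalSAE":
        # Infer sizes from the stored encoder matrix instead of hard-coding 3072 / 49152.
        d_model, d_sae = state_dict["W_enc"].shape
        sae = cls(d_model, d_sae, dtype=state_dict["W_enc"].dtype)
        sae.load_state_dict(state_dict)  # strict: fails loudly on unexpected or missing keys
        return sae.eval()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.W_enc.dtype)  # match the stored dtype (e.g. bfloat16)
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        return f @ self.W_dec + self.b_dec


# state_dict = load_file(model_path)
# sae = MinimalSAE.from_state_dict(state_dict)
# feature_acts = sae.encode(resid_acts)  # resid_acts: [..., 3072] residual stream at this layer
```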
Limitations
- Layer Specificity: Each SAE was trained on a single layer's residual stream. Applying it to a different layer or a different model will produce unreliable, largely meaningless features.
- Feature Completeness: Although the training mixture emphasizes safety and reasoning, the dictionary is unlikely to capture every concept the model represents, and some features may remain polysemantic.
Model tree for analist/ministral-3b-sae: base model mistralai/Ministral-3-3B-Base-2512