# Pythia-410m Bilinear MLP Transcoders
This repository contains bilinear transcoder models trained to approximate the MLP layers of EleutherAI/pythia-410m.
## Overview
Transcoders are auxiliary models trained to approximate the behavior of transformer components (in this case, the MLP sublayers) using simpler architectures. These bilinear transcoders use a Hadamard (element-wise product) neural network architecture to approximate each of the 24 MLP layers in Pythia-410m.
## Model Architecture
- Base Model: EleutherAI/pythia-410m (24 layers)
- Transcoder Type: Bilinear (Hadamard Neural Network)
- Architecture: `output = W_left @ (x ⊙ (W_right @ x)) + bias`, where ⊙ denotes the element-wise (Hadamard) product
- Input dimension: 1024 (d_model)
- Hidden dimension: 4096 (4x expansion)
- Output dimension: 1024 (d_model)
- Training: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
- Dataset: monology/pile-uncopyrighted
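For a rough sense of scale, the dimensions above imply about 8.4M parameters per transcoder, assuming the two projection matrices shown in the Usage example below (`W_right`: 1024 → 4096 without bias, `W_left`: 4096 → 1024 with bias):

```python
# Rough per-layer parameter count, assuming the two projection matrices
# shown in the Usage example below (W_right without bias, W_left with bias).
d_model, d_hidden = 1024, 4096

w_right_params = d_model * d_hidden           # 1024 -> 4096 projection, no bias
w_left_params = d_hidden * d_model + d_model  # 4096 -> 1024 projection + output bias

print(f"~{(w_right_params + w_left_params) / 1e6:.2f}M parameters per layer")  # ~8.39M
```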
## Performance Summary
All 24 layers achieve >82% variance explained (variance explained = 1 - FVU, where FVU is the fraction of variance unexplained), with most layers above 93%:
| Layer | Final FVU | Variance Explained | Notes |
|---|---|---|---|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |
Average across all layers: 93.4% variance explained (FVU = 0.0657)
## Repository Structure
```
.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md
```
## Usage
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Load transcoder for layer 3
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores a Python config object and metric lists
)

# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")

# Reconstruct model (example - you'll need the Bilinear class)
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)

    def forward(self, x):
        right = self.W_right(x)                           # (..., n_hidden)
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)  # (..., n_inputs, n_hidden)
        return self.W_left(hadamard.sum(dim=-2))          # (..., n_outputs)

transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()

# Use transcoder to approximate the MLP
with torch.no_grad():
    # Residual stream entering layer 3 (hidden_states[i] is the input to block i).
    # Note: the exact MLP input in GPT-NeoX passes through the block's
    # post-attention LayerNorm first, so this is an approximation.
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP output with the transcoder
    transcoded_output = transcoder(mlp_input)
```
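To check reconstruction quality against the real MLP, one option is to capture the MLP's actual input and output with a forward hook and compute FVU directly. Below is a minimal sketch continuing from the example above; it assumes the standard GPT-NeoX module layout (`model.gpt_neox.layers[i].mlp`), and since training used per-batch normalization, the number it prints is only indicative:

```python
# Capture the true MLP input/output of the chosen layer with a forward hook
# (assumes the GPT-NeoX layout used by Pythia: model.gpt_neox.layers[i].mlp).
captured = {}

def capture_mlp(module, args, output):
    captured["mlp_in"] = args[0].detach()   # hidden states fed into the MLP
    captured["mlp_out"] = output.detach()   # the MLP's actual output

handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(capture_mlp)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

# Compare the transcoder's reconstruction with the real MLP output.
with torch.no_grad():
    approx = transcoder(captured["mlp_in"])

# One common FVU definition; training-time normalization and the exact metric
# used during training may make this differ from the reported numbers.
target = captured["mlp_out"]
fvu = (target - approx).pow(2).sum() / (target - target.mean()).pow(2).sum()
print(f"FVU on this prompt: {fvu.item():.4f}")
```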
## Training Details
- Optimizer: Muon (momentum-based optimizer)
- Learning Rate: 0.02 (hardcoded for Muon)
- Batch Size: 512
- Total Batches: 3000 per layer
- Training Time: ~75 minutes per layer on A100
- Normalization: Per-batch z-score normalization (see the sketch after this list)
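The exact normalization scheme is not spelled out here; one plausible reading (z-scoring each feature over the activations in a batch) is sketched below, and the actual training code may differ:

```python
import torch

def zscore(acts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize each feature to zero mean / unit variance over the batch dimension.

    This is only one plausible reading of "per-batch z-score normalization";
    the training code may use per-example or other statistics instead.
    """
    mean = acts.mean(dim=0, keepdim=True)
    std = acts.std(dim=0, keepdim=True)
    return (acts - mean) / (std + eps)

# Example: a batch of 512 activation vectors of width 1024 (d_model)
batch = torch.randn(512, 1024) * 3.0 + 1.0
normalized = zscore(batch)
print(normalized.mean().item(), normalized.std().item())  # ~0 and ~1
```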
## Checkpoint Contents
Each checkpoint (.pt file) contains:
- `model_state_dict`: Model weights
- `optimizer_state_dict`: Optimizer state
- `config`: Configuration object with dimensions
- `mse_losses`: List of MSE losses per batch
- `variance_explained`: List of variance explained per batch
- `fvu_values`: List of FVU values per batch
- `layer_idx`: Layer index (0-23)
- `d_model`: Model dimension (1024)
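Since each checkpoint stores its full training curves, the reported per-layer metrics can be read back directly from these lists, for example:

```python
import torch

ckpt = torch.load(
    "layer_3/transcoder_weights_l3_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # the file stores plain Python objects (config, metric lists)
)

print("layer:", ckpt["layer_idx"])
print("final MSE loss:", ckpt["mse_losses"][-1])
print("final FVU:", ckpt["fvu_values"][-1])
print("final variance explained:", ckpt["variance_explained"][-1])
```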
## Key Findings
- Layer 0 is dramatically easier to approximate (99.2% VE) - nearly perfect reconstruction
- Layers 1-2 are hardest (~83% VE) - contain complex transformations
- Middle layers (3-22) are remarkably consistent (93-96% VE) - homogeneous structure
- Final layer is highly learnable (97.4% VE)
This suggests that the first and last transformer layers apply more structured, easily approximated transformations, while early layers (1-2) perform more complex computations that are difficult for bilinear models to capture.
## Citation
If you use these transcoders in your research, please cite:
```bibtex
@misc{pythia410m-bilinear-transcoders,
  title={Bilinear MLP Transcoders for Pythia-410m},
  author={[Your Name]},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```
## License
MIT License
## Acknowledgments
- Base model: EleutherAI/pythia-410m
- Training dataset: monology/pile-uncopyrighted