Pythia-410m Bilinear MLP Transcoders

This repository contains bilinear transcoder models trained to approximate the MLP layers of EleutherAI/pythia-410m.

Overview

Transcoders are auxiliary models that learn to approximate the behavior of transformer components (in this case, MLPs) using simpler architectures. These bilinear transcoders use a Hadamard neural network architecture to approximate each of the 24 MLP layers in Pythia-410m.

Model Architecture

  • Base Model: EleutherAI/pythia-410m (24 layers)
  • Transcoder Type: Bilinear (Hadamard Neural Network)
  • Architecture: output = W_left @ (x ⊙ (W_right @ x)) + bias (see the parameter-count sketch after this list)
    • Input dimension: 1024 (d_model)
    • Hidden dimension: 4096 (4x expansion)
    • Output dimension: 1024 (d_model)
  • Training: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
  • Dataset: monology/pile-uncopyrighted
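
For a rough sense of scale, here is a back-of-the-envelope parameter count under the dimensions listed above, assuming a single output bias as in the formula (this is arithmetic, not a statement about the exact checkpoint layout):

# Back-of-the-envelope parameter count for one bilinear transcoder,
# using the dimensions stated above (d_model = 1024, hidden = 4096).
d_model, d_hidden = 1024, 4096

w_right = d_model * d_hidden   # 1024 -> 4096 projection
w_left = d_hidden * d_model    # 4096 -> 1024 projection
bias = d_model                 # output bias

total = w_right + w_left + bias
print(f"{total:,} parameters per transcoder")  # 8,389,632, about 8.4M

That is roughly the same size as the Pythia-410m MLP it approximates, which also consists of a 1024-to-4096 and a 4096-to-1024 projection.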

Performance Summary

All 24 layers achieve >82% variance explained, with most layers >93%:

Layer   Final FVU     Variance Explained   Notes
0       0.0075        99.2%                Best performance
1-2     0.167-0.174   82.6-83.2%           Hardest to approximate
3-22    0.037-0.066   93.4-96.3%           Consistent performance
23      0.0259        97.4%                Second-best

Average across all layers: 93.4% variance explained (FVU = 0.0657)
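
FVU is the fraction of variance unexplained, so variance explained = 1 - FVU (e.g., 1 - 0.0657 gives the 93.4% average above). A minimal sketch of the metric, assuming the standard pooled definition; the exact reduction used during training may differ:

import torch

def fvu(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of variance unexplained: residual sum of squares over total variance."""
    pred = pred.reshape(-1, pred.shape[-1])
    target = target.reshape(-1, target.shape[-1])
    residual = ((pred - target) ** 2).sum()
    total = ((target - target.mean(dim=0)) ** 2).sum()
    return residual / total

# variance_explained = 1.0 - fvu(transcoder_output, mlp_output)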

Repository Structure

.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md

Usage

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Load transcoder for layer 3
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores a custom config object, not just tensors
)

# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")

# Reconstruct the model (illustrative sketch -- use the Bilinear class from the
# training code if available; the checkpoint's state_dict defines the real layout)
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)   # hidden -> output
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)  # input -> hidden

    def forward(self, x):
        right = self.W_right(x)
        # Outer product of the input with its hidden projection, reduced over
        # the input dimension before the output projection.
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)
        return self.W_left(hadamard.sum(dim=-2))

transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()

# Use the transcoder to approximate the MLP
with torch.no_grad():
    # Get the residual stream entering block `layer_idx`
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    # Note: the exact MLP input in GPT-NeoX is the post-attention LayerNorm output,
    # and activations were normalized during training, so this pass is illustrative.
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP output with the transcoder
    transcoded_output = transcoder(mlp_input)
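
To sanity-check a transcoder against the real MLP, the MLP's actual input and output can be captured with a forward hook. This is a sketch assuming the standard GPT-NeoX module layout (model.gpt_neox.layers[i].mlp); because training used per-batch normalization, the FVU on raw activations may not match the reported numbers exactly.

# Capture the true MLP input/output for this layer with a forward hook
captured = {}

def grab_mlp_io(module, inputs, output):
    captured["mlp_in"] = inputs[0].detach()
    captured["mlp_out"] = output.detach()

handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(grab_mlp_io)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

# Compare the transcoder's approximation to the real MLP output (FVU)
with torch.no_grad():
    approx = transcoder(captured["mlp_in"])
    target = captured["mlp_out"]
    resid = ((approx - target) ** 2).sum()
    total = ((target - target.mean(dim=(0, 1))) ** 2).sum()
    print("FVU:", (resid / total).item())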

Training Details

  • Optimizer: Muon (momentum-based, with Newton-Schulz orthogonalization of the updates)
  • Learning Rate: 0.02 (hardcoded for Muon)
  • Batch Size: 512
  • Total Batches: 3000 per layer
  • Training Time: ~75 minutes per layer on A100
  • Normalization: Per-batch z-score normalization (sketched below)
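
A minimal sketch of per-batch z-score normalization, assuming per-feature statistics over the current batch (the exact epsilon and whether inputs and targets share statistics are training-code details not documented here):

import torch

def zscore(batch: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standardize each feature using the current batch's mean and std."""
    mean = batch.mean(dim=0, keepdim=True)
    std = batch.std(dim=0, keepdim=True)
    return (batch - mean) / (std + eps)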

Checkpoint Contents

Each checkpoint (.pt file) contains:

  • model_state_dict: Model weights
  • optimizer_state_dict: Optimizer state
  • config: Configuration object with dimensions
  • mse_losses: List of MSE losses per batch
  • variance_explained: List of variance explained per batch
  • fvu_values: List of FVU values per batch (see the plotting sketch after this list)
  • layer_idx: Layer index (0-23)
  • d_model: Model dimension (1024)
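
For example, the stored curves make it easy to plot a layer's training trajectory directly from its checkpoint (a sketch assuming the keys listed above):

import matplotlib.pyplot as plt
import torch

ckpt = torch.load(
    "layer_3/transcoder_weights_l3_bilinear_muon_3000b.pt",
    map_location="cpu",
    weights_only=False,  # the checkpoint stores a custom config object
)

plt.plot(ckpt["fvu_values"])
plt.xlabel("Batch")
plt.ylabel("FVU")
plt.title(f"Layer {ckpt['layer_idx']} training curve")
plt.show()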

Key Findings

  1. Layer 0 is dramatically easier to approximate (99.2% VE) - nearly perfect reconstruction
  2. Layers 1-2 are hardest (~83% VE) - contain complex transformations
  3. Middle layers (3-22) are remarkably consistent (93-96% VE) - homogeneous structure
  4. Final layer is highly learnable (97.4% VE)

This suggests that the first and last MLP layers compute more structured transformations, while the early layers (1-2) perform more complex transformations that are difficult for a bilinear model to capture.

Citation

If you use these transcoders in your research, please cite:

@misc{pythia410m-bilinear-transcoders,
  title={Bilinear MLP Transcoders for Pythia-410m},
  author={[Your Name]},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}

License

MIT License
