metadata
license: cc-by-4.0
library_name: clipscope

CLIP-Scope

by Louka Ewington-Pitsos and Ram Rattan Goyal

Heavily inspired by google/gemma-scope, we are releasing a suite of 8 sparse autoencoders (SAEs) for laion/CLIP-ViT-L-14-laion2B-s32B-b82K.

| Layer | MSE | Explained Variance | Dead Feature Proportion | Active Feature Proportion | Sort Eval Accuracy (CLS token) |
|---|---|---|---|---|---|
| 2 | 267.95 | 0.763 | 0.000912 | 0.001 | - |
| 5 | 354.46 | 0.665 | 0 | 0.0034 | - |
| 8 | 357.58 | 0.642 | 0 | 0.01074 | - |
| 11 | 321.23 | 0.674 | 0 | 0.0415 | 0.7334 |
| 14 | 319.64 | 0.689 | 0 | 0.07866 | 0.7427 |
| 17 | 261.20 | 0.731 | 0 | 0.1477 | 0.8689 |
| 20 | 278.06 | 0.706 | 0.0000763 | 0.2036 | 0.9149 |
| 22 | 299.96 | 0.684 | 0 | 0.1588 | 0.8641 |

Training logs are available via Weights & Biases (wandb), and the training code is available on GitHub. The training process relies heavily on AWS ECS. As a result, the Weights & Biases logs may contain some strange artifacts wherever a spot instance was killed and another instance resumed training. Some of the code is copied directly from Hugo Fry's work.

Vital Statistics:

  • Number of tokens trained per autoencoder: 1.2 billion
  • Token type: all 257 image tokens (as opposed to just the CLS token)
  • Number of unique images trained per autoencoder: 4.5 million
  • Training Dataset: LAION-2B
  • SAE Architecture: TopK with k=32
  • Layer Location: always the residual stream
  • Training Checkpoints: every ~100 million tokens
  • Number of features per autoencoder: 65536 (expansion factor 64 over the 1024-dimensional residual stream)
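The list above says each SAE uses a TopK architecture with k=32. For readers unfamiliar with TopK SAEs, here is a minimal NumPy sketch of a forward pass in the style of Scaling and Evaluating Sparse Autoencoders. The encoder keeps only the k largest pre-activations for each input and zeroes the rest. This is an illustration under assumptions, not the clipscope implementation. The function name `topk_sae_forward`, the weight layout and the tiny random dimensions are all hypothetical.

```python
import numpy as np


def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Hypothetical TopK SAE forward pass (generic sketch, not clipscope's code).

    x: (batch, d_model) activations
    W_enc: (d_model, n_features), b_enc: (n_features,)
    W_dec: (n_features, d_model), b_dec: (d_model,)
    """
    # Encoder pre-activations, after subtracting the decoder bias.
    pre = (x - b_dec) @ W_enc + b_enc                  # (batch, n_features)
    # Column indices of the k largest pre-activations in each row.
    idx = np.argpartition(pre, -k, axis=-1)[:, -k:]
    rows = np.arange(pre.shape[0])[:, None]
    latent = np.zeros_like(pre)
    # Keep only the top-k values, with a ReLU so negative winners stay at zero.
    latent[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    # Linear decoder maps the sparse code back to activation space.
    reconstruction = latent @ W_dec + b_dec            # (batch, d_model)
    return latent, reconstruction


rng = np.random.default_rng(0)
d_model, n_features, k = 16, 64, 4                     # toy sizes, not the real 1024 / 65536 / 32
W_enc = rng.normal(size=(d_model, n_features))
W_dec = W_enc.T.copy()
b_enc = np.zeros(n_features)
b_dec = np.zeros(d_model)
x = rng.normal(size=(3, d_model))

latent, recon = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
print(latent.shape, recon.shape)                       # (3, 64) (3, 16)
print((latent != 0).sum(axis=-1))                      # at most k non-zero latents per row
```

Because the sparsity level is fixed by k, no L1 penalty is needed to control how many features fire per token.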

Usage

First, install our PyPI package and Pillow (PIL):

pip install clipscope pillow

Then run:

from PIL import Image  # `import PIL` alone does not reliably expose PIL.Image
from clipscope import ConfiguredViT, TopKSAE

device = 'cpu'
filename_in_hf_repo = "22_resid/1200013184.pt"
sae = TopKSAE.from_pretrained(checkpoint=filename_in_hf_repo, device=device)

locations = [(22, 'resid')]
transformer = ConfiguredViT(locations, device=device)

image = Image.new("RGB", (224, 224), (0, 0, 0))  # black image for testing

activations = transformer.all_activations(image)[locations[0]]  # (1, 257, 1024)
assert activations.shape == (1, 257, 1024)

activations = activations[:, 0]  # just the CLS token -> (1, 1024)
# alternatively, keep all 257 tokens by merging the batch and token dimensions
# activations = activations.flatten(0, 1)  # (257, 1024)

print('activations shape', activations.shape)

output = sae.forward_verbose(activations)

print('output keys', output.keys())

print('latent shape', output['latent'].shape) # (1, 65536)
print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)

Formulae

We calculate MSE as (batch - reconstruction).pow(2).sum(dim=-1).mean(), i.e. the squared error between the batch and the reconstruction, summed across features and averaged over the batch. Before training, we normalize the input with batch norm to bring all activations into a similar range.
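The following minimal NumPy sketch mirrors the MSE expression above and adds a small worked example. The normalization step assumes batch-norm-style standardization without a learned scale or shift, where each feature gets zero mean and unit variance over the batch. The README does not specify the exact batch-norm configuration, so treat it as an assumption. The function names `batchnorm_normalize` and `summed_mse` are hypothetical.

```python
import numpy as np


def batchnorm_normalize(batch, eps=1e-5):
    # Assumption: batch-norm-style standardization with no learned affine parameters.
    # Each feature (column) is shifted to zero mean and scaled to unit variance over the batch.
    mean = batch.mean(axis=0, keepdims=True)
    var = batch.var(axis=0, keepdims=True)
    return (batch - mean) / np.sqrt(var + eps)


def summed_mse(batch, reconstruction):
    # Squared error summed over the feature dimension, then averaged over samples.
    # NumPy equivalent of (batch - reconstruction).pow(2).sum(dim=-1).mean()
    return ((batch - reconstruction) ** 2).sum(axis=-1).mean()


batch = np.array([[1.0, 2.0], [3.0, 4.0]])
reconstruction = np.array([[1.0, 1.0], [3.0, 6.0]])
print(summed_mse(batch, reconstruction))  # 2.5, i.e. ((0 + 1) + (0 + 4)) / 2
```

Because errors are summed rather than averaged over the 1024 residual dimensions, MSE values in the hundreds, as in the table, are on a per-token scale.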

We calculate Explained Variance as

delta_variance = (batch - reconstruction).pow(2).sum(dim=-1)
activation_variance = (batch - batch.mean(dim=-1, keepdim=True)).pow(2).sum(dim=-1)
explained_variance = (1 - delta_variance / activation_variance).mean()
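Here is the same explained-variance calculation translated to NumPy, with a hand-checkable example. The function name `explained_variance` is just for illustration. Note that `activation_variance` measures each sample's spread around its own mean across features, so the metric is computed per token and then averaged.

```python
import numpy as np


def explained_variance(batch, reconstruction):
    # Per-sample squared reconstruction error (the same quantity as in the MSE).
    delta_variance = ((batch - reconstruction) ** 2).sum(axis=-1)
    # Per-sample spread of the activation vector around its own mean across features.
    activation_variance = ((batch - batch.mean(axis=-1, keepdims=True)) ** 2).sum(axis=-1)
    # Fraction of that spread captured by the reconstruction, averaged over samples.
    return (1 - delta_variance / activation_variance).mean()


batch = np.array([[1.0, 3.0], [0.0, 4.0]])
reconstruction = np.array([[1.0, 2.0], [0.0, 4.0]])
# Row 1: error 1, spread 2 -> 0.5; row 2: error 0, spread 8 -> 1.0; mean = 0.75
print(explained_variance(batch, reconstruction))  # 0.75
```

A perfect reconstruction scores 1.0. Reconstructing every token as its own per-feature mean scores 0.0.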

We calculate dead feature proportion as the proportion of features which have not activated in the last 10,000,000 samples.
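One way to track this during training is sketched below. It assumes that one "sample" means one row of the SAE input batch, i.e. one token. The class `DeadFeatureTracker` is hypothetical and not part of clipscope. A feature that fires anywhere in a batch has its counter reset for the whole batch, so the window is only tracked at batch granularity.

```python
import numpy as np


class DeadFeatureTracker:
    """Hypothetical helper: counts samples seen since each feature last fired."""

    def __init__(self, n_features, window=10_000_000):
        self.window = window
        self.samples_since_fired = np.zeros(n_features, dtype=np.int64)

    def update(self, latent):
        # latent: (batch, n_features) SAE activations for one training batch.
        fired = (latent > 0).any(axis=0)
        self.samples_since_fired += latent.shape[0]
        # Batch-granularity approximation: any firing in the batch resets the counter.
        self.samples_since_fired[fired] = 0
        return self.dead_proportion()

    def dead_proportion(self):
        # A feature is dead if it has not fired within the last `window` samples.
        return float((self.samples_since_fired >= self.window).mean())


tracker = DeadFeatureTracker(n_features=4, window=6)
b1 = np.array([[1.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 0, 0]])
print(tracker.update(b1))  # 0.0 (no feature has been silent for 6 samples yet)
b2 = np.array([[0.5, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
print(tracker.update(b2))  # 0.5 (features 2 and 3 have been silent for 6 samples)
```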

Subjective Interpretability

To give an intuitive feel for the interpretability of these models, we run 500,000 randomly selected images from LAION-2B through the final trained SAE for each layer and record the latent activations on the CLS token for each image. We then winnow these down to the first 100 features that activate for at least 9 images, so that every candidate can fill a 3x3 grid. We cherry-pick 3 of these features for each layer and display each one as a 3x3 grid of images. Below are the 3 grids for layer 22. Grids for the other layers are available in the README.md for that layer.
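The winnowing step can be sketched in NumPy as follows. It assumes the CLS-token latents are stored as an `(n_images, n_features)` array. It also assumes that "first 100" means the 100 lowest feature indices and that a grid shows a feature's 9 highest-activating images. The helpers `select_features` and `top_images` are hypothetical.

```python
import numpy as np


def select_features(latents, min_images=9, max_features=100):
    # latents: (n_images, n_features) CLS-token activations, one row per image.
    images_per_feature = (latents > 0).sum(axis=0)
    eligible = np.flatnonzero(images_per_feature >= min_images)
    # "first" is assumed to mean the lowest feature indices.
    return eligible[:max_features]


def top_images(latents, feature, n=9):
    # Indices of the n images that activate this feature most strongly,
    # e.g. to lay out as a 3x3 grid.
    return np.argsort(latents[:, feature])[::-1][:n]


latents = np.zeros((12, 4))
latents[:10, 1] = np.arange(1, 11)  # feature 1 fires on 10 images
latents[:3, 2] = 1.0                # feature 2 fires on only 3 images
latents[:, 3] = 0.5                 # feature 3 fires on all 12 images
print(select_features(latents))     # [1 3]
print(top_images(latents, 1))       # [9 8 7 6 5 4 3 2 1]
```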

feature 308

feature 464

feature 575

Automated Sort Evals

We performed automated sort evals following Anthropic, with two changes. First, the formatted dataset examples for both features are replaced by a 3x3 grid of each feature's top 9 activating images. Second, the formatted query example is replaced by an image that activates for only one of the two features and is not included in either grid. Our methodology was as follows:

  1. pass 500,000 LAION images from the training dataset through the SAE and record the first 2048 latent activations for each image (only 2048, to conserve space)
  2. ignore all features that do not activate for at least 10 of those images
  3. for the remaining features, if there are more than 100, select the first 100
  4. if there are fewer than 100, the comparison would not be fair, so we cannot proceed (this is why we have no sort eval scores for layers 2, 5 and 8)
  5. randomly select 500 pairs from among the 100 features, and perform a sort eval for each pair using gpt-4o
  6. 5 times, randomly select 400 of these 500 sort evals and record the accuracy (n correct / 400) on that subsample
  7. calculate the mean and standard deviation of these 5 accuracies.
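Steps 5 to 7 can be sketched with the Python standard library. The gpt-4o judging call is deliberately left out, and the outcomes below are synthetic placeholders, not our results. Several details are assumptions: each subsample is drawn without replacement, pairs are drawn independently, and the reported spread is the sample standard deviation (`statistics.pstdev` would give the population version). The helper names are hypothetical.

```python
import random
import statistics


def sample_pairs(features, n_pairs=500, seed=0):
    # Step 5: draw random pairs of distinct features (independent draws, an assumption).
    rng = random.Random(seed)
    return [tuple(rng.sample(features, 2)) for _ in range(n_pairs)]


def subsample_accuracy(correct, sample_size=400, repeats=5, seed=0):
    # Steps 6-7: correct holds one bool per sort eval (True if the judge sorted correctly).
    rng = random.Random(seed)
    accuracies = [
        sum(rng.sample(correct, sample_size)) / sample_size  # n correct / 400
        for _ in range(repeats)
    ]
    return statistics.mean(accuracies), statistics.stdev(accuracies)


features = list(range(100))
pairs = sample_pairs(features)
correct = [True] * 400 + [False] * 100  # synthetic placeholder outcomes
mean_acc, std_acc = subsample_accuracy(correct)
print(len(pairs), round(mean_acc, 3), round(std_acc, 3))
```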

The outcomes are plotted below. Active Feature Proportion is the proportion of features that activate for at least 10 images across the 500,000-image dataset for that layer. For the CLS token at layer 2, only 2/2048 features were "active" in this sense.
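As an arithmetic check, 2/2048 ≈ 0.000977, which rounds to the 0.001 listed for layer 2 in the table. The sketch below computes the proportion from an assumed `(n_images, 2048)` array of recorded CLS latents. The helper `active_feature_proportion` is hypothetical.

```python
import numpy as np


def active_feature_proportion(latents, min_images=10):
    # latents: (n_images, n_recorded_features), e.g. the first 2048 CLS-token latents.
    images_per_feature = (latents > 0).sum(axis=0)
    return float((images_per_feature >= min_images).mean())


latents = np.zeros((20, 2048))
latents[:, :2] = 1.0  # only 2 of 2048 features fire, each on all 20 images
p = active_feature_proportion(latents)
print(p)            # 0.0009765625
print(round(p, 3))  # 0.001
```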

Token-wise MSE

All layers were trained across all 257 image tokens. Below we provide plots showing the reconstruction MSE for each token (other than the CLS token) as training progressed. Throughout training, the outer tokens appear easier to reconstruct than those in the middle. This is presumably because the middle tokens tend to capture more important information (e.g. foreground objects) and are therefore more information-rich.
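A 224x224 input with 14x14 patches gives a 16x16 grid of 256 patch tokens plus the CLS token, for 257 tokens in total. A per-token MSE map like the ones plotted can therefore be sketched by dropping the CLS token (index 0, as in the usage example) and reshaping the rest. The row-major patch layout and the helper `tokenwise_mse_grid` are assumptions.

```python
import numpy as np


def tokenwise_mse_grid(batch, reconstruction, grid_size=16):
    # batch, reconstruction: (n_images, 257, d_model); token 0 is the CLS token.
    # Squared error summed over features, averaged over images -> one value per token.
    per_token = ((batch - reconstruction) ** 2).sum(axis=-1).mean(axis=0)  # (257,)
    patch_mse = per_token[1:]  # drop CLS, leaving 16 * 16 = 256 patch tokens
    # Assumption: patch tokens are ordered row-major over the image.
    return patch_mse.reshape(grid_size, grid_size)


batch = np.zeros((2, 257, 4))
reconstruction = np.zeros((2, 257, 4))
reconstruction[:, 1, :] = 1.0  # error only on the first patch token
grid = tokenwise_mse_grid(batch, reconstruction)
print(grid.shape)  # (16, 16)
print(grid[0, 0])  # 4.0 (top-left patch; every other patch is reconstructed exactly)
```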

References

We draw heavily from prior visual sparse autoencoder research by Hugo Fry and Gytis Daujotas. We also rely on autointerpretability research from Anthropic's Circuits Updates, and we take the TopK SAE architecture and training methodology from Scaling and Evaluating Sparse Autoencoders. We base all our training and inference on data from the LAION project.