---
license: cc-by-4.0
---

# CLIP-Scope

by Louka Ewington-Pitsos and Ram ____

Heavily inspired by google/gemma-scope, we are releasing a suite of 8 sparse autoencoders for laion/CLIP-ViT-L-14-laion2B-s32B-b82K.

| Layer | MSE | Explained Variance | Dead Feature Proportion |
|-------|--------|--------------------|-------------------------|
| 2 | 267.95 | 0.763 | 0.000912 |
| 5 | 354.46 | 0.665 | 0 |
| 8 | 357.58 | 0.642 | 0 |
| 11 | 321.23 | 0.674 | 0 |
| 14 | 319.64 | 0.689 | 0 |
| 17 | 261.20 | 0.731 | 0 |
| 20 | 278.06 | 0.706 | 0.0000763 |
| 22 | 299.96 | 0.684 | 0 |

Training logs are available via wandb and training code is available on GitHub. The training process relied heavily on AWS ECS, so the logs may contain some strange artifacts where a spot instance was killed and training was resumed by another instance. Some of the code is adapted directly from Hugo Fry.

## Vital Statistics

- Number of tokens trained per autoencoder: 1.2 billion
- Token type: all 257 image tokens (as opposed to just the CLS token)
- Number of unique images trained per autoencoder: 4.5 million
- Training dataset: LAION-2B
- SAE architecture: top-k with k=32 (see the sketch after this list)
- Layer location: always the residual stream
- Training checkpoints: every ~25 million tokens
- Number of features: 65536
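
For intuition, here is a minimal sketch of a top-k SAE forward pass with these dimensions (d_model 1024, 65536 features, k=32). It is an illustration of the architecture only, not the released training code, and the parameter names (`W_enc`, `W_dec`, `b_pre`) are assumptions made for the example.

```python
import torch
import torch.nn as nn

class TopKSAESketch(nn.Module):
    """Illustrative top-k sparse autoencoder; not the released implementation."""

    def __init__(self, d_model=1024, n_features=65536, k=32):
        super().__init__()
        self.k = k
        self.b_pre = nn.Parameter(torch.zeros(d_model))                    # pre-encoder bias
        self.W_enc = nn.Parameter(torch.randn(d_model, n_features) * 0.01)
        self.W_dec = nn.Parameter(torch.randn(n_features, d_model) * 0.01)

    def forward(self, x):                                                  # x: (batch, d_model)
        pre_acts = (x - self.b_pre) @ self.W_enc                           # (batch, n_features)
        top = torch.topk(pre_acts, self.k, dim=-1)                         # keep only the k largest pre-activations
        latent = torch.zeros_like(pre_acts).scatter_(-1, top.indices, top.values)
        reconstruction = latent @ self.W_dec + self.b_pre                  # (batch, d_model)
        return latent, reconstruction

latent, reconstruction = TopKSAESketch()(torch.randn(4, 1024))
print(latent.shape, reconstruction.shape)  # (4, 65536), (4, 1024)
```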

## Usage

```python
import PIL.Image
from clipscope import ConfiguredViT, TopKSAE

device = 'cpu'
filename_in_hf_repo = "725159424.pt"
sae = TopKSAE.from_pretrained(repo_id="lewington/CLIP-ViT-L-scope", filename=filename_in_hf_repo, device=device)

transformer_name = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
locations = [(22, 'resid')]
transformer = ConfiguredViT(locations, transformer_name, device=device)

image = PIL.Image.new("RGB", (224, 224), (0, 0, 0))  # black image for testing

activations = transformer.all_activations(image)[locations[0]]  # (1, 257, 1024)
assert activations.shape == (1, 257, 1024)

activations = activations[:, 0]  # just the CLS token
# alternatively, flatten the activations across all tokens:
# activations = activations.flatten(1)

print('activations shape', activations.shape)

output = sae.forward_verbose(activations)

print('output keys', output.keys())

print('latent shape', output['latent'].shape)  # (1, 65536)
print('reconstruction shape', output['reconstruction'].shape)  # (1, 1024)
```
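
To inspect which features fire for an image, you can read off the largest entries of the latent vector. Continuing from the snippet above, and assuming `output['latent']` is a torch tensor:

```python
import torch

top = torch.topk(output['latent'][0], k=5)  # 5 strongest features for this image
for idx, val in zip(top.indices.tolist(), top.values.tolist()):
    print(f"feature {idx}: activation {val:.3f}")
```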

## Error Formulae

We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()`, i.e. the MSE between the batch and the un-normalized reconstruction, summed across the feature dimension and averaged over the batch. We use batch norm to bring all activations into a similar range.

We calculate Explained Variance as

```python
delta_variance = (batch - reconstruction).pow(2).sum(dim=-1)
activation_variance = (batch - batch.mean(dim=-1, keepdim=True)).pow(2).sum(dim=-1)
explained_variance = (1 - delta_variance / activation_variance).mean()
```

We calculate Dead Feature Proportion as the proportion of features which have not activated in the last 10,000,000 samples.
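
A hedged sketch of how this statistic could be tracked during training (the actual training code may differ); `latent` here is the `(batch, n_features)` activation matrix the SAE produces for each training batch:

```python
import torch

N_FEATURES = 65536
WINDOW = 10_000_000  # a feature is "dead" if it has not fired in this many samples

samples_since_fired = torch.zeros(N_FEATURES, dtype=torch.long)

def update_dead_feature_stats(latent):
    """latent: (batch, n_features) SAE activations for one training batch."""
    fired = (latent > 0).any(dim=0)             # features that activated at least once in this batch
    samples_since_fired.add_(latent.shape[0])   # age every feature by the batch size
    samples_since_fired[fired] = 0              # reset the counter for features that fired

def dead_feature_proportion():
    return (samples_since_fired >= WINDOW).float().mean().item()
```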

## Subjective Interpretability

To give an intuitive feel for the interpretability of these models, we run 500,000 randomly selected images from LAION-2B through the final trained SAE for each layer and record the latent activations for each image. We then winnow these down to the first 100 features which activate for at least 9 images, cherry-pick 3 of them, and display them in a 3x3 grid for each layer. We do this twice: once for the CLS token and once for token 137 (near the middle of the image). Below are the 6 grids for layer 22. Other grids are available for each layer.
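
A minimal sketch of the winnowing step, under the simplifying assumption that the recorded activations for one token position fit in a single `(n_images, n_features)` tensor (variable names are hypothetical):

```python
import torch

def first_interpretable_features(latents, min_images=9, n_keep=100):
    """latents: (n_images, n_features) recorded SAE activations at one token position."""
    images_per_feature = (latents > 0).sum(dim=0)                        # images that activate each feature
    qualifying = (images_per_feature >= min_images).nonzero(as_tuple=True)[0]
    return qualifying[:n_keep]                                           # the first 100 qualifying features
```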

## Automated Sort EVALs

### Token-wise MSE