---
license: cc-by-4.0
library_name: clipscope
---
# CLIP-Scope
by [Louka Ewington-Pitsos](https://www.linkedin.com/in/louka-ewington-pitsos-2a92b21a0/?originalSubdomain=au) and [Ram Rattan Goyal](https://www.linkedin.com/in/ram-rattan-goyal/)
Heavily inspired by [google/gemma-scope](https://huggingface.co/google/gemma-scope), we are releasing a suite of 8 sparse autoencoders for [laion/CLIP-ViT-L-14-laion2B-s32B-b82K](https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K).
![](./media/mse.png)
| Layer | MSE | Explained Variance | Dead Feature Proportion | Active Feature Proportion | Sort Eval Accuracy (CLS token) |
|-------|---------|--------------------|-------------------------|---------------------------|-------------------------------|
| 2 | 267.95 | 0.763 | 0.000912 | 0.001 | - |
| 5 | 354.46 | 0.665 | 0 | 0.0034 | - |
| 8 | 357.58 | 0.642 | 0 | 0.01074 | - |
| 11 | 321.23 | 0.674 | 0 | 0.0415 | 0.7334 |
| 14 | 319.64 | 0.689 | 0 | 0.07866 | 0.7427 |
| 17 | 261.20 | 0.731 | 0 | 0.1477 | 0.8689 |
| 20 | 278.06 | 0.706 | 0.0000763 | 0.2036 | 0.9149 |
| 22 | 299.96 | 0.684 | 0 | 0.1588 | 0.8641 |
![](./media/sort-eval.png)
Training logs are available [via wandb](https://wandb.ai/lewington/ViT-L-14-laion2B-s32B-b82K/workspace) and training code is available on [GitHub](https://github.com/Lewington-pitsos/vitsae). The training process relies heavily on [AWS ECS](https://aws.amazon.com/ecs/), so the Weights & Biases logs may contain some strange artifacts where a spot instance was killed and training was resumed by another instance. Some of the code is taken directly from [Hugo Fry](https://github.com/HugoFry/mats_sae_training_for_ViTs).
### Vital Statistics:
- Number of tokens trained per autoencoder: 1.2 billion
- Token type: all 257 image tokens (as opposed to just the CLS token)
- Number of unique images trained per autoencoder: 4.5 million
- Training dataset: [LAION-2B](https://huggingface.co/datasets/laion/laion2B-multi-joined-translated-to-en)
- SAE architecture: TopK with k=32
- Layer location: always the residual stream
- Training checkpoints: every ~100 million tokens
- Number of features per autoencoder: 65536 (expansion factor 64 over the 1024-dimensional residual stream)
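As a rough illustration of what "TopK with k=32" means, here is a minimal NumPy sketch of a TopK SAE forward pass in the style of the OpenAI paper cited in the references. It is not clipscope's `TopKSAE` implementation: the function name `topk_sae_forward` is hypothetical, and subtracting the decoder bias before encoding and clipping the kept values at zero are assumptions about details that may differ in the released code. The real SAEs map 1024-dimensional activations to 65536 latents; the demo uses tiny random weights so it runs instantly.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k=32):
    """Sketch of a TopK SAE forward pass (hypothetical helper, not clipscope's API).

    x: (batch, d_in); W_enc: (d_in, n_features); W_dec: (n_features, d_in).
    """
    pre = (x - b_dec) @ W_enc + b_enc  # (batch, n_features) pre-activations
    # indices of the k largest pre-activations in each row (unordered among themselves)
    idx = np.argpartition(pre, -k, axis=-1)[:, -k:]
    vals = np.take_along_axis(pre, idx, axis=-1)
    latent = np.zeros_like(pre)
    # keep only the top-k values (clipped at 0); every other latent stays exactly 0
    np.put_along_axis(latent, idx, np.maximum(vals, 0.0), axis=-1)
    reconstruction = latent @ W_dec + b_dec  # (batch, d_in)
    return latent, reconstruction

# The released SAEs use d_in=1024 and 65536 latents; tiny sizes keep this demo fast.
rng = np.random.default_rng(0)
d_in, n_features = 16, 256
W_enc = rng.normal(size=(d_in, n_features))
W_dec = rng.normal(size=(n_features, d_in))
b_enc, b_dec = np.zeros(n_features), np.zeros(d_in)
x = rng.normal(size=(4, d_in))

latent, recon = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k=32)
print(latent.shape, recon.shape)  # (4, 256) (4, 16)
print((latent != 0).sum(axis=-1).max() <= 32)  # True
```

Because at most k latents can be nonzero per token, sparsity is fixed by construction rather than tuned through an L1 penalty, which is the main motivation for the TopK architecture in that paper.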
## Usage
First, install our PyPI package and Pillow (PIL):
`pip install clipscope pillow`
Then run:
```python
from PIL import Image
from clipscope import ConfiguredViT, TopKSAE

device = "cpu"
filename_in_hf_repo = "22_resid/1200013184.pt"
sae = TopKSAE.from_pretrained(checkpoint=filename_in_hf_repo, device=device)

locations = [(22, "resid")]
transformer = ConfiguredViT(locations, device=device)

image = Image.new("RGB", (224, 224), (0, 0, 0))  # black image for testing

activations = transformer.all_activations(image)[locations[0]]  # (1, 257, 1024)
assert activations.shape == (1, 257, 1024)

activations = activations[:, 0]  # just the CLS token -> (1, 1024)
# alternatively, encode every token by merging the batch and token dimensions:
# activations = activations.flatten(0, 1)  # (257, 1024)

print("activations shape", activations.shape)

output = sae.forward_verbose(activations)
print("output keys", output.keys())
print("latent shape", output["latent"].shape)  # (1, 65536)
print("reconstruction shape", output["reconstruction"].shape)  # (1, 1024)
```
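To see which features fire for an image, one option is to sort the latent vector. The helper below is hypothetical (it is not part of clipscope) and works on a NumPy copy of the latent, for example `output["latent"][0].detach().cpu().numpy()`, assuming `output["latent"]` is a PyTorch tensor.

```python
import numpy as np

def top_features(latent, n=5):
    """Return (feature_index, activation) pairs for the n strongest nonzero latents."""
    latent = np.asarray(latent).reshape(-1)  # e.g. (65536,) for a single CLS token
    order = np.argsort(latent)[::-1][:n]     # indices sorted by activation, descending
    return [(int(i), float(latent[i])) for i in order if latent[i] > 0]

latent = np.zeros(65536, dtype=np.float32)
latent[308], latent[464] = 2.5, 1.0
print(top_features(latent))  # [(308, 2.5), (464, 1.0)]
```

With a TopK SAE at most 32 entries can be nonzero per token, so asking for more than 32 features simply returns fewer pairs.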
## Formulae
We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()`, i.e. the squared error between the batch and the reconstruction, summed across features and averaged over the batch. Before training, we normalize the inputs using batch norm so that activations from all layers fall into a similar range.
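The MSE definition can be checked on a tiny batch. This NumPy sketch mirrors the PyTorch expression above. The `normalize` function is only an assumption about what normalizing with batch norm involves (per-feature standardization with batch statistics and no learned scale or shift); it is not the exact training code.

```python
import numpy as np

def normalize(batch, eps=1e-5):
    # Assumption: batch-norm-style standardization per feature, no learned affine params.
    return (batch - batch.mean(axis=0)) / np.sqrt(batch.var(axis=0) + eps)

def summed_mse(batch, reconstruction):
    # Squared error summed over the feature dimension, then averaged over the batch.
    return float(((batch - reconstruction) ** 2).sum(axis=-1).mean())

batch = np.array([[1.0, 2.0], [3.0, 4.0]])
reconstruction = np.array([[1.0, 1.0], [3.0, 2.0]])
print(summed_mse(batch, reconstruction))  # 2.5, i.e. (1 + 4) / 2
```

Because the error is summed rather than averaged over the 1024 residual dimensions, the MSE values in the table are on a much larger scale than a per-element MSE would be.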
We calculate Explained Variance as
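```python
# per-sample squared reconstruction error, summed over features
delta_variance = (batch - reconstruction).pow(2).sum(dim=-1)
# per-sample squared deviation from that sample's own mean (taken over the feature dimension)
activation_variance = (batch - batch.mean(dim=-1, keepdim=True)).pow(2).sum(dim=-1)
# fraction of that variance captured by the reconstruction, averaged over the batch
explained_variance = (1 - delta_variance / activation_variance).mean()
```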
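A quick sanity check of this definition, translated to NumPy: a perfect reconstruction scores 1, and reconstructing every sample as its own mean (taken over the feature dimension) scores exactly 0. Note that `activation_variance` is measured around each sample's own mean rather than the batch mean, and a sample whose features are all equal would make the ratio undefined.

```python
import numpy as np

def explained_variance(batch, reconstruction):
    """NumPy translation of the PyTorch formula above."""
    delta_variance = ((batch - reconstruction) ** 2).sum(axis=-1)
    activation_variance = ((batch - batch.mean(axis=-1, keepdims=True)) ** 2).sum(axis=-1)
    return float((1 - delta_variance / activation_variance).mean())

batch = np.array([[1.0, 2.0, 3.0], [0.0, 4.0, 8.0]])
per_sample_mean = np.broadcast_to(batch.mean(axis=-1, keepdims=True), batch.shape)
print(explained_variance(batch, batch))            # 1.0
print(explained_variance(batch, per_sample_mean))  # 0.0
```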
We calculate dead feature proportion as the proportion of features which have not activated in the last 10,000,000 samples.
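A minimal sketch of how such a rolling dead-feature count could be tracked. `DeadFeatureTracker` is a hypothetical helper, and counting each row of a latent batch as one sample (i.e. one token) is an assumption.

```python
import numpy as np

class DeadFeatureTracker:
    """Hypothetical helper: a feature is dead if it has not fired in the last `window` samples."""

    def __init__(self, n_features, window=10_000_000):
        self.window = window
        self.samples_seen = 0
        # sample count at which each feature last fired (0 = never fired)
        self.last_fired = np.zeros(n_features, dtype=np.int64)

    def update(self, latent):
        """latent: (batch, n_features) SAE activations; each row counts as one sample."""
        self.samples_seen += latent.shape[0]
        fired = (latent > 0).any(axis=0)
        self.last_fired[fired] = self.samples_seen

    def dead_proportion(self):
        idle = self.samples_seen - self.last_fired
        return float((idle >= self.window).mean())

tracker = DeadFeatureTracker(n_features=3, window=4)
tracker.update(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]))
tracker.update(np.zeros((3, 3)))
print(tracker.dead_proportion())  # features 1 and 2 have been idle for >= 4 samples
```

Storing only the last-fired counter per feature keeps memory at one integer per latent, which matters with 65536 features and a 10,000,000-sample window.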
## Subjective Interpretability
To give an intuitive feel for the interpretability of these models, we run 500,000 randomly selected images from LAION-2B through the final trained SAE for each layer and record the latent activations on the CLS token for each image. We then winnow these down to the first 100 features that activate for at least 9 images. We cherry-pick 3 of these and display each one as a 3x3 grid of images. Below are the 3 grids for layer 22; grids for the other layers are available in the `README.md` for each layer.
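The selection step could look like the sketch below. `features_for_grids` is a hypothetical name; reading "first 100 features" as the 100 lowest feature indices is an assumption, and the cherry-picking is left to a human.

```python
import numpy as np

def features_for_grids(latents, min_images=9, max_features=100, grid_size=9):
    """latents: (n_images, n_features) CLS-token activations.

    Returns {feature_index: indices of its top `grid_size` images} for the first
    `max_features` features (by index) that fire on at least `min_images` images.
    """
    counts = (latents > 0).sum(axis=0)
    eligible = np.flatnonzero(counts >= min_images)[:max_features]
    grids = {}
    for f in eligible:
        top = np.argsort(latents[:, f])[::-1][:grid_size]  # strongest activations first
        grids[int(f)] = top.tolist()
    return grids

latents = np.zeros((20, 5))
latents[:10, 1] = np.arange(1, 11)  # feature 1 fires on 10 images
latents[:12, 3] = 1.0               # feature 3 fires on 12 images
latents[:3, 0] = 1.0                # feature 0 fires on only 3 images
grids = features_for_grids(latents)
print(sorted(grids))  # [1, 3]
print(grids[1])       # [9, 8, 7, 6, 5, 4, 3, 2, 1]
```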
**feature 308**
![](./22_resid/examples/308_grid.png)
**feature 464**
![](./22_resid/examples/464_grid.png)
**feature 575**
![](./22_resid/examples/575_grid.png)
## Automated Sort Evals
We performed automated [Sort Evals](https://transformer-circuits.pub/2024/august-update/index.html) following Anthropic, except that the *formatted dataset examples* for both features are replaced by a 3x3 grid of each feature's top 9 activating images, and the *formatted query example* is replaced by an image that activates only one of the two features and is not included in either grid. Our methodology was as follows:
1. Pass 500,000 LAION images from the training dataset through the SAE and record the first 2048 latent activations for each image (only 2048, to conserve space).
2. Ignore all features that do not activate for at least 10 of those images.
3. If more than 100 features remain, select the first 100.
4. If fewer than 100 remain, the comparison would not be fair, so we do not proceed (this is why there are no sort eval scores for layers 2, 5 and 8).
5. Randomly select 500 pairs from among the 100 features and perform a sort eval for each pair using GPT-4o.
6. Randomly select 400 of these 500 sort evals, 5 times, each time recording the accuracy (n correct / 400) for that subsample.
7. Calculate the mean and standard deviation of these 5 accuracies.
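The pair sampling and the repeated 400-of-500 subsampling in the last three steps can be sketched with the Python standard library. The GPT-4o call is out of scope, so `correct` stands in for the graded sort eval results (the values below are placeholders, not real results). Both function names are hypothetical, pairs are assumed to be distinct and unordered, and using the sample standard deviation (`statistics.stdev`) is an assumption since the estimator is not specified.

```python
import random
import statistics

def sample_pairs(features, n_pairs=500, seed=0):
    """Draw n_pairs distinct unordered pairs of distinct features (hypothetical helper)."""
    rng = random.Random(seed)
    all_pairs = [(a, b) for i, a in enumerate(features) for b in features[i + 1:]]
    return rng.sample(all_pairs, n_pairs)

def subsampled_accuracy(correct, sample_size=400, repeats=5, seed=0):
    """Mean and sample std of accuracy over `repeats` random subsets of the sort evals."""
    rng = random.Random(seed)
    accuracies = [sum(rng.sample(correct, sample_size)) / sample_size for _ in range(repeats)]
    return statistics.mean(accuracies), statistics.stdev(accuracies)

pairs = sample_pairs(list(range(100)))             # 500 of the 4950 possible pairs
correct = [i % 4 != 0 for i in range(len(pairs))]  # placeholder grades, not real results
mean_acc, std_acc = subsampled_accuracy(correct)
print(len(pairs), round(mean_acc, 3), round(std_acc, 3))
```

Each 400-sample subset is drawn without replacement, so the spread across the 5 repeats reflects only which evals were left out.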
The outcomes are plotted below. Active Feature Proportion is the proportion of features that activate for at least 10 images across the 500,000-image dataset for that layer. For the CLS token at layer 2, only 2/2048 features were "active" in this sense.
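For concreteness, the active feature proportion could be computed as follows (a sketch; `active_feature_proportion` is a hypothetical name). With 2 of 2048 features active, the proportion is 2 / 2048 ≈ 0.00098, which rounds to the 0.001 reported for layer 2 in the table.

```python
import numpy as np

def active_feature_proportion(latents, min_images=10):
    """latents: (n_images, n_features) CLS-token activations, e.g. the first 2048 features."""
    images_per_feature = (latents > 0).sum(axis=0)
    return float((images_per_feature >= min_images).mean())

latents = np.zeros((50, 2048))
latents[:10, 0] = 1.0  # fires on 10 images -> active
latents[:25, 5] = 1.0  # fires on 25 images -> active
latents[:9, 7] = 1.0   # fires on only 9 images -> not active
print(active_feature_proportion(latents))  # 0.0009765625 (= 2 / 2048)
```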
![](./media/sort-eval.png)
![](./media/active-feature-proportion.png)
## Token-wise MSE
All layers were trained across all 257 image tokens. Below we provide plots showing the reconstruction MSE for each token (other than the CLS token) as training progressed. Throughout training, the outer tokens appear easier to reconstruct than those in the middle, presumably because the central tokens tend to cover foreground objects and are therefore more information-rich.
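One way to produce maps like these is to average the summed squared error per token over a batch of images, drop the CLS token, and lay the remaining 256 patches out on a 16x16 grid (224 / 14 = 16 patches per side). This is a sketch, not the plotting code used for the figures; `tokenwise_error_grid` is a hypothetical name, and it assumes the CLS token sits at index 0 and that patches are ordered row by row.

```python
import numpy as np

def tokenwise_error_grid(batch, reconstruction, grid=16):
    """batch, reconstruction: (n_images, 257, d) with the CLS token assumed at index 0."""
    # squared error summed over features, averaged over images -> (257,)
    per_token = ((batch - reconstruction) ** 2).sum(axis=-1).mean(axis=0)
    # drop CLS and lay the 256 patch tokens out row by row on a 16x16 grid
    return per_token[1:].reshape(grid, grid)

batch = np.zeros((2, 257, 4))
reconstruction = np.zeros((2, 257, 4))
reconstruction[:, 0] = 5.0  # CLS error is ignored
reconstruction[:, 1] = 1.0  # first patch: error 4 * 1.0**2 = 4.0
heatmap = tokenwise_error_grid(batch, reconstruction)
print(heatmap.shape, heatmap[0, 0], heatmap.sum())  # (16, 16) 4.0 4.0
```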
![](./media/layer_22_training_outputs.png)
![](./media/layer_22_individually_scaled.png)
## References
We draw heavily from prior visual sparse autoencoder research by [Hugo Fry](https://www.lesswrong.com/posts/bCtbuWraqYTDtuARg/towards-multimodal-interpretability-learning-sparse-2) and [Gytis Daujotas](https://www.lesswrong.com/posts/iYFuZo9BMvr6GgMs5/case-study-interpreting-manipulating-and-controlling-clip). We also rely on auto-interpretability research from [Anthropic Circuits Updates](https://transformer-circuits.pub/2024/august-update/index.html), and take the TopK SAE architecture and training methodology from [Scaling and Evaluating Sparse Autoencoders](https://cdn.openai.com/papers/sparse-autoencoders.pdf). We base all our training and inference on data from the [LAION project](https://laion.ai/laion-400-open-dataset/).