---
license: cc-by-4.0
library_name: clipscope
---

# CLIP-Scope

by [Louka Ewington-Pitsos](https://www.linkedin.com/in/louka-ewington-pitsos-2a92b21a0/?originalSubdomain=au) and [Ram Rattan Goyal](https://www.linkedin.com/in/ram-rattan-goyal/)

Heavily inspired by [google/gemma-scope](https://huggingface.co/google/gemma-scope) we are releasing a suite of 8 sparse autoencoders for [laion/CLIP-ViT-L-14-laion2B-s32B-b82K](https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K). 

![](./media/mse.png)


| Layer |   MSE   | Explained Variance | Dead Feature Proportion | Active Feature Proportion | Sort Eval Accuracy (CLS token) |
|-------|---------|--------------------|-------------------------|---------------------------|-------------------------------|
|   2   | 267.95  |        0.763       |        0.000912         |           0.001            |               -               |
|   5   | 354.46  |        0.665       |            0            |           0.0034           |               -               |
|   8   | 357.58  |        0.642       |            0            |           0.01074          |               -               |
|  11   | 321.23  |        0.674       |            0            |           0.0415           |            0.7334              |
|  14   | 319.64  |        0.689       |            0            |           0.07866          |            0.7427              |
|  17   | 261.20  |        0.731       |            0            |           0.1477           |            0.8689              |
|  20   | 278.06  |        0.706       |        0.0000763        |           0.2036           |            0.9149              |
|  22   | 299.96  |        0.684       |            0            |           0.1588           |            0.8641              |



![](./media/sort-eval.png)

Training logs are available [via wandb](https://wandb.ai/lewington/ViT-L-14-laion2B-s32B-b82K/workspace) and training code is available on [github](https://github.com/Lewington-pitsos/vitsae). Training relies heavily on [AWS ECS](https://aws.amazon.com/ecs/), so the Weights & Biases logs may contain some strange artifacts where a spot instance was killed and training was resumed by another instance. Some of the code is adapted directly from [Hugo Fry](https://github.com/HugoFry/mats_sae_training_for_ViTs).

### Vital Statistics:

- Number of tokens trained per autoencoder: 1.2 Billion
- Token type: all 257 image tokens (as opposed to just the CLS token)
- Number of unique images trained per autoencoder: 4.5 Million
- Training Dataset: [Laion-2b](https://huggingface.co/datasets/laion/laion2B-multi-joined-translated-to-en)
- SAE Architecture: topk with k=32 (a minimal sketch of the forward pass is given below)
- Layer Location: always the residual stream
- Training Checkpoints: every ~100 million tokens
- Number of features per autoencoder: 65536 (expansion factor 16)
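
To make the architecture above concrete, here is a minimal, illustrative sketch of a TopK SAE forward pass with these hyperparameters (k=32, 65536 features, 1024-dimensional residual stream inputs). The class and method names are hypothetical; the actual implementation lives in the `clipscope` package and the training repository.

```python
import torch
import torch.nn as nn

class MinimalTopKSAE(nn.Module):
    """Illustrative TopK SAE; not the clipscope implementation."""

    def __init__(self, d_in=1024, n_features=65536, k=32):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_in, n_features)
        self.decoder = nn.Linear(n_features, d_in)

    def forward(self, x):
        pre_acts = self.encoder(x)                                # (batch, n_features)
        values, indices = torch.topk(pre_acts, self.k, dim=-1)    # keep the k largest pre-activations
        latent = torch.zeros_like(pre_acts)
        latent.scatter_(-1, indices, torch.relu(values))          # exactly k active features per sample
        reconstruction = self.decoder(latent)                     # (batch, d_in)
        return {"latent": latent, "reconstruction": reconstruction}
```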

## Usage

First install our PyPI package and Pillow (PIL):

`pip install clipscope pillow`

Then

```python
import PIL.Image
from clipscope import ConfiguredViT, TopKSAE

device='cpu'
filename_in_hf_repo = "22_resid/1200013184.pt"
sae = TopKSAE.from_pretrained(checkpoint=filename_in_hf_repo, device=device)

locations = [(22, 'resid')]
transformer = ConfiguredViT(locations, device=device)

input = PIL.Image.new("RGB", (224, 224), (0, 0, 0)) # black image for testing

activations = transformer.all_activations(input)[locations[0]] # (1, 257, 1024)
assert activations.shape == (1, 257, 1024)

activations = activations[:, 0] # just the cls token
# alternatively flatten the activations
# activations = activations.flatten(1)

print('activations shape', activations.shape)

output = sae.forward_verbose(activations)

print('output keys', output.keys())

print('latent shape', output['latent'].shape) # (1, 65536)
print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)
```

## Formulae

We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()`, i.e. the squared error between the batch and the reconstruction, summed across the activation dimension and averaged over the batch. We normalize the input activations with batch norm before training to bring all layers into a similar range.
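
As a rough sketch of how these two steps fit together (the `normalize` helper below is a stand-in for the batch-norm style normalization, not the exact training code):

```python
import torch

def normalize(activations, eps=1e-6):
    # Stand-in for the batch-norm style normalization that brings raw ViT
    # activations from different layers into a similar numeric range.
    mean = activations.mean(dim=0, keepdim=True)
    std = activations.std(dim=0, keepdim=True)
    return (activations - mean) / (std + eps)

def mse(batch, reconstruction):
    # Squared error summed over the activation dimension, averaged over the batch.
    return (batch - reconstruction).pow(2).sum(dim=-1).mean()

# Hypothetical usage:
# batch = normalize(raw_activations)                      # (batch_size, 1024)
# loss = mse(batch, sae.forward_verbose(batch)['reconstruction'])
```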

We calculate Explained Variance as 

```python
delta_variance = (batch - reconstruction).pow(2).sum(dim=-1)
activation_variance = (batch - batch.mean(dim=-1, keepdim=True)).pow(2).sum(dim=-1)
explained_variance = (1 - delta_variance / activation_variance).mean()
```

We calculate dead feature proportion as the proportion of features which have not activated in the last 10,000,000 samples.
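
A minimal sketch of how this statistic can be tracked (the counter and its update logic here are illustrative, not the exact training code):

```python
import torch

DEAD_WINDOW = 10_000_000  # a feature counts as dead if it has not fired in this many samples

def update_dead_feature_stats(latent, samples_since_fired):
    # latent: (batch, n_features) SAE latent activations for one batch
    # samples_since_fired: (n_features,) running count of samples since each feature last fired
    samples_since_fired += latent.shape[0]          # every feature ages by the batch size
    fired = (latent > 0).any(dim=0)                 # features that activated in this batch
    samples_since_fired[fired] = 0                  # reset the counter for firing features
    dead_feature_proportion = (samples_since_fired >= DEAD_WINDOW).float().mean()
    return dead_feature_proportion
```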

## Subjective Interpretability  

To give an intuitive feel for the interpretability of these models, we ran 500,000 randomly selected laion2b images through the final trained SAE for each layer and recorded the latent activations on the CLS token for each image. We then winnowed down to the first 100 features which activate for at least 9 images, cherry-picked 3 of them, and displayed each as a 3x3 grid of its top activating images. Below are the 3 grids for layer 22; grids for the other layers are available in the `README.md` for each layer.

**feature 308**

![](./22_resid/examples/308_grid.png)

**feature 464**

![](./22_resid/examples/464_grid.png)

**feature 575**

![](./22_resid/examples/575_grid.png)

## Automated Sort Evals

We performed automated [Sort Evals](https://transformer-circuits.pub/2024/august-update/index.html) following Anthropic, except that the *formatted dataset examples* for both features were replaced by a 3x3 grid of the top 9 activating images, and the *formatted query example* was replaced by an image which activates for only one of the two features but does not appear in either grid. Our methodology was as follows:

1. pass 500,000 laion images from the training dataset through the SAE and record the first 2048 latent activations for each image (only 2048 to conserve space)
2. ignore all features which do not activate for at least 10 of those images
3. if more than 100 features remain, select the first 100
4. if fewer than 100 remain, the comparison would not be fair, so we do not proceed (this is why there are no sort eval scores for layers 8, 5 and 2)
5. randomly select 500 pairs from among the 100 features and perform a sort eval for each pair using gpt-4o
6. randomly select 400 samples from among these 500 sort evals 5 times, each time recording the accuracy (n correct / 400) for that subsample (see the sketch below)
7. calculate the mean and standard deviation of these 5 accuracies.
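
The resampling in steps 6 and 7 amounts to something like the following (the names are illustrative; the sort eval results themselves come from the gpt-4o comparisons described above):

```python
import random
import statistics

def subsample_accuracy(sort_eval_correct, n_samples=400, n_repeats=5, seed=0):
    # sort_eval_correct: list of 500 booleans, True where gpt-4o identified the correct feature
    rng = random.Random(seed)
    accuracies = []
    for _ in range(n_repeats):
        subsample = rng.sample(sort_eval_correct, n_samples)
        accuracies.append(sum(subsample) / n_samples)
    return statistics.mean(accuracies), statistics.stdev(accuracies)
```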

The outcomes are plotted below. Active Feature Proportion is the proportion of features which activate for at least 10 images across the 500,000-image dataset for that layer. For the CLS token at layer 2, only 2/2048 features were "active" in this sense.

![](./media/sort-eval.png)

![](./media/active-feature-proportion.png)

## Token-wise MSE

All layers were trained on all 257 image tokens. Below we provide plots of the reconstruction MSE for each token (other than the CLS token) as training progressed. Throughout training, the outer tokens appear easier to reconstruct than those in the middle, presumably because the central tokens capture more important information (i.e. foreground objects) and are therefore more information rich.
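
As a rough sketch of how a per-token MSE map like the ones below can be computed (shapes follow the usage example above; this is not the exact plotting code):

```python
import torch

def per_token_mse(activations, reconstructions):
    # activations, reconstructions: (batch, 257, 1024) residual stream activations and SAE outputs
    mse_per_token = (activations - reconstructions).pow(2).sum(dim=-1).mean(dim=0)  # (257,)
    patch_mse = mse_per_token[1:]         # drop the CLS token, keeping the 256 image patch tokens
    return patch_mse.reshape(16, 16)      # 16x16 grid matching the ViT-L/14 patch layout at 224x224
```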

![](./media/layer_22_training_outputs.png)
![](./media/layer_22_individually_scaled.png)

## References

We draw heavily from prior Visual Sparse Autoencoder research work by [Hugo Fry](https://www.lesswrong.com/posts/bCtbuWraqYTDtuARg/towards-multimodal-interpretability-learning-sparse-2) and [Gytis Daujotas](https://www.lesswrong.com/posts/iYFuZo9BMvr6GgMs5/case-study-interpreting-manipulating-and-controlling-clip). We also rely on Autointerpretability research from [Anthropic Circuits Updates](https://transformer-circuits.pub/2024/august-update/index.html), and take the TopKSAE architecture and training methodology from [Scaling and Evaluating Sparse Autoencoders](https://cdn.openai.com/papers/sparse-autoencoders.pdf). We base all our training and inference on data from the [LAION project](https://laion.ai/laion-400-open-dataset/).