anonymous-author-129 committed on
Commit 5bc51c9 · verified · 1 Parent(s): 9524885

Upload 31 files

Files changed (31)
  1. SAE/__init__.py +1 -0
  2. SAE/config.json +23 -0
  3. SAE/dataset_iterator.py +53 -0
  4. SAE/sae.py +215 -0
  5. SAE/sae_utils.py +46 -0
  6. SDLens/__init__.py +1 -0
  7. SDLens/hooked_scheduler.py +40 -0
  8. SDLens/hooked_sd_pipeline.py +321 -0
  9. app.ipynb +0 -0
  10. app.py +405 -0
  11. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  12. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  13. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  14. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  15. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  16. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  17. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  18. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  19. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  20. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  21. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  22. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  23. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  24. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  25. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  26. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  27. requirements.txt +11 -0
  28. scripts/collect_latents_dataset.py +96 -0
  29. scripts/train_sae.py +308 -0
  30. utils/__init__.py +1 -0
  31. utils/hooks.py +46 -0
SAE/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sae import SparseAutoencoder
SAE/config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "sae_configs": [
+     {
+       "d_model": 1280,
+       "n_dirs": 5120,
+       "k": 20
+     },
+     {
+       "d_model": 1280,
+       "n_dirs": 640,
+       "k": 20
+     }
+   ],
+   "bs": 4096,
+   "log_interval": 500,
+   "save_interval": 5000,
+
+   "paths_to_latents": [
+     "PASS YOUR PATHS HERE. Example /home/username/latents/<timestamp>. It should contain tar archives with latents."
+   ],
+   "save_path_base": "<Your SAE save path>",
+   "block_name": "unet.down_blocks.2.attentions.1"
+ }
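For orientation, this file is parsed by `scripts/train_sae.py` into the dataclasses defined in `SAE/sae_utils.py`. A minimal sketch of that round trip (run from the repository root; the printed name depends on the defaults above):

```python
import json
from SAE.sae_utils import Config

# Each entry of "sae_configs" becomes one SAETrainingConfig; all SAEs are
# trained on activations from the block named by "block_name".
cfg = Config(json.load(open('SAE/config.json')))
for sae_cfg in cfg.saes:
    # e.g. unet.down_blocks.2.attentions.1_k20_hidden5120_auxk256_bs4096_lr0.0001
    print(sae_cfg.sae_name)
```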
SAE/dataset_iterator.py ADDED
@@ -0,0 +1,53 @@
+ import webdataset as wds
+ import os
+ import torch
+
+ class ActivationsDataloader:
+     def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50):
+         assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
+
+         self.dataset = wds.WebDataset(
+             [os.path.join(path_to_dataset, f"{block_name}.tar")
+              for path_to_dataset in paths_to_datasets]
+         ).decode("torch")
+         self.iter = iter(self.dataset)
+         self.buffer = None
+         self.pointer = 0
+         self.num_in_buffer = num_in_buffer
+         self.output_or_diff = output_or_diff
+         self.batch_size = batch_size
+         self.one_size = None
+
+     def renew_buffer(self, to_retrieve):
+         to_merge = []
+         if self.buffer is not None and self.buffer.shape[0] > self.pointer:
+             to_merge = [self.buffer[self.pointer:].clone()]
+             del self.buffer
+         for _ in range(to_retrieve):
+             sample = next(self.iter)
+             latents = sample['output.pth'] if self.output_or_diff == 'output' else sample['diff.pth']
+             latents = latents.permute((0, 1, 3, 4, 2))
+             latents = latents.reshape((-1, latents.shape[-1]))
+             to_merge.append(latents.to('cuda'))
+             self.one_size = latents.shape[0]
+         self.buffer = torch.cat(to_merge, dim=0)
+         shuffled_indices = torch.randperm(self.buffer.shape[0])
+         self.buffer = self.buffer[shuffled_indices]
+         self.pointer = 0
+
+     def iterate(self):
+         while True:
+             if self.buffer is None or self.buffer.shape[0] - self.pointer < self.num_in_buffer * self.one_size * 4 // 5:
+                 try:
+                     to_retrieve = self.num_in_buffer if self.buffer is None else self.num_in_buffer // 5
+                     self.renew_buffer(to_retrieve)
+                 except StopIteration:
+                     break
+
+             batch = self.buffer[self.pointer: self.pointer + self.batch_size]
+             self.pointer += self.batch_size
+
+             assert batch.shape[0] == self.batch_size
+             yield batch
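A minimal usage sketch for the loader above (the dataset path is a placeholder for a directory produced by `scripts/collect_latents_dataset.py`; a GPU is required since the buffer lives on `'cuda'`):

```python
from SAE.dataset_iterator import ActivationsDataloader

# Streams shuffled rows of shape (batch_size, d_model) from the tar shards,
# refilling and reshuffling the on-GPU buffer as it drains.
dataloader = ActivationsDataloader(
    paths_to_datasets=["/path/to/latents/<timestamp>"],  # placeholder path
    block_name="unet.down_blocks.2.attentions.1",
    batch_size=4096,
)
for batch in dataloader.iterate():
    print(batch.shape)  # torch.Size([4096, 1280]) for this block
    break
```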
SAE/sae.py ADDED
@@ -0,0 +1,215 @@
+ '''
+ Adapted from
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py
+ '''
+
+ import torch
+ import torch.nn as nn
+ import os
+ import json
+ import spaces
+ import logging
+
+ class SparseAutoencoder(nn.Module):
+     """
+     Top-K Autoencoder with sparse kernels. Implements:
+
+         latents = relu(topk(encoder(x - pre_bias) + latent_bias))
+         recons = decoder(latents) + pre_bias
+     """
+
+     def __init__(
+         self,
+         n_dirs_local: int,
+         d_model: int,
+         k: int,
+         auxk: int | None,
+         dead_steps_threshold: int,
+     ):
+         super().__init__()
+         self.n_dirs_local = n_dirs_local
+         self.d_model = d_model
+         self.k = k
+         self.auxk = auxk
+         self.dead_steps_threshold = dead_steps_threshold
+
+         self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
+         self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
+
+         self.pre_bias = nn.Parameter(torch.zeros(d_model))
+         self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
+
+         self.stats_last_nonzero: torch.Tensor
+         self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
+
+         ## initialization
+
+         # "tied" init
+         self.decoder.weight.data = self.encoder.weight.data.T.clone()
+
+         # store decoder in column major layout for kernel
+         self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
+
+         unit_norm_decoder_(self)
+
+     def auxk_mask_fn(self, x):
+         dead_mask = self.stats_last_nonzero > self.dead_steps_threshold
+         x.data *= dead_mask  # inplace to save memory
+         return x
+
+     def save_to_disk(self, path: str):
+         PATH_TO_CFG = 'config.json'
+         PATH_TO_WEIGHTS = 'state_dict.pth'
+
+         cfg = {
+             "n_dirs_local": self.n_dirs_local,
+             "d_model": self.d_model,
+             "k": self.k,
+             "auxk": self.auxk,
+             "dead_steps_threshold": self.dead_steps_threshold,
+         }
+
+         os.makedirs(path, exist_ok=True)
+
+         with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
+             json.dump(cfg, f)
+
+         torch.save({
+             "state_dict": self.state_dict(),
+         }, os.path.join(path, PATH_TO_WEIGHTS))
+
+     @classmethod
+     def load_from_disk(cls, path: str):
+         PATH_TO_CFG = 'config.json'
+         PATH_TO_WEIGHTS = 'state_dict.pth'
+
+         with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
+             cfg = json.load(f)
+
+         ae = cls(
+             n_dirs_local=cfg["n_dirs_local"],
+             d_model=cfg["d_model"],
+             k=cfg["k"],
+             auxk=cfg["auxk"],
+             dead_steps_threshold=cfg["dead_steps_threshold"],
+         )
+
+         state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
+         ae.load_state_dict(state_dict)
+
+         return ae
+
+     @property
+     def n_dirs(self):
+         return self.n_dirs_local
+
+     def encode(self, x):
+         x = x.to('cuda') - self.pre_bias
+         latents_pre_act = self.encoder(x) + self.latent_bias
+
+         vals, inds = torch.topk(
+             latents_pre_act,
+             k=self.k,
+             dim=-1
+         )
+
+         latents = torch.zeros_like(latents_pre_act)
+         latents.scatter_(-1, inds, torch.relu(vals))
+
+         return latents
+
+     def forward(self, x):
+         x = x - self.pre_bias
+         latents_pre_act = self.encoder(x) + self.latent_bias
+         vals, inds = torch.topk(
+             latents_pre_act,
+             k=self.k,
+             dim=-1
+         )
+
+         ## set num nonzero stat ##
+         tmp = torch.zeros_like(self.stats_last_nonzero)
+         tmp.scatter_add_(
+             0,
+             inds.reshape(-1),
+             (vals > 1e-3).to(tmp.dtype).reshape(-1),
+         )
+         self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
+         self.stats_last_nonzero += 1
+         ## end stats ##
+
+         ## auxk
+         if self.auxk is not None:  # for auxk
+             # IMPORTANT: has to go after stats update!
+             # WARN: auxk_mask_fn can mutate latents_pre_act!
+             auxk_vals, auxk_inds = torch.topk(
+                 self.auxk_mask_fn(latents_pre_act),
+                 k=self.auxk,
+                 dim=-1
+             )
+         else:
+             auxk_inds = None
+             auxk_vals = None
+         ## end auxk
+
+         vals = torch.relu(vals)
+         if auxk_vals is not None:
+             auxk_vals = torch.relu(auxk_vals)
+
+         rows, cols = latents_pre_act.size()
+         row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
+         vals = vals.reshape(-1)
+         inds = inds.reshape(-1)
+
+         indices = torch.stack([row_indices.to(inds.device), inds])
+
+         sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+         recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+
+         return recons, {
+             "inds": inds,
+             "vals": vals,
+             "auxk_inds": auxk_inds,
+             "auxk_vals": auxk_vals,
+         }
+
+     def decode_sparse(self, inds, vals):
+         rows, cols = inds.shape[0], self.n_dirs
+
+         row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
+         vals = vals.reshape(-1)
+         inds = inds.reshape(-1)
+
+         indices = torch.stack([row_indices.to(inds.device), inds])
+
+         sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+         recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+         return recons
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+
+ def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
+     """
+     Unit normalize the decoder weights of an autoencoder.
+     """
+     autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
+
+
+ def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
+     """Project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed."""
+
+     assert autoencoder.decoder.weight.grad is not None
+
+     autoencoder.decoder.weight.grad += \
+         torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) * \
+         autoencoder.decoder.weight.data * -1
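A usage sketch for the class above, assuming one of the released checkpoint directories from this commit (the random batch is a stand-in for real block activations):

```python
import torch
from SAE import SparseAutoencoder

sae = SparseAutoencoder.load_from_disk(
    "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"
).to("cuda")

x = torch.randn(8, sae.d_model, device="cuda")  # stand-in for activations
latents = sae.encode(x)   # (8, 5120); at most k=10 nonzeros per row
recons, info = sae(x)     # reconstruction plus top-k indices/values
print(latents.count_nonzero(dim=-1))  # <= k for every row
```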
SAE/sae_utils.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ import torch
+ from dataclasses import dataclass, field
+
+ @dataclass
+ class SAETrainingConfig:
+     d_model: int
+     n_dirs: int
+     k: int
+     block_name: str
+     bs: int
+     save_path_base: str
+     auxk: int = 256
+     lr: float = 1e-4
+     eps: float = 6.25e-10
+     dead_toks_threshold: int = 10_000_000
+     auxk_coef: float = 1/32
+
+     @property
+     def sae_name(self):
+         return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+     @property
+     def save_path(self):
+         return os.path.join(self.save_path_base, self.sae_name)
+
+ @dataclass
+ class Config:
+     saes: list[SAETrainingConfig]
+     paths_to_latents: list[str]
+     log_interval: int
+     save_interval: int
+     bs: int
+     block_name: str
+     wandb_project: str = 'sdxl_sae_train'
+     wandb_name: str = 'multiple_sae'
+
+     def __init__(self, cfg_json):
+         self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base'])
+                      for sae_cfg in cfg_json['sae_configs']]
+
+         self.save_path_base = cfg_json['save_path_base']
+         self.paths_to_latents = cfg_json['paths_to_latents']
+         self.log_interval = cfg_json['log_interval']
+         self.save_interval = cfg_json['save_interval']
+         self.bs = cfg_json['bs']
+         self.block_name = cfg_json['block_name']
SDLens/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hooked_sd_pipeline import HookedIFPipeline, HookedStableDiffusionXLPipeline
SDLens/hooked_scheduler.py ADDED
@@ -0,0 +1,40 @@
+ from diffusers import DDPMScheduler
+ import torch
+
+ class HookedNoiseScheduler:
+     scheduler: DDPMScheduler
+     pre_hooks: list
+     post_hooks: list
+
+     def __init__(self, scheduler):
+         object.__setattr__(self, 'scheduler', scheduler)
+         object.__setattr__(self, 'pre_hooks', [])
+         object.__setattr__(self, 'post_hooks', [])
+
+     def step(
+         self,
+         model_output, timestep, sample, generator, return_dict
+     ):
+         assert return_dict == False, "return_dict == True is not implemented"
+         for hook in self.pre_hooks:
+             hook_output = hook(model_output, timestep, sample, generator)
+             if hook_output is not None:
+                 model_output, timestep, sample, generator = hook_output
+
+         (pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)
+
+         for hook in self.post_hooks:
+             hook_output = hook(pred_prev_sample)
+             if hook_output is not None:
+                 pred_prev_sample = hook_output
+
+         return (pred_prev_sample, )
+
+     def __getattr__(self, name):
+         return getattr(self.scheduler, name)
+
+     def __setattr__(self, name, value):
+         if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
+             object.__setattr__(self, name, value)
+         else:
+             setattr(self.scheduler, name, value)
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ from diffusers import StableDiffusionXLPipeline, IFPipeline
3
+ from typing import List, Dict, Callable, Union
4
+ import torch
5
+ from .hooked_scheduler import HookedNoiseScheduler
6
+ import spaces
7
+
8
+ def retrieve(io):
9
+ if isinstance(io, tuple):
10
+ if len(io) == 1:
11
+ return io[0]
12
+ else:
13
+ raise ValueError("A tuple should have length of 1")
14
+ elif isinstance(io, torch.Tensor):
15
+ return io
16
+ else:
17
+ raise ValueError("Input/Output must be a tensor, or 1-element tuple")
18
+
19
+
20
+ class HookedDiffusionAbstractPipeline:
21
+ parent_cls = None
22
+ pipe = None
23
+
24
+ def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
25
+ if use_hooked_scheduler:
26
+ pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
27
+ self.__dict__['pipe'] = pipe
28
+ self.use_hooked_scheduler = use_hooked_scheduler
29
+
30
+ @classmethod
31
+ def from_pretrained(cls, *args, **kwargs):
32
+ return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
33
+
34
+ def run_with_hooks(self,
35
+ *args,
36
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
37
+ **kwargs
38
+ ):
39
+ '''
40
+ Run the pipeline with hooks at specified positions.
41
+ Returns the final output.
42
+
43
+ Args:
44
+ *args: Arguments to pass to the pipeline.
45
+ position_hook_dict: A dictionary mapping positions to hooks.
46
+ The keys are positions in the pipeline where the hooks should be registered.
47
+ The values are either a single hook or a list of hooks to be registered at the specified position.
48
+ Each hook should be a callable that takes three arguments: (module, input, output).
49
+ **kwargs: Keyword arguments to pass to the pipeline.
50
+ '''
51
+ hooks = []
52
+ for position, hook in position_hook_dict.items():
53
+ if isinstance(hook, list):
54
+ for h in hook:
55
+ hooks.append(self._register_general_hook(position, h))
56
+ else:
57
+ hooks.append(self._register_general_hook(position, hook))
58
+
59
+ hooks = [hook for hook in hooks if hook is not None]
60
+
61
+ try:
62
+ output = self.pipe(*args, **kwargs)
63
+ finally:
64
+ for hook in hooks:
65
+ hook.remove()
66
+ if self.use_hooked_scheduler:
67
+ self.pipe.scheduler.pre_hooks = []
68
+ self.pipe.scheduler.post_hooks = []
69
+
70
+ return output
71
+
72
+
73
+ def run_with_cache(self,
74
+ *args,
75
+ positions_to_cache: List[str],
76
+ save_input: bool = False,
77
+ save_output: bool = True,
78
+ **kwargs
79
+ ):
80
+ '''
81
+ Run the pipeline with caching at specified positions.
82
+
83
+ This method allows you to cache the intermediate inputs and/or outputs of the pipeline
84
+ at certain positions. The final output of the pipeline and a dictionary of cached values
85
+ are returned.
86
+
87
+ Args:
88
+ *args: Arguments to pass to the pipeline.
89
+ positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
90
+ inputs/outputs should be cached.
91
+ save_input (bool, optional): If True, caches the input at each specified position.
92
+ Defaults to False.
93
+ save_output (bool, optional): If True, caches the output at each specified position.
94
+ Defaults to True.
95
+ **kwargs: Keyword arguments to pass to the pipeline.
96
+
97
+ Returns:
98
+ final_output: The final output of the pipeline after execution.
99
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
100
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
101
+ depending on the flags `save_input` and `save_output`.
102
+ '''
103
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
104
+ hooks = [
105
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
106
+ ]
107
+ hooks = [hook for hook in hooks if hook is not None]
108
+ output = self.pipe(*args, **kwargs)
109
+ for hook in hooks:
110
+ hook.remove()
111
+ if self.use_hooked_scheduler:
112
+ self.pipe.scheduler.pre_hooks = []
113
+ self.pipe.scheduler.post_hooks = []
114
+
115
+ cache_dict = {}
116
+ if save_input:
117
+ for position, block in cache_input.items():
118
+ cache_input[position] = torch.stack(block, dim=1)
119
+ cache_dict['input'] = cache_input
120
+
121
+ if save_output:
122
+ for position, block in cache_output.items():
123
+ cache_output[position] = torch.stack(block, dim=1)
124
+ cache_dict['output'] = cache_output
125
+ return output, cache_dict
126
+
127
+
128
+ def run_with_hooks_and_cache(self,
129
+ *args,
130
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
131
+ positions_to_cache: List[str] = [],
132
+ save_input: bool = False,
133
+ save_output: bool = True,
134
+ **kwargs
135
+ ):
136
+ '''
137
+ Run the pipeline with hooks and caching at specified positions.
138
+
139
+ This method allows you to register hooks at certain positions in the pipeline and
140
+ cache intermediate inputs and/or outputs at specified positions. Hooks can be used
141
+ for inspecting or modifying the pipeline's execution, and caching stores intermediate
142
+ values for later inspection or use.
143
+
144
+ Args:
145
+ *args: Arguments to pass to the pipeline.
146
+ position_hook_dict Dict[str, Union[Callable, List[Callable]]]:
147
+ A dictionary where the keys are the positions in the pipeline, and the values
148
+ are hooks (either a single hook or a list of hooks) to be registered at those positions.
149
+ Each hook should be a callable that accepts three arguments: (module, input, output).
150
+ positions_to_cache (List[str], optional): A list of positions in the pipeline where
151
+ intermediate inputs/outputs should be cached. Defaults to an empty list.
152
+ save_input (bool, optional): If True, caches the input at each specified position.
153
+ Defaults to False.
154
+ save_output (bool, optional): If True, caches the output at each specified position.
155
+ Defaults to True.
156
+ **kwargs: Additional keyword arguments to pass to the pipeline.
157
+
158
+ Returns:
159
+ final_output: The final output of the pipeline after execution.
160
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
161
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
162
+ depending on the flags `save_input` and `save_output`.
163
+ '''
164
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
165
+ hooks = [
166
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
167
+ ]
168
+
169
+ for position, hook in position_hook_dict.items():
170
+ if isinstance(hook, list):
171
+ for h in hook:
172
+ hooks.append(self._register_general_hook(position, h))
173
+ else:
174
+ hooks.append(self._register_general_hook(position, hook))
175
+
176
+ hooks = [hook for hook in hooks if hook is not None]
177
+ output = self.pipe(*args, **kwargs)
178
+ for hook in hooks:
179
+ hook.remove()
180
+ if self.use_hooked_scheduler:
181
+ self.pipe.scheduler.pre_hooks = []
182
+ self.pipe.scheduler.post_hooks = []
183
+
184
+ cache_dict = {}
185
+ if save_input:
186
+ for position, block in cache_input.items():
187
+ cache_input[position] = torch.stack(block, dim=1)
188
+ cache_dict['input'] = cache_input
189
+
190
+ if save_output:
191
+ for position, block in cache_output.items():
192
+ cache_output[position] = torch.stack(block, dim=1)
193
+ cache_dict['output'] = cache_output
194
+
195
+ return output, cache_dict
196
+
197
+
198
+ def _locate_block(self, position: str):
199
+ '''
200
+ Locate the block at the specified position in the pipeline.
201
+ '''
202
+ block = self.pipe
203
+ for step in position.split('.'):
204
+ if step.isdigit():
205
+ step = int(step)
206
+ block = block[step]
207
+ else:
208
+ block = getattr(block, step)
209
+ return block
210
+
211
+
212
+ def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
213
+
214
+ if position.endswith('$self_attention') or position.endswith('$cross_attention'):
215
+ return self._register_cache_attention_hook(position, cache_output)
216
+
217
+ if position == 'noise':
218
+ def hook(model_output, timestep, sample, generator):
219
+ if position not in cache_output:
220
+ cache_output[position] = []
221
+ cache_output[position].append(sample)
222
+
223
+ if self.use_hooked_scheduler:
224
+ self.pipe.scheduler.post_hooks.append(hook)
225
+ else:
226
+ raise ValueError('Cannot cache noise without using hooked scheduler')
227
+ return
228
+
229
+ block = self._locate_block(position)
230
+
231
+ def hook(module, input, kwargs, output):
232
+ if cache_input is not None:
233
+ if position not in cache_input:
234
+ cache_input[position] = []
235
+ cache_input[position].append(retrieve(input))
236
+
237
+ if cache_output is not None:
238
+ if position not in cache_output:
239
+ cache_output[position] = []
240
+ cache_output[position].append(retrieve(output))
241
+
242
+ return block.register_forward_hook(hook, with_kwargs=True)
243
+
244
+ def _register_cache_attention_hook(self, position, cache):
245
+ attn_block = self._locate_block(position.split('$')[0])
246
+ if position.endswith('$self_attention'):
247
+ attn_block = attn_block.attn1
248
+ elif position.endswith('$cross_attention'):
249
+ attn_block = attn_block.attn2
250
+ else:
251
+ raise ValueError('Wrong attention type')
252
+
253
+ def hook(module, args, kwargs, output):
254
+ hidden_states = args[0]
255
+ encoder_hidden_states = kwargs['encoder_hidden_states']
256
+ attention_mask = kwargs['attention_mask']
257
+ batch_size, sequence_length, _ = hidden_states.shape
258
+ attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
259
+ query = attn_block.to_q(hidden_states)
260
+
261
+
262
+ if encoder_hidden_states is None:
263
+ encoder_hidden_states = hidden_states
264
+ elif attn_block.norm_cross is not None:
265
+ encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)
266
+
267
+ key = attn_block.to_k(encoder_hidden_states)
268
+ value = attn_block.to_v(encoder_hidden_states)
269
+
270
+ query = attn_block.head_to_batch_dim(query)
271
+ key = attn_block.head_to_batch_dim(key)
272
+ value = attn_block.head_to_batch_dim(value)
273
+
274
+ attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
275
+ attention_probs = attention_probs.view(
276
+ batch_size,
277
+ attention_probs.shape[0] // batch_size,
278
+ attention_probs.shape[1],
279
+ attention_probs.shape[2]
280
+ )
281
+ if position not in cache:
282
+ cache[position] = []
283
+ cache[position].append(attention_probs)
284
+
285
+ return attn_block.register_forward_hook(hook, with_kwargs=True)
286
+
287
+ def _register_general_hook(self, position, hook):
288
+ if position == 'scheduler_pre':
289
+ if not self.use_hooked_scheduler:
290
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
291
+ self.pipe.scheduler.pre_hooks.append(hook)
292
+ return
293
+ elif position == 'scheduler_post':
294
+ if not self.use_hooked_scheduler:
295
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
296
+ self.pipe.scheduler.post_hooks.append(hook)
297
+ return
298
+
299
+ block = self._locate_block(position)
300
+ return block.register_forward_hook(hook)
301
+
302
+ def to(self, *args, **kwargs):
303
+ self.pipe = self.pipe.to(*args, **kwargs)
304
+ return self
305
+
306
+ def __getattr__(self, name):
307
+ return getattr(self.pipe, name)
308
+
309
+ def __setattr__(self, name, value):
310
+ return setattr(self.pipe, name, value)
311
+
312
+ def __call__(self, *args, **kwargs):
313
+ return self.pipe(*args, **kwargs)
314
+
315
+
316
+ class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
317
+ parent_cls = StableDiffusionXLPipeline
318
+
319
+
320
+ class HookedIFPipeline(HookedDiffusionAbstractPipeline):
321
+ parent_cls = IFPipeline
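A usage sketch for the caching API (the prompt and position are examples; position strings follow the attribute-path convention used by `_locate_block`):

```python
from SDLens import HookedStableDiffusionXLPipeline

pipe = HookedStableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo")
pipe.to("cuda")

output, cache = pipe.run_with_cache(
    "a photo of a cat",
    positions_to_cache=["unet.down_blocks.2.attentions.1"],
    save_input=True,
    save_output=True,
    num_inference_steps=1,
    guidance_scale=0.0,
)
# cache["output"][position] stacks one tensor per UNet call along dim=1
print(cache["output"]["unet.down_blocks.2.attentions.1"].shape)
```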
app.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,405 @@
+ import gradio as gr
+ import os
+ import torch
+ from PIL import Image
+ from SDLens import HookedStableDiffusionXLPipeline
+ from SAE import SparseAutoencoder
+ from utils import add_feature_on_area, replace_with_feature
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import ListedColormap
+ import threading
+ import spaces
+
+ code_to_block = {
+     "down.2.1": "unet.down_blocks.2.attentions.1",
+     "mid.0": "unet.mid_block.attentions.0",
+     "up.0.1": "unet.up_blocks.0.attentions.1",
+     "up.0.0": "unet.up_blocks.0.attentions.0"
+ }
+ lock = threading.Lock()
+
+
+ def process_cache(cache, saes_dict):
+     top_features_dict = {}
+     sparse_maps_dict = {}
+
+     for code in code_to_block.keys():
+         block = code_to_block[code]
+         sae = saes_dict[code]
+
+         diff = cache["output"][block] - cache["input"][block]
+         diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
+         with torch.no_grad():
+             sparse_maps = sae.encode(diff)
+         averages = torch.mean(sparse_maps, dim=(0, 1))
+
+         top_features = torch.topk(averages, 10).indices
+
+         top_features_dict[code] = top_features.cpu().tolist()
+         sparse_maps_dict[code] = sparse_maps.cpu().numpy()
+
+     return top_features_dict, sparse_maps_dict
+
+
+ def plot_image_heatmap(cache, block_select, radio):
+     code = block_select.split()[0]
+     feature = int(radio)
+     block = code_to_block[code]
+
+     heatmap = cache["heatmaps"][code][:, :, feature]
+     heatmap = np.kron(heatmap, np.ones((32, 32)))
+     image = cache["image"].convert("RGBA")
+
+     jet = plt.cm.jet
+     cmap = jet(np.arange(jet.N))
+     cmap[:1, -1] = 0
+     cmap[1:, -1] = 0.6
+     cmap = ListedColormap(cmap)
+     heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
+     heatmap_rgba = cmap(heatmap)
+     heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
+     heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
+
+     return heatmap_with_transparency
+
+
+ def create_prompt_part(pipe, saes_dict, demo):
+     @spaces.GPU
+     def image_gen(prompt):
+         lock.acquire()
+         try:
+             images, cache = pipe.run_with_cache(
+                 prompt,
+                 positions_to_cache=list(code_to_block.values()),
+                 num_inference_steps=1,
+                 generator=torch.Generator(device="cpu").manual_seed(42),
+                 guidance_scale=0.0,
+                 save_input=True,
+                 save_output=True
+             )
+         finally:
+             lock.release()
+
+         top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict)
+         return images.images[0], {
+             "image": images.images[0],
+             "heatmaps": top_sparse_maps_dict,
+             "features": top_features_dict
+         }
+
+     def update_radio(cache, block_select):
+         code = block_select.split()[0]
+         return gr.update(choices=cache["features"][code])
+
+     def update_img(cache, block_select, radio):
+         new_img = plot_image_heatmap(cache, block_select, radio)
+         return new_img
+
+     with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
+         cache = gr.State(value={
+             "image": None,
+             "heatmaps": None,
+             "features": []
+         })
+         with gr.Row():
+             with gr.Column(scale=7):
+                 with gr.Row(equal_height=True):
+                     prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eating a dish with peas.")
+                     button = gr.Button("Generate", elem_classes="generate_button1")
+
+                 with gr.Row():
+                     image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
+
+             with gr.Column(scale=4):
+                 block_select = gr.Dropdown(
+                     choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
+                     value="down.2.1 (composition)",
+                     label="Select block",
+                     elem_id="block_select",
+                     interactive=True
+                 )
+                 radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
+
+         button.click(image_gen, [prompt_field], outputs=[image, cache])
+         cache.change(update_radio, [cache, block_select], outputs=[radio])
+         block_select.select(update_radio, [cache, block_select], outputs=[radio])
+         radio.select(update_img, [cache, block_select, radio], outputs=[image])
+         demo.load(image_gen, [prompt_field], outputs=[image, cache])
+
+     return explore_tab
+
+ def downsample_mask(image, factor):
+     downsampled = image.reshape(
+         (image.shape[0] // factor, factor,
+          image.shape[1] // factor, factor)
+     )
+     downsampled = downsampled.mean(axis=(1, 3))
+     return downsampled
+
+ def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
+     @spaces.GPU
+     def image_gen(prompt, num_steps):
+         lock.acquire()
+         try:
+             images = pipe.run_with_hooks(
+                 prompt,
+                 position_hook_dict={},
+                 num_inference_steps=num_steps,
+                 generator=torch.Generator(device="cpu").manual_seed(42),
+                 guidance_scale=0.0
+             )
+         finally:
+             lock.release()
+         return images.images[0]
+
+     @spaces.GPU
+     def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image):
+         block = block_str.split(" ")[0]
+
+         mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
+         mask = downsample_mask(mask, 32)
+         mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
+
+         if mask.sum() == 0:
+             gr.Info("No mask selected, please draw on the input image")
+
+         def hook(module, input, output):
+             return add_feature_on_area(
+                 saes_dict[block],
+                 brush_index,
+                 mask * means_dict[block][brush_index] * strength,
+                 module,
+                 input,
+                 output
+             )
+
+         lock.acquire()
+         try:
+             image = pipe.run_with_hooks(
+                 prompt,
+                 position_hook_dict={code_to_block[block]: hook},
+                 num_inference_steps=num_steps,
+                 generator=torch.Generator(device="cpu").manual_seed(42),
+                 guidance_scale=0.0
+             ).images[0]
+         finally:
+             lock.release()
+         return image
+
+     @spaces.GPU
+     def feature_icon(block_str, brush_index):
+         block = block_str.split(" ")[0]
+         if block in ["mid.0", "up.0.0"]:
+             gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
+
+         def hook(module, input, output):
+             return replace_with_feature(
+                 saes_dict[block],
+                 brush_index,
+                 means_dict[block][brush_index] * saes_dict[block].k,
+                 module,
+                 input,
+                 output
+             )
+
+         lock.acquire()
+         try:
+             image = pipe.run_with_hooks(
+                 "",
+                 position_hook_dict={code_to_block[block]: hook},
+                 num_inference_steps=1,
+                 generator=torch.Generator(device="cpu").manual_seed(42),
+                 guidance_scale=0.0
+             ).images[0]
+         finally:
+             lock.release()
+         return image
+
+     with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
+         image_state = gr.State(value=None)
+         with gr.Row():
+             with gr.Column(scale=3):
+                 # Generation column
+                 with gr.Row():
+                     # prompt and num_steps
+                     prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
+                     num_steps = gr.Number(value=1, label="Number of steps", minimum=1, maximum=4, elem_id="num_steps", precision=0)
+                 with gr.Row():
+                     # Generate button
+                     button_generate = gr.Button("Generate", elem_id="generate_button")
+             with gr.Column(scale=3):
+                 # Intervention column
+                 with gr.Row():
+                     # dropdowns and number inputs
+                     with gr.Column(scale=7):
+                         with gr.Row():
+                             block_select = gr.Dropdown(
+                                 choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
+                                 value="down.2.1 (composition)",
+                                 label="Select block",
+                                 elem_id="block_select"
+                             )
+                             brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, elem_id="brush_index", precision=0)
+                         with gr.Row():
+                             button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
+                     with gr.Column(scale=3):
+                         with gr.Row():
+                             strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
+                         with gr.Row():
+                             button = gr.Button('Apply', elem_id="apply_button")
+
+         with gr.Row():
+             with gr.Column():
+                 # Input image
+                 i_image = gr.Sketchpad(
+                     height=610,
+                     layers=False, transforms=[], placeholder="Generate and paint!",
+                     brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
+                     container=False,
+                     canvas_size=(512, 512),
+                     label="Input Image")
+                 clear_button = gr.Button("Clear")
+                 clear_button.click(lambda x: x, [image_state], [i_image])
+             # Output image
+             o_image = gr.Image(width=512, height=512, label="Output Image")
+
+         # Set up the click events
+         button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
+         image_state.change(lambda x: x, [image_state], [i_image])
+         button.click(image_mod,
+                      inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image],
+                      outputs=o_image)
+         button_icon.click(feature_icon, inputs=[block_select, brush_index], outputs=o_image)
+         demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])
+
+     return intervene_tab
+
+
+ def create_top_images_part(demo):
+     def update_top_images(block_select, brush_index):
+         block = block_select.split(" ")[0]
+         url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
+         return url
+
+     with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
+         with gr.Row():
+             block_select = gr.Dropdown(
+                 choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
+                 value="down.2.1 (composition)",
+                 label="Select block"
+             )
+             brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, precision=0)
+         with gr.Row():
+             image = gr.Image(width=600, height=600, label="Top Images")
+
+         block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
+         brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
+         demo.load(update_top_images, [block_select, brush_index], outputs=[image])
+     return top_images_tab
+
+
+ def create_intro_part():
+     with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
+         gr.Markdown(
+             '''# Unpacking SDXL Turbo with Sparse Autoencoders
+ ## Demo Overview
+ This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model.
+
+ ## How to Use
+ ### Explore
+ * Enter a prompt in the text box and click on the "Generate" button to generate an image.
+ * You can observe the active features in different blocks plotted on top of the generated image.
+ ### Top Images
+ * For each feature, you can view the top images that activate the feature the most.
+ ### Paint!
+ * Generate an image using the prompt.
+ * Paint on the generated image to apply interventions.
+ * Use the "Feature Icon" button to understand how the selected brush functions.
+
+ ### Remarks
+ * Not all brushes mix well with all images. Experiment with different brushes and strengths.
+ * Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
+ * This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
+
+ ### Interesting features to try
+ To get started, try the following features:
+ - down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon)
+ - up.0.1 (style): 4977 (tiger stripes) 90 (fur) 2615 (twilight blur)
+ '''
+         )
+
+     return intro_tab
+
+
+ def create_demo(pipe, saes_dict, means_dict):
+     custom_css = """
+     .tabs button {
+         font-size: 20px !important;    /* Adjust font size for tab text */
+         padding: 10px !important;      /* Adjust padding to make the tabs bigger */
+         font-weight: bold !important;  /* Adjust font weight to make the text bold */
+     }
+     .generate_button1 {
+         max-width: 160px !important;
+         margin-top: 20px !important;
+         margin-bottom: 20px !important;
+     }
+     """
+
+     with gr.Blocks(css=custom_css) as demo:
+         with create_intro_part():
+             pass
+         with create_prompt_part(pipe, saes_dict, demo):
+             pass
+         with create_top_images_part(demo):
+             pass
+         with create_intervene_part(pipe, saes_dict, means_dict, demo):
+             pass
+
+     return demo
+
+
+ if __name__ == "__main__":
+     dtype = torch.float32
+     pipe = HookedStableDiffusionXLPipeline.from_pretrained(
+         'stabilityai/sdxl-turbo',
+         torch_dtype=dtype,
+         variant=("fp16" if dtype == torch.float16 else None)
+     )
+     pipe.set_progress_bar_config(disable=True)
+     pipe.to('cuda')
+
+     path_to_checkpoints = './checkpoints/'
+
+     saes_dict = {}
+     means_dict = {}
+
+     for code, block in code_to_block.items():
+         sae = SparseAutoencoder.load_from_disk(
+             os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"),
+         )
+         means = torch.load(
+             os.path.join(path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final", "mean.pt"),
+             weights_only=True
+         )
+         saes_dict[code] = sae.to('cuda', dtype=dtype)
+         means_dict[code] = means.to('cuda', dtype=dtype)
+
+     demo = create_demo(pipe, saes_dict, means_dict)
+     demo.launch()
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9197875b722b020a2fafb7f5b8a96cf958610231df48f1db0c24134b02550ef
+ size 130
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d0d19fbcaa4ce8ca8d80b5572c58e30118395c1dbf60dc3cb4b70a75ece91cb
+ size 133
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7408c703cd359453a778da0995c7a4b956fc7ecd09dcbcf8854542450ce698b0
+ size 130
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d92206402c52e067c4f905e43908634044f056c13191ef0a43a3add82503697f
+ size 130
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:589671ca17b36e9d5f5286721f8bcba2f6317de289239b65a8a44deab2c5e65b
+ size 133
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:57890256dbe4a4a802b0c359805c1cc3cf64856c416a93e07f05d4f72e19a793
+ size 130
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a70298fd4c98ed142b762e7729450df5422759b8c37ac7b6e43d0b607daa986
+ size 130
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c013be6cc9bfd78d6f01a6c1c8c54b53204473ebc8f385cf7be30c0253dbd373
+ size 133
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67e44758d99f68f3fb406207ebe6ccd980cef13ca3c52e76d760607d3be2774b
+ size 130
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a725d48ebbd3f44eaccbfbf426cf304007e4bdd7d390d710a3b0d36ed06702d
+ size 130
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b14db32636802bd9fd43b5cc8f7948234b6d2083e3c7fb7a8668594862f45796
+ size 133
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d508933b4f1334a81757f72bd451ebdeaf1c3d48cf75c9063ce26204bcc5f71f
+ size 130
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ diffusers==0.29.2
+ gradio==5.6.0
+ --extra-index-url https://download.pytorch.org/whl/cu113
+ torch
+ numpy
+ matplotlib
+ pillow
+ wandb
+ einops
+ transformers
+ accelerate
scripts/collect_latents_dataset.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import sys
+ import io
+ import tarfile
+ import torch
+ import webdataset as wds
+ import numpy as np
+
+ from tqdm import tqdm
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline
+
+ import datetime
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader
+ import diffusers
+ import fire
+
+ def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
+     blocks_to_save = [
+         'unet.down_blocks.2.attentions.1',
+         'unet.mid_block.attentions.0',
+         'unet.up_blocks.0.attentions.0',
+         'unet.up_blocks.0.attentions.1',
+     ]
+
+     # Initialization
+     dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
+     pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
+     pipe.to('cuda')
+     pipe.set_progress_bar_config(disable=True)
+     dataloader = DataLoader(dataset, batch_size=dataset_batch_size)
+
+     ct = datetime.datetime.now()
+     save_path = os.path.join(save_path, str(ct))
+     # Collecting dataset
+     os.makedirs(save_path, exist_ok=True)
+
+     writers = {
+         block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
+     }
+
+     writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})
+
+     def to_kwargs(kwargs_to_save):
+         kwargs = kwargs_to_save.copy()
+         seed = kwargs['seed']
+         del kwargs['seed']
+         kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
+         return kwargs
+
+     for num_document, batch in tqdm(enumerate(dataloader)):
+         if num_document < start_at:
+             continue
+
+         if num_document >= finish_at:
+             break
+
+         kwargs_to_save = {
+             'prompt': batch['caption'],
+             'positions_to_cache': blocks_to_save,
+             'save_input': True,
+             'save_output': True,
+             'num_inference_steps': 1,
+             'guidance_scale': 0.0,
+             'seed': num_document,
+             'output_type': 'pil'
+         }
+
+         kwargs = to_kwargs(kwargs_to_save)
+
+         output, cache = pipe.run_with_cache(
+             **kwargs
+         )
+
+         blocks = cache['input'].keys()
+         for block in blocks:
+             sample = {
+                 "__key__": f"sample_{num_document}",
+                 "output.pth": cache['output'][block],
+                 "diff.pth": cache['output'][block] - cache['input'][block],
+                 "gen_args.json": kwargs_to_save
+             }
+
+             writers[block].write(sample)
+         writers['images'].write({
+             "__key__": f"sample_{num_document}",
+             "images.npy": np.stack(output.images)
+         })
+
+     for block, writer in writers.items():
+         writer.close()
+
+ if __name__ == '__main__':
+     fire.Fire(main)
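Each resulting shard can be read back independently; a sketch, with the path a placeholder mirroring the save layout above:

```python
import webdataset as wds

# Each sample holds the block output and the output-minus-input residual
# that the SAEs are trained on ('diff.pth'), keyed by sample index.
ds = wds.WebDataset(
    "/path/to/latents/<timestamp>/unet.down_blocks.2.attentions.1.tar"  # placeholder
).decode("torch")
sample = next(iter(ds))
print(sample["diff.pth"].shape, sample["gen_args.json"]["prompt"][:1])
```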
scripts/train_sae.py ADDED
@@ -0,0 +1,308 @@
+ '''
+ Adapted from
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
+ '''
+
+ import os
+ import sys
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from typing import Callable, Iterable, Iterator
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.distributed import ReduceOp
+ from SAE.dataset_iterator import ActivationsDataloader
+ from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
+ from SAE.sae_utils import SAETrainingConfig, Config
+
+ from types import SimpleNamespace
+ from typing import Optional, List
+ import json
+
+ import tqdm
+ import wandb
+
+ def weighted_average(points: torch.Tensor, weights: torch.Tensor):
+     weights = weights / weights.sum()
+     return (points * weights.view(-1, 1)).sum(dim=0)
+
+
+ @torch.no_grad()
+ def geometric_median_objective(
+     median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
+ ) -> torch.Tensor:
+     norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
+     return (norms * weights).sum()
+
+
+ def compute_geometric_median(
+     points: torch.Tensor,
+     weights: Optional[torch.Tensor] = None,
+     eps: float = 1e-6,
+     maxiter: int = 100,
+     ftol: float = 1e-20,
+     do_log: bool = False,
+ ):
+     """
+     :param points: ``torch.Tensor`` of shape ``(n, d)``
+     :param weights: Optional ``torch.Tensor`` of shape ``(n,)``.
+     :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+         Equivalently, this is a smoothing parameter. Default 1e-6.
+     :param maxiter: Maximum number of Weiszfeld iterations. Default 100.
+     :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
+     :param do_log: If true will return a log of function values encountered through the course of the algorithm.
+     :return: SimpleNamespace object with fields
+         - `median`: estimate of the geometric median, a ``torch.Tensor`` of shape ``(d,)``
+         - `termination`: string explaining how the algorithm terminated.
+         - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
+     """
+     with torch.no_grad():
+         if weights is None:
+             weights = torch.ones((points.shape[0],), device=points.device)
+         # initialize median estimate at mean
+         new_weights = weights
+         median = weighted_average(points, weights)
+         objective_value = geometric_median_objective(median, points, weights)
+         logs = [objective_value] if do_log else None
+
+         # Weiszfeld iterations
+         early_termination = False
+         pbar = tqdm.tqdm(range(maxiter))
+         for _ in pbar:
+             prev_obj_value = objective_value
+
+             norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
+             new_weights = weights / torch.clamp(norms, min=eps)
+             median = weighted_average(points, new_weights)
+             objective_value = geometric_median_objective(median, points, weights)
+
+             if logs is not None:
+                 logs.append(objective_value)
+             if abs(prev_obj_value - objective_value) <= ftol * objective_value:
+                 early_termination = True
+                 break
+
+             pbar.set_description(f"Objective value: {objective_value:.4f}")
+
+     median = weighted_average(points, new_weights)  # allow autodiff to track it
+     return SimpleNamespace(
+         median=median,
+         new_weights=new_weights,
+         termination=(
+             "function value converged within tolerance"
+             if early_termination
+             else "maximum iterations reached"
+         ),
+         logs=logs,
+     )
+
+ def maybe_transpose(x):
+     return x.T if not x.is_contiguous() and x.T.is_contiguous() else x
+
+ RANK = 0
+
+ class Logger:
+     def __init__(self, sae_name, **kws):
+         self.vals = {}
+         self.enabled = (RANK == 0) and not kws.pop("dummy", False)
+         self.sae_name = sae_name
+
+     def logkv(self, k, v):
+         if self.enabled:
+             self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
+         return v
+
+     def dumpkvs(self, step):
+         if self.enabled:
+             wandb.log(self.vals, step=step)
+         self.vals = {}
+
+
+ class FeaturesStats:
+     def __init__(self, dim, logger):
+         self.dim = dim
+         self.logger = logger
+         self.reinit()
+
+     def reinit(self):
+         self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
+         self.n = 0
+
+     def update(self, inds):
+         self.n += inds.shape[0]
+         inds = inds.flatten().detach()
+         self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))
+
+     def log(self):
+         self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())
+
+ def training_loop_(
+     aes,
+     train_acts_iter,
+     loss_fn,
+     log_interval,
+     save_interval,
+     loggers,
+     sae_cfgs,
+ ):
+     sae_packs = []
+     for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
+         pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
+         fstats = FeaturesStats(ae.n_dirs, logger)
+         opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
+         sae_packs.append((ae, cfg, logger, pbar, fstats, opt))
+
+     for i, flat_acts_train_batch in enumerate(train_acts_iter):
+         flat_acts_train_batch = flat_acts_train_batch.cuda()
+
+         for ae, cfg, logger, pbar, fstats, opt in sae_packs:
+             recons, info = ae(flat_acts_train_batch)
+             loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)
+
+             fstats.update(info['inds'])
+
+             bs = flat_acts_train_batch.shape[0]
+             logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
+             logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
+             logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())
+
+             logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
+             logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())
+
+             if (i + 1) % log_interval == 0:
+                 fstats.log()
+                 fstats.reinit()
+
+             if (i + 1) % save_interval == 0:
+                 ae.save_to_disk(f"{cfg.save_path}/{i + 1}")
+
+             loss.backward()
+
+             unit_norm_decoder_(ae)
+             unit_norm_decoder_grad_adjustment_(ae)
+
+             opt.step()
+             opt.zero_grad()
+             logger.dumpkvs(i)
+
+             pbar.set_description(f"Training Loss {loss.item():.4f}")
+             pbar.update(1)
+
+     for ae, cfg, logger, pbar, fstats, opt in sae_packs:
+         pbar.close()
+         ae.save_to_disk(f"{cfg.save_path}/final")
+
+
+ def init_from_data_(ae, stats_acts_sample):
+     ae.pre_bias.data = (
+         compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
+     )
+
+
+ def mse(recons, x):
+     # return ((recons - x) ** 2).sum(dim=-1).mean()
+     return ((recons - x) ** 2).mean()
+
+ def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
+     # only used for auxk
+     xs_mu = xs.mean(dim=0)
+
+     loss = mse(recon, xs) / mse(
+         xs_mu[None, :].broadcast_to(xs.shape), xs
+     )
+
+     return loss
+
+ def explained_variance(recons, x):
+     # Compute the variance of the difference
+     diff = x - recons
+     diff_var = torch.var(diff, dim=0, unbiased=False)
+
+     # Compute the variance of the original tensor
+     x_var = torch.var(x, dim=0, unbiased=False)
+
+     # Avoid division by zero
+     explained_var = 1 - diff_var / (x_var + 1e-8)
+
+     return explained_var.mean()
+
+
+ def main():
+     cfg = Config(json.load(open('SAE/config.json')))
+
+     dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)
+
+     acts_iter = dataloader.iterate()
+     stats_acts_sample = torch.cat([
+         next(acts_iter).cpu() for _ in range(10)
+     ], dim=0)
+
+     aes = [
+         SparseAutoencoder(
+             n_dirs_local=sae.n_dirs,
+             d_model=sae.d_model,
+             k=sae.k,
+             auxk=sae.auxk,
+             dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
+         ).cuda()
+         for sae in cfg.saes
+     ]
+
+     for ae in aes:
+         init_from_data_(ae, stats_acts_sample)
+
+     mse_scale = (
+         1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
+     )
+     mse_scale = mse_scale.item()
+     del stats_acts_sample
+
+     wandb.init(
+         project=cfg.wandb_project,
+         name=cfg.wandb_name,
+     )
+
+     loggers = [Logger(
+         sae_name=cfg_sae.sae_name,
+         dummy=False,
+     ) for cfg_sae in cfg.saes]
+
+     training_loop_(
+         aes,
+         acts_iter,
+         lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
+             # MSE
+             logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
+             # AuxK
+             + logger.logkv(
+                 "train_maxk_recons",
+                 cfg_sae.auxk_coef
+                 * normalized_mse(
+                     ae.decode_sparse(
+                         info["auxk_inds"],
+                         info["auxk_vals"],
+                     ),
+                     flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
+                 ).nan_to_num(0),
+             )
+         ),
+         sae_cfgs=cfg.saes,
+         loggers=loggers,
+         log_interval=cfg.log_interval,
+         save_interval=cfg.save_interval,
+     )
+
+
+ if __name__ == "__main__":
+     main()
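For reference, the objective assembled by the lambda passed to `training_loop_` is, for a batch x with reconstruction x̂:

```latex
\mathcal{L}
  = c_{\text{mse}} \,\mathrm{MSE}(\hat{x}, x)
  + \lambda_{\text{auxk}} \,
    \frac{\mathrm{MSE}\!\left(\hat{x}_{\text{auxk}},\; x - \hat{x} + b_{\text{pre}}\right)}
         {\mathrm{MSE}\!\left(\bar{x},\; x\right)}
```

Here c_mse is `mse_scale` (estimated from the initial activation sample), λ_auxk is `auxk_coef` (1/32 by default), x̂_auxk is `decode_sparse` applied to the auxk dead-latent activations, x̄ is the batch-mean baseline used by `normalized_mse`, and `nan_to_num(0)` zeroes the term when no latents are dead.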
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hooks import *
utils/hooks.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ import spaces
+
+ @torch.no_grad()
+ def add_feature(sae, feature_idx, value, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     if activation_map.dim() == 2:
+         activation_map = activation_map.unsqueeze(0)
+     mask[..., feature_idx] = activation_map.to(mask.device)
+     to_add = mask @ sae.decoder.weight.T
+     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def replace_with_feature(sae, feature_idx, value, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     mask = torch.zeros_like(activated, device=diff.device)
+     mask[..., feature_idx] = value
+     to_add = mask @ sae.decoder.weight.T
+     return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def reconstruct_sae_hook(sae, module, input, output):
+     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
+     activated = sae.encode(diff)
+     reconstructed = sae.decoder(activated) + sae.pre_bias
+     return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
+
+
+ @torch.no_grad()
+ def ablate_block(module, input, output):
+     return input
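These hooks are meant to be partially applied and handed to the hooked pipeline; a sketch following the pattern in app.py (the prompt and the strength value 10.0 are examples; feature 4998 is the "cartoon" feature named in the demo instructions):

```python
from functools import partial
from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder
from utils import add_feature

pipe = HookedStableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo").to("cuda")
sae = SparseAutoencoder.load_from_disk(
    "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"
).to("cuda")

# Bind the SAE, feature index, and strength; the pipeline supplies
# (module, input, output) when the forward hook fires.
hook = partial(add_feature, sae, 4998, 10.0)
image = pipe.run_with_hooks(
    "a dog plays with a ball",
    position_hook_dict={"unet.down_blocks.2.attentions.1": hook},
    num_inference_steps=1,
    guidance_scale=0.0,
).images[0]
```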