taufeeque commited on
Commit
7f9376c
β€’
1 Parent(s): 676f3c4

Add streamlit webapp files

Browse files
Code_Browser.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web App for the Codebook Features project."""
2
+
3
+ import glob
4
+ import os
5
+
6
+ import streamlit as st
7
+
8
+ import code_search_utils
9
+ import webapp_utils
10
+
11
+ DEPLOY_MODE = True
12
+
13
+
14
+ webapp_utils.load_widget_state()
15
+
16
+ st.set_page_config(
17
+ page_title="Codebook Features",
18
+ page_icon="πŸ“š",
19
+ )
20
+
21
+ st.title("Codebook Features")
22
+
23
+ base_cache_dir = "cache/"
24
+ dirs = glob.glob(base_cache_dir + "models/*/")
25
+ model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
26
+ model_name_options = ["_".join(m) for m in model_name_options]
27
+ model_name_options = sorted(set(model_name_options))
28
+
29
+ model_name = st.selectbox(
30
+ "Model",
31
+ model_name_options,
32
+ key=webapp_utils.persist("model_name"),
33
+ )
34
+
35
+ model = model_name.split("_")[0].split("#")[0]
36
+ model_layers = {
37
+ "pythia-410m-deduped": 24,
38
+ "pythia-70m-deduped": 6,
39
+ "gpt2": 12,
40
+ "TinyStories-1Layer-21M": 1,
41
+ }
42
+ model_heads = {
43
+ "pythia-410m-deduped": 16,
44
+ "pythia-70m-deduped": 8,
45
+ "gpt2": 12,
46
+ "TinyStories-1Layer-21M": 16,
47
+ }
48
+ ccb = model_name.split("_")[1]
49
+ ccb = "_ccb" if ccb == "ccb" else ""
50
+ cb_at = "_".join(model_name.split("_")[2:])
51
+ seq_len = 512 if "tinystories" in model_name.lower() else 1024
52
+ st.session_state["seq_len"] = seq_len
53
+
54
+ codes_cache_path = base_cache_dir + f"models/{model_name}_*"
55
+ dirs = glob.glob(codes_cache_path)
56
+ dirs.sort(key=os.path.getmtime)
57
+
58
+ # session states
59
+ is_attn = "attn" in cb_at
60
+ num_layers = model_layers[model]
61
+ num_heads = model_heads[model]
62
+ codes_cache_path = dirs[-1] + "/"
63
+
64
+ model_info = code_search_utils.parse_model_info(codes_cache_path)
65
+ num_codes = model_info.num_codes
66
+ dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
67
+
68
+ (
69
+ tokens_str,
70
+ tokens_text,
71
+ token_byte_pos,
72
+ cb_acts,
73
+ act_count_ft_tkns,
74
+ metrics,
75
+ ) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
76
+ metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
77
+ metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys}
78
+
79
+ st.session_state["model_name_id"] = model_name
80
+ st.session_state["cb_acts"] = cb_acts
81
+ st.session_state["tokens_text"] = tokens_text
82
+ st.session_state["tokens_str"] = tokens_str
83
+ st.session_state["act_count_ft_tkns"] = act_count_ft_tkns
84
+
85
+ st.session_state["num_codes"] = num_codes
86
+ st.session_state["ccb"] = ccb
87
+ st.session_state["cb_at"] = cb_at
88
+ st.session_state["is_attn"] = is_attn
89
+
90
+ st.markdown("## Metrics")
91
+ # hide metrics by default
92
+ if st.checkbox("Show Model Metrics"):
93
+ st.write(metrics)
94
+
95
+ st.markdown("## Demo Codes")
96
+ demo_file_path = codes_cache_path + "demo_codes.txt"
97
+
98
+ if st.checkbox("Show Demo Codes"):
99
+ try:
100
+ with open(demo_file_path, "r") as f:
101
+ demo_codes = f.readlines()
102
+ except FileNotFoundError:
103
+ demo_codes = []
104
+
105
+ code_desc, code_regex = "", ""
106
+ demo_codes = [code.strip() for code in demo_codes if code.strip()]
107
+
108
+ num_cols = 6 if is_attn else 5
109
+ cols = st.columns([1] * (num_cols - 1) + [2])
110
+ # st.markdown(button_height_style, unsafe_allow_html=True)
111
+ cols[0].markdown("Search", help="Button to see token activations for the code.")
112
+ cols[1].write("Code")
113
+ cols[2].write("Layer")
114
+ if is_attn:
115
+ cols[3].write("Head")
116
+ cols[-2].markdown(
117
+ "Num Acts",
118
+ help="Number of tokens that the code activates on in the acts dataset.",
119
+ )
120
+ cols[-1].markdown("Description", help="Interpreted description of the code.")
121
+
122
+ if len(demo_codes) == 0:
123
+ st.markdown(
124
+ f"""
125
+ <div style="font-size: 1.3rem; color: red;">
126
+ No demo codes found in file {demo_file_path}
127
+ </div>
128
+ """,
129
+ unsafe_allow_html=True,
130
+ )
131
+ skip = True
132
+ for code_txt in demo_codes:
133
+ if code_txt.startswith("##"):
134
+ skip = True
135
+ continue
136
+ if code_txt.startswith("#"):
137
+ code_desc, code_regex = code_txt[1:].split(":")
138
+ code_desc, code_regex = code_desc.strip(), code_regex.strip()
139
+ skip = False
140
+ continue
141
+ if skip:
142
+ continue
143
+ code_info = code_search_utils.get_code_info_pr_from_str(code_txt, code_regex)
144
+ comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}"
145
+ button_key = (
146
+ f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"
147
+ + (f"head{code_info.head}" if code_info.head is not None else "")
148
+ )
149
+ cols = st.columns([1] * (num_cols - 1) + [2])
150
+ button_clicked = cols[0].button(
151
+ "πŸ”",
152
+ key=button_key,
153
+ )
154
+ if button_clicked:
155
+ webapp_utils.set_ct_acts(
156
+ code_info.code, code_info.layer, code_info.head, None, is_attn
157
+ )
158
+ cols[1].write(code_info.code)
159
+ cols[2].write(str(code_info.layer))
160
+ if is_attn:
161
+ cols[3].write(str(code_info.head))
162
+ cols[-2].write(str(act_count_ft_tkns[comp_info][code_info.code]))
163
+ cols[-1].write(code_desc)
164
+ skip = True
165
+
166
+
167
+ st.markdown("## Code Search")
168
+
169
+ regex_pattern = st.text_input(
170
+ "Enter a regex pattern",
171
+ help="Wrap code token in the first group. E.g. New (York)",
172
+ key="regex_pattern",
173
+ )
174
+ # topk = st.slider("Top K", 1, 20, 10)
175
+ prec_col, sort_col = st.columns(2)
176
+ prec_threshold = prec_col.slider(
177
+ "Precision Threshold",
178
+ 0.0,
179
+ 1.0,
180
+ 0.9,
181
+ help="Shows codes with precision on the regex pattern above the threshold.",
182
+ )
183
+ sort_by_options = ["Precision", "Recall", "Num Acts"]
184
+ sort_by_name = sort_col.radio(
185
+ "Sort By",
186
+ sort_by_options,
187
+ index=0,
188
+ horizontal=True,
189
+ help="Sorts the codes by the selected metric.",
190
+ )
191
+ sort_by = sort_by_options.index(sort_by_name)
192
+
193
+
194
+ @st.cache_data(ttl=3600)
195
+ def get_codebook_wise_codes_for_regex(regex_pattern, prec_threshold, ccb, model_name):
196
+ """Get codebook wise codes for a given regex pattern."""
197
+ assert model_name is not None # required for loading from correct cache data
198
+ return code_search_utils.get_codes_from_pattern(
199
+ regex_pattern,
200
+ tokens_text,
201
+ token_byte_pos,
202
+ cb_acts,
203
+ act_count_ft_tkns,
204
+ ccb=ccb,
205
+ topk=8,
206
+ prec_threshold=prec_threshold,
207
+ )
208
+
209
+
210
+ if regex_pattern:
211
+ codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
212
+ regex_pattern,
213
+ prec_threshold,
214
+ ccb,
215
+ model_name,
216
+ )
217
+ st.markdown(f"Found :green[{re_token_matches}] matches")
218
+ num_search_cols = 7 if is_attn else 6
219
+ non_deploy_offset = 0
220
+ if not DEPLOY_MODE:
221
+ non_deploy_offset = 1
222
+ num_search_cols += non_deploy_offset
223
+
224
+ cols = st.columns(num_search_cols)
225
+
226
+ # st.markdown(button_height_style, unsafe_allow_html=True)
227
+
228
+ cols[0].markdown("Search", help="Button to see token activations for the code.")
229
+ cols[1].write("Layer")
230
+ if is_attn:
231
+ cols[2].write("Head")
232
+ cols[-4 - non_deploy_offset].write("Code")
233
+ cols[-3 - non_deploy_offset].write("Precision")
234
+ cols[-2 - non_deploy_offset].write("Recall")
235
+ cols[-1 - non_deploy_offset].markdown(
236
+ "Num Acts",
237
+ help="Number of tokens that the code activates on in the acts dataset.",
238
+ )
239
+ if not DEPLOY_MODE:
240
+ cols[-1].markdown(
241
+ "Save to Demos",
242
+ help="Button to save the code to demos along with the regex pattern.",
243
+ )
244
+ all_codes = codebook_wise_codes.items()
245
+ all_codes = [
246
+ (cb_name, code_pr_info)
247
+ for cb_name, code_pr_infos in all_codes
248
+ for code_pr_info in code_pr_infos
249
+ ]
250
+ all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
251
+ for cb_name, (code, prec, rec, code_acts) in all_codes:
252
+ layer_head = cb_name.split("_")
253
+ layer = layer_head[0][5:]
254
+ head = layer_head[1][4:] if len(layer_head) > 1 else None
255
+ button_key = f"search_code{code}_layer{layer}" + (
256
+ f"head{head}" if head is not None else ""
257
+ )
258
+ cols = st.columns(num_search_cols)
259
+ extra_args = {
260
+ "prec": prec,
261
+ "recall": rec,
262
+ "num_acts": code_acts,
263
+ "regex": regex_pattern,
264
+ }
265
+ button_clicked = cols[0].button("πŸ”", key=button_key)
266
+ if button_clicked:
267
+ webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
268
+ cols[1].write(layer)
269
+ if is_attn:
270
+ cols[2].write(head)
271
+ cols[-4 - non_deploy_offset].write(code)
272
+ cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
273
+ cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
274
+ cols[-1 - non_deploy_offset].write(str(code_acts))
275
+ if not DEPLOY_MODE:
276
+ webapp_utils.add_save_code_button(
277
+ demo_file_path,
278
+ num_acts=code_acts,
279
+ save_regex=True,
280
+ prec=prec,
281
+ recall=rec,
282
+ button_st_container=cols[-1],
283
+ button_key_suffix=f"_code{code}_layer{layer}_head{head}",
284
+ )
285
+
286
+ if len(all_codes) == 0:
287
+ st.markdown(
288
+ f"""
289
+ <div style="font-size: 1.0rem; color: red;">
290
+ No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
291
+ </div>
292
+ """,
293
+ unsafe_allow_html=True,
294
+ )
295
+
296
+
297
+ st.markdown("## Code Token Activations")
298
+
299
+ filter_codes = st.checkbox("Filter Codes", key="filter_codes")
300
+ act_range, layer_code_acts = None, None
301
+ if filter_codes:
302
+ act_range = st.slider(
303
+ "Num Acts",
304
+ 0,
305
+ 10_000,
306
+ (100, 10_000),
307
+ key="ct_act_range",
308
+ help="Filter codes by the number of tokens they activate on.",
309
+ )
310
+
311
+ cols = st.columns(5 if is_attn else 4)
312
+ layer = cols[0].number_input("Layer", 0, num_layers - 1, 0, key="ct_act_layer")
313
+ if is_attn:
314
+ head = cols[1].number_input("Head", 0, num_heads - 1, 0, key="ct_act_head")
315
+ else:
316
+ head = None
317
+
318
+ def_code = st.session_state.get("ct_act_code", 0)
319
+ if filter_codes:
320
+ layer_code_acts = act_count_ft_tkns[
321
+ f"layer{layer}{'_head'+str(head) if head is not None else ''}"
322
+ ]
323
+ def_code = webapp_utils.find_next_code(def_code, layer_code_acts, act_range)
324
+ if "ct_act_code" in st.session_state:
325
+ st.session_state["ct_act_code"] = def_code
326
+
327
+ code = cols[-3].number_input(
328
+ "Code",
329
+ 0,
330
+ num_codes - 1,
331
+ def_code,
332
+ key="ct_act_code",
333
+ )
334
+ num_examples = cols[-2].number_input(
335
+ "Max Results",
336
+ -1,
337
+ 1000, # setting to 1000 for efficiency purposes even though it can be more than 1000.
338
+ 100,
339
+ help="Number of examples to show in the results. Set to -1 to show all examples.",
340
+ )
341
+ ctx_size = cols[-1].number_input(
342
+ "Context Size",
343
+ 1,
344
+ 10,
345
+ 5,
346
+ help="Number of tokens to show before and after the code token.",
347
+ )
348
+
349
+ acts, acts_count = webapp_utils.get_code_acts(
350
+ model_name,
351
+ tokens_str,
352
+ code,
353
+ layer,
354
+ head,
355
+ ctx_size,
356
+ num_examples,
357
+ )
358
+
359
+ st.write(
360
+ f"Token Activations for Layer {layer}{f' Head {head}' if head is not None else ''} Code {code} | "
361
+ f"Activates on {acts_count[0]} tokens on the acts dataset",
362
+ )
363
+
364
+ if not DEPLOY_MODE:
365
+ webapp_utils.add_save_code_button(
366
+ demo_file_path,
367
+ acts_count[0],
368
+ save_regex=False,
369
+ button_text=True,
370
+ button_key_suffix="_token_acts",
371
+ )
372
+
373
+ st.markdown(webapp_utils.escape_markdown(acts), unsafe_allow_html=True)
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
5
  colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.25.0
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
5
  colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.25.0
8
+ app_file: Code_Browser.py
9
  pinned: false
10
  license: mit
11
  ---
code_search_utils.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions to help with searching codes using regex."""
2
+
3
+ import pickle
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ import utils
13
+
14
+
15
+ def load_dataset_cache(cache_base_path):
16
+ """Load cache files required for dataset from `cache_base_path`."""
17
+ tokens_str = np.load(cache_base_path + "tokens_str.npy")
18
+ tokens_text = np.load(cache_base_path + "tokens_text.npy")
19
+ token_byte_pos = np.load(cache_base_path + "token_byte_pos.npy")
20
+ return tokens_str, tokens_text, token_byte_pos
21
+
22
+
23
+ def load_code_search_cache(cache_base_path):
24
+ """Load cache files required for code search from `cache_base_path`."""
25
+ metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()
26
+ with open(cache_base_path + "cb_acts.pkl", "rb") as f:
27
+ cb_acts = pickle.load(f)
28
+ with open(cache_base_path + "act_count_ft_tkns.pkl", "rb") as f:
29
+ act_count_ft_tkns = pickle.load(f)
30
+
31
+ return cb_acts, act_count_ft_tkns, metrics
32
+
33
+
34
+ def search_re(re_pattern, tokens_text):
35
+ """Get list of (example_id, token_pos) where re_pattern matches in tokens_text."""
36
+ # TODO: ensure that parantheses are not escaped
37
+ if re_pattern.find("(") == -1:
38
+ re_pattern = f"({re_pattern})"
39
+ return [
40
+ (i, finditer.span(1)[0])
41
+ for i, text in enumerate(tokens_text)
42
+ for finditer in re.finditer(re_pattern, text)
43
+ if finditer.span(1)[0] != finditer.span(1)[1]
44
+ ]
45
+
46
+
47
+ def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
48
+ """Get (example_id, token_pos_id) for given (example_id, byte_id)."""
49
+ example_id, byte_id = example_byte_id
50
+ index = np.searchsorted(token_byte_pos[example_id], byte_id, side="right")
51
+ return (example_id, index)
52
+
53
+
54
+ def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
55
+ """Get codes, prec, recall for given token_pos_ids and codebook_acts."""
56
+ codes = np.array(
57
+ [
58
+ codebook_acts[example_id][token_pos_id]
59
+ for example_id, token_pos_id in token_pos_ids
60
+ ]
61
+ )
62
+ codes, counts = np.unique(codes, return_counts=True)
63
+ recall = counts / len(token_pos_ids)
64
+ idx = recall > 0.01
65
+ codes, counts, recall = codes[idx], counts[idx], recall[idx]
66
+ if cb_act_counts is not None:
67
+ code_acts = np.array([cb_act_counts[code] for code in codes])
68
+ prec = counts / code_acts
69
+ sort_idx = np.argsort(prec)[::-1]
70
+ else:
71
+ code_acts = np.zeros_like(codes)
72
+ prec = np.zeros_like(codes)
73
+ sort_idx = np.argsort(recall)[::-1]
74
+ codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
75
+ code_acts = code_acts[sort_idx]
76
+ return codes, prec, recall, code_acts
77
+
78
+
79
+ def get_neuron_pr(
80
+ token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts, topk=10
81
+ ):
82
+ """Get codes, prec, recall for given token_pos_ids and codebook_acts."""
83
+ # check if neuron_acts_by_ex is a torch tensor
84
+ if isinstance(neuron_acts_by_ex, torch.Tensor):
85
+ re_neuron_acts = torch.stack(
86
+ [
87
+ neuron_acts_by_ex[example_id, token_pos_id]
88
+ for example_id, token_pos_id in token_pos_ids
89
+ ],
90
+ dim=-1,
91
+ ) # (layers, 2, dim_size, matches)
92
+ re_neuron_acts = torch.sort(re_neuron_acts, dim=-1).values
93
+ else:
94
+ re_neuron_acts = np.stack(
95
+ [
96
+ neuron_acts_by_ex[example_id, token_pos_id]
97
+ for example_id, token_pos_id in token_pos_ids
98
+ ],
99
+ axis=-1,
100
+ ) # (layers, 2, dim_size, matches)
101
+ re_neuron_acts.sort(axis=-1)
102
+ re_neuron_acts = torch.from_numpy(re_neuron_acts)
103
+ # re_neuron_acts = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1]) :]
104
+ print("Examples for recall", recall, ":", int(recall * re_neuron_acts.shape[-1]))
105
+ act_thresh = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1])]
106
+ # binary search act_thresh in neuron_sorted_acts
107
+ assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
108
+ prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
109
+ prec_den = prec_den.squeeze(-1)
110
+ prec_den = neuron_sorted_acts.shape[-1] - prec_den
111
+ prec = int(recall * re_neuron_acts.shape[-1]) / prec_den
112
+ assert (
113
+ prec.shape == re_neuron_acts.shape[:-1]
114
+ ), f"{prec.shape} != {re_neuron_acts.shape[:-1]}"
115
+
116
+ best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
117
+ best_prec = prec[best_neuron_idx]
118
+ print("max prec:", best_prec)
119
+ best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
120
+ best_neuron_acts = neuron_acts_by_ex[
121
+ :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
122
+ ]
123
+ best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
124
+ best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)
125
+
126
+ return best_prec, best_neuron_acts, best_neuron_idx
127
+
128
+
129
+ def convert_to_adv_name(name, cb_at, ccb=""):
130
+ """Convert layer0_head0 to layer0_attn_preproj_ccb0."""
131
+ if ccb:
132
+ layer, head = name.split("_")
133
+ return layer + f"_{cb_at}_ccb" + head[4:]
134
+ else:
135
+ return layer + "_" + cb_at
136
+
137
+
138
+ def convert_to_base_name(name, ccb=""):
139
+ """Convert layer0_attn_preproj_ccb0 to layer0_head0."""
140
+ split_name = name.split("_")
141
+ layer, head = split_name[0], split_name[-1][3:]
142
+ if "ccb" in name:
143
+ return layer + "_head" + head
144
+ else:
145
+ return layer
146
+
147
+
148
+ def get_layer_head_from_base_name(name):
149
+ """Convert layer0_head0 to 0, 0."""
150
+ split_name = name.split("_")
151
+ layer = int(split_name[0][5:])
152
+ head = None
153
+ if len(split_name) > 1:
154
+ head = int(split_name[-1][4:])
155
+ return layer, head
156
+
157
+
158
+ def get_layer_head_from_adv_name(name):
159
+ """Convert layer0_attn_preproj_ccb0 to 0, 0."""
160
+ base_name = convert_to_base_name(name)
161
+ layer, head = get_layer_head_from_base_name(base_name)
162
+ return layer, head
163
+
164
+
165
+ def get_codes_from_pattern(
166
+ re_pattern,
167
+ tokens_text,
168
+ token_byte_pos,
169
+ cb_acts,
170
+ act_count_ft_tkns,
171
+ ccb="",
172
+ topk=5,
173
+ prec_threshold=0.5,
174
+ ):
175
+ """Fetch codes from a given regex pattern."""
176
+ byte_ids = search_re(re_pattern, tokens_text)
177
+ token_pos_ids = [
178
+ byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
179
+ ]
180
+ token_pos_ids = np.unique(token_pos_ids, axis=0)
181
+ re_token_matches = len(token_pos_ids)
182
+ codebook_wise_codes = {}
183
+ for cb_name, cb in tqdm(cb_acts.items()):
184
+ base_cb_name = convert_to_base_name(cb_name, ccb=ccb)
185
+ codes, prec, recall, code_acts = get_code_pr(
186
+ token_pos_ids,
187
+ cb,
188
+ cb_act_counts=act_count_ft_tkns[base_cb_name],
189
+ )
190
+ idx = np.arange(min(topk, len(codes)))
191
+ idx = idx[prec[:topk] > prec_threshold]
192
+ codes, prec, recall = codes[idx], prec[idx], recall[idx]
193
+ code_acts = code_acts[idx]
194
+ codes_pr = list(zip(codes, prec, recall, code_acts))
195
+ codebook_wise_codes[base_cb_name] = codes_pr
196
+ return codebook_wise_codes, re_token_matches
197
+
198
+
199
+ def get_neurons_from_pattern(
200
+ re_pattern,
201
+ tokens_text,
202
+ token_byte_pos,
203
+ neuron_acts_by_ex,
204
+ neuron_sorted_acts,
205
+ recall_threshold,
206
+ ):
207
+ """Fetch the best neuron (with act thresh given by recall) from a given regex pattern."""
208
+ byte_ids = search_re(re_pattern, tokens_text)
209
+ token_pos_ids = [
210
+ byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
211
+ ]
212
+ token_pos_ids = np.unique(token_pos_ids, axis=0)
213
+ re_token_matches = len(token_pos_ids)
214
+ best_prec, best_neuron_acts, best_neuron_idx = get_neuron_pr(
215
+ token_pos_ids,
216
+ recall_threshold,
217
+ neuron_acts_by_ex,
218
+ neuron_sorted_acts,
219
+ )
220
+ return best_prec, best_neuron_acts, best_neuron_idx, re_token_matches
221
+
222
+
223
+ def compare_codes_with_neurons(
224
+ best_codes_info,
225
+ tokens_text,
226
+ token_byte_pos,
227
+ neuron_acts_by_ex,
228
+ neuron_sorted_acts,
229
+ ):
230
+ """Compare codes with neurons."""
231
+ assert isinstance(neuron_acts_by_ex, np.ndarray)
232
+ (
233
+ all_best_prec,
234
+ all_best_neuron_acts,
235
+ all_best_neuron_idxs,
236
+ all_re_token_matches,
237
+ ) = zip(
238
+ *[
239
+ get_neurons_from_pattern(
240
+ code_info.re_pattern,
241
+ tokens_text,
242
+ token_byte_pos,
243
+ neuron_acts_by_ex,
244
+ neuron_sorted_acts,
245
+ code_info.recall,
246
+ )
247
+ for code_info in tqdm(range(len(best_codes_info)))
248
+ ],
249
+ strict=True,
250
+ )
251
+ code_best_precs = np.array(
252
+ [code_info.prec for code_info in range(len(best_codes_info))]
253
+ )
254
+ codes_better_than_neurons = code_best_precs > np.array(all_best_prec)
255
+ return codes_better_than_neurons.mean()
256
+
257
+
258
+ def get_code_info_pr_from_str(code_txt, regex):
259
+ """Extract code info fields from string."""
260
+ code_txt = code_txt.strip()
261
+ code_txt = code_txt.split(", ")
262
+ code_txt = dict(txt.split(": ") for txt in code_txt)
263
+ return utils.CodeInfo(**code_txt)
264
+
265
+
266
+ @dataclass
267
+ class ModelInfoForWebapp:
268
+ """Model info for webapp."""
269
+
270
+ model_name: str
271
+ pretrained_path: str
272
+ dataset_name: str
273
+ num_codes: int
274
+ cb_at: str
275
+ ccb: str
276
+ n_layers: int
277
+ n_heads: Optional[int] = None
278
+ seed: int = 42
279
+ max_samples: int = 2000
280
+
281
+ def __post_init__(self):
282
+ """Convert to correct types."""
283
+ self.num_codes = int(self.num_codes)
284
+ self.n_layers = int(self.n_layers)
285
+ if self.n_heads == "None":
286
+ self.n_heads = None
287
+ elif self.n_heads is not None:
288
+ self.n_heads = int(self.n_heads)
289
+ self.seed = int(self.seed)
290
+ self.max_samples = int(self.max_samples)
291
+
292
+
293
+ def parse_model_info(path):
294
+ """Parse model info from path."""
295
+ with open(path + "info.txt", "r") as f:
296
+ lines = f.readlines()
297
+ lines = dict(line.strip().split(": ") for line in lines)
298
+ return ModelInfoForWebapp(**lines)
299
+ return ModelInfoForWebapp(**lines)
pages/Concept_Code.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Web app page for showing codes for different examples in the dataset."""
2
+
3
+
4
+ import streamlit as st
5
+ from streamlit_extras.switch_page_button import switch_page
6
+
7
+ import code_search_utils
8
+ import webapp_utils
9
+
10
+ webapp_utils.load_widget_state()
11
+
12
+ if "cb_acts" not in st.session_state:
13
+ switch_page("Code_Browser")
14
+
15
+ total_examples = 2000
16
+ prec_threshold = 0.01
17
+
18
+ model_name = st.session_state["model_name_id"]
19
+ seq_len = st.session_state["seq_len"]
20
+ tokens_text = st.session_state["tokens_text"]
21
+ tokens_str = st.session_state["tokens_str"]
22
+ cb_acts = st.session_state["cb_acts"]
23
+ act_count_ft_tkns = st.session_state["act_count_ft_tkns"]
24
+ ccb = st.session_state["ccb"]
25
+
26
+
27
+ def get_example_concept_codes(example_id):
28
+ """Get concept codes for the given example id."""
29
+ token_pos_ids = [(example_id, i) for i in range(seq_len)]
30
+ all_codes = []
31
+ for cb_name, cb in cb_acts.items():
32
+ base_cb_name = code_search_utils.convert_to_base_name(cb_name, ccb=ccb)
33
+ codes, prec, rec, code_acts = code_search_utils.get_code_pr(
34
+ token_pos_ids,
35
+ cb,
36
+ act_count_ft_tkns[base_cb_name],
37
+ )
38
+ prec_sat_idx = prec >= prec_threshold
39
+ codes, prec, rec, code_acts = (
40
+ codes[prec_sat_idx],
41
+ prec[prec_sat_idx],
42
+ rec[prec_sat_idx],
43
+ code_acts[prec_sat_idx],
44
+ )
45
+ rec_sat_idx = rec >= recall_threshold
46
+ codes, prec, rec, code_acts = (
47
+ codes[rec_sat_idx],
48
+ prec[rec_sat_idx],
49
+ rec[rec_sat_idx],
50
+ code_acts[rec_sat_idx],
51
+ )
52
+ codes_pr = list(zip(codes, prec, rec, code_acts))
53
+ all_codes.append((cb_name, codes_pr))
54
+ return all_codes
55
+
56
+
57
+ def find_next_example(example_id):
58
+ """Find the example after `example_id` that has concept codes."""
59
+ initial_example_id = example_id
60
+ example_id += 1
61
+ while example_id != initial_example_id:
62
+ all_codes = get_example_concept_codes(example_id)
63
+ codes_found = sum([len(code_pr_infos) for _, code_pr_infos in all_codes])
64
+ if codes_found > 0:
65
+ st.session_state["example_id"] = example_id
66
+ return
67
+ example_id = (example_id + 1) % total_examples
68
+ st.error(
69
+ f"No examples found at the specified recall threshold: {recall_threshold}.",
70
+ icon="🚨",
71
+ )
72
+
73
+
74
+ def redirect_to_main_with_code(code, layer, head):
75
+ """Redirect to main page with the given code."""
76
+ st.session_state["ct_act_code"] = code
77
+ st.session_state["ct_act_layer"] = layer
78
+ if st.session_state["is_attn"]:
79
+ st.session_state["ct_act_head"] = head
80
+ switch_page("Code Browser")
81
+
82
+
83
+ def show_examples_for_concept_code(code, layer, head, code_act_ratio=0.3):
84
+ """Show examples that the code activates on."""
85
+ ex_acts, _ = webapp_utils.get_code_acts(
86
+ model_name,
87
+ tokens_str,
88
+ code,
89
+ layer,
90
+ head,
91
+ ctx_size=5,
92
+ return_example_list=True,
93
+ )
94
+ filt_ex_acts = []
95
+ for act_str, num_acts in ex_acts:
96
+ if num_acts > seq_len * code_act_ratio:
97
+ filt_ex_acts.append(act_str)
98
+ st.markdown("#### Examples for Code")
99
+ st.markdown(
100
+ webapp_utils.escape_markdown("".join(filt_ex_acts)), unsafe_allow_html=True
101
+ )
102
+
103
+
104
+ is_attn = st.session_state["is_attn"]
105
+
106
+ st.markdown("## Concept Code")
107
+ concept_code_description = (
108
+ "Concept codes are codes that activate a lot on only a particular set of examples that share a concept. "
109
+ "Hence such codes can be thought to correspond to more higher-level concepts or features and "
110
+ "can activate on most tokens that belong in an example text. This interface provides a way to search for such "
111
+ "codes by going through different examples using Example ID."
112
+ )
113
+ st.write(concept_code_description)
114
+
115
+ # ex_col, p_col, r_col, trunc_col, sort_col = st.columns([1, 2, 2, 1, 1])
116
+ ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
117
+ example_id = ex_col.number_input(
118
+ "Example ID",
119
+ 0,
120
+ total_examples - 1,
121
+ 0,
122
+ key="example_id",
123
+ )
124
+ # prec_threshold = p_col.slider(
125
+ # "Precision Threshold",
126
+ # 0.0,
127
+ # 1.0,
128
+ # 0.02,
129
+ # key="prec",
130
+ # help="Precision Threshold controls the specificity of the codes for the given example.",
131
+ # )
132
+ recall_threshold = r_col.slider(
133
+ "Recall Threshold",
134
+ 0.0,
135
+ 1.0,
136
+ 0.3,
137
+ key="recall",
138
+ help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
139
+ )
140
+ example_truncation = trunc_col.number_input(
141
+ "Max Output Chars", 0, 10240, 1024, key="max_chars"
142
+ )
143
+ sort_by_options = ["Precision", "Recall", "Num Acts"]
144
+ sort_by_name = sort_col.radio(
145
+ "Sort By",
146
+ sort_by_options,
147
+ index=0,
148
+ horizontal=True,
149
+ help="Sorts the codes by the selected metric.",
150
+ )
151
+ sort_by = sort_by_options.index(sort_by_name)
152
+
153
+
154
+ button = st.button(
155
+ "Find Next Example",
156
+ key="find_next_example",
157
+ on_click=find_next_example,
158
+ args=(example_id,),
159
+ help="Find an example which has codes above the recall threshold.",
160
+ )
161
+ # if button:
162
+ # find_next_example(st.session_state["example_id"])
163
+
164
+
165
+ st.markdown("### Example Text")
166
+ trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else ""
167
+ st.write(tokens_text[example_id][:example_truncation] + trunc_suffix)
168
+
169
+ cols = st.columns(7 if is_attn else 6)
170
+ cols[0].markdown("Search", help="Button to see token activations for the code.")
171
+ cols[1].write("Layer")
172
+ if is_attn:
173
+ cols[2].write("Head")
174
+ cols[-4].write("Code")
175
+ cols[-3].write("Precision")
176
+ cols[-2].write("Recall")
177
+ cols[-1].markdown(
178
+ "Num Acts",
179
+ help="Number of tokens that the code activates on in the acts dataset.",
180
+ )
181
+
182
+ all_codes = get_example_concept_codes(example_id)
183
+ all_codes = [
184
+ (cb_name, code_pr_info)
185
+ for cb_name, code_pr_infos in all_codes
186
+ for code_pr_info in code_pr_infos
187
+ ]
188
+ all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
189
+
190
+ for cb_name, (code, p, r, acts) in all_codes:
191
+ cols = st.columns(7 if is_attn else 6)
192
+ code_button = cols[0].button(
193
+ "πŸ”",
194
+ key=f"ex-code-{code}-{cb_name}",
195
+ )
196
+ layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
197
+ cols[1].write(str(layer))
198
+ if is_attn:
199
+ cols[2].write(str(head))
200
+
201
+ cols[-4].write(code)
202
+ cols[-3].write(f"{p*100:.2f}%")
203
+ cols[-2].write(f"{r*100:.2f}%")
204
+ cols[-1].write(str(acts))
205
+
206
+ if code_button:
207
+ show_examples_for_concept_code(
208
+ code,
209
+ layer,
210
+ head,
211
+ code_act_ratio=recall_threshold,
212
+ )
213
+ if len(all_codes) == 0:
214
+ st.markdown(
215
+ f"<div style='text-align:center'>No codes found at recall threshold: {recall_threshold}</div>",
216
+ unsafe_allow_html=True,
217
+ )
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy
2
+ torch>=2.0.0
3
+ tqdm
4
+ termcolor
5
+ streamlit_extras
utils.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Util functions for codebook features."""
2
+ import re
3
+ import typing
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from termcolor import colored
12
+ from tqdm import tqdm
13
+
14
+
15
+ @dataclass
16
+ class CodeInfo:
17
+ """Dataclass for codebook info."""
18
+
19
+ code: int
20
+ layer: int
21
+ head: Optional[int]
22
+ cb_at: Optional[str] = None
23
+
24
+ # for patching interventions
25
+ pos: Optional[int] = None
26
+ code_pos: Optional[int] = -1
27
+
28
+ # for description & regex-based interpretation
29
+ description: Optional[str] = None
30
+ regex: Optional[str] = None
31
+ prec: Optional[float] = None
32
+ recall: Optional[float] = None
33
+ num_acts: Optional[int] = None
34
+
35
+ def __post_init__(self):
36
+ """Convert to appropriate types."""
37
+ self.code = int(self.code)
38
+ self.layer = int(self.layer)
39
+ if self.head:
40
+ self.head = int(self.head)
41
+ if self.pos:
42
+ self.pos = int(self.pos)
43
+ if self.code_pos:
44
+ self.code_pos = int(self.code_pos)
45
+ if self.prec:
46
+ self.prec = float(self.prec)
47
+ assert 0 <= self.prec <= 1
48
+ if self.recall:
49
+ self.recall = float(self.recall)
50
+ assert 0 <= self.recall <= 1
51
+ if self.num_acts:
52
+ self.num_acts = int(self.num_acts)
53
+
54
+ def check_description_info(self):
55
+ """Check if the regex info is present."""
56
+ assert self.num_acts is not None and self.description is not None
57
+ if self.regex is not None:
58
+ assert self.prec is not None and self.recall is not None
59
+
60
+ def check_patch_info(self):
61
+ """Check if the patch info is present."""
62
+ # TODO: pos can be none for patching
63
+ assert self.pos is not None and self.code_pos is not None
64
+
65
+ def __repr__(self):
66
+ """Return the string representation."""
67
+ repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
68
+ if self.pos is not None or self.code_pos is not None:
69
+ repr += f", pos={self.pos}, code_pos={self.code_pos}"
70
+ if self.description is not None:
71
+ repr += f", description={self.description}"
72
+ if self.regex is not None:
73
+ repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}"
74
+ if self.num_acts is not None:
75
+ repr += f", num_acts={self.num_acts}"
76
+ repr += ")"
77
+ return repr
78
+
79
+
80
+ def logits_to_pred(logits, tokenizer, k=5):
81
+ """Convert logits to top-k predictions."""
82
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
83
+ probs = sorted_logits.softmax(dim=-1)
84
+ topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]]
85
+ topk_preds = [
86
+ tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch
87
+ ]
88
+ return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
89
+
90
+
91
+ def patch_codebook_ids(
92
+ corrupted_codebook_ids, hook, pos, cache, cache_pos=None, code_idx=None
93
+ ):
94
+ """Patch codebook ids with cached ids."""
95
+ if cache_pos is None:
96
+ cache_pos = pos
97
+ if code_idx is None:
98
+ corrupted_codebook_ids[:, pos] = cache[hook.name][:, cache_pos]
99
+ else:
100
+ for code_id in range(32):
101
+ if code_id in code_idx:
102
+ corrupted_codebook_ids[:, pos, code_id] = cache[hook.name][
103
+ :, cache_pos, code_id
104
+ ]
105
+ else:
106
+ corrupted_codebook_ids[:, pos, code_id] = -1
107
+
108
+ return corrupted_codebook_ids
109
+
110
+
111
+ def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
112
+ """Calculate the average logit difference between the answer and the other token."""
113
+ # Only the final logits are relevant for the answer
114
+ final_logits = logits[:, -1, :]
115
+ answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
116
+ answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
117
+ if per_prompt:
118
+ return answer_logit_diff
119
+ else:
120
+ return answer_logit_diff.mean()
121
+
122
+
123
+ def normalize_patched_logit_diff(
124
+ patched_logit_diff,
125
+ base_average_logit_diff,
126
+ corrupted_average_logit_diff,
127
+ ):
128
+ """Normalize the patched logit difference."""
129
+ # Subtract corrupted logit diff to measure the improvement,
130
+ # divide by the total improvement from clean to corrupted to normalise
131
+ # 0 means zero change, negative means actively made worse,
132
+ # 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
133
+ return (patched_logit_diff - corrupted_average_logit_diff) / (
134
+ base_average_logit_diff - corrupted_average_logit_diff
135
+ )
136
+
137
+
138
+ def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
139
+ """Return the set of token ids each codebook feature activates on."""
140
+ codebook_ids = cb_acts[cb_key]
141
+
142
+ if code is None:
143
+ features_tokens = [[] for _ in range(num_codes)]
144
+ for i in tqdm(range(codebook_ids.shape[0])):
145
+ for j in range(codebook_ids.shape[1]):
146
+ for k in range(codebook_ids.shape[2]):
147
+ features_tokens[codebook_ids[i, j, k]].append((i, j))
148
+ else:
149
+ idx0, idx1, _ = np.where(codebook_ids == code)
150
+ features_tokens = list(zip(idx0, idx1))
151
+
152
+ return features_tokens
153
+
154
+
155
+ def color_str(s: str, color: str, html: bool):
156
+ """Color the string for html or terminal."""
157
+ if html:
158
+ return f"<span style='color:{color}'>{s}</span>"
159
+ else:
160
+ return colored(s, color)
161
+
162
+
163
+ def color_tokens_red_automata(tokens, red_idx, html=False):
164
+ """Separate states with a dash and color red the tokens in red_idx."""
165
+ ret_string = ""
166
+ itr_over_red_idx = 0
167
+ tokens_enumerate = enumerate(tokens)
168
+ if tokens[0] == "<|endoftext|>":
169
+ next(tokens_enumerate)
170
+ if red_idx[0] == 0:
171
+ itr_over_red_idx += 1
172
+ for i, c in tokens_enumerate:
173
+ if i % 2 == 1:
174
+ ret_string += "-"
175
+ if itr_over_red_idx < len(red_idx) and i == red_idx[itr_over_red_idx]:
176
+ ret_string += color_str(c, "red", html)
177
+ itr_over_red_idx += 1
178
+ else:
179
+ ret_string += c
180
+ return ret_string
181
+
182
+
183
+ def color_tokens_red(tokens, red_idx, n=3, html=False):
184
+ """Color red the tokens in red_idx."""
185
+ ret_string = ""
186
+ last_colored_token_idx = -1
187
+ for i in red_idx:
188
+ c_str = tokens[i]
189
+ if i <= last_colored_token_idx + 2 * n + 1:
190
+ ret_string += "".join(tokens[last_colored_token_idx + 1 : i])
191
+ else:
192
+ ret_string += "".join(
193
+ tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1]
194
+ )
195
+ ret_string += " ... "
196
+ ret_string += "".join(tokens[i - n : i])
197
+ ret_string += color_str(c_str, "red", html)
198
+ last_colored_token_idx = i
199
+ ret_string += "".join(
200
+ tokens[
201
+ last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens))
202
+ ]
203
+ )
204
+ return ret_string
205
+
206
+
207
+ def prepare_example_print(
208
+ example_id,
209
+ example_tokens,
210
+ tokens_to_color_red,
211
+ html,
212
+ color_red_fn=color_tokens_red,
213
+ ):
214
+ """Format example to print."""
215
+ example_output = color_str(example_id, "green", html)
216
+ example_output += (
217
+ ": "
218
+ + color_red_fn(example_tokens, tokens_to_color_red, html=html)
219
+ + ("<br>" if html else "\n")
220
+ )
221
+ return example_output
222
+
223
+
224
+ def tkn_print(
225
+ ll,
226
+ tokens,
227
+ separate_states,
228
+ n=3,
229
+ max_examples=100,
230
+ randomize=False,
231
+ html=False,
232
+ return_example_list=False,
233
+ ):
234
+ """Format and prints the tokens in ll."""
235
+ if randomize:
236
+ raise NotImplementedError("Randomize not yet implemented.")
237
+ indices = range(len(ll))
238
+ print_output = [] if return_example_list else ""
239
+ curr_ex = ll[0][0]
240
+ total_examples = 0
241
+ tokens_to_color_red = []
242
+ color_red_fn = (
243
+ color_tokens_red_automata if separate_states else partial(color_tokens_red, n=n)
244
+ )
245
+ for idx in indices:
246
+ if total_examples > max_examples:
247
+ break
248
+ i, j = ll[idx]
249
+
250
+ if i != curr_ex and curr_ex >= 0:
251
+ curr_ex_output = prepare_example_print(
252
+ curr_ex,
253
+ tokens[curr_ex],
254
+ tokens_to_color_red,
255
+ html,
256
+ color_red_fn,
257
+ )
258
+ total_examples += 1
259
+ if return_example_list:
260
+ print_output.append((curr_ex_output, len(tokens_to_color_red)))
261
+ else:
262
+ print_output += curr_ex_output
263
+ curr_ex = i
264
+ tokens_to_color_red = []
265
+ tokens_to_color_red.append(j)
266
+ curr_ex_output = prepare_example_print(
267
+ curr_ex,
268
+ tokens[curr_ex],
269
+ tokens_to_color_red,
270
+ html,
271
+ color_red_fn,
272
+ )
273
+ if return_example_list:
274
+ print_output.append((curr_ex_output, len(tokens_to_color_red)))
275
+ else:
276
+ print_output += curr_ex_output
277
+ asterisk_str = "********************************************"
278
+ print_output += color_str(asterisk_str, "green", html)
279
+ total_examples += 1
280
+
281
+ return print_output
282
+
283
+
284
+ def print_ft_tkns(
285
+ ft_tkns,
286
+ tokens,
287
+ separate_states=False,
288
+ n=3,
289
+ start=0,
290
+ stop=1000,
291
+ indices=None,
292
+ max_examples=100,
293
+ freq_filter=None,
294
+ randomize=False,
295
+ html=False,
296
+ return_example_list=False,
297
+ ):
298
+ """Print the tokens for the codebook features."""
299
+ indices = list(range(start, stop)) if indices is None else indices
300
+ num_tokens = len(tokens) * len(tokens[0])
301
+ codes, token_act_freqs, token_acts = [], [], []
302
+ for i in indices:
303
+ tkns = ft_tkns[i]
304
+ freq = (len(tkns), 100 * len(tkns) / num_tokens)
305
+ if freq_filter is not None and freq[1] > freq_filter:
306
+ continue
307
+ codes.append(i)
308
+ token_act_freqs.append(freq)
309
+ if len(tkns) > 0:
310
+ tkn_acts = tkn_print(
311
+ tkns,
312
+ tokens,
313
+ separate_states,
314
+ n=n,
315
+ max_examples=max_examples,
316
+ randomize=randomize,
317
+ html=html,
318
+ return_example_list=return_example_list,
319
+ )
320
+ token_acts.append(tkn_acts)
321
+ else:
322
+ token_acts.append("")
323
+ return codes, token_act_freqs, token_acts
324
+
325
+
326
+ def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
327
+ """Patch in the `code` at `run_cb_ids`."""
328
+ pos = slice(None) if pos is None else pos
329
+ code_pos = slice(None) if code_pos is None else code_pos
330
+
331
+ if code_pos == "append":
332
+ assert pos == slice(None)
333
+ run_cb_ids = F.pad(run_cb_ids, (0, 1), mode="constant", value=code)
334
+ if isinstance(pos, typing.Iterable) or isinstance(pos, typing.Iterable):
335
+ for p in pos:
336
+ run_cb_ids[:, p, code_pos] = code
337
+ else:
338
+ run_cb_ids[:, pos, code_pos] = code
339
+ return run_cb_ids
340
+
341
+
342
+ def get_cb_layer_name(cb_at, layer_idx, head_idx=None):
343
+ """Get the layer name used to store hooks/cache."""
344
+ if head_idx is None:
345
+ return f"blocks.{layer_idx}.{cb_at}.codebook_layer.hook_codebook_ids"
346
+ else:
347
+ return f"blocks.{layer_idx}.{cb_at}.codebook_layer.codebook.{head_idx}.hook_codebook_ids"
348
+
349
+
350
+ def get_cb_layer_names(layer, patch_types, n_heads):
351
+ """Get the layer names used to store hooks/cache."""
352
+ layer_names = []
353
+ attn_added, mlp_added = False, False
354
+ if "attn_out" in patch_types:
355
+ attn_added = True
356
+ for head in range(n_heads):
357
+ layer_names.append(
358
+ f"blocks.{layer}.attn.codebook_layer.codebook.{head}.hook_codebook_ids"
359
+ )
360
+ if "mlp_out" in patch_types:
361
+ mlp_added = True
362
+ layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
363
+
364
+ for patch_type in patch_types:
365
+ # match patch_type of the pattern attn_\d_head_\d
366
+ attn_head = re.match(r"attn_(\d)_head_(\d)", patch_type)
367
+ if (not attn_added) and attn_head and attn_head[1] == str(layer):
368
+ layer_names.append(
369
+ f"blocks.{layer}.attn.codebook_layer.codebook.{attn_head[2]}.hook_codebook_ids"
370
+ )
371
+ mlp = re.match(r"mlp_(\d)", patch_type)
372
+ if (not mlp_added) and mlp and mlp[1] == str(layer):
373
+ layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
374
+
375
+ return layer_names
376
+
377
+
378
+ def cb_layer_name_to_info(layer_name):
379
+ """Get the layer info from the layer name."""
380
+ layer_name_split = layer_name.split(".")
381
+ layer_idx = int(layer_name_split[1])
382
+ cb_at = layer_name_split[2]
383
+ if cb_at == "mlp":
384
+ head_idx = None
385
+ else:
386
+ head_idx = int(layer_name_split[5])
387
+ return cb_at, layer_idx, head_idx
388
+
389
+
390
+ def get_hooks(code, cb_at, layer_idx, head_idx=None, pos=None):
391
+ """Get the hooks for the codebook features."""
392
+ hook_fns = [
393
+ partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
394
+ ]
395
+ return [
396
+ (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
397
+ for i in range(len(code))
398
+ ]
399
+
400
+
401
+ def run_with_codes(
402
+ input, cb_model, code, cb_at, layer_idx, head_idx=None, pos=None, prepend_bos=True
403
+ ):
404
+ """Run the model with the codebook features patched in."""
405
+ hook_fns = [
406
+ partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
407
+ ]
408
+ cb_model.reset_codebook_metrics()
409
+ cb_model.reset_hook_kwargs()
410
+ fwd_hooks = [
411
+ (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
412
+ for i in range(len(cb_at))
413
+ ]
414
+ with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
415
+ patched_logits, patched_cache = hooked_model.run_with_cache(
416
+ input, prepend_bos=prepend_bos
417
+ )
418
+ return patched_logits, patched_cache
419
+
420
+
421
+ def in_hook_list(list_of_arg_tuples, layer, head=None):
422
+ """Check if the component specified by `layer` and `head` is in the `list_of_arg_tuples`."""
423
+ # if head is not provided, then checks in MLP
424
+ for arg_tuple in list_of_arg_tuples:
425
+ if head is None:
426
+ if arg_tuple.cb_at == "mlp" and arg_tuple.layer == layer:
427
+ return True
428
+ else:
429
+ if (
430
+ arg_tuple.cb_at == "attn"
431
+ and arg_tuple.layer == layer
432
+ and arg_tuple.head == head
433
+ ):
434
+ return True
435
+ return False
436
+
437
+
438
+ # def generate_with_codes(input, code, cb_at, layer_idx, head_idx=None, pos=None, disable_other_comps=False):
439
+ def generate_with_codes(
440
+ input,
441
+ cb_model,
442
+ list_of_code_infos=(),
443
+ disable_other_comps=False,
444
+ automata=None,
445
+ generate_kwargs=None,
446
+ ):
447
+ """Model's generation with the codebook features patched in."""
448
+ if generate_kwargs is None:
449
+ generate_kwargs = {}
450
+ hook_fns = [
451
+ partial(patch_in_codes, pos=tupl.pos, code=tupl.code)
452
+ for tupl in list_of_code_infos
453
+ ]
454
+ fwd_hooks = [
455
+ (get_cb_layer_name(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
456
+ for i, tupl in enumerate(list_of_code_infos)
457
+ ]
458
+ cb_model.reset_hook_kwargs()
459
+ if disable_other_comps:
460
+ for layer, cb in cb_model.all_codebooks.items():
461
+ for head_idx, head in enumerate(cb[0].codebook):
462
+ if not in_hook_list(list_of_code_infos, layer, head_idx):
463
+ head.set_hook_kwargs(
464
+ disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
465
+ )
466
+ if not in_hook_list(list_of_code_infos, layer):
467
+ cb[1].set_hook_kwargs(
468
+ disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
469
+ )
470
+ with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
471
+ gen = hooked_model.generate(input, **generate_kwargs)
472
+ return automata.seq_to_traj(gen)[0] if automata is not None else gen
473
+
474
+
475
+ def kl_div(logits1, logits2, pos=-1, reduction="batchmean"):
476
+ """Calculate the KL divergence between the logits at `pos`."""
477
+ logits1_last, logits2_last = logits1[:, pos, :], logits2[:, pos, :]
478
+ # calculate kl divergence between clean and mod logits last
479
+ return F.kl_div(
480
+ F.log_softmax(logits1_last, dim=-1),
481
+ F.log_softmax(logits2_last, dim=-1),
482
+ log_target=True,
483
+ reduction=reduction,
484
+ )
485
+
486
+
487
+ def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
488
+ """Compute the Jensen-Shannon divergence between two distributions."""
489
+ if len(logits1.shape) == 3:
490
+ logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :]
491
+
492
+ probs1 = F.softmax(logits1, dim=-1)
493
+ probs2 = F.softmax(logits2, dim=-1)
494
+
495
+ total_m = (0.5 * (probs1 + probs2)).log()
496
+
497
+ loss = 0.0
498
+ loss += F.kl_div(
499
+ total_m,
500
+ F.log_softmax(logits1, dim=-1),
501
+ log_target=True,
502
+ reduction=reduction,
503
+ )
504
+ loss += F.kl_div(
505
+ total_m,
506
+ F.log_softmax(logits2, dim=-1),
507
+ log_target=True,
508
+ reduction=reduction,
509
+ )
510
+ return 0.5 * loss
511
+
512
+
513
+ def residual_stream_patching_hook(resid_pre, hook, cache, position: int):
514
+ """Patch in the codebook features at `position` from `cache`."""
515
+ clean_resid_pre = cache[hook.name]
516
+ resid_pre[:, position, :] = clean_resid_pre[:, position, :]
517
+ return resid_pre
518
+
519
+
520
+ def find_code_changes(cache1, cache2, pos=None):
521
+ """Find the codebook codes that are different between the two caches."""
522
+ for k in cache1.keys():
523
+ if "codebook" in k:
524
+ c1 = cache1[k][0, pos]
525
+ c2 = cache2[k][0, pos]
526
+ if not torch.all(c1 == c2):
527
+ print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())
528
+ print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())
529
+
530
+
531
+ def common_codes_in_cache(cache_codes, threshold=0.0):
532
+ """Get the common code in the cache."""
533
+ codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True)
534
+ counts = counts.float() * 100
535
+ counts /= cache_codes.shape[1]
536
+ counts, indices = torch.sort(counts, descending=True)
537
+ codes = codes[indices]
538
+ indices = counts > threshold
539
+ codes, counts = codes[indices], counts[indices]
540
+ return codes, counts
541
+
542
+
543
+ def parse_code_info_string(
544
+ info_str: str, cb_at="attn", pos=None, code_pos=-1
545
+ ) -> CodeInfo:
546
+ """Parse the code info string.
547
+
548
+ The format of the `info_str` is:
549
+ `code: 0, layer: 0, head: 0, occ_freq: 0.0, train_act_freq: 0.0`.
550
+ """
551
+ code, layer, head, occ_freq, train_act_freq = info_str.split(", ")
552
+ code = int(code.split(": ")[1])
553
+ layer = int(layer.split(": ")[1])
554
+ head = int(head.split(": ")[1]) if head else None
555
+ occ_freq = float(occ_freq.split(": ")[1])
556
+ train_act_freq = float(train_act_freq.split(": ")[1])
557
+ return CodeInfo(code, layer, head, pos=pos, code_pos=code_pos, cb_at=cb_at)
558
+
559
+
560
+ def parse_concept_codes_string(info_str: str, pos=None, code_append=False):
561
+ """Parse the concept codes string."""
562
+ code_info_strs = info_str.strip().split("\n")
563
+ concept_codes = []
564
+ layer, head = None, None
565
+ code_pos = "append" if code_append else -1
566
+ for code_info_str in code_info_strs:
567
+ concept_codes.append(
568
+ parse_code_info_string(code_info_str, pos=pos, code_pos=code_pos)
569
+ )
570
+ if code_append:
571
+ continue
572
+ if layer == concept_codes[-1].layer and head == concept_codes[-1].head:
573
+ code_pos -= 1
574
+ else:
575
+ code_pos = -1
576
+ concept_codes[-1].code_pos = code_pos
577
+ layer, head = concept_codes[-1].layer, concept_codes[-1].head
578
+ return concept_codes
webapp_utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for running webapp using streamlit."""
2
+
3
+
4
+ import streamlit as st
5
+ from streamlit.components.v1 import html
6
+
7
+ import code_search_utils
8
+ import utils
9
+
10
+ _PERSIST_STATE_KEY = f"{__name__}_PERSIST"
11
+ TOTAL_SAVE_BUTTONS = 0
12
+
13
+
14
+ def persist(key: str) -> str:
15
+ """Mark widget state as persistent."""
16
+ if _PERSIST_STATE_KEY not in st.session_state:
17
+ st.session_state[_PERSIST_STATE_KEY] = set()
18
+
19
+ st.session_state[_PERSIST_STATE_KEY].add(key)
20
+
21
+ return key
22
+
23
+
24
+ def load_widget_state():
25
+ """Load persistent widget state."""
26
+ if _PERSIST_STATE_KEY in st.session_state:
27
+ st.session_state.update(
28
+ {
29
+ key: value
30
+ for key, value in st.session_state.items()
31
+ if key in st.session_state[_PERSIST_STATE_KEY]
32
+ }
33
+ )
34
+
35
+
36
+ @st.cache_resource
37
+ def load_dataset_cache(dataset_cache_path):
38
+ """Load cache files required for dataset from `cache_path`."""
39
+ return code_search_utils.load_dataset_cache(dataset_cache_path)
40
+
41
+
42
+ @st.cache_resource
43
+ def load_code_search_cache(codes_cache_path, dataset_cache_path):
44
+ """Load cache files required for code search from `codes_cache_path`."""
45
+ (
46
+ tokens_str,
47
+ tokens_text,
48
+ token_byte_pos,
49
+ ) = load_dataset_cache(dataset_cache_path)
50
+ (
51
+ cb_acts,
52
+ act_count_ft_tkns,
53
+ metrics,
54
+ ) = code_search_utils.load_code_search_cache(codes_cache_path)
55
+ return tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics
56
+
57
+
58
+ @st.cache_data(max_entries=100)
59
+ def load_ft_tkns(model_id, layer, head=None, code=None):
60
+ """Load the code-to-token map for a codebook."""
61
+ # model_id required to not mix cache_data for different models
62
+ assert model_id is not None
63
+ cb_at = st.session_state["cb_at"]
64
+ ccb = st.session_state["ccb"]
65
+ cb_acts = st.session_state["cb_acts"]
66
+ if head is not None:
67
+ cb_name = f"layer{layer}_{cb_at}{ccb}{head}"
68
+ else:
69
+ cb_name = f"layer{layer}_{cb_at}"
70
+ return utils.features_to_tokens(
71
+ cb_name,
72
+ cb_acts,
73
+ num_codes=st.session_state["num_codes"],
74
+ code=code,
75
+ )
76
+
77
+
78
+ def get_code_acts(
79
+ model_id,
80
+ tokens_str,
81
+ code,
82
+ layer,
83
+ head=None,
84
+ ctx_size=5,
85
+ num_examples=100,
86
+ return_example_list=False,
87
+ ):
88
+ """Get the token activations for a given code."""
89
+ ft_tkns = load_ft_tkns(model_id, layer, head, code)
90
+ ft_tkns = [ft_tkns]
91
+ _, freqs, acts = utils.print_ft_tkns(
92
+ ft_tkns,
93
+ tokens=tokens_str,
94
+ indices=[0],
95
+ html=True,
96
+ n=ctx_size,
97
+ max_examples=num_examples,
98
+ return_example_list=return_example_list,
99
+ )
100
+ return acts[0], freqs[0]
101
+
102
+
103
+ def set_ct_acts(code, layer, head=None, extra_args=None, is_attn=False):
104
+ """Set the code and layer for the token activations."""
105
+ # convert to int
106
+ code, layer, head = int(code), int(layer), int(head) if head is not None else None
107
+ st.session_state["ct_act_code"] = code
108
+ st.session_state["ct_act_layer"] = layer
109
+ if is_attn:
110
+ st.session_state["ct_act_head"] = head
111
+ st.session_state["filter_codes"] = False
112
+
113
+ my_html = """
114
+ <script>
115
+ document.location.href = "#code-token-activations";
116
+ </script>
117
+ """
118
+ html(my_html, height=0, width=0, scrolling=False)
119
+
120
+
121
+ def find_next_code(code, layer_code_acts, act_range=None):
122
+ """Find the next code that has activations in the given range."""
123
+ if act_range is None:
124
+ return code
125
+ for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
126
+ if code_act_count >= act_range[0] and code_act_count <= act_range[1]:
127
+ code += code_iter
128
+ break
129
+ return code
130
+
131
+
132
+ def escape_markdown(text):
133
+ """Escapes markdown special characters."""
134
+ MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
135
+ for char in MD_SPECIAL_CHARS:
136
+ text = text.replace(char, "\\" + char)
137
+ return text
138
+
139
+
140
+ def add_code_to_demo_file(code_info: utils.CodeInfo, file_path: str):
141
+ """Add code to demo file."""
142
+ # TODO: add check for duplicate code and return False if found
143
+ # TODO: convert saved codes to databases instead of txt files?
144
+ code_info.check_description_info()
145
+ with open(file_path, "a") as f:
146
+ f.write("\n")
147
+ f.write(f"# {code_info.description}:")
148
+ if code_info.regex:
149
+ f.write(f" {code_info.regex}")
150
+ f.write("\n")
151
+ f.write(f"layer: {code_info.layer}")
152
+ f.write(f", head: {code_info.head}" if code_info.head is not None else "")
153
+ f.write(f", code: {code_info.code}")
154
+ if code_info.regex:
155
+ f.write(f", prec: {code_info.prec:.4f}, recall: {code_info.recall:.4f}")
156
+ f.write(f", num_acts: {code_info.num_acts}\n")
157
+ return True
158
+
159
+
160
+ def add_save_code_button(
161
+ demo_file_path: str,
162
+ num_acts: int,
163
+ save_regex: bool = False,
164
+ prec: float = None,
165
+ recall: float = None,
166
+ button_st_container=st,
167
+ button_text: bool = False,
168
+ button_key_suffix: str = "",
169
+ ):
170
+ """Add a button on streamlit to save code to demo codes file."""
171
+ save_button = button_st_container.button(
172
+ "πŸ’Ύ" + (" Save Code to Demos" if button_text else ""),
173
+ key=f"save_code_button{button_key_suffix}",
174
+ help="Save code to demo codes file",
175
+ )
176
+ if save_button:
177
+ description = st.text_input(
178
+ "Write a description for the code",
179
+ key="save_code_desc",
180
+ )
181
+ if not description:
182
+ return
183
+
184
+ description = st.session_state.get("save_code_desc", None)
185
+ if description:
186
+ layer = st.session_state["ct_act_layer"]
187
+ is_attn = st.session_state["is_attn"]
188
+ if is_attn:
189
+ head = st.session_state["ct_act_head"]
190
+ else:
191
+ head = None
192
+
193
+ code = st.session_state["ct_act_code"]
194
+ code_info = utils.CodeInfo(
195
+ layer=layer,
196
+ head=head,
197
+ code=code,
198
+ description=description,
199
+ num_acts=num_acts,
200
+ )
201
+
202
+ if save_regex:
203
+ code_info.regex = st.session_state["regex_pattern"]
204
+ code_info.prec = prec
205
+ code_info.recall = recall
206
+
207
+ saved = add_code_to_demo_file(code_info, demo_file_path)
208
+ if saved:
209
+ st.success("Code saved!", icon="πŸŽ‰")
210
+ st.success("Code saved!", icon="πŸŽ‰")
webapp_utils_full_ft_tkns_for_ts.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for running webapp using streamlit."""
2
+
3
+
4
+ import streamlit as st
5
+ from streamlit.components.v1 import html
6
+
7
+ import code_search_utils
8
+ import utils
9
+
10
+ _PERSIST_STATE_KEY = f"{__name__}_PERSIST"
11
+ TOTAL_SAVE_BUTTONS = 0
12
+
13
+
14
+ def persist(key: str) -> str:
15
+ """Mark widget state as persistent."""
16
+ if _PERSIST_STATE_KEY not in st.session_state:
17
+ st.session_state[_PERSIST_STATE_KEY] = set()
18
+
19
+ st.session_state[_PERSIST_STATE_KEY].add(key)
20
+
21
+ return key
22
+
23
+
24
+ def load_widget_state():
25
+ """Load persistent widget state."""
26
+ if _PERSIST_STATE_KEY in st.session_state:
27
+ st.session_state.update(
28
+ {
29
+ key: value
30
+ for key, value in st.session_state.items()
31
+ if key in st.session_state[_PERSIST_STATE_KEY]
32
+ }
33
+ )
34
+
35
+
36
+ @st.cache_resource
37
+ def load_dataset_cache(dataset_cache_path):
38
+ """Load cache files required for dataset from `cache_path`."""
39
+ return code_search_utils.load_dataset_cache(dataset_cache_path)
40
+
41
+
42
+ @st.cache_resource
43
+ def load_code_search_cache(codes_cache_path, dataset_cache_path):
44
+ """Load cache files required for code search from `codes_cache_path`."""
45
+ (
46
+ tokens_str,
47
+ tokens_text,
48
+ token_byte_pos,
49
+ ) = load_dataset_cache(dataset_cache_path)
50
+ (
51
+ cb_acts,
52
+ act_count_ft_tkns,
53
+ metrics,
54
+ ) = code_search_utils.load_code_search_cache(codes_cache_path)
55
+ return tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics
56
+
57
+
58
+ @st.cache_data(max_entries=100)
59
+ def load_ft_tkns(model_id, layer, head=None, code=None):
60
+ """Load the code-to-token map for a codebook."""
61
+ # model_id required to not mix cache_data for different models
62
+ assert model_id is not None
63
+ cb_at = st.session_state["cb_at"]
64
+ ccb = st.session_state["ccb"]
65
+ cb_acts = st.session_state["cb_acts"]
66
+ if head is not None:
67
+ cb_name = f"layer{layer}_{cb_at}{ccb}{head}"
68
+ else:
69
+ cb_name = f"layer{layer}_{cb_at}"
70
+ return utils.features_to_tokens(
71
+ cb_name,
72
+ cb_acts,
73
+ num_codes=st.session_state["num_codes"],
74
+ code=code,
75
+ )
76
+
77
+
78
+ def get_code_acts(
79
+ model_id,
80
+ tokens_str,
81
+ code,
82
+ layer,
83
+ head=None,
84
+ ctx_size=5,
85
+ num_examples=100,
86
+ return_example_list=False,
87
+ ):
88
+ """Get the token activations for a given code."""
89
+ code_to_pass = None if "tinystories" in model_id.lower() else code
90
+ ft_tkns = load_ft_tkns(model_id, layer, head, code_to_pass)
91
+ if code_to_pass is not None:
92
+ ft_tkns = [ft_tkns]
93
+ else:
94
+ ft_tkns = ft_tkns[code : code + 1]
95
+ _, freqs, acts = utils.print_ft_tkns(
96
+ ft_tkns,
97
+ tokens=tokens_str,
98
+ indices=[0],
99
+ html=True,
100
+ n=ctx_size,
101
+ max_examples=num_examples,
102
+ return_example_list=return_example_list,
103
+ )
104
+ return acts[0], freqs[0]
105
+
106
+
107
+ def set_ct_acts(code, layer, head=None, extra_args=None, is_attn=False):
108
+ """Set the code and layer for the token activations."""
109
+ # convert to int
110
+ code, layer, head = int(code), int(layer), int(head) if head is not None else None
111
+ st.session_state["ct_act_code"] = code
112
+ st.session_state["ct_act_layer"] = layer
113
+ if is_attn:
114
+ st.session_state["ct_act_head"] = head
115
+ st.session_state["filter_codes"] = False
116
+
117
+ info_txt = (
118
+ f"layer: {layer},{f' head: {head},' if head is not None else ''} code: {code}"
119
+ )
120
+ if extra_args:
121
+ for k, v in extra_args.items():
122
+ info_txt += f", {k}: {v}"
123
+ my_html = f"""
124
+ <script>
125
+ async function myF() {{
126
+ await new Promise(r => setTimeout(r, 10));
127
+ const textarea = document.createElement("textarea");
128
+ textarea.textContent = "{info_txt}";
129
+ document.body.appendChild(textarea);
130
+ textarea.select();
131
+ document.execCommand("copy");
132
+ document.body.removeChild(textarea);
133
+ }}
134
+ myF();
135
+ window.location.hash = "code-token-activations";
136
+ console.log(window.location.hash)
137
+ </script>
138
+ """
139
+ html(my_html, height=0, width=0, scrolling=False)
140
+
141
+
142
+ def find_next_code(code, layer_code_acts, act_range=None):
143
+ """Find the next code that has activations in the given range."""
144
+ # code = st.session_state["ct_act_code"]
145
+ if act_range is None:
146
+ return code
147
+ for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
148
+ if code_act_count >= act_range[0] and code_act_count <= act_range[1]:
149
+ code += code_iter
150
+ # st.session_state["ct_act_code"] = code
151
+ break
152
+ return code
153
+
154
+
155
+ def escape_markdown(text):
156
+ """Escapes markdown special characters."""
157
+ MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
158
+ for char in MD_SPECIAL_CHARS:
159
+ text = text.replace(char, "\\" + char)
160
+ return text
161
+
162
+
163
+ def add_code_to_demo_file(code_info: utils.CodeInfo, file_path: str):
164
+ """Add code to demo file."""
165
+ # TODO: add check for duplicate code and return False if found
166
+ # TODO: convert saved codes to databases instead of txt files?
167
+ code_info.check_description_info()
168
+ with open(file_path, "a") as f:
169
+ f.write("\n")
170
+ f.write(f"# {code_info.description}:")
171
+ if code_info.regex:
172
+ f.write(f" {code_info.regex}")
173
+ f.write("\n")
174
+ f.write(f"layer: {code_info.layer}")
175
+ f.write(f", head: {code_info.head}" if code_info.head is not None else "")
176
+ f.write(f", code: {code_info.code}")
177
+ if code_info.regex:
178
+ f.write(f", prec: {code_info.prec:.4f}, recall: {code_info.recall:.4f}")
179
+ f.write(f", num_acts: {code_info.num_acts}\n")
180
+ return True
181
+
182
+
183
+ def add_save_code_button(
184
+ demo_file_path: str,
185
+ num_acts: int,
186
+ save_regex: bool = False,
187
+ prec: float = None,
188
+ recall: float = None,
189
+ button_st_container=st,
190
+ button_text: bool = False,
191
+ button_key_suffix: str = "",
192
+ ):
193
+ """Add a button on streamlit to save code to demo codes file."""
194
+ save_button = button_st_container.button(
195
+ "πŸ’Ύ" + (" Save Code to Demos" if button_text else ""),
196
+ key=f"save_code_button{button_key_suffix}",
197
+ help="Save code to demo codes file",
198
+ )
199
+ if save_button:
200
+ description = st.text_input(
201
+ "Write a description for the code",
202
+ key="save_code_desc",
203
+ )
204
+ if not description:
205
+ return
206
+
207
+ description = st.session_state.get("save_code_desc", None)
208
+ if description:
209
+ layer = st.session_state["ct_act_layer"]
210
+ is_attn = st.session_state["is_attn"]
211
+ if is_attn:
212
+ head = st.session_state["ct_act_head"]
213
+ else:
214
+ head = None
215
+
216
+ code = st.session_state["ct_act_code"]
217
+ code_info = utils.CodeInfo(
218
+ layer=layer,
219
+ head=head,
220
+ code=code,
221
+ description=description,
222
+ num_acts=num_acts,
223
+ )
224
+
225
+ if save_regex:
226
+ code_info.regex = st.session_state["regex_pattern"]
227
+ code_info.prec = prec
228
+ code_info.recall = recall
229
+
230
+ saved = add_code_to_demo_file(code_info, demo_file_path)
231
+ if saved:
232
+ st.success("Code saved!", icon="πŸŽ‰")
233
+ st.success("Code saved!", icon="πŸŽ‰")