taufeeque committed on
Commit
63b5bc1
•
1 Parent(s): b2a4148

Update code

Files changed (7)
  1. .gitignore +3 -0
  2. Code_Browser.py +180 -140
  3. README.md +2 -2
  4. code_search_utils.py +201 -97
  5. pages/Concept_Code.py +5 -17
  6. utils.py +187 -232
  7. webapp_utils.py +21 -9
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ __pycache__/
2
+ hgf_webapp/
3
+ .vscode/
Code_Browser.py CHANGED
@@ -1,15 +1,38 @@
1
  """Web App for the Codebook Features project."""
2
 
 
3
  import glob
4
  import os
5
 
6
  import streamlit as st
7
 
8
  import code_search_utils
 
9
  import webapp_utils
10
 
11
- DEPLOY_MODE = True
12
13
 
14
  webapp_utils.load_widget_state()
15
 
@@ -20,14 +43,17 @@ st.set_page_config(
20
 
21
  st.title("Codebook Features")
22
23
  pretty_model_names = {
24
  "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP",
25
- "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories-1L-21M-Attn",
26
- "TinyStories-33M_ccb_attn_preproj": "TinyStories-4L-33M-Attn",
 
27
  }
28
  orig_model_name = {v: k for k, v in pretty_model_names.items()}
29
 
30
- base_cache_dir = "cache/"
31
  dirs = glob.glob(base_cache_dir + "models/*/")
32
  model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
33
  model_name_options = ["_".join(m) for m in model_name_options]
@@ -41,25 +67,23 @@ p_model_name = st.selectbox(
41
  key=webapp_utils.persist("model_name"),
42
  )
43
  model_name = orig_model_name.get(p_model_name, p_model_name)
44
- model = model_name.split("_")[0].split("#")[0]
45
- ccb = model_name.split("_")[1]
46
- ccb = "_ccb" if ccb == "ccb" else ""
47
- cb_at = "_".join(model_name.split("_")[2:])
48
- seq_len = 512 if "tinystories" in model_name.lower() else 1024
49
- st.session_state["seq_len"] = seq_len
50
 
51
  codes_cache_path = base_cache_dir + f"models/{model_name}_*"
52
  dirs = glob.glob(codes_cache_path)
53
  dirs.sort(key=os.path.getmtime)
54
 
55
  # session states
56
- is_attn = "attn" in cb_at
57
  codes_cache_path = dirs[-1] + "/"
58
 
59
- model_info = code_search_utils.parse_model_info(codes_cache_path)
60
  num_codes = model_info.num_codes
61
  num_layers = model_info.n_layers
62
  num_heads = model_info.n_heads
63
  dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
64
 
65
  (
@@ -70,9 +94,12 @@ dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
70
  act_count_ft_tkns,
71
  metrics,
72
  ) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
 
73
  metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
74
  metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys}
75
 
 
 
76
  st.session_state["model_name_id"] = model_name
77
  st.session_state["cb_acts"] = cb_acts
78
  st.session_state["tokens_text"] = tokens_text
@@ -80,11 +107,13 @@ st.session_state["tokens_str"] = tokens_str
80
  st.session_state["act_count_ft_tkns"] = act_count_ft_tkns
81
 
82
  st.session_state["num_codes"] = num_codes
83
- st.session_state["ccb"] = ccb
84
  st.session_state["cb_at"] = cb_at
85
  st.session_state["is_attn"] = is_attn
 
86
 
87
- if not DEPLOY_MODE:
 
88
  st.markdown("## Metrics")
89
  # hide metrics by default
90
  if st.checkbox("Show Model Metrics"):
@@ -93,7 +122,7 @@ if not DEPLOY_MODE:
93
  st.markdown("## Demo Codes")
94
  demo_codes_desc = (
95
  "This section contains codes that we've found to be interpretable along "
96
- "with a description of the feature we think they are capturing."
97
  "Click on the πŸ” search button for a code to see the tokens that code activates on."
98
  )
99
  st.write(demo_codes_desc)
@@ -144,7 +173,7 @@ if st.checkbox("Show Demo Codes"):
144
  continue
145
  if skip:
146
  continue
147
- code_info = code_search_utils.get_code_info_pr_from_str(code_txt, code_regex)
148
  comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}"
149
  button_key = (
150
  f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"
@@ -167,150 +196,160 @@ if st.checkbox("Show Demo Codes"):
167
  cols[-1].write(code_desc)
168
  skip = True
169
 
 
170
 
171
  st.markdown("## Code Search")
172
-
173
- regex_pattern = st.text_input(
174
- "Enter a regex pattern",
175
- help="Wrap code token in the first group. E.g. New (York)",
176
- key="regex_pattern",
 
 
177
  )
178
- # topk = st.slider("Top K", 1, 20, 10)
179
- prec_col, sort_col = st.columns(2)
180
- prec_threshold = prec_col.slider(
181
- "Precision Threshold",
182
- 0.0,
183
- 1.0,
184
- 0.9,
185
- help="Shows codes with precision on the regex pattern above the threshold.",
186
- )
187
- sort_by_options = ["Precision", "Recall", "Num Acts"]
188
- sort_by_name = sort_col.radio(
189
- "Sort By",
190
- sort_by_options,
191
- index=0,
192
- horizontal=True,
193
- help="Sorts the codes by the selected metric.",
194
- )
195
- sort_by = sort_by_options.index(sort_by_name)
196
-
197
-
198
- @st.cache_data(ttl=3600)
199
- def get_codebook_wise_codes_for_regex(regex_pattern, prec_threshold, ccb, model_name):
200
- """Get codebook wise codes for a given regex pattern."""
201
- assert model_name is not None # required for loading from correct cache data
202
- return code_search_utils.get_codes_from_pattern(
203
- regex_pattern,
204
- tokens_text,
205
- token_byte_pos,
206
- cb_acts,
207
- act_count_ft_tkns,
208
- ccb=ccb,
209
- topk=8,
210
- prec_threshold=prec_threshold,
211
- )
212
 
213
-
214
- if regex_pattern:
215
- codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
216
- regex_pattern,
217
- prec_threshold,
218
- ccb,
219
- model_name,
220
  )
221
- st.markdown(
222
- f"Found <span style='color:green;'>{re_token_matches}</span> matches",
223
- unsafe_allow_html=True,
 
224
  )
225
- num_search_cols = 7 if is_attn else 6
226
- non_deploy_offset = 0
227
- if not DEPLOY_MODE:
228
- non_deploy_offset = 1
229
- num_search_cols += non_deploy_offset
230
-
231
- cols = st.columns(num_search_cols)
232
-
233
- # st.markdown(button_height_style, unsafe_allow_html=True)
234
-
235
- cols[0].markdown("Search", help="Button to see token activations for the code.")
236
- cols[1].write("Layer")
237
- if is_attn:
238
- cols[2].write("Head")
239
- cols[-4 - non_deploy_offset].write("Code")
240
- cols[-3 - non_deploy_offset].write("Precision")
241
- cols[-2 - non_deploy_offset].write("Recall")
242
- cols[-1 - non_deploy_offset].markdown(
243
- "Num Acts",
244
- help="Number of tokens that the code activates on in the acts dataset.",
245
  )
246
- if not DEPLOY_MODE:
247
- cols[-1].markdown(
248
- "Save to Demos",
249
- help="Button to save the code to demos along with the regex pattern.",
250
- )
251
- all_codes = codebook_wise_codes.items()
252
- all_codes = [
253
- (cb_name, code_pr_info)
254
- for cb_name, code_pr_infos in all_codes
255
- for code_pr_info in code_pr_infos
256
- ]
257
- all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
258
- for cb_name, (code, prec, rec, code_acts) in all_codes:
259
- layer_head = cb_name.split("_")
260
- layer = layer_head[0][5:]
261
- head = layer_head[1][4:] if len(layer_head) > 1 else None
262
- button_key = f"search_code{code}_layer{layer}" + (
263
- f"head{head}" if head is not None else ""
264
  )
265
- cols = st.columns(num_search_cols)
266
- extra_args = {
267
- "prec": prec,
268
- "recall": rec,
269
- "num_acts": code_acts,
270
- "regex": regex_pattern,
271
- }
272
- button_clicked = cols[0].button("🔍", key=button_key)
273
- if button_clicked:
274
- webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
275
- cols[1].write(layer)
276
- if is_attn:
277
- cols[2].write(head)
278
- cols[-4 - non_deploy_offset].write(code)
279
- cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
280
- cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
281
- cols[-1 - non_deploy_offset].write(str(code_acts))
282
- if not DEPLOY_MODE:
283
- webapp_utils.add_save_code_button(
284
- demo_file_path,
285
- num_acts=code_acts,
286
- save_regex=True,
287
- prec=prec,
288
- recall=rec,
289
- button_st_container=cols[-1],
290
- button_key_suffix=f"_code{code}_layer{layer}_head{head}",
291
- )
292
 
293
- if len(all_codes) == 0:
294
  st.markdown(
295
- f"""
296
- <div style="font-size: 1.0rem; color: red;">
297
- No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
298
- </div>
299
- """,
300
  unsafe_allow_html=True,
301
  )
302
 
 
303
 
304
  st.markdown("## Code Token Activations")
305
 
306
- filter_codes = st.checkbox("Show filters", key="filter_codes")
307
  act_range, layer_code_acts = None, None
308
  if filter_codes:
309
  act_range = st.slider(
310
- "Num Acts",
311
  0,
312
  10_000,
313
- (100, 10_000),
314
  key="ct_act_range",
315
  help="Filter codes by the number of tokens they activate on.",
316
  )
@@ -361,6 +400,7 @@ acts, acts_count = webapp_utils.get_code_acts(
361
  head,
362
  ctx_size,
363
  num_examples,
 
364
  )
365
 
366
  st.write(
@@ -368,7 +408,7 @@ st.write(
368
  f"Activates on {acts_count[0]} tokens on the acts dataset",
369
  )
370
 
371
- if not DEPLOY_MODE:
372
  webapp_utils.add_save_code_button(
373
  demo_file_path,
374
  acts_count[0],
 
1
  """Web App for the Codebook Features project."""
2
 
3
+ import argparse
4
  import glob
5
  import os
6
 
7
  import streamlit as st
8
 
9
  import code_search_utils
10
+ import utils
11
  import webapp_utils
12
 
13
+ # --- Parse command line arguments ---
14
 
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ "--deploy",
18
+ default=True,
19
+ # Coerce CLI strings like "False" or "0" to a real bool; argparse would
+ # otherwise treat any non-empty string value as truthy.
+ type=lambda x: str(x).lower() not in ("false", "0", "no"),
+ help="Deploy mode.",
20
+ )
21
+ parser.add_argument(
22
+ "--cache_dir",
23
+ type=str,
24
+ default="cache/",
25
+ help="Path to directory containing cache for codebook models.",
26
+ )
27
+ try:
28
+ args = parser.parse_args()
29
+ except SystemExit as e:
30
+ # This exception will be raised if --help or invalid command line arguments
31
+ # are used. Currently streamlit prevents the program from exiting normally
32
+ # so we have to do a hard exit.
33
+ os._exit(e.code if isinstance(e.code, int) else 1)
34
+
35
+ deploy = args.deploy
36
 
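Usage note (assumed invocation): Streamlit consumes its own command-line flags, so script arguments like the ones above are passed after a `--` separator, e.g. `streamlit run Code_Browser.py -- --deploy False --cache_dir cache/`.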
37
  webapp_utils.load_widget_state()
38
 
 
43
 
44
  st.title("Codebook Features")
45
 
46
+ # --- Load model info and cache ---
47
+
48
  pretty_model_names = {
49
  "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP",
50
+ "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories 1 Layer Attention Codebook",
51
+ "TinyStories-33M_ccb_attn_preproj": "TinyStories 4 Layer Attention Codebook",
52
+ "TinyStories-1Layer-21M_vcb_mlp": "TinyStories 1 Layer MLP Codebook",
53
  }
54
  orig_model_name = {v: k for k, v in pretty_model_names.items()}
55
 
56
+ base_cache_dir = args.cache_dir
57
  dirs = glob.glob(base_cache_dir + "models/*/")
58
  model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
59
  model_name_options = ["_".join(m) for m in model_name_options]
 
67
  key=webapp_utils.persist("model_name"),
68
  )
69
  model_name = orig_model_name.get(p_model_name, p_model_name)
70
+ is_fsm = "FSM" in p_model_name
71
 
72
  codes_cache_path = base_cache_dir + f"models/{model_name}_*"
73
  dirs = glob.glob(codes_cache_path)
74
  dirs.sort(key=os.path.getmtime)
75
 
76
  # session states
 
77
  codes_cache_path = dirs[-1] + "/"
78
 
79
+ model_info = utils.ModelInfoForWebapp.load(codes_cache_path)
80
  num_codes = model_info.num_codes
81
  num_layers = model_info.n_layers
82
  num_heads = model_info.n_heads
83
+ cb_at = model_info.cb_at
84
+ gcb = model_info.gcb
85
+ gcb = "_gcb" if gcb else ""
86
+ is_attn = "attn" in cb_at
87
  dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
88
 
89
  (
 
94
  act_count_ft_tkns,
95
  metrics,
96
  ) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
97
+ seq_len = len(tokens_str[0])
98
  metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
99
  metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys}
100
 
101
+ # --- Set the session states ---
102
+
103
  st.session_state["model_name_id"] = model_name
104
  st.session_state["cb_acts"] = cb_acts
105
  st.session_state["tokens_text"] = tokens_text
 
107
  st.session_state["act_count_ft_tkns"] = act_count_ft_tkns
108
 
109
  st.session_state["num_codes"] = num_codes
110
+ st.session_state["gcb"] = gcb
111
  st.session_state["cb_at"] = cb_at
112
  st.session_state["is_attn"] = is_attn
113
+ st.session_state["seq_len"] = seq_len
114
 
115
+
116
+ if not deploy:
117
  st.markdown("## Metrics")
118
  # hide metrics by default
119
  if st.checkbox("Show Model Metrics"):
 
122
  st.markdown("## Demo Codes")
123
  demo_codes_desc = (
124
  "This section contains codes that we've found to be interpretable along "
125
+ "with a description of the feature we think they are capturing. "
126
  "Click on the πŸ” search button for a code to see the tokens that code activates on."
127
  )
128
  st.write(demo_codes_desc)
 
173
  continue
174
  if skip:
175
  continue
176
+ code_info = utils.CodeInfo.from_str(code_txt, regex=code_regex)
177
  comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}"
178
  button_key = (
179
  f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"
 
196
  cols[-1].write(code_desc)
197
  skip = True
198
 
199
+ # --- Code Search ---
200
 
201
  st.markdown("## Code Search")
202
+ code_search_desc = (
203
+ "If you want to find whether the codebooks model has captured a relevant features from the data,"
204
+ " you can specify a regex pattern for your feature and find whether any code activating on the regex pattern"
205
+ " exists. The first group in the regex pattern is the token that the code activates on. If the group contains"
206
+ " multiple tokens, we search for codes that will activate on the first token in the group followed by the"
207
+ " subsequent tokens in the group. For example, the search term 'New (York)' will try to find codes that"
208
+ " activate on the bigram feature 'New York' at the York token."
209
  )
 
210
 
211
+ if st.checkbox("Search with Regex"):
212
+ st.write(code_search_desc)
213
+ regex_pattern = st.text_input(
214
+ "Enter a regex pattern",
215
+ help="Wrap code token in the first group. E.g. New (York)",
216
+ key="regex_pattern",
 
217
  )
218
+ # topk = st.slider("Top K", 1, 20, 10)
219
+ prec_col, sort_col = st.columns(2)
220
+ prec_threshold = prec_col.slider(
221
+ "Precision Threshold",
222
+ 0.0,
223
+ 1.0,
224
+ 0.9,
225
+ help="Shows codes with precision on the regex pattern above the threshold.",
226
  )
227
+ sort_by_options = ["Precision", "Recall", "Num Acts"]
228
+ sort_by_name = sort_col.radio(
229
+ "Sort By",
230
+ sort_by_options,
231
+ index=0,
232
+ horizontal=True,
233
+ help="Sorts the codes by the selected metric.",
234
  )
235
+ sort_by = sort_by_options.index(sort_by_name)
236
+
237
+ @st.cache_data(ttl=3600)
238
+ def get_codebook_wise_codes_for_regex(
239
+ regex_pattern, prec_threshold, gcb, model_name
240
+ ):
241
+ """Get codebook wise codes for a given regex pattern."""
242
+ assert model_name is not None # required for loading from correct cache data
243
+ return code_search_utils.get_codes_from_pattern(
244
+ regex_pattern,
245
+ tokens_text,
246
+ token_byte_pos,
247
+ cb_acts,
248
+ act_count_ft_tkns,
249
+ gcb=gcb,
250
+ topk=8,
251
+ prec_threshold=prec_threshold,
 
252
  )
253
 
254
+ if regex_pattern:
255
+ codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
256
+ regex_pattern,
257
+ prec_threshold,
258
+ gcb,
259
+ model_name,
260
+ )
261
  st.markdown(
262
+ f"Found <span style='color:green;'>{re_token_matches}</span> matches",
263
  unsafe_allow_html=True,
264
  )
265
+ num_search_cols = 7 if is_attn else 6
266
+ non_deploy_offset = 0
267
+ if not deploy:
268
+ non_deploy_offset = 1
269
+ num_search_cols += non_deploy_offset
270
+
271
+ cols = st.columns(num_search_cols)
272
+
273
+ cols[0].markdown("Search", help="Button to see token activations for the code.")
274
+ cols[1].write("Layer")
275
+ if is_attn:
276
+ cols[2].write("Head")
277
+ cols[-4 - non_deploy_offset].write("Code")
278
+ cols[-3 - non_deploy_offset].write("Precision")
279
+ cols[-2 - non_deploy_offset].write("Recall")
280
+ cols[-1 - non_deploy_offset].markdown(
281
+ "Num Acts",
282
+ help="Number of tokens that the code activates on in the acts dataset.",
283
+ )
284
+ if not deploy:
285
+ cols[-1].markdown(
286
+ "Save to Demos",
287
+ help="Button to save the code to demos along with the regex pattern.",
288
+ )
289
+ all_codes = codebook_wise_codes.items()
290
+ all_codes = [
291
+ (cb_name, code_pr_info)
292
+ for cb_name, code_pr_infos in all_codes
293
+ for code_pr_info in code_pr_infos
294
+ ]
295
+ all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
296
+ for cb_name, (code, prec, rec, code_acts) in all_codes:
297
+ layer_head = cb_name.split("_")
298
+ layer = layer_head[0][5:]
299
+ head = layer_head[1][4:] if len(layer_head) > 1 else None
300
+ button_key = f"search_code{code}_layer{layer}" + (
301
+ f"head{head}" if head is not None else ""
302
+ )
303
+ cols = st.columns(num_search_cols)
304
+ extra_args = {
305
+ "prec": prec,
306
+ "recall": rec,
307
+ "num_acts": code_acts,
308
+ "regex": regex_pattern,
309
+ }
310
+ button_clicked = cols[0].button("🔍", key=button_key)
311
+ if button_clicked:
312
+ webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
313
+ cols[1].write(layer)
314
+ if is_attn:
315
+ cols[2].write(head)
316
+ cols[-4 - non_deploy_offset].write(code)
317
+ cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
318
+ cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
319
+ cols[-1 - non_deploy_offset].write(str(code_acts))
320
+ if not deploy:
321
+ webapp_utils.add_save_code_button(
322
+ demo_file_path,
323
+ num_acts=code_acts,
324
+ save_regex=True,
325
+ prec=prec,
326
+ recall=rec,
327
+ button_st_container=cols[-1],
328
+ button_key_suffix=f"_code{code}_layer{layer}_head{head}",
329
+ )
330
+
331
+ if len(all_codes) == 0:
332
+ st.markdown(
333
+ f"""
334
+ <div style="font-size: 1.0rem; color: red;">
335
+ No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
336
+ </div>
337
+ """,
338
+ unsafe_allow_html=True,
339
+ )
340
 
341
+ # --- Display Code Token Activations ---
342
 
343
  st.markdown("## Code Token Activations")
344
 
345
+ filter_codes = st.checkbox("Show filters", key="filter_codes", value=True)
346
  act_range, layer_code_acts = None, None
347
  if filter_codes:
348
  act_range = st.slider(
349
+ "Minimum number of activations",
350
  0,
351
  10_000,
352
+ 100,
353
  key="ct_act_range",
354
  help="Filter codes by the number of tokens they activate on.",
355
  )
 
400
  head,
401
  ctx_size,
402
  num_examples,
403
+ is_fsm=is_fsm,
404
  )
405
 
406
  st.write(
 
408
  f"Activates on {acts_count[0]} tokens on the acts dataset",
409
  )
410
 
411
+ if not deploy:
412
  webapp_utils.add_save_code_button(
413
  demo_file_path,
414
  acts_count[0],
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Codebook Features
3
- emoji: 👀
4
  colorFrom: gray
5
- colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.25.0
8
  app_file: Code_Browser.py
 
1
  ---
2
  title: Codebook Features
3
+ emoji: 📚
4
  colorFrom: gray
5
+ colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.25.0
8
  app_file: Code_Browser.py
code_search_utils.py CHANGED
@@ -2,15 +2,11 @@
2
 
3
  import pickle
4
  import re
5
- from dataclasses import dataclass
6
- from typing import Optional
7
 
8
  import numpy as np
9
  import torch
10
  from tqdm import tqdm
11
 
12
- import utils
13
-
14
 
15
  def load_dataset_cache(cache_base_path):
16
  """Load cache files required for dataset from `cache_base_path`."""
@@ -31,28 +27,73 @@ def load_code_search_cache(cache_base_path):
31
  return cb_acts, act_count_ft_tkns, metrics
32
 
33
 
34
- def search_re(re_pattern, tokens_text):
35
- """Get list of (example_id, token_pos) where re_pattern matches in tokens_text."""
36
- # TODO: ensure that parantheses are not escaped
37
  if re_pattern.find("(") == -1:
38
  re_pattern = f"({re_pattern})"
39
- return [
40
  (i, finditer.span(1)[0])
41
  for i, text in enumerate(tokens_text)
42
  for finditer in re.finditer(re_pattern, text)
43
  if finditer.span(1)[0] != finditer.span(1)[1]
44
  ]
 
 
 
45
 
46
 
47
  def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
48
- """Get (example_id, token_pos_id) for given (example_id, byte_id)."""
49
  example_id, byte_id = example_byte_id
50
  index = np.searchsorted(token_byte_pos[example_id], byte_id, side="right")
51
  return (example_id, index)
52
 
53
 
54
- def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
55
- """Get codes, prec, recall for given token_pos_ids and codebook_acts."""
56
  codes = np.array(
57
  [
58
  codebook_acts[example_id][token_pos_id]
@@ -76,46 +117,64 @@ def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
76
  return codes, prec, recall, code_acts
77
 
78
 
79
- def get_neuron_pr(
80
- token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts, topk=10
81
  ):
82
- """Get codes, prec, recall for given token_pos_ids and codebook_acts."""
83
- # check if neuron_acts_by_ex is a torch tensor
84
  if isinstance(neuron_acts_by_ex, torch.Tensor):
85
- re_neuron_acts = torch.stack(
86
  [
87
  neuron_acts_by_ex[example_id, token_pos_id]
88
  for example_id, token_pos_id in token_pos_ids
89
  ],
90
  dim=-1,
91
  ) # (layers, 2, dim_size, matches)
92
- re_neuron_acts = torch.sort(re_neuron_acts, dim=-1).values
93
  else:
94
- re_neuron_acts = np.stack(
95
  [
96
  neuron_acts_by_ex[example_id, token_pos_id]
97
  for example_id, token_pos_id in token_pos_ids
98
  ],
99
  axis=-1,
100
  ) # (layers, 2, dim_size, matches)
101
- re_neuron_acts.sort(axis=-1)
102
- re_neuron_acts = torch.from_numpy(re_neuron_acts)
103
- # re_neuron_acts = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1]) :]
104
- print("Examples for recall", recall, ":", int(recall * re_neuron_acts.shape[-1]))
105
- act_thresh = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1])]
106
- # binary search act_thresh in neuron_sorted_acts
107
  assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
108
  prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
109
  prec_den = prec_den.squeeze(-1)
110
  prec_den = neuron_sorted_acts.shape[-1] - prec_den
111
- prec = int(recall * re_neuron_acts.shape[-1]) / prec_den
112
  assert (
113
- prec.shape == re_neuron_acts.shape[:-1]
114
- ), f"{prec.shape} != {re_neuron_acts.shape[:-1]}"
115
 
116
  best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
117
  best_prec = prec[best_neuron_idx]
118
- print("max prec:", best_prec)
119
  best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
120
  best_neuron_acts = neuron_acts_by_ex[
121
  :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
@@ -126,20 +185,20 @@ def get_neuron_pr(
126
  return best_prec, best_neuron_acts, best_neuron_idx
127
 
128
 
129
- def convert_to_adv_name(name, cb_at, ccb=""):
130
- """Convert layer0_head0 to layer0_attn_preproj_ccb0."""
131
- if ccb:
132
  layer, head = name.split("_")
133
- return layer + f"_{cb_at}_ccb" + head[4:]
134
  else:
135
  return layer + "_" + cb_at
136
 
137
 
138
- def convert_to_base_name(name, ccb=""):
139
- """Convert layer0_attn_preproj_ccb0 to layer0_head0."""
140
  split_name = name.split("_")
141
  layer, head = split_name[0], split_name[-1][3:]
142
- if "ccb" in name:
143
  return layer + "_head" + head
144
  else:
145
  return layer
@@ -156,7 +215,7 @@ def get_layer_head_from_base_name(name):
156
 
157
 
158
  def get_layer_head_from_adv_name(name):
159
- """Convert layer0_attn_preproj_ccb0 to 0, 0."""
160
  base_name = convert_to_base_name(name)
161
  layer, head = get_layer_head_from_base_name(base_name)
162
  return layer, head
@@ -168,12 +227,39 @@ def get_codes_from_pattern(
168
  token_byte_pos,
169
  cb_acts,
170
  act_count_ft_tkns,
171
- ccb="",
172
  topk=5,
173
  prec_threshold=0.5,
 
174
  ):
175
- """Fetch codes from a given regex pattern."""
176
- byte_ids = search_re(re_pattern, tokens_text)
177
  token_pos_ids = [
178
  byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
179
  ]
@@ -181,8 +267,8 @@ def get_codes_from_pattern(
181
  re_token_matches = len(token_pos_ids)
182
  codebook_wise_codes = {}
183
  for cb_name, cb in tqdm(cb_acts.items()):
184
- base_cb_name = convert_to_base_name(cb_name, ccb=ccb)
185
- codes, prec, recall, code_acts = get_code_pr(
186
  token_pos_ids,
187
  cb,
188
  cb_act_counts=act_count_ft_tkns[base_cb_name],
@@ -203,15 +289,49 @@ def get_neurons_from_pattern(
203
  neuron_acts_by_ex,
204
  neuron_sorted_acts,
205
  recall_threshold,
 
206
  ):
207
- """Fetch the best neuron (with act thresh given by recall) from a given regex pattern."""
208
- byte_ids = search_re(re_pattern, tokens_text)
209
  token_pos_ids = [
210
  byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
211
  ]
212
  token_pos_ids = np.unique(token_pos_ids, axis=0)
213
  re_token_matches = len(token_pos_ids)
214
- best_prec, best_neuron_acts, best_neuron_idx = get_neuron_pr(
215
  token_pos_ids,
216
  recall_threshold,
217
  neuron_acts_by_ex,
@@ -226,74 +346,58 @@ def compare_codes_with_neurons(
226
  token_byte_pos,
227
  neuron_acts_by_ex,
228
  neuron_sorted_acts,
 
229
  ):
230
- """Compare codes with neurons."""
231
  assert isinstance(neuron_acts_by_ex, np.ndarray)
232
  (
233
- all_best_prec,
234
  all_best_neuron_acts,
235
  all_best_neuron_idxs,
236
  all_re_token_matches,
237
  ) = zip(
238
  *[
239
  get_neurons_from_pattern(
240
- code_info.re_pattern,
241
  tokens_text,
242
  token_byte_pos,
243
  neuron_acts_by_ex,
244
  neuron_sorted_acts,
245
  code_info.recall,
 
246
  )
247
- for code_info in tqdm(range(len(best_codes_info)))
248
  ],
249
  strict=True,
250
  )
251
- code_best_precs = np.array(
252
- [code_info.prec for code_info in range(len(best_codes_info))]
253
- )
254
- codes_better_than_neurons = code_best_precs > np.array(all_best_prec)
255
- return codes_better_than_neurons.mean()
256
-
257
-
258
- def get_code_info_pr_from_str(code_txt, regex):
259
- """Extract code info fields from string."""
260
- code_txt = code_txt.strip()
261
- code_txt = code_txt.split(", ")
262
- code_txt = dict(txt.split(": ") for txt in code_txt)
263
- return utils.CodeInfo(**code_txt)
264
-
265
-
266
- @dataclass
267
- class ModelInfoForWebapp:
268
- """Model info for webapp."""
269
-
270
- model_name: str
271
- pretrained_path: str
272
- dataset_name: str
273
- num_codes: int
274
- cb_at: str
275
- ccb: str
276
- n_layers: int
277
- n_heads: Optional[int] = None
278
- seed: int = 42
279
- max_samples: int = 2000
280
-
281
- def __post_init__(self):
282
- """Convert to correct types."""
283
- self.num_codes = int(self.num_codes)
284
- self.n_layers = int(self.n_layers)
285
- if self.n_heads == "None":
286
- self.n_heads = None
287
- elif self.n_heads is not None:
288
- self.n_heads = int(self.n_heads)
289
- self.seed = int(self.seed)
290
- self.max_samples = int(self.max_samples)
291
-
292
-
293
- def parse_model_info(path):
294
- """Parse model info from path."""
295
- with open(path + "info.txt", "r") as f:
296
- lines = f.readlines()
297
- lines = dict(line.strip().split(": ") for line in lines)
298
- return ModelInfoForWebapp(**lines)
299
- return ModelInfoForWebapp(**lines)
 
2
 
3
  import pickle
4
  import re
 
 
5
 
6
  import numpy as np
7
  import torch
8
  from tqdm import tqdm
9
 
 
 
10
 
11
  def load_dataset_cache(cache_base_path):
12
  """Load cache files required for dataset from `cache_base_path`."""
 
27
  return cb_acts, act_count_ft_tkns, metrics
28
 
29
 
30
+ def search_re(re_pattern, tokens_text, at_odd_even=-1):
31
+ """Get list of (example_id, token_pos) where re_pattern matches in tokens_text.
32
+
33
+ Args:
34
+ re_pattern: regex pattern to search for.
35
+ tokens_text: list of example texts.
36
+ at_odd_even: to limit matches to odd or even positions only.
37
+ -1 (default): to not limit matches.
38
+ 0: to limit matches to even positions only.
39
+ 1: to limit matches to odd positions only.
40
+ This is useful for the TokFSM dataset when searching for states
41
+ since the first token of each state is always at an even position.
42
+ """
43
+ # TODO: ensure that parentheses are not escaped
44
+ assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"
45
  if re_pattern.find("(") == -1:
46
  re_pattern = f"({re_pattern})"
47
+ res = [
48
  (i, finditer.span(1)[0])
49
  for i, text in enumerate(tokens_text)
50
  for finditer in re.finditer(re_pattern, text)
51
  if finditer.span(1)[0] != finditer.span(1)[1]
52
  ]
53
+ if at_odd_even != -1:
54
+ res = [r for r in res if r[1] % 2 == at_odd_even]
55
+ return res
56
 
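A minimal sketch of how `search_re` behaves on made-up text (offsets and the even/odd filter follow the implementation above):

texts = ["the cat sat on the mat"]
# Wrap the token of interest in the first group; matches report (example_id, group start offset).
search_re(r"the (cat)", texts)         # [(0, 4)]
# 'a' occurs at offsets 5, 9, and 20; at_odd_even=0 keeps only even offsets.
search_re(r"a", texts, at_odd_even=0)  # [(0, 20)]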
57
 
58
  def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
59
+ """Convert byte position (or character position in a text) to its token position.
60
+
61
+ Used to convert the searched regex span to its token position.
62
+
63
+ Args:
64
+ example_byte_id: tuple of (example_id, byte_id) where byte_id is a
65
+ character's position in the text.
66
+ token_byte_pos: numpy array of shape (num_examples, seq_len) where
67
+ `token_byte_pos[example_id][token_pos]` is the byte position of
68
+ the token at `token_pos` in the example with `example_id`.
69
+
70
+ Returns:
71
+ (example_id, token_pos_id) tuple.
72
+ """
73
  example_id, byte_id = example_byte_id
74
  index = np.searchsorted(token_byte_pos[example_id], byte_id, side="right")
75
  return (example_id, index)
76
 
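A small worked example of the byte-to-token conversion (arrays made up; assumes `token_byte_pos` stores each token's end offset within the text):

import numpy as np

token_byte_pos = np.array([[3, 7, 11, 15]])      # token end offsets for example 0
# A regex match starting at byte 8 of example 0 falls inside token 2.
byte_id_to_token_pos_id((0, 8), token_byte_pos)  # (0, 2)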
77
 
78
+ def get_code_precision_and_recall(token_pos_ids, codebook_acts, cb_act_counts=None):
79
+ """Search for the codes that activate on the given `token_pos_ids`.
80
+
81
+ Args:
82
+ token_pos_ids: list of (example_id, token_pos_id) tuples.
83
+ codebook_acts: numpy array of activations of a codebook on a dataset with
84
+ shape (num_examples, seq_len, k_codebook).
85
+ cb_act_counts: numpy array of shape (num_codes,) where `cb_act_counts[code]`
86
+ is the number of times the code `code` is activated in the dataset.
87
+
88
+ Returns:
89
+ codes: numpy array of code ids sorted by their precision on the given `token_pos_ids`.
90
+ prec: numpy array where `prec[i]` is the precision of the code
91
+ `codes[i]` for the given `token_pos_ids`.
92
+ recall: numpy array where `recall[i]` is the recall of the code
93
+ `codes[i]` for the given `token_pos_ids`.
94
+ code_acts: numpy array where `code_acts[i]` is the number of times
95
+ the code `codes[i]` is activated in the dataset.
96
+ """
97
  codes = np.array(
98
  [
99
  codebook_acts[example_id][token_pos_id]
 
117
  return codes, prec, recall, code_acts
118
 
119
 
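For intuition, precision and recall here follow the usual definitions implied by the docstring above (numbers hypothetical): if a code fires on 50 tokens in the dataset and 40 of those are among 80 regex-matched tokens, then

prec = 40 / 50    # 0.8: fraction of the code's activations that land on the pattern
recall = 40 / 80  # 0.5: fraction of the pattern's tokens that the code covers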
120
+ def get_neuron_precision_and_recall(
121
+ token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
122
  ):
123
+ """Get the neurons with the highest precision and recall for the given `token_pos_ids`.
124
+
125
+ Args:
126
+ token_pos_ids: list of token (example_id, token_pos_id) tuples from a dataset over which
127
+ the neurons with the highest precision and recall are to be found.
128
+ recall: recall threshold for the neurons (this determines their activation threshold).
129
+ neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
130
+ on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
131
+ The dimension of size 2 is there because we consider neurons from both attention and mlp.
132
+ neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
133
+ on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
134
+ This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
135
+ dimensions to the last dimensions and then sorting the last dimension.
136
+
137
+ Returns:
138
+ best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
139
+ best_neuron_acts: number of activations of the best neuron for the given `token_pos_ids`
140
+ based on the threshold determined by the `recall` argument.
141
+ best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
142
+ `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
143
+ and `neuron_id` is the neuron's index in the layer.
144
+ """
145
  if isinstance(neuron_acts_by_ex, torch.Tensor):
146
+ neuron_acts_on_pattern = torch.stack(
147
  [
148
  neuron_acts_by_ex[example_id, token_pos_id]
149
  for example_id, token_pos_id in token_pos_ids
150
  ],
151
  dim=-1,
152
  ) # (layers, 2, dim_size, matches)
153
+ neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
154
  else:
155
+ neuron_acts_on_pattern = np.stack(
156
  [
157
  neuron_acts_by_ex[example_id, token_pos_id]
158
  for example_id, token_pos_id in token_pos_ids
159
  ],
160
  axis=-1,
161
  ) # (layers, 2, dim_size, matches)
162
+ neuron_acts_on_pattern.sort(axis=-1)
163
+ neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)
164
+ act_thresh = neuron_acts_on_pattern[
165
+ :, :, :, -int(recall * neuron_acts_on_pattern.shape[-1])
166
+ ]
 
167
  assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
168
  prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
169
  prec_den = prec_den.squeeze(-1)
170
  prec_den = neuron_sorted_acts.shape[-1] - prec_den
171
+ prec = int(recall * neuron_acts_on_pattern.shape[-1]) / prec_den
172
  assert (
173
+ prec.shape == neuron_acts_on_pattern.shape[:-1]
174
+ ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"
175
 
176
  best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
177
  best_prec = prec[best_neuron_idx]
 
178
  best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
179
  best_neuron_acts = neuron_acts_by_ex[
180
  :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
 
185
  return best_prec, best_neuron_acts, best_neuron_idx
186
 
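A hedged sketch of how `neuron_sorted_acts` might be precomputed from `neuron_acts_by_ex`, following the docstring above (move the example/position dimensions to the end, then sort; tensor sizes are dummies):

import torch

acts = torch.rand(4, 16, 2, 2, 8)   # (num_examples, seq_len, num_layers, 2, dim_size)
flat = acts.flatten(0, 1)           # (num_examples * seq_len, num_layers, 2, dim_size)
neuron_sorted_acts = flat.permute(1, 2, 3, 0).sort(dim=-1).values  # (num_layers, 2, dim_size, N)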
187
 
188
+ def convert_to_adv_name(name, cb_at, gcb=""):
189
+ """Convert layer0_head0 to layer0_attn_preproj_gcb0."""
190
+ if gcb:
191
  layer, head = name.split("_")
192
+ return layer + f"_{cb_at}_gcb" + head[4:]
193
  else:
194
  return layer + "_" + cb_at
195
 
196
 
197
+ def convert_to_base_name(name, gcb=""):
198
+ """Convert layer0_attn_preproj_gcb0 to layer0_head0."""
199
  split_name = name.split("_")
200
  layer, head = split_name[0], split_name[-1][3:]
201
+ if "gcb" in name:
202
  return layer + "_head" + head
203
  else:
204
  return layer
 
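An illustrative round trip of the two naming helpers (argument values assumed):

convert_to_adv_name("layer0_head3", cb_at="attn_preproj", gcb="_gcb")  # 'layer0_attn_preproj_gcb3'
convert_to_base_name("layer0_attn_preproj_gcb3", gcb="_gcb")           # 'layer0_head3'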
215
 
216
 
217
  def get_layer_head_from_adv_name(name):
218
+ """Convert layer0_attn_preproj_gcb0 to 0, 0."""
219
  base_name = convert_to_base_name(name)
220
  layer, head = get_layer_head_from_base_name(base_name)
221
  return layer, head
 
227
  token_byte_pos,
228
  cb_acts,
229
  act_count_ft_tkns,
230
+ gcb="",
231
  topk=5,
232
  prec_threshold=0.5,
233
+ at_odd_even=-1,
234
  ):
235
+ """Fetch codes that activate on a given regex pattern.
236
+
237
+ Retrieves at most `topk` codes that activate with precision above `prec_threshold`.
238
+
239
+ Args:
240
+ re_pattern: regex pattern to search for.
241
+ tokens_text: list of example texts of a dataset.
242
+ token_byte_pos: numpy array of shape (num_examples, seq_len) where
243
+ `token_byte_pos[example_id][token_pos]` is the byte position of
244
+ the token at `token_pos` in the example with `example_id`.
245
+ cb_acts: dict of codebook activations.
246
+ act_count_ft_tkns: dict mapping each codebook to its per-code counts of token activations on the dataset.
247
+ gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.
248
+ topk: maximum number of codes to return per codebook.
249
+ prec_threshold: minimum precision required for a code to be returned.
250
+ at_odd_even: to limit matches to odd or even positions only.
251
+ -1 (default): to not limit matches.
252
+ 0: to limit matches to even positions only.
253
+ 1: to limit matches to odd positions only.
254
+ This is useful for the TokFSM dataset when searching for states
255
+ since the first token of each state is always at an even position.
256
+
257
+ Returns:
258
+ codebook_wise_codes: dict of codebook name to list of
259
+ (code, prec, recall, code_acts) tuples.
260
+ re_token_matches: number of tokens that match the regex pattern.
261
+ """
262
+ byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
263
  token_pos_ids = [
264
  byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
265
  ]
 
267
  re_token_matches = len(token_pos_ids)
268
  codebook_wise_codes = {}
269
  for cb_name, cb in tqdm(cb_acts.items()):
270
+ base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
271
+ codes, prec, recall, code_acts = get_code_precision_and_recall(
272
  token_pos_ids,
273
  cb,
274
  cb_act_counts=act_count_ft_tkns[base_cb_name],
 
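A hedged usage sketch of `get_codes_from_pattern`, reusing the cache objects loaded in Code_Browser.py above:

codebook_wise_codes, n_matches = get_codes_from_pattern(
    r"New (York)", tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns,
    gcb=gcb, topk=8, prec_threshold=0.9,
)
for cb_name, code_list in codebook_wise_codes.items():
    for code, prec, recall, code_acts in code_list:
        print(cb_name, code, f"{prec:.2f}", f"{recall:.2f}", code_acts)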
289
  neuron_acts_by_ex,
290
  neuron_sorted_acts,
291
  recall_threshold,
292
+ at_odd_even=-1,
293
  ):
294
+ """Fetch the highest precision neurons that activate on a given regex pattern.
295
+
296
+ The activation threshold for the neurons is determined by the `recall_threshold`.
297
+
298
+ Args:
299
+ re_pattern: regex pattern to search for.
300
+ tokens_text: list of example texts of a dataset.
301
+ token_byte_pos: numpy array of shape (num_examples, seq_len) where
302
+ `token_byte_pos[example_id][token_pos]` is the byte position of
303
+ the token at `token_pos` in the example with `example_id`.
304
+ neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
305
+ on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
306
+ The dimension of size 2 is there because we consider neurons from both attention and mlp.
307
+ neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
308
+ on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
309
+ This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
310
+ dimensions to the last dimensions and then sorting the last dimension.
311
+ recall_threshold: recall threshold for the neurons (this determines their activation threshold).
312
+ at_odd_even: to limit matches to odd or even positions only.
313
+ -1 (default): to not limit matches.
314
+ 0: to limit matches to even positions only.
315
+ 1: to limit matches to odd positions only.
316
+ This is useful for the TokFSM dataset when searching for states
317
+ since the first token of each state is always at an even position.
318
+
319
+ Returns:
320
+ best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
321
+ best_neuron_acts: number of activations of the best neuron for the given `token_pos_ids`
322
+ based on the threshold determined by the `recall_threshold` argument.
323
+ best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
324
+ `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
325
+ and `neuron_id` is the neuron's index in the layer.
326
+ re_token_matches: number of tokens that match the regex pattern.
327
+ """
328
+ byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
329
  token_pos_ids = [
330
  byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
331
  ]
332
  token_pos_ids = np.unique(token_pos_ids, axis=0)
333
  re_token_matches = len(token_pos_ids)
334
+ best_prec, best_neuron_acts, best_neuron_idx = get_neuron_precision_and_recall(
335
  token_pos_ids,
336
  recall_threshold,
337
  neuron_acts_by_ex,
 
346
  token_byte_pos,
347
  neuron_acts_by_ex,
348
  neuron_sorted_acts,
349
+ at_odd_even=-1,
350
  ):
351
+ """Compare codes with the highest precision neurons on the regex pattern of the code.
352
+
353
+ Args:
354
+ best_codes_info: list of CodeInfo objects.
355
+ tokens_text: list of example texts of a dataset.
356
+ token_byte_pos: numpy array of shape (num_examples, seq_len) where
357
+ `token_byte_pos[example_id][token_pos]` is the byte position of
358
+ the token at `token_pos` in the example with `example_id`.
359
+ neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
360
+ on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
361
+ The dimension of size 2 is there because we consider neurons from both attention and mlp.
362
+ neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
363
+ on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
364
+ This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
365
+ dimensions to the last dimensions and then sorting the last dimension.
366
+ at_odd_even: to limit matches to odd or even positions only.
367
+ -1 (default): to not limit matches.
368
+ 0: to limit matches to even positions only.
369
+ 1: to limit matches to odd positions only.
370
+ This is useful for the TokFSM dataset when searching for states
371
+ since the first token of each state is always at an even position.
372
+
373
+ Returns:
374
+ codes_better_than_neurons: fraction of codes that have higher precision than the highest
375
+ precision neuron on the regex pattern of the code.
376
+ code_best_precs: array of the precision of each code in `best_codes_info`.
377
+ neuron_best_prec: array of the highest neuron precision on each code's regex pattern.
378
+ """
379
  assert isinstance(neuron_acts_by_ex, np.ndarray)
380
  (
381
+ neuron_best_prec,
382
  all_best_neuron_acts,
383
  all_best_neuron_idxs,
384
  all_re_token_matches,
385
  ) = zip(
386
  *[
387
  get_neurons_from_pattern(
388
+ code_info.regex,
389
  tokens_text,
390
  token_byte_pos,
391
  neuron_acts_by_ex,
392
  neuron_sorted_acts,
393
  code_info.recall,
394
+ at_odd_even=at_odd_even,
395
  )
396
+ for code_info in tqdm(best_codes_info)
397
  ],
398
  strict=True,
399
  )
400
+ neuron_best_prec = np.array(neuron_best_prec)
401
+ code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
402
+ codes_better_than_neurons = code_best_precs > neuron_best_prec
403
+ return codes_better_than_neurons.mean(), code_best_precs, neuron_best_prec
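A hedged usage sketch of `compare_codes_with_neurons` (inputs assumed to be prepared as documented above):

frac_better, code_precs, neuron_precs = compare_codes_with_neurons(
    best_codes_info, tokens_text, token_byte_pos,
    neuron_acts_by_ex, neuron_sorted_acts,
)
print(f"{frac_better:.1%} of codes beat the best neuron on their own regex pattern")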
pages/Concept_Code.py CHANGED
@@ -21,7 +21,7 @@ tokens_text = st.session_state["tokens_text"]
21
  tokens_str = st.session_state["tokens_str"]
22
  cb_acts = st.session_state["cb_acts"]
23
  act_count_ft_tkns = st.session_state["act_count_ft_tkns"]
24
- ccb = st.session_state["ccb"]
25
 
26
 
27
  def get_example_concept_codes(example_id):
@@ -29,8 +29,8 @@ def get_example_concept_codes(example_id):
29
  token_pos_ids = [(example_id, i) for i in range(seq_len)]
30
  all_codes = []
31
  for cb_name, cb in cb_acts.items():
32
- base_cb_name = code_search_utils.convert_to_base_name(cb_name, ccb=ccb)
33
- codes, prec, rec, code_acts = code_search_utils.get_code_pr(
34
  token_pos_ids,
35
  cb,
36
  act_count_ft_tkns[base_cb_name],
@@ -112,7 +112,6 @@ concept_code_description = (
112
  )
113
  st.write(concept_code_description)
114
 
115
- # ex_col, p_col, r_col, trunc_col, sort_col = st.columns([1, 2, 2, 1, 1])
116
  ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
117
  example_id = ex_col.number_input(
118
  "Example ID",
@@ -121,14 +120,6 @@ example_id = ex_col.number_input(
121
  0,
122
  key="example_id",
123
  )
124
- # prec_threshold = p_col.slider(
125
- # "Precision Threshold",
126
- # 0.0,
127
- # 1.0,
128
- # 0.02,
129
- # key="prec",
130
- # help="Precision Threshold controls the specificity of the codes for the given example.",
131
- # )
132
  recall_threshold = r_col.slider(
133
  "Recall Threshold",
134
  0.0,
@@ -138,13 +129,13 @@ recall_threshold = r_col.slider(
138
  help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
139
  )
140
  example_truncation = trunc_col.number_input(
141
- "Max Output Chars", 0, 10240, 1024, key="max_chars"
142
  )
143
  sort_by_options = ["Precision", "Recall", "Num Acts"]
144
  sort_by_name = sort_col.radio(
145
  "Sort By",
146
  sort_by_options,
147
- index=0,
148
  horizontal=True,
149
  help="Sorts the codes by the selected metric.",
150
  )
@@ -158,9 +149,6 @@ button = st.button(
158
  args=(example_id,),
159
  help="Find an example which has codes above the recall threshold.",
160
  )
161
- # if button:
162
- # find_next_example(st.session_state["example_id"])
163
-
164
 
165
  st.markdown("### Example Text")
166
  trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else ""
 
21
  tokens_str = st.session_state["tokens_str"]
22
  cb_acts = st.session_state["cb_acts"]
23
  act_count_ft_tkns = st.session_state["act_count_ft_tkns"]
24
+ gcb = st.session_state["gcb"]
25
 
26
 
27
  def get_example_concept_codes(example_id):
 
29
  token_pos_ids = [(example_id, i) for i in range(seq_len)]
30
  all_codes = []
31
  for cb_name, cb in cb_acts.items():
32
+ base_cb_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
33
+ codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall(
34
  token_pos_ids,
35
  cb,
36
  act_count_ft_tkns[base_cb_name],
 
112
  )
113
  st.write(concept_code_description)
114
 
 
115
  ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
116
  example_id = ex_col.number_input(
117
  "Example ID",
 
120
  0,
121
  key="example_id",
122
  )
123
  recall_threshold = r_col.slider(
124
  "Recall Threshold",
125
  0.0,
 
129
  help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
130
  )
131
  example_truncation = trunc_col.number_input(
132
+ "Max Output Chars", 0, 102400, 1024, key="max_chars"
133
  )
134
  sort_by_options = ["Precision", "Recall", "Num Acts"]
135
  sort_by_name = sort_col.radio(
136
  "Sort By",
137
  sort_by_options,
138
+ index=1,
139
  horizontal=True,
140
  help="Sorts the codes by the selected metric.",
141
  )
 
149
  args=(example_id,),
150
  help="Find an example which has codes above the recall threshold.",
151
  )
152
 
153
  st.markdown("### Example Text")
154
  trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else ""
utils.py CHANGED
@@ -1,4 +1,6 @@
1
  """Util functions for codebook features."""
 
 
2
  import re
3
  import typing
4
  from dataclasses import dataclass
@@ -57,11 +59,6 @@ class CodeInfo:
57
  if self.regex is not None:
58
  assert self.prec is not None and self.recall is not None
59
 
60
- def check_patch_info(self):
61
- """Check if the patch info is present."""
62
- # TODO: pos can be none for patching
63
- assert self.pos is not None and self.code_pos is not None
64
-
65
  def __repr__(self):
66
  """Return the string representation."""
67
  repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
@@ -76,6 +73,57 @@ class CodeInfo:
76
  repr += ")"
77
  return repr
78
79
 
80
  def logits_to_pred(logits, tokenizer, k=5):
81
  """Convert logits to top-k predictions."""
@@ -88,53 +136,6 @@ def logits_to_pred(logits, tokenizer, k=5):
88
  return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
89
 
90
 
91
- def patch_codebook_ids(
92
- corrupted_codebook_ids, hook, pos, cache, cache_pos=None, code_idx=None
93
- ):
94
- """Patch codebook ids with cached ids."""
95
- if cache_pos is None:
96
- cache_pos = pos
97
- if code_idx is None:
98
- corrupted_codebook_ids[:, pos] = cache[hook.name][:, cache_pos]
99
- else:
100
- for code_id in range(32):
101
- if code_id in code_idx:
102
- corrupted_codebook_ids[:, pos, code_id] = cache[hook.name][
103
- :, cache_pos, code_id
104
- ]
105
- else:
106
- corrupted_codebook_ids[:, pos, code_id] = -1
107
-
108
- return corrupted_codebook_ids
109
-
110
-
111
- def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
112
- """Calculate the average logit difference between the answer and the other token."""
113
- # Only the final logits are relevant for the answer
114
- final_logits = logits[:, -1, :]
115
- answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
116
- answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
117
- if per_prompt:
118
- return answer_logit_diff
119
- else:
120
- return answer_logit_diff.mean()
121
-
122
-
123
- def normalize_patched_logit_diff(
124
- patched_logit_diff,
125
- base_average_logit_diff,
126
- corrupted_average_logit_diff,
127
- ):
128
- """Normalize the patched logit difference."""
129
- # Subtract corrupted logit diff to measure the improvement,
130
- # divide by the total improvement from clean to corrupted to normalise
131
- # 0 means zero change, negative means actively made worse,
132
- # 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
133
- return (patched_logit_diff - corrupted_average_logit_diff) / (
134
- base_average_logit_diff - corrupted_average_logit_diff
135
- )
136
-
137
-
138
  def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
139
  """Return the set of token ids each codebook feature activates on."""
140
  codebook_ids = cb_acts[cb_key]
@@ -154,7 +155,6 @@ def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
154
 
155
  def color_str(s: str, html: bool, color: Optional[str] = None):
156
  """Color the string for html or terminal."""
157
-
158
  if html:
159
  color = "DeepSkyBlue" if color is None else color
160
  return f"<span style='color:{color}'>{s}</span>"
@@ -163,7 +163,7 @@ def color_str(s: str, html: bool, color: Optional[str] = None):
163
  return colored(s, color)
164
 
165
 
166
- def color_tokens_automata(tokens, color_idx, html=False):
167
  """Separate states with a dash and color red the tokens in color_idx."""
168
  ret_string = ""
169
  itr_over_color_idx = 0
@@ -224,31 +224,48 @@ def prepare_example_print(
224
  return example_output
225
 
226
 
227
- def tkn_print(
228
- ll,
229
  tokens,
230
- separate_states,
231
  n=3,
232
  max_examples=100,
233
  randomize=False,
234
  html=False,
235
  return_example_list=False,
236
  ):
237
- """Format and prints the tokens in ll."""
238
  if randomize:
239
  raise NotImplementedError("Randomize not yet implemented.")
240
- indices = range(len(ll))
241
  print_output = [] if return_example_list else ""
242
- curr_ex = ll[0][0]
243
  total_examples = 0
244
  tokens_to_color = []
245
- color_fn = color_tokens_automata if separate_states else partial(color_tokens, n=n)
246
  for idx in indices:
247
  if total_examples > max_examples:
248
  break
249
- i, j = ll[idx]
250
 
251
  if i != curr_ex and curr_ex >= 0:
 
252
  curr_ex_output = prepare_example_print(
253
  curr_ex,
254
  tokens[curr_ex],
@@ -275,17 +292,16 @@ def tkn_print(
275
  print_output.append((curr_ex_output, len(tokens_to_color)))
276
  else:
277
  print_output += curr_ex_output
278
- asterisk_str = "********************************************"
279
- print_output += color_str(asterisk_str, html, "green")
280
  total_examples += 1
281
 
282
  return print_output
283
 
284
 
285
- def print_ft_tkns(
286
  ft_tkns,
287
  tokens,
288
- separate_states=False,
289
  n=3,
290
  start=0,
291
  stop=1000,
@@ -301,17 +317,17 @@ def print_ft_tkns(
301
  num_tokens = len(tokens) * len(tokens[0])
302
  codes, token_act_freqs, token_acts = [], [], []
303
  for i in indices:
304
- tkns = ft_tkns[i]
305
- freq = (len(tkns), 100 * len(tkns) / num_tokens)
306
  if freq_filter is not None and freq[1] > freq_filter:
307
  continue
308
  codes.append(i)
309
  token_act_freqs.append(freq)
310
- if len(tkns) > 0:
311
- tkn_acts = tkn_print(
312
- tkns,
313
  tokens,
314
- separate_states,
315
  n=n,
316
  max_examples=max_examples,
317
  randomize=randomize,
@@ -340,149 +356,59 @@ def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
340
  return run_cb_ids
341
 
342
 
343
- def get_cb_layer_name(cb_at, layer_idx, head_idx=None):
344
  """Get the layer name used to store hooks/cache."""
345
- if head_idx is None:
346
- return f"blocks.{layer_idx}.{cb_at}.codebook_layer.hook_codebook_ids"
347
- else:
348
- return f"blocks.{layer_idx}.{cb_at}.codebook_layer.codebook.{head_idx}.hook_codebook_ids"
349
-
350
-
351
- def get_cb_layer_names(layer, patch_types, n_heads):
352
- """Get the layer names used to store hooks/cache."""
353
- layer_names = []
354
- attn_added, mlp_added = False, False
355
- if "attn_out" in patch_types:
356
- attn_added = True
357
- for head in range(n_heads):
358
- layer_names.append(
359
- f"blocks.{layer}.attn.codebook_layer.codebook.{head}.hook_codebook_ids"
360
- )
361
- if "mlp_out" in patch_types:
362
- mlp_added = True
363
- layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
364
-
365
- for patch_type in patch_types:
366
- # match patch_type of the pattern attn_\d_head_\d
367
- attn_head = re.match(r"attn_(\d)_head_(\d)", patch_type)
368
- if (not attn_added) and attn_head and attn_head[1] == str(layer):
369
- layer_names.append(
370
- f"blocks.{layer}.attn.codebook_layer.codebook.{attn_head[2]}.hook_codebook_ids"
371
- )
372
- mlp = re.match(r"mlp_(\d)", patch_type)
373
- if (not mlp_added) and mlp and mlp[1] == str(layer):
374
- layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
375
-
376
- return layer_names
377
-
378
-
379
- def cb_layer_name_to_info(layer_name):
380
- """Get the layer info from the layer name."""
381
- layer_name_split = layer_name.split(".")
382
- layer_idx = int(layer_name_split[1])
383
- cb_at = layer_name_split[2]
384
- if cb_at == "mlp":
385
- head_idx = None
386
  else:
387
- head_idx = int(layer_name_split[5])
388
- return cb_at, layer_idx, head_idx
389
-
390
-
391
- def get_hooks(code, cb_at, layer_idx, head_idx=None, pos=None):
392
- """Get the hooks for the codebook features."""
393
- hook_fns = [
394
- partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
395
- ]
396
- return [
397
- (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
398
- for i in range(len(code))
399
- ]
400
-
401
-
402
- def run_with_codes(
403
- input, cb_model, code, cb_at, layer_idx, head_idx=None, pos=None, prepend_bos=True
404
- ):
405
- """Run the model with the codebook features patched in."""
406
- hook_fns = [
407
- partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
408
- ]
409
- cb_model.reset_codebook_metrics()
410
- cb_model.reset_hook_kwargs()
411
- fwd_hooks = [
412
- (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
413
- for i in range(len(cb_at))
414
- ]
415
- with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
416
- patched_logits, patched_cache = hooked_model.run_with_cache(
417
- input, prepend_bos=prepend_bos
418
- )
419
- return patched_logits, patched_cache
420
-
421
-
422
- def in_hook_list(list_of_arg_tuples, layer, head=None):
423
- """Check if the component specified by `layer` and `head` is in the `list_of_arg_tuples`."""
424
- # if head is not provided, then checks in MLP
425
- for arg_tuple in list_of_arg_tuples:
426
- if head is None:
427
- if arg_tuple.cb_at == "mlp" and arg_tuple.layer == layer:
428
- return True
429
- else:
430
- if (
431
- arg_tuple.cb_at == "attn"
432
- and arg_tuple.layer == layer
433
- and arg_tuple.head == head
434
- ):
435
- return True
436
- return False
437
 
438
 
439
- # def generate_with_codes(input, code, cb_at, layer_idx, head_idx=None, pos=None, disable_other_comps=False):
440
- def generate_with_codes(
441
  input,
442
  cb_model,
 
 
443
  list_of_code_infos=(),
444
- disable_other_comps=False,
445
- automata=None,
446
- generate_kwargs=None,
447
  ):
448
- """Model's generation with the codebook features patched in."""
449
- if generate_kwargs is None:
450
- generate_kwargs = {}
 
 
 
 
451
  hook_fns = [
452
- partial(patch_in_codes, pos=tupl.pos, code=tupl.code)
453
  for tupl in list_of_code_infos
454
  ]
455
  fwd_hooks = [
456
- (get_cb_layer_name(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
457
  for i, tupl in enumerate(list_of_code_infos)
458
  ]
459
  cb_model.reset_hook_kwargs()
460
- if disable_other_comps:
461
- for layer, cb in cb_model.all_codebooks.items():
462
- for head_idx, head in enumerate(cb[0].codebook):
463
- if not in_hook_list(list_of_code_infos, layer, head_idx):
464
- head.set_hook_kwargs(
465
- disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
466
- )
467
- if not in_hook_list(list_of_code_infos, layer):
468
- cb[1].set_hook_kwargs(
469
- disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
470
- )
471
  with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
472
- gen = hooked_model.generate(input, **generate_kwargs)
473
- return automata.seq_to_traj(gen)[0] if automata is not None else gen
474
 
475
 
476
- def kl_div(logits1, logits2, pos=-1, reduction="batchmean"):
477
- """Calculate the KL divergence between the logits at `pos`."""
478
- logits1_last, logits2_last = logits1[:, pos, :], logits2[:, pos, :]
479
- # calculate kl divergence between clean and mod logits last
480
- return F.kl_div(
481
- F.log_softmax(logits1_last, dim=-1),
482
- F.log_softmax(logits2_last, dim=-1),
483
- log_target=True,
484
- reduction=reduction,
 
 
 
 
 
485
  )
 
486
 
487
 
488
  def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
@@ -511,11 +437,27 @@ def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
511
  return 0.5 * loss
512
 
513
 
514
- def residual_stream_patching_hook(resid_pre, hook, cache, position: int):
515
- """Patch in the codebook features at `position` from `cache`."""
516
- clean_resid_pre = cache[hook.name]
517
- resid_pre[:, position, :] = clean_resid_pre[:, position, :]
518
- return resid_pre
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
 
521
  def find_code_changes(cache1, cache2, pos=None):
@@ -525,8 +467,8 @@ def find_code_changes(cache1, cache2, pos=None):
525
  c1 = cache1[k][0, pos]
526
  c2 = cache2[k][0, pos]
527
  if not torch.all(c1 == c2):
528
- print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())
529
- print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())
530
 
531
 
532
  def common_codes_in_cache(cache_codes, threshold=0.0):
@@ -541,39 +483,52 @@ def common_codes_in_cache(cache_codes, threshold=0.0):
541
  return codes, counts
542
 
543
 
544
- def parse_code_info_string(
545
- info_str: str, cb_at="attn", pos=None, code_pos=-1
546
- ) -> CodeInfo:
547
- """Parse the code info string.
548
-
549
- The format of the `info_str` is:
550
- `code: 0, layer: 0, head: 0, occ_freq: 0.0, train_act_freq: 0.0`.
551
- """
552
- code, layer, head, occ_freq, train_act_freq = info_str.split(", ")
553
- code = int(code.split(": ")[1])
554
- layer = int(layer.split(": ")[1])
555
- head = int(head.split(": ")[1]) if head else None
556
- occ_freq = float(occ_freq.split(": ")[1])
557
- train_act_freq = float(train_act_freq.split(": ")[1])
558
- return CodeInfo(code, layer, head, pos=pos, code_pos=code_pos, cb_at=cb_at)
559
-
560
-
561
- def parse_concept_codes_string(info_str: str, pos=None, code_append=False):
562
- """Parse the concept codes string."""
563
  code_info_strs = info_str.strip().split("\n")
564
- concept_codes = []
 
565
  layer, head = None, None
566
- code_pos = "append" if code_append else -1
 
 
 
567
  for code_info_str in code_info_strs:
568
- concept_codes.append(
569
- parse_code_info_string(code_info_str, pos=pos, code_pos=code_pos)
 
 
 
 
 
570
  )
571
- if code_append:
572
  continue
573
- if layer == concept_codes[-1].layer and head == concept_codes[-1].head:
574
- code_pos -= 1
575
  else:
576
  code_pos = -1
577
- concept_codes[-1].code_pos = code_pos
578
- layer, head = concept_codes[-1].layer, concept_codes[-1].head
579
- return concept_codes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 """Util functions for codebook features."""
+
+import pathlib
 import re
 import typing
 from dataclasses import dataclass
 
         if self.regex is not None:
             assert self.prec is not None and self.recall is not None
 
     def __repr__(self):
         """Return the string representation."""
         repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
 
         repr += ")"
         return repr
 
+    @classmethod
+    def from_str(cls, code_txt, *args, **kwargs):
+        """Extract code info fields from string."""
+        code_txt = code_txt.strip().lower()
+        code_txt = code_txt.split(", ")
+        code_txt = dict(txt.split(": ") for txt in code_txt)
+        return cls(*args, **code_txt, **kwargs)
+
+
+@dataclass
+class ModelInfoForWebapp:
+    """Model info for webapp."""
+
+    model_name: str
+    pretrained_path: str
+    dataset_name: str
+    num_codes: int
+    cb_at: str
+    gcb: str
+    n_layers: int
+    n_heads: Optional[int] = None
+    seed: int = 42
+    max_samples: int = 2000
+
+    def __post_init__(self):
+        """Convert to correct types."""
+        self.num_codes = int(self.num_codes)
+        self.n_layers = int(self.n_layers)
+        if self.n_heads == "None":
+            self.n_heads = None
+        elif self.n_heads is not None:
+            self.n_heads = int(self.n_heads)
+        self.seed = int(self.seed)
+        self.max_samples = int(self.max_samples)
+
+    @classmethod
+    def load(cls, path):
+        """Parse model info from path."""
+        path = pathlib.Path(path)
+        with open(path / "info.txt", "r") as f:
+            lines = f.readlines()
+        lines = dict(line.strip().split(": ") for line in lines)
+        return cls(**lines)
+
+    def save(self, path):
+        """Save model info to path."""
+        path = pathlib.Path(path)
+        with open(path / "info.txt", "w") as f:
+            for k, v in self.__dict__.items():
+                f.write(f"{k}: {v}\n")
+
 
 def logits_to_pred(logits, tokenizer, k=5):
     """Convert logits to top-k predictions."""
 
     return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
 
 
 def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
     """Return the set of token ids each codebook feature activates on."""
     codebook_ids = cb_acts[cb_key]
 
 def color_str(s: str, html: bool, color: Optional[str] = None):
     """Color the string for html or terminal."""
     if html:
         color = "DeepSkyBlue" if color is None else color
         return f"<span style='color:{color}'>{s}</span>"
 
     return colored(s, color)
 
 
+def color_tokens_tokfsm(tokens, color_idx, html=False):
     """Separate states with a dash and color red the tokens in color_idx."""
     ret_string = ""
     itr_over_color_idx = 0
 
     return example_output
 
 
+def print_token_activations_of_code(
+    code_act_by_pos,
     tokens,
+    is_fsm=False,
     n=3,
     max_examples=100,
     randomize=False,
     html=False,
     return_example_list=False,
 ):
+    """Print the context with the tokens that a code activates on.
+
+    Args:
+        code_act_by_pos: list of (example_id, token_pos_id) tuples specifying
+            the token positions that a code activates on in a dataset.
+        tokens: list of tokens of a dataset.
+        is_fsm: whether the dataset is the TokFSM dataset.
+        n: context to print around each side of a token that the code activates on.
+        max_examples: maximum number of examples to print.
+        randomize: whether to randomize the order of examples.
+        html: whether to format the output for html instead of the terminal.
+        return_example_list: whether to return the output as a list of examples or as a single string.
+
+    Returns:
+        A single formatted string of all examples if `return_example_list` is False,
+        otherwise a list of (example_string, num_tokens_colored) tuples, one per example.
+    """
     if randomize:
         raise NotImplementedError("Randomize not yet implemented.")
+    indices = range(len(code_act_by_pos))
     print_output = [] if return_example_list else ""
+    curr_ex = code_act_by_pos[0][0]
     total_examples = 0
     tokens_to_color = []
+    color_fn = color_tokens_tokfsm if is_fsm else partial(color_tokens, n=n)
     for idx in indices:
         if total_examples > max_examples:
             break
+        i, j = code_act_by_pos[idx]
 
         if i != curr_ex and curr_ex >= 0:
+            # got a new example, so print the previous one
             curr_ex_output = prepare_example_print(
                 curr_ex,
                 tokens[curr_ex],
 
             print_output.append((curr_ex_output, len(tokens_to_color)))
         else:
             print_output += curr_ex_output
+            print_output += color_str("*" * 50, html, "green")
         total_examples += 1
 
     return print_output
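
For orientation, the `code_act_by_pos` argument is just a flat list of (example, position) pairs, grouped by example. A minimal sketch with invented values, where `dataset_tokens` stands in for the cached dataset's per-example token lists:

    # Hypothetical activations: the code fires on tokens 5-6 of example 0
    # and on token 17 of example 2.
    code_act_by_pos = [(0, 5), (0, 6), (2, 17)]
    html_out = print_token_activations_of_code(
        code_act_by_pos, dataset_tokens, n=3, html=True
    )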
 
 
+def print_token_activations_of_codes(
     ft_tkns,
     tokens,
+    is_fsm=False,
     n=3,
     start=0,
     stop=1000,
 
     num_tokens = len(tokens) * len(tokens[0])
     codes, token_act_freqs, token_acts = [], [], []
     for i in indices:
+        tkns_of_code = ft_tkns[i]
+        freq = (len(tkns_of_code), 100 * len(tkns_of_code) / num_tokens)
         if freq_filter is not None and freq[1] > freq_filter:
             continue
         codes.append(i)
         token_act_freqs.append(freq)
+        if len(tkns_of_code) > 0:
+            tkn_acts = print_token_activations_of_code(
+                tkns_of_code,
                 tokens,
+                is_fsm,
                 n=n,
                 max_examples=max_examples,
                 randomize=randomize,
 
     return run_cb_ids
 
 
+def get_cb_hook_key(cb_at: str, layer_idx: int, gcb_idx: Optional[int] = None):
     """Get the layer name used to store hooks/cache."""
+    comp_name = "attn" if "attn" in cb_at else "mlp"
+    if gcb_idx is None:
+        return f"blocks.{layer_idx}.{comp_name}.codebook_layer.hook_codebook_ids"
     else:
+        return f"blocks.{layer_idx}.{comp_name}.codebook_layer.codebook.{gcb_idx}.hook_codebook_ids"
 
 
+def run_model_fn_with_codes(
     input,
     cb_model,
+    fn_name,
+    fn_kwargs=None,
     list_of_code_infos=(),
 ):
+    """Run the `cb_model`'s `fn_name` method while activating the codes in `list_of_code_infos`.
+
+    A common use case is running the `run_with_cache` method while activating the codes.
+    For running the `generate` method, use `generate_with_codes` instead.
+    """
+    if fn_kwargs is None:
+        fn_kwargs = {}
     hook_fns = [
+        partial(patch_in_codes, pos=tupl.pos, code=tupl.code, code_pos=tupl.code_pos)
         for tupl in list_of_code_infos
     ]
     fwd_hooks = [
+        (get_cb_hook_key(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
         for i, tupl in enumerate(list_of_code_infos)
     ]
     cb_model.reset_hook_kwargs()
     with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
+        ret = hooked_model.__getattribute__(fn_name)(input, **fn_kwargs)
+    return ret
 
 
+def generate_with_codes(
+    input,
+    cb_model,
+    list_of_code_infos=(),
+    tokfsm=None,
+    generate_kwargs=None,
+):
+    """Sample from the language model while activating the codes in `list_of_code_infos`."""
+    gen = run_model_fn_with_codes(
+        input,
+        cb_model,
+        "generate",
+        generate_kwargs,
+        list_of_code_infos,
     )
+    return tokfsm.seq_to_traj(gen) if tokfsm is not None else gen
 
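A minimal usage sketch for the two entry points above, assuming a codebook model `cb_model` is loaded and using placeholder code and prompt values; the `CodeInfo` fields follow the dataclass defined earlier in this file:

    code = CodeInfo(code=123, layer=0, head=2, cb_at="attn", pos=None, code_pos=-1)
    # Run a forward pass with the code forced on and capture activations.
    logits, cache = run_model_fn_with_codes(
        "Once upon a time", cb_model, "run_with_cache", {}, [code]
    )
    # Sample text with the same code forced on; tokfsm is only passed for the
    # TokFSM dataset, where token sequences map back to state trajectories.
    text = generate_with_codes(
        "Once upon a time", cb_model, [code],
        generate_kwargs={"max_new_tokens": 20},
    )
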
 def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
 
     return 0.5 * loss
 
 
+def cb_hook_key_to_info(layer_hook_key: str):
+    """Get the layer info from the codebook layer hook key.
+
+    Args:
+        layer_hook_key: the hook key of the codebook layer,
+            e.g. `blocks.3.attn.codebook_layer.hook_codebook_ids`.
+
+    Returns:
+        comp_name: the name of the component the codebook is applied at.
+        layer_idx: the layer index.
+        gcb_idx: the codebook index if the codebook layer is grouped, otherwise None.
+    """
+    layer_search = re.search(r"blocks\.(\d+)\.(\w+)\.", layer_hook_key)
+    assert layer_search is not None
+    layer_idx, comp_name = int(layer_search.group(1)), layer_search.group(2)
+    gcb_idx_search = re.search(r"codebook\.(\d+)", layer_hook_key)
+    if gcb_idx_search is not None:
+        gcb_idx = int(gcb_idx_search.group(1))
+    else:
+        gcb_idx = None
+    return comp_name, layer_idx, gcb_idx
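
`cb_hook_key_to_info` inverts `get_cb_hook_key`, e.g.:

    assert cb_hook_key_to_info(
        "blocks.3.attn.codebook_layer.hook_codebook_ids"
    ) == ("attn", 3, None)
    assert cb_hook_key_to_info(
        "blocks.3.attn.codebook_layer.codebook.2.hook_codebook_ids"
    ) == ("attn", 3, 2)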
 
 def find_code_changes(cache1, cache2, pos=None):
 
             c1 = cache1[k][0, pos]
             c2 = cache2[k][0, pos]
             if not torch.all(c1 == c2):
+                print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())
+                print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())
 
  def common_codes_in_cache(cache_codes, threshold=0.0):
 
     return codes, counts
 
 
+def parse_topic_codes_string(
+    info_str: str,
+    pos: Optional[int] = None,
+    code_append: Optional[bool] = False,
+    **code_info_kwargs,
+):
+    """Parse the topic codes string."""
     code_info_strs = info_str.strip().split("\n")
+    code_info_strs = [e.strip() for e in code_info_strs if e]
+    topic_codes = []
     layer, head = None, None
+    if code_append is None:
+        code_pos = None
+    else:
+        code_pos = "append" if code_append else -1
     for code_info_str in code_info_strs:
+        topic_codes.append(
+            CodeInfo.from_str(
+                code_info_str,
+                pos=pos,
+                code_pos=code_pos,
+                **code_info_kwargs,
+            )
         )
+        if code_append is None or code_append:
             continue
+        if layer == topic_codes[-1].layer and head == topic_codes[-1].head:
+            code_pos -= 1  # type: ignore
         else:
             code_pos = -1
+        topic_codes[-1].code_pos = code_pos
+        layer, head = topic_codes[-1].layer, topic_codes[-1].head
+    return topic_codes
+
+
+def find_similar_codes(cb_model, code_info, n=8):
+    """Find the `n` most similar codes to the given code using cosine similarity.
+
+    Useful for finding related codes for interpretability.
+    """
+    codebook = cb_model.get_codebook(code_info)
+    device = codebook.weight.device
+    code = codebook(torch.tensor(code_info.code).to(device))
+    code = code.to(device)
+    logits = torch.matmul(code, codebook.weight.T)
+    _, indices = torch.topk(logits, n)
+    assert indices[0] == code_info.code
+    assert torch.allclose(logits[indices[0]], torch.tensor(1.0))
+    return indices[1:], logits[indices[1:]].tolist()
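
For reference, the strings consumed by `CodeInfo.from_str` (and hence by `parse_topic_codes_string`) follow the `key: value` format that the removed `parse_code_info_string` documented, one code per line. A small sketch with made-up codes, assuming `CodeInfo` converts its string fields during validation:

    topic_str = "code: 123, layer: 0, head: 2\ncode: 456, layer: 0, head: 2"
    codes = parse_topic_codes_string(topic_str, cb_at="attn")
    # With code_append=False (the default), consecutive codes on the same
    # (layer, head) get code_pos -1, -2, ... so they patch successive
    # positions from the end of the sequence.

The `ModelInfoForWebapp` helper added earlier round-trips the same `key: value` layout through an `info.txt` file, so a hypothetical cache directory would contain lines such as:

    model_name: TinyStories-1Layer-21M
    dataset_name: tinystories
    num_codes: 1024
    cb_at: attn_preproj
    n_layers: 1
    n_heads: 16

which `ModelInfoForWebapp.load(path)` parses back into typed fields.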
webapp_utils.py CHANGED
@@ -1,6 +1,9 @@
 """Utility functions for running webapp using streamlit."""
 
 
+from typing import Optional
+
+import numpy as np
 import streamlit as st
 from streamlit.components.v1 import html
 
@@ -61,10 +64,10 @@ def load_ft_tkns(model_id, layer, head=None, code=None):
     # model_id required to not mix cache_data for different models
     assert model_id is not None
     cb_at = st.session_state["cb_at"]
-    ccb = st.session_state["ccb"]
+    gcb = st.session_state["gcb"]
     cb_acts = st.session_state["cb_acts"]
     if head is not None:
-        cb_name = f"layer{layer}_{cb_at}{ccb}{head}"
+        cb_name = f"layer{layer}_{cb_at}{gcb}{head}"
     else:
         cb_name = f"layer{layer}_{cb_at}"
     return utils.features_to_tokens(
@@ -84,11 +87,12 @@ def get_code_acts(
     ctx_size=5,
     num_examples=100,
     return_example_list=False,
+    is_fsm=False,
 ):
     """Get the token activations for a given code."""
     ft_tkns = load_ft_tkns(model_id, layer, head, code)
     ft_tkns = [ft_tkns]
-    _, freqs, acts = utils.print_ft_tkns(
+    _, freqs, acts = utils.print_token_activations_of_codes(
         ft_tkns,
         tokens=tokens_str,
         indices=[0],
@@ -96,6 +100,7 @@ def get_code_acts(
         n=ctx_size,
         max_examples=num_examples,
         return_example_list=return_example_list,
+        is_fsm=is_fsm,
     )
     return acts[0], freqs[0]
 
@@ -122,8 +127,16 @@ def find_next_code(code, layer_code_acts, act_range=None):
     """Find the next code that has activations in the given range."""
     if act_range is None:
         return code
+    min_act, max_act = 0, np.inf
+    if isinstance(act_range, tuple):
+        if len(act_range) == 2:
+            min_act, max_act = act_range
+        else:
+            min_act = act_range[0]
+    elif isinstance(act_range, int):
+        min_act = act_range
     for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
-        if code_act_count >= act_range[0] and code_act_count <= act_range[1]:
+        if code_act_count >= min_act and code_act_count <= max_act:
             code += code_iter
             break
     return code
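
With the broadened `act_range` handling above, a bare int now acts as a lower bound while a 2-tuple still bounds both sides; with invented activation counts:

    layer_code_acts = [0, 3, 120, 7, 50]
    find_next_code(0, layer_code_acts, act_range=(10, 100))  # -> 4 (first count within [10, 100])
    find_next_code(0, layer_code_acts, act_range=100)        # -> 2 (first count >= 100)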
@@ -161,8 +174,8 @@ def add_save_code_button(
     demo_file_path: str,
     num_acts: int,
     save_regex: bool = False,
-    prec: float = None,
-    recall: float = None,
+    prec: Optional[float] = None,
+    recall: Optional[float] = None,
     button_st_container=st,
     button_text: bool = False,
     button_key_suffix: str = "",
@@ -176,12 +189,12 @@ def add_save_code_button(
     if save_button:
         description = st.text_input(
             "Write a description for the code",
-            key="save_code_desc",
+            key=f"save_code_desc{button_key_suffix}",
         )
         if not description:
             return
 
-    description = st.session_state.get("save_code_desc", None)
+    description = st.session_state.get(f"save_code_desc{button_key_suffix}", None)
     if description:
         layer = st.session_state["ct_act_layer"]
         is_attn = st.session_state["is_attn"]
@@ -207,4 +220,3 @@ def add_save_code_button(
         saved = add_code_to_demo_file(code_info, demo_file_path)
         if saved:
             st.success("Code saved!", icon="🎉")
-            st.success("Code saved!", icon="🎉")