Spaces:
Runtime error
Runtime error
Add streamlit webapp files
Browse files- Code_Browser.py +373 -0
- README.md +1 -1
- code_search_utils.py +299 -0
- pages/Concept_Code.py +217 -0
- requirements.txt +5 -0
- utils.py +578 -0
- webapp_utils.py +210 -0
- webapp_utils_full_ft_tkns_for_ts.py +233 -0
Code_Browser.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Web App for the Codebook Features project."""
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
+
import code_search_utils
|
9 |
+
import webapp_utils
|
10 |
+
|
11 |
+
DEPLOY_MODE = True
|
12 |
+
|
13 |
+
|
14 |
+
webapp_utils.load_widget_state()
|
15 |
+
|
16 |
+
st.set_page_config(
|
17 |
+
page_title="Codebook Features",
|
18 |
+
page_icon="π",
|
19 |
+
)
|
20 |
+
|
21 |
+
st.title("Codebook Features")
|
22 |
+
|
23 |
+
base_cache_dir = "cache/"
|
24 |
+
dirs = glob.glob(base_cache_dir + "models/*/")
|
25 |
+
model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
|
26 |
+
model_name_options = ["_".join(m) for m in model_name_options]
|
27 |
+
model_name_options = sorted(set(model_name_options))
|
28 |
+
|
29 |
+
model_name = st.selectbox(
|
30 |
+
"Model",
|
31 |
+
model_name_options,
|
32 |
+
key=webapp_utils.persist("model_name"),
|
33 |
+
)
|
34 |
+
|
35 |
+
model = model_name.split("_")[0].split("#")[0]
|
36 |
+
model_layers = {
|
37 |
+
"pythia-410m-deduped": 24,
|
38 |
+
"pythia-70m-deduped": 6,
|
39 |
+
"gpt2": 12,
|
40 |
+
"TinyStories-1Layer-21M": 1,
|
41 |
+
}
|
42 |
+
model_heads = {
|
43 |
+
"pythia-410m-deduped": 16,
|
44 |
+
"pythia-70m-deduped": 8,
|
45 |
+
"gpt2": 12,
|
46 |
+
"TinyStories-1Layer-21M": 16,
|
47 |
+
}
|
48 |
+
ccb = model_name.split("_")[1]
|
49 |
+
ccb = "_ccb" if ccb == "ccb" else ""
|
50 |
+
cb_at = "_".join(model_name.split("_")[2:])
|
51 |
+
seq_len = 512 if "tinystories" in model_name.lower() else 1024
|
52 |
+
st.session_state["seq_len"] = seq_len
|
53 |
+
|
54 |
+
codes_cache_path = base_cache_dir + f"models/{model_name}_*"
|
55 |
+
dirs = glob.glob(codes_cache_path)
|
56 |
+
dirs.sort(key=os.path.getmtime)
|
57 |
+
|
58 |
+
# session states
|
59 |
+
is_attn = "attn" in cb_at
|
60 |
+
num_layers = model_layers[model]
|
61 |
+
num_heads = model_heads[model]
|
62 |
+
codes_cache_path = dirs[-1] + "/"
|
63 |
+
|
64 |
+
model_info = code_search_utils.parse_model_info(codes_cache_path)
|
65 |
+
num_codes = model_info.num_codes
|
66 |
+
dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
|
67 |
+
|
68 |
+
(
|
69 |
+
tokens_str,
|
70 |
+
tokens_text,
|
71 |
+
token_byte_pos,
|
72 |
+
cb_acts,
|
73 |
+
act_count_ft_tkns,
|
74 |
+
metrics,
|
75 |
+
) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
|
76 |
+
metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
|
77 |
+
metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys}
|
78 |
+
|
79 |
+
st.session_state["model_name_id"] = model_name
|
80 |
+
st.session_state["cb_acts"] = cb_acts
|
81 |
+
st.session_state["tokens_text"] = tokens_text
|
82 |
+
st.session_state["tokens_str"] = tokens_str
|
83 |
+
st.session_state["act_count_ft_tkns"] = act_count_ft_tkns
|
84 |
+
|
85 |
+
st.session_state["num_codes"] = num_codes
|
86 |
+
st.session_state["ccb"] = ccb
|
87 |
+
st.session_state["cb_at"] = cb_at
|
88 |
+
st.session_state["is_attn"] = is_attn
|
89 |
+
|
90 |
+
st.markdown("## Metrics")
|
91 |
+
# hide metrics by default
|
92 |
+
if st.checkbox("Show Model Metrics"):
|
93 |
+
st.write(metrics)
|
94 |
+
|
95 |
+
st.markdown("## Demo Codes")
|
96 |
+
demo_file_path = codes_cache_path + "demo_codes.txt"
|
97 |
+
|
98 |
+
if st.checkbox("Show Demo Codes"):
|
99 |
+
try:
|
100 |
+
with open(demo_file_path, "r") as f:
|
101 |
+
demo_codes = f.readlines()
|
102 |
+
except FileNotFoundError:
|
103 |
+
demo_codes = []
|
104 |
+
|
105 |
+
code_desc, code_regex = "", ""
|
106 |
+
demo_codes = [code.strip() for code in demo_codes if code.strip()]
|
107 |
+
|
108 |
+
num_cols = 6 if is_attn else 5
|
109 |
+
cols = st.columns([1] * (num_cols - 1) + [2])
|
110 |
+
# st.markdown(button_height_style, unsafe_allow_html=True)
|
111 |
+
cols[0].markdown("Search", help="Button to see token activations for the code.")
|
112 |
+
cols[1].write("Code")
|
113 |
+
cols[2].write("Layer")
|
114 |
+
if is_attn:
|
115 |
+
cols[3].write("Head")
|
116 |
+
cols[-2].markdown(
|
117 |
+
"Num Acts",
|
118 |
+
help="Number of tokens that the code activates on in the acts dataset.",
|
119 |
+
)
|
120 |
+
cols[-1].markdown("Description", help="Interpreted description of the code.")
|
121 |
+
|
122 |
+
if len(demo_codes) == 0:
|
123 |
+
st.markdown(
|
124 |
+
f"""
|
125 |
+
<div style="font-size: 1.3rem; color: red;">
|
126 |
+
No demo codes found in file {demo_file_path}
|
127 |
+
</div>
|
128 |
+
""",
|
129 |
+
unsafe_allow_html=True,
|
130 |
+
)
|
131 |
+
skip = True
|
132 |
+
for code_txt in demo_codes:
|
133 |
+
if code_txt.startswith("##"):
|
134 |
+
skip = True
|
135 |
+
continue
|
136 |
+
if code_txt.startswith("#"):
|
137 |
+
code_desc, code_regex = code_txt[1:].split(":")
|
138 |
+
code_desc, code_regex = code_desc.strip(), code_regex.strip()
|
139 |
+
skip = False
|
140 |
+
continue
|
141 |
+
if skip:
|
142 |
+
continue
|
143 |
+
code_info = code_search_utils.get_code_info_pr_from_str(code_txt, code_regex)
|
144 |
+
comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}"
|
145 |
+
button_key = (
|
146 |
+
f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"
|
147 |
+
+ (f"head{code_info.head}" if code_info.head is not None else "")
|
148 |
+
)
|
149 |
+
cols = st.columns([1] * (num_cols - 1) + [2])
|
150 |
+
button_clicked = cols[0].button(
|
151 |
+
"π",
|
152 |
+
key=button_key,
|
153 |
+
)
|
154 |
+
if button_clicked:
|
155 |
+
webapp_utils.set_ct_acts(
|
156 |
+
code_info.code, code_info.layer, code_info.head, None, is_attn
|
157 |
+
)
|
158 |
+
cols[1].write(code_info.code)
|
159 |
+
cols[2].write(str(code_info.layer))
|
160 |
+
if is_attn:
|
161 |
+
cols[3].write(str(code_info.head))
|
162 |
+
cols[-2].write(str(act_count_ft_tkns[comp_info][code_info.code]))
|
163 |
+
cols[-1].write(code_desc)
|
164 |
+
skip = True
|
165 |
+
|
166 |
+
|
167 |
+
st.markdown("## Code Search")
|
168 |
+
|
169 |
+
regex_pattern = st.text_input(
|
170 |
+
"Enter a regex pattern",
|
171 |
+
help="Wrap code token in the first group. E.g. New (York)",
|
172 |
+
key="regex_pattern",
|
173 |
+
)
|
174 |
+
# topk = st.slider("Top K", 1, 20, 10)
|
175 |
+
prec_col, sort_col = st.columns(2)
|
176 |
+
prec_threshold = prec_col.slider(
|
177 |
+
"Precision Threshold",
|
178 |
+
0.0,
|
179 |
+
1.0,
|
180 |
+
0.9,
|
181 |
+
help="Shows codes with precision on the regex pattern above the threshold.",
|
182 |
+
)
|
183 |
+
sort_by_options = ["Precision", "Recall", "Num Acts"]
|
184 |
+
sort_by_name = sort_col.radio(
|
185 |
+
"Sort By",
|
186 |
+
sort_by_options,
|
187 |
+
index=0,
|
188 |
+
horizontal=True,
|
189 |
+
help="Sorts the codes by the selected metric.",
|
190 |
+
)
|
191 |
+
sort_by = sort_by_options.index(sort_by_name)
|
192 |
+
|
193 |
+
|
194 |
+
@st.cache_data(ttl=3600)
|
195 |
+
def get_codebook_wise_codes_for_regex(regex_pattern, prec_threshold, ccb, model_name):
|
196 |
+
"""Get codebook wise codes for a given regex pattern."""
|
197 |
+
assert model_name is not None # required for loading from correct cache data
|
198 |
+
return code_search_utils.get_codes_from_pattern(
|
199 |
+
regex_pattern,
|
200 |
+
tokens_text,
|
201 |
+
token_byte_pos,
|
202 |
+
cb_acts,
|
203 |
+
act_count_ft_tkns,
|
204 |
+
ccb=ccb,
|
205 |
+
topk=8,
|
206 |
+
prec_threshold=prec_threshold,
|
207 |
+
)
|
208 |
+
|
209 |
+
|
210 |
+
if regex_pattern:
|
211 |
+
codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
|
212 |
+
regex_pattern,
|
213 |
+
prec_threshold,
|
214 |
+
ccb,
|
215 |
+
model_name,
|
216 |
+
)
|
217 |
+
st.markdown(f"Found :green[{re_token_matches}] matches")
|
218 |
+
num_search_cols = 7 if is_attn else 6
|
219 |
+
non_deploy_offset = 0
|
220 |
+
if not DEPLOY_MODE:
|
221 |
+
non_deploy_offset = 1
|
222 |
+
num_search_cols += non_deploy_offset
|
223 |
+
|
224 |
+
cols = st.columns(num_search_cols)
|
225 |
+
|
226 |
+
# st.markdown(button_height_style, unsafe_allow_html=True)
|
227 |
+
|
228 |
+
cols[0].markdown("Search", help="Button to see token activations for the code.")
|
229 |
+
cols[1].write("Layer")
|
230 |
+
if is_attn:
|
231 |
+
cols[2].write("Head")
|
232 |
+
cols[-4 - non_deploy_offset].write("Code")
|
233 |
+
cols[-3 - non_deploy_offset].write("Precision")
|
234 |
+
cols[-2 - non_deploy_offset].write("Recall")
|
235 |
+
cols[-1 - non_deploy_offset].markdown(
|
236 |
+
"Num Acts",
|
237 |
+
help="Number of tokens that the code activates on in the acts dataset.",
|
238 |
+
)
|
239 |
+
if not DEPLOY_MODE:
|
240 |
+
cols[-1].markdown(
|
241 |
+
"Save to Demos",
|
242 |
+
help="Button to save the code to demos along with the regex pattern.",
|
243 |
+
)
|
244 |
+
all_codes = codebook_wise_codes.items()
|
245 |
+
all_codes = [
|
246 |
+
(cb_name, code_pr_info)
|
247 |
+
for cb_name, code_pr_infos in all_codes
|
248 |
+
for code_pr_info in code_pr_infos
|
249 |
+
]
|
250 |
+
all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
|
251 |
+
for cb_name, (code, prec, rec, code_acts) in all_codes:
|
252 |
+
layer_head = cb_name.split("_")
|
253 |
+
layer = layer_head[0][5:]
|
254 |
+
head = layer_head[1][4:] if len(layer_head) > 1 else None
|
255 |
+
button_key = f"search_code{code}_layer{layer}" + (
|
256 |
+
f"head{head}" if head is not None else ""
|
257 |
+
)
|
258 |
+
cols = st.columns(num_search_cols)
|
259 |
+
extra_args = {
|
260 |
+
"prec": prec,
|
261 |
+
"recall": rec,
|
262 |
+
"num_acts": code_acts,
|
263 |
+
"regex": regex_pattern,
|
264 |
+
}
|
265 |
+
button_clicked = cols[0].button("π", key=button_key)
|
266 |
+
if button_clicked:
|
267 |
+
webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
|
268 |
+
cols[1].write(layer)
|
269 |
+
if is_attn:
|
270 |
+
cols[2].write(head)
|
271 |
+
cols[-4 - non_deploy_offset].write(code)
|
272 |
+
cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
|
273 |
+
cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
|
274 |
+
cols[-1 - non_deploy_offset].write(str(code_acts))
|
275 |
+
if not DEPLOY_MODE:
|
276 |
+
webapp_utils.add_save_code_button(
|
277 |
+
demo_file_path,
|
278 |
+
num_acts=code_acts,
|
279 |
+
save_regex=True,
|
280 |
+
prec=prec,
|
281 |
+
recall=rec,
|
282 |
+
button_st_container=cols[-1],
|
283 |
+
button_key_suffix=f"_code{code}_layer{layer}_head{head}",
|
284 |
+
)
|
285 |
+
|
286 |
+
if len(all_codes) == 0:
|
287 |
+
st.markdown(
|
288 |
+
f"""
|
289 |
+
<div style="font-size: 1.0rem; color: red;">
|
290 |
+
No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
|
291 |
+
</div>
|
292 |
+
""",
|
293 |
+
unsafe_allow_html=True,
|
294 |
+
)
|
295 |
+
|
296 |
+
|
297 |
+
st.markdown("## Code Token Activations")
|
298 |
+
|
299 |
+
filter_codes = st.checkbox("Filter Codes", key="filter_codes")
|
300 |
+
act_range, layer_code_acts = None, None
|
301 |
+
if filter_codes:
|
302 |
+
act_range = st.slider(
|
303 |
+
"Num Acts",
|
304 |
+
0,
|
305 |
+
10_000,
|
306 |
+
(100, 10_000),
|
307 |
+
key="ct_act_range",
|
308 |
+
help="Filter codes by the number of tokens they activate on.",
|
309 |
+
)
|
310 |
+
|
311 |
+
cols = st.columns(5 if is_attn else 4)
|
312 |
+
layer = cols[0].number_input("Layer", 0, num_layers - 1, 0, key="ct_act_layer")
|
313 |
+
if is_attn:
|
314 |
+
head = cols[1].number_input("Head", 0, num_heads - 1, 0, key="ct_act_head")
|
315 |
+
else:
|
316 |
+
head = None
|
317 |
+
|
318 |
+
def_code = st.session_state.get("ct_act_code", 0)
|
319 |
+
if filter_codes:
|
320 |
+
layer_code_acts = act_count_ft_tkns[
|
321 |
+
f"layer{layer}{'_head'+str(head) if head is not None else ''}"
|
322 |
+
]
|
323 |
+
def_code = webapp_utils.find_next_code(def_code, layer_code_acts, act_range)
|
324 |
+
if "ct_act_code" in st.session_state:
|
325 |
+
st.session_state["ct_act_code"] = def_code
|
326 |
+
|
327 |
+
code = cols[-3].number_input(
|
328 |
+
"Code",
|
329 |
+
0,
|
330 |
+
num_codes - 1,
|
331 |
+
def_code,
|
332 |
+
key="ct_act_code",
|
333 |
+
)
|
334 |
+
num_examples = cols[-2].number_input(
|
335 |
+
"Max Results",
|
336 |
+
-1,
|
337 |
+
1000, # setting to 1000 for efficiency purposes even though it can be more than 1000.
|
338 |
+
100,
|
339 |
+
help="Number of examples to show in the results. Set to -1 to show all examples.",
|
340 |
+
)
|
341 |
+
ctx_size = cols[-1].number_input(
|
342 |
+
"Context Size",
|
343 |
+
1,
|
344 |
+
10,
|
345 |
+
5,
|
346 |
+
help="Number of tokens to show before and after the code token.",
|
347 |
+
)
|
348 |
+
|
349 |
+
acts, acts_count = webapp_utils.get_code_acts(
|
350 |
+
model_name,
|
351 |
+
tokens_str,
|
352 |
+
code,
|
353 |
+
layer,
|
354 |
+
head,
|
355 |
+
ctx_size,
|
356 |
+
num_examples,
|
357 |
+
)
|
358 |
+
|
359 |
+
st.write(
|
360 |
+
f"Token Activations for Layer {layer}{f' Head {head}' if head is not None else ''} Code {code} | "
|
361 |
+
f"Activates on {acts_count[0]} tokens on the acts dataset",
|
362 |
+
)
|
363 |
+
|
364 |
+
if not DEPLOY_MODE:
|
365 |
+
webapp_utils.add_save_code_button(
|
366 |
+
demo_file_path,
|
367 |
+
acts_count[0],
|
368 |
+
save_regex=False,
|
369 |
+
button_text=True,
|
370 |
+
button_key_suffix="_token_acts",
|
371 |
+
)
|
372 |
+
|
373 |
+
st.markdown(webapp_utils.escape_markdown(acts), unsafe_allow_html=True)
|
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
|
|
5 |
colorTo: green
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.25.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
|
|
5 |
colorTo: green
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.25.0
|
8 |
+
app_file: Code_Browser.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
code_search_utils.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Functions to help with searching codes using regex."""
|
2 |
+
|
3 |
+
import pickle
|
4 |
+
import re
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import utils
|
13 |
+
|
14 |
+
|
15 |
+
def load_dataset_cache(cache_base_path):
|
16 |
+
"""Load cache files required for dataset from `cache_base_path`."""
|
17 |
+
tokens_str = np.load(cache_base_path + "tokens_str.npy")
|
18 |
+
tokens_text = np.load(cache_base_path + "tokens_text.npy")
|
19 |
+
token_byte_pos = np.load(cache_base_path + "token_byte_pos.npy")
|
20 |
+
return tokens_str, tokens_text, token_byte_pos
|
21 |
+
|
22 |
+
|
23 |
+
def load_code_search_cache(cache_base_path):
|
24 |
+
"""Load cache files required for code search from `cache_base_path`."""
|
25 |
+
metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()
|
26 |
+
with open(cache_base_path + "cb_acts.pkl", "rb") as f:
|
27 |
+
cb_acts = pickle.load(f)
|
28 |
+
with open(cache_base_path + "act_count_ft_tkns.pkl", "rb") as f:
|
29 |
+
act_count_ft_tkns = pickle.load(f)
|
30 |
+
|
31 |
+
return cb_acts, act_count_ft_tkns, metrics
|
32 |
+
|
33 |
+
|
34 |
+
def search_re(re_pattern, tokens_text):
|
35 |
+
"""Get list of (example_id, token_pos) where re_pattern matches in tokens_text."""
|
36 |
+
# TODO: ensure that parantheses are not escaped
|
37 |
+
if re_pattern.find("(") == -1:
|
38 |
+
re_pattern = f"({re_pattern})"
|
39 |
+
return [
|
40 |
+
(i, finditer.span(1)[0])
|
41 |
+
for i, text in enumerate(tokens_text)
|
42 |
+
for finditer in re.finditer(re_pattern, text)
|
43 |
+
if finditer.span(1)[0] != finditer.span(1)[1]
|
44 |
+
]
|
45 |
+
|
46 |
+
|
47 |
+
def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
|
48 |
+
"""Get (example_id, token_pos_id) for given (example_id, byte_id)."""
|
49 |
+
example_id, byte_id = example_byte_id
|
50 |
+
index = np.searchsorted(token_byte_pos[example_id], byte_id, side="right")
|
51 |
+
return (example_id, index)
|
52 |
+
|
53 |
+
|
54 |
+
def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
|
55 |
+
"""Get codes, prec, recall for given token_pos_ids and codebook_acts."""
|
56 |
+
codes = np.array(
|
57 |
+
[
|
58 |
+
codebook_acts[example_id][token_pos_id]
|
59 |
+
for example_id, token_pos_id in token_pos_ids
|
60 |
+
]
|
61 |
+
)
|
62 |
+
codes, counts = np.unique(codes, return_counts=True)
|
63 |
+
recall = counts / len(token_pos_ids)
|
64 |
+
idx = recall > 0.01
|
65 |
+
codes, counts, recall = codes[idx], counts[idx], recall[idx]
|
66 |
+
if cb_act_counts is not None:
|
67 |
+
code_acts = np.array([cb_act_counts[code] for code in codes])
|
68 |
+
prec = counts / code_acts
|
69 |
+
sort_idx = np.argsort(prec)[::-1]
|
70 |
+
else:
|
71 |
+
code_acts = np.zeros_like(codes)
|
72 |
+
prec = np.zeros_like(codes)
|
73 |
+
sort_idx = np.argsort(recall)[::-1]
|
74 |
+
codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
|
75 |
+
code_acts = code_acts[sort_idx]
|
76 |
+
return codes, prec, recall, code_acts
|
77 |
+
|
78 |
+
|
79 |
+
def get_neuron_pr(
|
80 |
+
token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts, topk=10
|
81 |
+
):
|
82 |
+
"""Get codes, prec, recall for given token_pos_ids and codebook_acts."""
|
83 |
+
# check if neuron_acts_by_ex is a torch tensor
|
84 |
+
if isinstance(neuron_acts_by_ex, torch.Tensor):
|
85 |
+
re_neuron_acts = torch.stack(
|
86 |
+
[
|
87 |
+
neuron_acts_by_ex[example_id, token_pos_id]
|
88 |
+
for example_id, token_pos_id in token_pos_ids
|
89 |
+
],
|
90 |
+
dim=-1,
|
91 |
+
) # (layers, 2, dim_size, matches)
|
92 |
+
re_neuron_acts = torch.sort(re_neuron_acts, dim=-1).values
|
93 |
+
else:
|
94 |
+
re_neuron_acts = np.stack(
|
95 |
+
[
|
96 |
+
neuron_acts_by_ex[example_id, token_pos_id]
|
97 |
+
for example_id, token_pos_id in token_pos_ids
|
98 |
+
],
|
99 |
+
axis=-1,
|
100 |
+
) # (layers, 2, dim_size, matches)
|
101 |
+
re_neuron_acts.sort(axis=-1)
|
102 |
+
re_neuron_acts = torch.from_numpy(re_neuron_acts)
|
103 |
+
# re_neuron_acts = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1]) :]
|
104 |
+
print("Examples for recall", recall, ":", int(recall * re_neuron_acts.shape[-1]))
|
105 |
+
act_thresh = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1])]
|
106 |
+
# binary search act_thresh in neuron_sorted_acts
|
107 |
+
assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
|
108 |
+
prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
|
109 |
+
prec_den = prec_den.squeeze(-1)
|
110 |
+
prec_den = neuron_sorted_acts.shape[-1] - prec_den
|
111 |
+
prec = int(recall * re_neuron_acts.shape[-1]) / prec_den
|
112 |
+
assert (
|
113 |
+
prec.shape == re_neuron_acts.shape[:-1]
|
114 |
+
), f"{prec.shape} != {re_neuron_acts.shape[:-1]}"
|
115 |
+
|
116 |
+
best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
|
117 |
+
best_prec = prec[best_neuron_idx]
|
118 |
+
print("max prec:", best_prec)
|
119 |
+
best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
|
120 |
+
best_neuron_acts = neuron_acts_by_ex[
|
121 |
+
:, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
|
122 |
+
]
|
123 |
+
best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
|
124 |
+
best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)
|
125 |
+
|
126 |
+
return best_prec, best_neuron_acts, best_neuron_idx
|
127 |
+
|
128 |
+
|
129 |
+
def convert_to_adv_name(name, cb_at, ccb=""):
|
130 |
+
"""Convert layer0_head0 to layer0_attn_preproj_ccb0."""
|
131 |
+
if ccb:
|
132 |
+
layer, head = name.split("_")
|
133 |
+
return layer + f"_{cb_at}_ccb" + head[4:]
|
134 |
+
else:
|
135 |
+
return layer + "_" + cb_at
|
136 |
+
|
137 |
+
|
138 |
+
def convert_to_base_name(name, ccb=""):
|
139 |
+
"""Convert layer0_attn_preproj_ccb0 to layer0_head0."""
|
140 |
+
split_name = name.split("_")
|
141 |
+
layer, head = split_name[0], split_name[-1][3:]
|
142 |
+
if "ccb" in name:
|
143 |
+
return layer + "_head" + head
|
144 |
+
else:
|
145 |
+
return layer
|
146 |
+
|
147 |
+
|
148 |
+
def get_layer_head_from_base_name(name):
|
149 |
+
"""Convert layer0_head0 to 0, 0."""
|
150 |
+
split_name = name.split("_")
|
151 |
+
layer = int(split_name[0][5:])
|
152 |
+
head = None
|
153 |
+
if len(split_name) > 1:
|
154 |
+
head = int(split_name[-1][4:])
|
155 |
+
return layer, head
|
156 |
+
|
157 |
+
|
158 |
+
def get_layer_head_from_adv_name(name):
|
159 |
+
"""Convert layer0_attn_preproj_ccb0 to 0, 0."""
|
160 |
+
base_name = convert_to_base_name(name)
|
161 |
+
layer, head = get_layer_head_from_base_name(base_name)
|
162 |
+
return layer, head
|
163 |
+
|
164 |
+
|
165 |
+
def get_codes_from_pattern(
|
166 |
+
re_pattern,
|
167 |
+
tokens_text,
|
168 |
+
token_byte_pos,
|
169 |
+
cb_acts,
|
170 |
+
act_count_ft_tkns,
|
171 |
+
ccb="",
|
172 |
+
topk=5,
|
173 |
+
prec_threshold=0.5,
|
174 |
+
):
|
175 |
+
"""Fetch codes from a given regex pattern."""
|
176 |
+
byte_ids = search_re(re_pattern, tokens_text)
|
177 |
+
token_pos_ids = [
|
178 |
+
byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
|
179 |
+
]
|
180 |
+
token_pos_ids = np.unique(token_pos_ids, axis=0)
|
181 |
+
re_token_matches = len(token_pos_ids)
|
182 |
+
codebook_wise_codes = {}
|
183 |
+
for cb_name, cb in tqdm(cb_acts.items()):
|
184 |
+
base_cb_name = convert_to_base_name(cb_name, ccb=ccb)
|
185 |
+
codes, prec, recall, code_acts = get_code_pr(
|
186 |
+
token_pos_ids,
|
187 |
+
cb,
|
188 |
+
cb_act_counts=act_count_ft_tkns[base_cb_name],
|
189 |
+
)
|
190 |
+
idx = np.arange(min(topk, len(codes)))
|
191 |
+
idx = idx[prec[:topk] > prec_threshold]
|
192 |
+
codes, prec, recall = codes[idx], prec[idx], recall[idx]
|
193 |
+
code_acts = code_acts[idx]
|
194 |
+
codes_pr = list(zip(codes, prec, recall, code_acts))
|
195 |
+
codebook_wise_codes[base_cb_name] = codes_pr
|
196 |
+
return codebook_wise_codes, re_token_matches
|
197 |
+
|
198 |
+
|
199 |
+
def get_neurons_from_pattern(
|
200 |
+
re_pattern,
|
201 |
+
tokens_text,
|
202 |
+
token_byte_pos,
|
203 |
+
neuron_acts_by_ex,
|
204 |
+
neuron_sorted_acts,
|
205 |
+
recall_threshold,
|
206 |
+
):
|
207 |
+
"""Fetch the best neuron (with act thresh given by recall) from a given regex pattern."""
|
208 |
+
byte_ids = search_re(re_pattern, tokens_text)
|
209 |
+
token_pos_ids = [
|
210 |
+
byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
|
211 |
+
]
|
212 |
+
token_pos_ids = np.unique(token_pos_ids, axis=0)
|
213 |
+
re_token_matches = len(token_pos_ids)
|
214 |
+
best_prec, best_neuron_acts, best_neuron_idx = get_neuron_pr(
|
215 |
+
token_pos_ids,
|
216 |
+
recall_threshold,
|
217 |
+
neuron_acts_by_ex,
|
218 |
+
neuron_sorted_acts,
|
219 |
+
)
|
220 |
+
return best_prec, best_neuron_acts, best_neuron_idx, re_token_matches
|
221 |
+
|
222 |
+
|
223 |
+
def compare_codes_with_neurons(
|
224 |
+
best_codes_info,
|
225 |
+
tokens_text,
|
226 |
+
token_byte_pos,
|
227 |
+
neuron_acts_by_ex,
|
228 |
+
neuron_sorted_acts,
|
229 |
+
):
|
230 |
+
"""Compare codes with neurons."""
|
231 |
+
assert isinstance(neuron_acts_by_ex, np.ndarray)
|
232 |
+
(
|
233 |
+
all_best_prec,
|
234 |
+
all_best_neuron_acts,
|
235 |
+
all_best_neuron_idxs,
|
236 |
+
all_re_token_matches,
|
237 |
+
) = zip(
|
238 |
+
*[
|
239 |
+
get_neurons_from_pattern(
|
240 |
+
code_info.re_pattern,
|
241 |
+
tokens_text,
|
242 |
+
token_byte_pos,
|
243 |
+
neuron_acts_by_ex,
|
244 |
+
neuron_sorted_acts,
|
245 |
+
code_info.recall,
|
246 |
+
)
|
247 |
+
for code_info in tqdm(range(len(best_codes_info)))
|
248 |
+
],
|
249 |
+
strict=True,
|
250 |
+
)
|
251 |
+
code_best_precs = np.array(
|
252 |
+
[code_info.prec for code_info in range(len(best_codes_info))]
|
253 |
+
)
|
254 |
+
codes_better_than_neurons = code_best_precs > np.array(all_best_prec)
|
255 |
+
return codes_better_than_neurons.mean()
|
256 |
+
|
257 |
+
|
258 |
+
def get_code_info_pr_from_str(code_txt, regex):
|
259 |
+
"""Extract code info fields from string."""
|
260 |
+
code_txt = code_txt.strip()
|
261 |
+
code_txt = code_txt.split(", ")
|
262 |
+
code_txt = dict(txt.split(": ") for txt in code_txt)
|
263 |
+
return utils.CodeInfo(**code_txt)
|
264 |
+
|
265 |
+
|
266 |
+
@dataclass
|
267 |
+
class ModelInfoForWebapp:
|
268 |
+
"""Model info for webapp."""
|
269 |
+
|
270 |
+
model_name: str
|
271 |
+
pretrained_path: str
|
272 |
+
dataset_name: str
|
273 |
+
num_codes: int
|
274 |
+
cb_at: str
|
275 |
+
ccb: str
|
276 |
+
n_layers: int
|
277 |
+
n_heads: Optional[int] = None
|
278 |
+
seed: int = 42
|
279 |
+
max_samples: int = 2000
|
280 |
+
|
281 |
+
def __post_init__(self):
|
282 |
+
"""Convert to correct types."""
|
283 |
+
self.num_codes = int(self.num_codes)
|
284 |
+
self.n_layers = int(self.n_layers)
|
285 |
+
if self.n_heads == "None":
|
286 |
+
self.n_heads = None
|
287 |
+
elif self.n_heads is not None:
|
288 |
+
self.n_heads = int(self.n_heads)
|
289 |
+
self.seed = int(self.seed)
|
290 |
+
self.max_samples = int(self.max_samples)
|
291 |
+
|
292 |
+
|
293 |
+
def parse_model_info(path):
|
294 |
+
"""Parse model info from path."""
|
295 |
+
with open(path + "info.txt", "r") as f:
|
296 |
+
lines = f.readlines()
|
297 |
+
lines = dict(line.strip().split(": ") for line in lines)
|
298 |
+
return ModelInfoForWebapp(**lines)
|
299 |
+
return ModelInfoForWebapp(**lines)
|
pages/Concept_Code.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Web app page for showing codes for different examples in the dataset."""
|
2 |
+
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
from streamlit_extras.switch_page_button import switch_page
|
6 |
+
|
7 |
+
import code_search_utils
|
8 |
+
import webapp_utils
|
9 |
+
|
10 |
+
webapp_utils.load_widget_state()
|
11 |
+
|
12 |
+
if "cb_acts" not in st.session_state:
|
13 |
+
switch_page("Code_Browser")
|
14 |
+
|
15 |
+
total_examples = 2000
|
16 |
+
prec_threshold = 0.01
|
17 |
+
|
18 |
+
model_name = st.session_state["model_name_id"]
|
19 |
+
seq_len = st.session_state["seq_len"]
|
20 |
+
tokens_text = st.session_state["tokens_text"]
|
21 |
+
tokens_str = st.session_state["tokens_str"]
|
22 |
+
cb_acts = st.session_state["cb_acts"]
|
23 |
+
act_count_ft_tkns = st.session_state["act_count_ft_tkns"]
|
24 |
+
ccb = st.session_state["ccb"]
|
25 |
+
|
26 |
+
|
27 |
+
def get_example_concept_codes(example_id):
|
28 |
+
"""Get concept codes for the given example id."""
|
29 |
+
token_pos_ids = [(example_id, i) for i in range(seq_len)]
|
30 |
+
all_codes = []
|
31 |
+
for cb_name, cb in cb_acts.items():
|
32 |
+
base_cb_name = code_search_utils.convert_to_base_name(cb_name, ccb=ccb)
|
33 |
+
codes, prec, rec, code_acts = code_search_utils.get_code_pr(
|
34 |
+
token_pos_ids,
|
35 |
+
cb,
|
36 |
+
act_count_ft_tkns[base_cb_name],
|
37 |
+
)
|
38 |
+
prec_sat_idx = prec >= prec_threshold
|
39 |
+
codes, prec, rec, code_acts = (
|
40 |
+
codes[prec_sat_idx],
|
41 |
+
prec[prec_sat_idx],
|
42 |
+
rec[prec_sat_idx],
|
43 |
+
code_acts[prec_sat_idx],
|
44 |
+
)
|
45 |
+
rec_sat_idx = rec >= recall_threshold
|
46 |
+
codes, prec, rec, code_acts = (
|
47 |
+
codes[rec_sat_idx],
|
48 |
+
prec[rec_sat_idx],
|
49 |
+
rec[rec_sat_idx],
|
50 |
+
code_acts[rec_sat_idx],
|
51 |
+
)
|
52 |
+
codes_pr = list(zip(codes, prec, rec, code_acts))
|
53 |
+
all_codes.append((cb_name, codes_pr))
|
54 |
+
return all_codes
|
55 |
+
|
56 |
+
|
57 |
+
def find_next_example(example_id):
|
58 |
+
"""Find the example after `example_id` that has concept codes."""
|
59 |
+
initial_example_id = example_id
|
60 |
+
example_id += 1
|
61 |
+
while example_id != initial_example_id:
|
62 |
+
all_codes = get_example_concept_codes(example_id)
|
63 |
+
codes_found = sum([len(code_pr_infos) for _, code_pr_infos in all_codes])
|
64 |
+
if codes_found > 0:
|
65 |
+
st.session_state["example_id"] = example_id
|
66 |
+
return
|
67 |
+
example_id = (example_id + 1) % total_examples
|
68 |
+
st.error(
|
69 |
+
f"No examples found at the specified recall threshold: {recall_threshold}.",
|
70 |
+
icon="π¨",
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def redirect_to_main_with_code(code, layer, head):
|
75 |
+
"""Redirect to main page with the given code."""
|
76 |
+
st.session_state["ct_act_code"] = code
|
77 |
+
st.session_state["ct_act_layer"] = layer
|
78 |
+
if st.session_state["is_attn"]:
|
79 |
+
st.session_state["ct_act_head"] = head
|
80 |
+
switch_page("Code Browser")
|
81 |
+
|
82 |
+
|
83 |
+
def show_examples_for_concept_code(code, layer, head, code_act_ratio=0.3):
|
84 |
+
"""Show examples that the code activates on."""
|
85 |
+
ex_acts, _ = webapp_utils.get_code_acts(
|
86 |
+
model_name,
|
87 |
+
tokens_str,
|
88 |
+
code,
|
89 |
+
layer,
|
90 |
+
head,
|
91 |
+
ctx_size=5,
|
92 |
+
return_example_list=True,
|
93 |
+
)
|
94 |
+
filt_ex_acts = []
|
95 |
+
for act_str, num_acts in ex_acts:
|
96 |
+
if num_acts > seq_len * code_act_ratio:
|
97 |
+
filt_ex_acts.append(act_str)
|
98 |
+
st.markdown("#### Examples for Code")
|
99 |
+
st.markdown(
|
100 |
+
webapp_utils.escape_markdown("".join(filt_ex_acts)), unsafe_allow_html=True
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
is_attn = st.session_state["is_attn"]
|
105 |
+
|
106 |
+
st.markdown("## Concept Code")
|
107 |
+
concept_code_description = (
|
108 |
+
"Concept codes are codes that activate a lot on only a particular set of examples that share a concept. "
|
109 |
+
"Hence such codes can be thought to correspond to more higher-level concepts or features and "
|
110 |
+
"can activate on most tokens that belong in an example text. This interface provides a way to search for such "
|
111 |
+
"codes by going through different examples using Example ID."
|
112 |
+
)
|
113 |
+
st.write(concept_code_description)
|
114 |
+
|
115 |
+
# ex_col, p_col, r_col, trunc_col, sort_col = st.columns([1, 2, 2, 1, 1])
|
116 |
+
ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
|
117 |
+
example_id = ex_col.number_input(
|
118 |
+
"Example ID",
|
119 |
+
0,
|
120 |
+
total_examples - 1,
|
121 |
+
0,
|
122 |
+
key="example_id",
|
123 |
+
)
|
124 |
+
# prec_threshold = p_col.slider(
|
125 |
+
# "Precision Threshold",
|
126 |
+
# 0.0,
|
127 |
+
# 1.0,
|
128 |
+
# 0.02,
|
129 |
+
# key="prec",
|
130 |
+
# help="Precision Threshold controls the specificity of the codes for the given example.",
|
131 |
+
# )
|
132 |
+
recall_threshold = r_col.slider(
|
133 |
+
"Recall Threshold",
|
134 |
+
0.0,
|
135 |
+
1.0,
|
136 |
+
0.3,
|
137 |
+
key="recall",
|
138 |
+
help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
|
139 |
+
)
|
140 |
+
example_truncation = trunc_col.number_input(
|
141 |
+
"Max Output Chars", 0, 10240, 1024, key="max_chars"
|
142 |
+
)
|
143 |
+
sort_by_options = ["Precision", "Recall", "Num Acts"]
|
144 |
+
sort_by_name = sort_col.radio(
|
145 |
+
"Sort By",
|
146 |
+
sort_by_options,
|
147 |
+
index=0,
|
148 |
+
horizontal=True,
|
149 |
+
help="Sorts the codes by the selected metric.",
|
150 |
+
)
|
151 |
+
sort_by = sort_by_options.index(sort_by_name)
|
152 |
+
|
153 |
+
|
154 |
+
button = st.button(
|
155 |
+
"Find Next Example",
|
156 |
+
key="find_next_example",
|
157 |
+
on_click=find_next_example,
|
158 |
+
args=(example_id,),
|
159 |
+
help="Find an example which has codes above the recall threshold.",
|
160 |
+
)
|
161 |
+
# if button:
|
162 |
+
# find_next_example(st.session_state["example_id"])
|
163 |
+
|
164 |
+
|
165 |
+
st.markdown("### Example Text")
|
166 |
+
trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else ""
|
167 |
+
st.write(tokens_text[example_id][:example_truncation] + trunc_suffix)
|
168 |
+
|
169 |
+
cols = st.columns(7 if is_attn else 6)
|
170 |
+
cols[0].markdown("Search", help="Button to see token activations for the code.")
|
171 |
+
cols[1].write("Layer")
|
172 |
+
if is_attn:
|
173 |
+
cols[2].write("Head")
|
174 |
+
cols[-4].write("Code")
|
175 |
+
cols[-3].write("Precision")
|
176 |
+
cols[-2].write("Recall")
|
177 |
+
cols[-1].markdown(
|
178 |
+
"Num Acts",
|
179 |
+
help="Number of tokens that the code activates on in the acts dataset.",
|
180 |
+
)
|
181 |
+
|
182 |
+
all_codes = get_example_concept_codes(example_id)
|
183 |
+
all_codes = [
|
184 |
+
(cb_name, code_pr_info)
|
185 |
+
for cb_name, code_pr_infos in all_codes
|
186 |
+
for code_pr_info in code_pr_infos
|
187 |
+
]
|
188 |
+
all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
|
189 |
+
|
190 |
+
for cb_name, (code, p, r, acts) in all_codes:
|
191 |
+
cols = st.columns(7 if is_attn else 6)
|
192 |
+
code_button = cols[0].button(
|
193 |
+
"π",
|
194 |
+
key=f"ex-code-{code}-{cb_name}",
|
195 |
+
)
|
196 |
+
layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
|
197 |
+
cols[1].write(str(layer))
|
198 |
+
if is_attn:
|
199 |
+
cols[2].write(str(head))
|
200 |
+
|
201 |
+
cols[-4].write(code)
|
202 |
+
cols[-3].write(f"{p*100:.2f}%")
|
203 |
+
cols[-2].write(f"{r*100:.2f}%")
|
204 |
+
cols[-1].write(str(acts))
|
205 |
+
|
206 |
+
if code_button:
|
207 |
+
show_examples_for_concept_code(
|
208 |
+
code,
|
209 |
+
layer,
|
210 |
+
head,
|
211 |
+
code_act_ratio=recall_threshold,
|
212 |
+
)
|
213 |
+
if len(all_codes) == 0:
|
214 |
+
st.markdown(
|
215 |
+
f"<div style='text-align:center'>No codes found at recall threshold: {recall_threshold}</div>",
|
216 |
+
unsafe_allow_html=True,
|
217 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
torch>=2.0.0
|
3 |
+
tqdm
|
4 |
+
termcolor
|
5 |
+
streamlit_extras
|
utils.py
ADDED
@@ -0,0 +1,578 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Util functions for codebook features."""
|
2 |
+
import re
|
3 |
+
import typing
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from functools import partial
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from termcolor import colored
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class CodeInfo:
|
17 |
+
"""Dataclass for codebook info."""
|
18 |
+
|
19 |
+
code: int
|
20 |
+
layer: int
|
21 |
+
head: Optional[int]
|
22 |
+
cb_at: Optional[str] = None
|
23 |
+
|
24 |
+
# for patching interventions
|
25 |
+
pos: Optional[int] = None
|
26 |
+
code_pos: Optional[int] = -1
|
27 |
+
|
28 |
+
# for description & regex-based interpretation
|
29 |
+
description: Optional[str] = None
|
30 |
+
regex: Optional[str] = None
|
31 |
+
prec: Optional[float] = None
|
32 |
+
recall: Optional[float] = None
|
33 |
+
num_acts: Optional[int] = None
|
34 |
+
|
35 |
+
def __post_init__(self):
|
36 |
+
"""Convert to appropriate types."""
|
37 |
+
self.code = int(self.code)
|
38 |
+
self.layer = int(self.layer)
|
39 |
+
if self.head:
|
40 |
+
self.head = int(self.head)
|
41 |
+
if self.pos:
|
42 |
+
self.pos = int(self.pos)
|
43 |
+
if self.code_pos:
|
44 |
+
self.code_pos = int(self.code_pos)
|
45 |
+
if self.prec:
|
46 |
+
self.prec = float(self.prec)
|
47 |
+
assert 0 <= self.prec <= 1
|
48 |
+
if self.recall:
|
49 |
+
self.recall = float(self.recall)
|
50 |
+
assert 0 <= self.recall <= 1
|
51 |
+
if self.num_acts:
|
52 |
+
self.num_acts = int(self.num_acts)
|
53 |
+
|
54 |
+
def check_description_info(self):
|
55 |
+
"""Check if the regex info is present."""
|
56 |
+
assert self.num_acts is not None and self.description is not None
|
57 |
+
if self.regex is not None:
|
58 |
+
assert self.prec is not None and self.recall is not None
|
59 |
+
|
60 |
+
def check_patch_info(self):
|
61 |
+
"""Check if the patch info is present."""
|
62 |
+
# TODO: pos can be none for patching
|
63 |
+
assert self.pos is not None and self.code_pos is not None
|
64 |
+
|
65 |
+
def __repr__(self):
|
66 |
+
"""Return the string representation."""
|
67 |
+
repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
|
68 |
+
if self.pos is not None or self.code_pos is not None:
|
69 |
+
repr += f", pos={self.pos}, code_pos={self.code_pos}"
|
70 |
+
if self.description is not None:
|
71 |
+
repr += f", description={self.description}"
|
72 |
+
if self.regex is not None:
|
73 |
+
repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}"
|
74 |
+
if self.num_acts is not None:
|
75 |
+
repr += f", num_acts={self.num_acts}"
|
76 |
+
repr += ")"
|
77 |
+
return repr
|
78 |
+
|
79 |
+
|
80 |
+
def logits_to_pred(logits, tokenizer, k=5):
|
81 |
+
"""Convert logits to top-k predictions."""
|
82 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
83 |
+
probs = sorted_logits.softmax(dim=-1)
|
84 |
+
topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]]
|
85 |
+
topk_preds = [
|
86 |
+
tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch
|
87 |
+
]
|
88 |
+
return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
|
89 |
+
|
90 |
+
|
91 |
+
def patch_codebook_ids(
|
92 |
+
corrupted_codebook_ids, hook, pos, cache, cache_pos=None, code_idx=None
|
93 |
+
):
|
94 |
+
"""Patch codebook ids with cached ids."""
|
95 |
+
if cache_pos is None:
|
96 |
+
cache_pos = pos
|
97 |
+
if code_idx is None:
|
98 |
+
corrupted_codebook_ids[:, pos] = cache[hook.name][:, cache_pos]
|
99 |
+
else:
|
100 |
+
for code_id in range(32):
|
101 |
+
if code_id in code_idx:
|
102 |
+
corrupted_codebook_ids[:, pos, code_id] = cache[hook.name][
|
103 |
+
:, cache_pos, code_id
|
104 |
+
]
|
105 |
+
else:
|
106 |
+
corrupted_codebook_ids[:, pos, code_id] = -1
|
107 |
+
|
108 |
+
return corrupted_codebook_ids
|
109 |
+
|
110 |
+
|
111 |
+
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
|
112 |
+
"""Calculate the average logit difference between the answer and the other token."""
|
113 |
+
# Only the final logits are relevant for the answer
|
114 |
+
final_logits = logits[:, -1, :]
|
115 |
+
answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
|
116 |
+
answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
|
117 |
+
if per_prompt:
|
118 |
+
return answer_logit_diff
|
119 |
+
else:
|
120 |
+
return answer_logit_diff.mean()
|
121 |
+
|
122 |
+
|
123 |
+
def normalize_patched_logit_diff(
|
124 |
+
patched_logit_diff,
|
125 |
+
base_average_logit_diff,
|
126 |
+
corrupted_average_logit_diff,
|
127 |
+
):
|
128 |
+
"""Normalize the patched logit difference."""
|
129 |
+
# Subtract corrupted logit diff to measure the improvement,
|
130 |
+
# divide by the total improvement from clean to corrupted to normalise
|
131 |
+
# 0 means zero change, negative means actively made worse,
|
132 |
+
# 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
|
133 |
+
return (patched_logit_diff - corrupted_average_logit_diff) / (
|
134 |
+
base_average_logit_diff - corrupted_average_logit_diff
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
|
139 |
+
"""Return the set of token ids each codebook feature activates on."""
|
140 |
+
codebook_ids = cb_acts[cb_key]
|
141 |
+
|
142 |
+
if code is None:
|
143 |
+
features_tokens = [[] for _ in range(num_codes)]
|
144 |
+
for i in tqdm(range(codebook_ids.shape[0])):
|
145 |
+
for j in range(codebook_ids.shape[1]):
|
146 |
+
for k in range(codebook_ids.shape[2]):
|
147 |
+
features_tokens[codebook_ids[i, j, k]].append((i, j))
|
148 |
+
else:
|
149 |
+
idx0, idx1, _ = np.where(codebook_ids == code)
|
150 |
+
features_tokens = list(zip(idx0, idx1))
|
151 |
+
|
152 |
+
return features_tokens
|
153 |
+
|
154 |
+
|
155 |
+
def color_str(s: str, color: str, html: bool):
|
156 |
+
"""Color the string for html or terminal."""
|
157 |
+
if html:
|
158 |
+
return f"<span style='color:{color}'>{s}</span>"
|
159 |
+
else:
|
160 |
+
return colored(s, color)
|
161 |
+
|
162 |
+
|
163 |
+
def color_tokens_red_automata(tokens, red_idx, html=False):
|
164 |
+
"""Separate states with a dash and color red the tokens in red_idx."""
|
165 |
+
ret_string = ""
|
166 |
+
itr_over_red_idx = 0
|
167 |
+
tokens_enumerate = enumerate(tokens)
|
168 |
+
if tokens[0] == "<|endoftext|>":
|
169 |
+
next(tokens_enumerate)
|
170 |
+
if red_idx[0] == 0:
|
171 |
+
itr_over_red_idx += 1
|
172 |
+
for i, c in tokens_enumerate:
|
173 |
+
if i % 2 == 1:
|
174 |
+
ret_string += "-"
|
175 |
+
if itr_over_red_idx < len(red_idx) and i == red_idx[itr_over_red_idx]:
|
176 |
+
ret_string += color_str(c, "red", html)
|
177 |
+
itr_over_red_idx += 1
|
178 |
+
else:
|
179 |
+
ret_string += c
|
180 |
+
return ret_string
|
181 |
+
|
182 |
+
|
183 |
+
def color_tokens_red(tokens, red_idx, n=3, html=False):
|
184 |
+
"""Color red the tokens in red_idx."""
|
185 |
+
ret_string = ""
|
186 |
+
last_colored_token_idx = -1
|
187 |
+
for i in red_idx:
|
188 |
+
c_str = tokens[i]
|
189 |
+
if i <= last_colored_token_idx + 2 * n + 1:
|
190 |
+
ret_string += "".join(tokens[last_colored_token_idx + 1 : i])
|
191 |
+
else:
|
192 |
+
ret_string += "".join(
|
193 |
+
tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1]
|
194 |
+
)
|
195 |
+
ret_string += " ... "
|
196 |
+
ret_string += "".join(tokens[i - n : i])
|
197 |
+
ret_string += color_str(c_str, "red", html)
|
198 |
+
last_colored_token_idx = i
|
199 |
+
ret_string += "".join(
|
200 |
+
tokens[
|
201 |
+
last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens))
|
202 |
+
]
|
203 |
+
)
|
204 |
+
return ret_string
|
205 |
+
|
206 |
+
|
207 |
+
def prepare_example_print(
|
208 |
+
example_id,
|
209 |
+
example_tokens,
|
210 |
+
tokens_to_color_red,
|
211 |
+
html,
|
212 |
+
color_red_fn=color_tokens_red,
|
213 |
+
):
|
214 |
+
"""Format example to print."""
|
215 |
+
example_output = color_str(example_id, "green", html)
|
216 |
+
example_output += (
|
217 |
+
": "
|
218 |
+
+ color_red_fn(example_tokens, tokens_to_color_red, html=html)
|
219 |
+
+ ("<br>" if html else "\n")
|
220 |
+
)
|
221 |
+
return example_output
|
222 |
+
|
223 |
+
|
224 |
+
def tkn_print(
|
225 |
+
ll,
|
226 |
+
tokens,
|
227 |
+
separate_states,
|
228 |
+
n=3,
|
229 |
+
max_examples=100,
|
230 |
+
randomize=False,
|
231 |
+
html=False,
|
232 |
+
return_example_list=False,
|
233 |
+
):
|
234 |
+
"""Format and prints the tokens in ll."""
|
235 |
+
if randomize:
|
236 |
+
raise NotImplementedError("Randomize not yet implemented.")
|
237 |
+
indices = range(len(ll))
|
238 |
+
print_output = [] if return_example_list else ""
|
239 |
+
curr_ex = ll[0][0]
|
240 |
+
total_examples = 0
|
241 |
+
tokens_to_color_red = []
|
242 |
+
color_red_fn = (
|
243 |
+
color_tokens_red_automata if separate_states else partial(color_tokens_red, n=n)
|
244 |
+
)
|
245 |
+
for idx in indices:
|
246 |
+
if total_examples > max_examples:
|
247 |
+
break
|
248 |
+
i, j = ll[idx]
|
249 |
+
|
250 |
+
if i != curr_ex and curr_ex >= 0:
|
251 |
+
curr_ex_output = prepare_example_print(
|
252 |
+
curr_ex,
|
253 |
+
tokens[curr_ex],
|
254 |
+
tokens_to_color_red,
|
255 |
+
html,
|
256 |
+
color_red_fn,
|
257 |
+
)
|
258 |
+
total_examples += 1
|
259 |
+
if return_example_list:
|
260 |
+
print_output.append((curr_ex_output, len(tokens_to_color_red)))
|
261 |
+
else:
|
262 |
+
print_output += curr_ex_output
|
263 |
+
curr_ex = i
|
264 |
+
tokens_to_color_red = []
|
265 |
+
tokens_to_color_red.append(j)
|
266 |
+
curr_ex_output = prepare_example_print(
|
267 |
+
curr_ex,
|
268 |
+
tokens[curr_ex],
|
269 |
+
tokens_to_color_red,
|
270 |
+
html,
|
271 |
+
color_red_fn,
|
272 |
+
)
|
273 |
+
if return_example_list:
|
274 |
+
print_output.append((curr_ex_output, len(tokens_to_color_red)))
|
275 |
+
else:
|
276 |
+
print_output += curr_ex_output
|
277 |
+
asterisk_str = "********************************************"
|
278 |
+
print_output += color_str(asterisk_str, "green", html)
|
279 |
+
total_examples += 1
|
280 |
+
|
281 |
+
return print_output
|
282 |
+
|
283 |
+
|
284 |
+
def print_ft_tkns(
|
285 |
+
ft_tkns,
|
286 |
+
tokens,
|
287 |
+
separate_states=False,
|
288 |
+
n=3,
|
289 |
+
start=0,
|
290 |
+
stop=1000,
|
291 |
+
indices=None,
|
292 |
+
max_examples=100,
|
293 |
+
freq_filter=None,
|
294 |
+
randomize=False,
|
295 |
+
html=False,
|
296 |
+
return_example_list=False,
|
297 |
+
):
|
298 |
+
"""Print the tokens for the codebook features."""
|
299 |
+
indices = list(range(start, stop)) if indices is None else indices
|
300 |
+
num_tokens = len(tokens) * len(tokens[0])
|
301 |
+
codes, token_act_freqs, token_acts = [], [], []
|
302 |
+
for i in indices:
|
303 |
+
tkns = ft_tkns[i]
|
304 |
+
freq = (len(tkns), 100 * len(tkns) / num_tokens)
|
305 |
+
if freq_filter is not None and freq[1] > freq_filter:
|
306 |
+
continue
|
307 |
+
codes.append(i)
|
308 |
+
token_act_freqs.append(freq)
|
309 |
+
if len(tkns) > 0:
|
310 |
+
tkn_acts = tkn_print(
|
311 |
+
tkns,
|
312 |
+
tokens,
|
313 |
+
separate_states,
|
314 |
+
n=n,
|
315 |
+
max_examples=max_examples,
|
316 |
+
randomize=randomize,
|
317 |
+
html=html,
|
318 |
+
return_example_list=return_example_list,
|
319 |
+
)
|
320 |
+
token_acts.append(tkn_acts)
|
321 |
+
else:
|
322 |
+
token_acts.append("")
|
323 |
+
return codes, token_act_freqs, token_acts
|
324 |
+
|
325 |
+
|
326 |
+
def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
|
327 |
+
"""Patch in the `code` at `run_cb_ids`."""
|
328 |
+
pos = slice(None) if pos is None else pos
|
329 |
+
code_pos = slice(None) if code_pos is None else code_pos
|
330 |
+
|
331 |
+
if code_pos == "append":
|
332 |
+
assert pos == slice(None)
|
333 |
+
run_cb_ids = F.pad(run_cb_ids, (0, 1), mode="constant", value=code)
|
334 |
+
if isinstance(pos, typing.Iterable) or isinstance(pos, typing.Iterable):
|
335 |
+
for p in pos:
|
336 |
+
run_cb_ids[:, p, code_pos] = code
|
337 |
+
else:
|
338 |
+
run_cb_ids[:, pos, code_pos] = code
|
339 |
+
return run_cb_ids
|
340 |
+
|
341 |
+
|
342 |
+
def get_cb_layer_name(cb_at, layer_idx, head_idx=None):
|
343 |
+
"""Get the layer name used to store hooks/cache."""
|
344 |
+
if head_idx is None:
|
345 |
+
return f"blocks.{layer_idx}.{cb_at}.codebook_layer.hook_codebook_ids"
|
346 |
+
else:
|
347 |
+
return f"blocks.{layer_idx}.{cb_at}.codebook_layer.codebook.{head_idx}.hook_codebook_ids"
|
348 |
+
|
349 |
+
|
350 |
+
def get_cb_layer_names(layer, patch_types, n_heads):
|
351 |
+
"""Get the layer names used to store hooks/cache."""
|
352 |
+
layer_names = []
|
353 |
+
attn_added, mlp_added = False, False
|
354 |
+
if "attn_out" in patch_types:
|
355 |
+
attn_added = True
|
356 |
+
for head in range(n_heads):
|
357 |
+
layer_names.append(
|
358 |
+
f"blocks.{layer}.attn.codebook_layer.codebook.{head}.hook_codebook_ids"
|
359 |
+
)
|
360 |
+
if "mlp_out" in patch_types:
|
361 |
+
mlp_added = True
|
362 |
+
layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
|
363 |
+
|
364 |
+
for patch_type in patch_types:
|
365 |
+
# match patch_type of the pattern attn_\d_head_\d
|
366 |
+
attn_head = re.match(r"attn_(\d)_head_(\d)", patch_type)
|
367 |
+
if (not attn_added) and attn_head and attn_head[1] == str(layer):
|
368 |
+
layer_names.append(
|
369 |
+
f"blocks.{layer}.attn.codebook_layer.codebook.{attn_head[2]}.hook_codebook_ids"
|
370 |
+
)
|
371 |
+
mlp = re.match(r"mlp_(\d)", patch_type)
|
372 |
+
if (not mlp_added) and mlp and mlp[1] == str(layer):
|
373 |
+
layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
|
374 |
+
|
375 |
+
return layer_names
|
376 |
+
|
377 |
+
|
378 |
+
def cb_layer_name_to_info(layer_name):
|
379 |
+
"""Get the layer info from the layer name."""
|
380 |
+
layer_name_split = layer_name.split(".")
|
381 |
+
layer_idx = int(layer_name_split[1])
|
382 |
+
cb_at = layer_name_split[2]
|
383 |
+
if cb_at == "mlp":
|
384 |
+
head_idx = None
|
385 |
+
else:
|
386 |
+
head_idx = int(layer_name_split[5])
|
387 |
+
return cb_at, layer_idx, head_idx
|
388 |
+
|
389 |
+
|
390 |
+
def get_hooks(code, cb_at, layer_idx, head_idx=None, pos=None):
|
391 |
+
"""Get the hooks for the codebook features."""
|
392 |
+
hook_fns = [
|
393 |
+
partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
|
394 |
+
]
|
395 |
+
return [
|
396 |
+
(get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
|
397 |
+
for i in range(len(code))
|
398 |
+
]
|
399 |
+
|
400 |
+
|
401 |
+
def run_with_codes(
|
402 |
+
input, cb_model, code, cb_at, layer_idx, head_idx=None, pos=None, prepend_bos=True
|
403 |
+
):
|
404 |
+
"""Run the model with the codebook features patched in."""
|
405 |
+
hook_fns = [
|
406 |
+
partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
|
407 |
+
]
|
408 |
+
cb_model.reset_codebook_metrics()
|
409 |
+
cb_model.reset_hook_kwargs()
|
410 |
+
fwd_hooks = [
|
411 |
+
(get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
|
412 |
+
for i in range(len(cb_at))
|
413 |
+
]
|
414 |
+
with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
|
415 |
+
patched_logits, patched_cache = hooked_model.run_with_cache(
|
416 |
+
input, prepend_bos=prepend_bos
|
417 |
+
)
|
418 |
+
return patched_logits, patched_cache
|
419 |
+
|
420 |
+
|
421 |
+
def in_hook_list(list_of_arg_tuples, layer, head=None):
|
422 |
+
"""Check if the component specified by `layer` and `head` is in the `list_of_arg_tuples`."""
|
423 |
+
# if head is not provided, then checks in MLP
|
424 |
+
for arg_tuple in list_of_arg_tuples:
|
425 |
+
if head is None:
|
426 |
+
if arg_tuple.cb_at == "mlp" and arg_tuple.layer == layer:
|
427 |
+
return True
|
428 |
+
else:
|
429 |
+
if (
|
430 |
+
arg_tuple.cb_at == "attn"
|
431 |
+
and arg_tuple.layer == layer
|
432 |
+
and arg_tuple.head == head
|
433 |
+
):
|
434 |
+
return True
|
435 |
+
return False
|
436 |
+
|
437 |
+
|
438 |
+
# def generate_with_codes(input, code, cb_at, layer_idx, head_idx=None, pos=None, disable_other_comps=False):
|
439 |
+
def generate_with_codes(
|
440 |
+
input,
|
441 |
+
cb_model,
|
442 |
+
list_of_code_infos=(),
|
443 |
+
disable_other_comps=False,
|
444 |
+
automata=None,
|
445 |
+
generate_kwargs=None,
|
446 |
+
):
|
447 |
+
"""Model's generation with the codebook features patched in."""
|
448 |
+
if generate_kwargs is None:
|
449 |
+
generate_kwargs = {}
|
450 |
+
hook_fns = [
|
451 |
+
partial(patch_in_codes, pos=tupl.pos, code=tupl.code)
|
452 |
+
for tupl in list_of_code_infos
|
453 |
+
]
|
454 |
+
fwd_hooks = [
|
455 |
+
(get_cb_layer_name(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
|
456 |
+
for i, tupl in enumerate(list_of_code_infos)
|
457 |
+
]
|
458 |
+
cb_model.reset_hook_kwargs()
|
459 |
+
if disable_other_comps:
|
460 |
+
for layer, cb in cb_model.all_codebooks.items():
|
461 |
+
for head_idx, head in enumerate(cb[0].codebook):
|
462 |
+
if not in_hook_list(list_of_code_infos, layer, head_idx):
|
463 |
+
head.set_hook_kwargs(
|
464 |
+
disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
|
465 |
+
)
|
466 |
+
if not in_hook_list(list_of_code_infos, layer):
|
467 |
+
cb[1].set_hook_kwargs(
|
468 |
+
disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
|
469 |
+
)
|
470 |
+
with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
|
471 |
+
gen = hooked_model.generate(input, **generate_kwargs)
|
472 |
+
return automata.seq_to_traj(gen)[0] if automata is not None else gen
|
473 |
+
|
474 |
+
|
475 |
+
def kl_div(logits1, logits2, pos=-1, reduction="batchmean"):
|
476 |
+
"""Calculate the KL divergence between the logits at `pos`."""
|
477 |
+
logits1_last, logits2_last = logits1[:, pos, :], logits2[:, pos, :]
|
478 |
+
# calculate kl divergence between clean and mod logits last
|
479 |
+
return F.kl_div(
|
480 |
+
F.log_softmax(logits1_last, dim=-1),
|
481 |
+
F.log_softmax(logits2_last, dim=-1),
|
482 |
+
log_target=True,
|
483 |
+
reduction=reduction,
|
484 |
+
)
|
485 |
+
|
486 |
+
|
487 |
+
def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
|
488 |
+
"""Compute the Jensen-Shannon divergence between two distributions."""
|
489 |
+
if len(logits1.shape) == 3:
|
490 |
+
logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :]
|
491 |
+
|
492 |
+
probs1 = F.softmax(logits1, dim=-1)
|
493 |
+
probs2 = F.softmax(logits2, dim=-1)
|
494 |
+
|
495 |
+
total_m = (0.5 * (probs1 + probs2)).log()
|
496 |
+
|
497 |
+
loss = 0.0
|
498 |
+
loss += F.kl_div(
|
499 |
+
total_m,
|
500 |
+
F.log_softmax(logits1, dim=-1),
|
501 |
+
log_target=True,
|
502 |
+
reduction=reduction,
|
503 |
+
)
|
504 |
+
loss += F.kl_div(
|
505 |
+
total_m,
|
506 |
+
F.log_softmax(logits2, dim=-1),
|
507 |
+
log_target=True,
|
508 |
+
reduction=reduction,
|
509 |
+
)
|
510 |
+
return 0.5 * loss
|
511 |
+
|
512 |
+
|
513 |
+
def residual_stream_patching_hook(resid_pre, hook, cache, position: int):
|
514 |
+
"""Patch in the codebook features at `position` from `cache`."""
|
515 |
+
clean_resid_pre = cache[hook.name]
|
516 |
+
resid_pre[:, position, :] = clean_resid_pre[:, position, :]
|
517 |
+
return resid_pre
|
518 |
+
|
519 |
+
|
520 |
+
def find_code_changes(cache1, cache2, pos=None):
|
521 |
+
"""Find the codebook codes that are different between the two caches."""
|
522 |
+
for k in cache1.keys():
|
523 |
+
if "codebook" in k:
|
524 |
+
c1 = cache1[k][0, pos]
|
525 |
+
c2 = cache2[k][0, pos]
|
526 |
+
if not torch.all(c1 == c2):
|
527 |
+
print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())
|
528 |
+
print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())
|
529 |
+
|
530 |
+
|
531 |
+
def common_codes_in_cache(cache_codes, threshold=0.0):
|
532 |
+
"""Get the common code in the cache."""
|
533 |
+
codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True)
|
534 |
+
counts = counts.float() * 100
|
535 |
+
counts /= cache_codes.shape[1]
|
536 |
+
counts, indices = torch.sort(counts, descending=True)
|
537 |
+
codes = codes[indices]
|
538 |
+
indices = counts > threshold
|
539 |
+
codes, counts = codes[indices], counts[indices]
|
540 |
+
return codes, counts
|
541 |
+
|
542 |
+
|
543 |
+
def parse_code_info_string(
|
544 |
+
info_str: str, cb_at="attn", pos=None, code_pos=-1
|
545 |
+
) -> CodeInfo:
|
546 |
+
"""Parse the code info string.
|
547 |
+
|
548 |
+
The format of the `info_str` is:
|
549 |
+
`code: 0, layer: 0, head: 0, occ_freq: 0.0, train_act_freq: 0.0`.
|
550 |
+
"""
|
551 |
+
code, layer, head, occ_freq, train_act_freq = info_str.split(", ")
|
552 |
+
code = int(code.split(": ")[1])
|
553 |
+
layer = int(layer.split(": ")[1])
|
554 |
+
head = int(head.split(": ")[1]) if head else None
|
555 |
+
occ_freq = float(occ_freq.split(": ")[1])
|
556 |
+
train_act_freq = float(train_act_freq.split(": ")[1])
|
557 |
+
return CodeInfo(code, layer, head, pos=pos, code_pos=code_pos, cb_at=cb_at)
|
558 |
+
|
559 |
+
|
560 |
+
def parse_concept_codes_string(info_str: str, pos=None, code_append=False):
|
561 |
+
"""Parse the concept codes string."""
|
562 |
+
code_info_strs = info_str.strip().split("\n")
|
563 |
+
concept_codes = []
|
564 |
+
layer, head = None, None
|
565 |
+
code_pos = "append" if code_append else -1
|
566 |
+
for code_info_str in code_info_strs:
|
567 |
+
concept_codes.append(
|
568 |
+
parse_code_info_string(code_info_str, pos=pos, code_pos=code_pos)
|
569 |
+
)
|
570 |
+
if code_append:
|
571 |
+
continue
|
572 |
+
if layer == concept_codes[-1].layer and head == concept_codes[-1].head:
|
573 |
+
code_pos -= 1
|
574 |
+
else:
|
575 |
+
code_pos = -1
|
576 |
+
concept_codes[-1].code_pos = code_pos
|
577 |
+
layer, head = concept_codes[-1].layer, concept_codes[-1].head
|
578 |
+
return concept_codes
|
webapp_utils.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions for running webapp using streamlit."""
|
2 |
+
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
from streamlit.components.v1 import html
|
6 |
+
|
7 |
+
import code_search_utils
|
8 |
+
import utils
|
9 |
+
|
10 |
+
_PERSIST_STATE_KEY = f"{__name__}_PERSIST"
|
11 |
+
TOTAL_SAVE_BUTTONS = 0
|
12 |
+
|
13 |
+
|
14 |
+
def persist(key: str) -> str:
|
15 |
+
"""Mark widget state as persistent."""
|
16 |
+
if _PERSIST_STATE_KEY not in st.session_state:
|
17 |
+
st.session_state[_PERSIST_STATE_KEY] = set()
|
18 |
+
|
19 |
+
st.session_state[_PERSIST_STATE_KEY].add(key)
|
20 |
+
|
21 |
+
return key
|
22 |
+
|
23 |
+
|
24 |
+
def load_widget_state():
|
25 |
+
"""Load persistent widget state."""
|
26 |
+
if _PERSIST_STATE_KEY in st.session_state:
|
27 |
+
st.session_state.update(
|
28 |
+
{
|
29 |
+
key: value
|
30 |
+
for key, value in st.session_state.items()
|
31 |
+
if key in st.session_state[_PERSIST_STATE_KEY]
|
32 |
+
}
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache_resource
|
37 |
+
def load_dataset_cache(dataset_cache_path):
|
38 |
+
"""Load cache files required for dataset from `cache_path`."""
|
39 |
+
return code_search_utils.load_dataset_cache(dataset_cache_path)
|
40 |
+
|
41 |
+
|
42 |
+
@st.cache_resource
|
43 |
+
def load_code_search_cache(codes_cache_path, dataset_cache_path):
|
44 |
+
"""Load cache files required for code search from `codes_cache_path`."""
|
45 |
+
(
|
46 |
+
tokens_str,
|
47 |
+
tokens_text,
|
48 |
+
token_byte_pos,
|
49 |
+
) = load_dataset_cache(dataset_cache_path)
|
50 |
+
(
|
51 |
+
cb_acts,
|
52 |
+
act_count_ft_tkns,
|
53 |
+
metrics,
|
54 |
+
) = code_search_utils.load_code_search_cache(codes_cache_path)
|
55 |
+
return tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics
|
56 |
+
|
57 |
+
|
58 |
+
@st.cache_data(max_entries=100)
|
59 |
+
def load_ft_tkns(model_id, layer, head=None, code=None):
|
60 |
+
"""Load the code-to-token map for a codebook."""
|
61 |
+
# model_id required to not mix cache_data for different models
|
62 |
+
assert model_id is not None
|
63 |
+
cb_at = st.session_state["cb_at"]
|
64 |
+
ccb = st.session_state["ccb"]
|
65 |
+
cb_acts = st.session_state["cb_acts"]
|
66 |
+
if head is not None:
|
67 |
+
cb_name = f"layer{layer}_{cb_at}{ccb}{head}"
|
68 |
+
else:
|
69 |
+
cb_name = f"layer{layer}_{cb_at}"
|
70 |
+
return utils.features_to_tokens(
|
71 |
+
cb_name,
|
72 |
+
cb_acts,
|
73 |
+
num_codes=st.session_state["num_codes"],
|
74 |
+
code=code,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def get_code_acts(
|
79 |
+
model_id,
|
80 |
+
tokens_str,
|
81 |
+
code,
|
82 |
+
layer,
|
83 |
+
head=None,
|
84 |
+
ctx_size=5,
|
85 |
+
num_examples=100,
|
86 |
+
return_example_list=False,
|
87 |
+
):
|
88 |
+
"""Get the token activations for a given code."""
|
89 |
+
ft_tkns = load_ft_tkns(model_id, layer, head, code)
|
90 |
+
ft_tkns = [ft_tkns]
|
91 |
+
_, freqs, acts = utils.print_ft_tkns(
|
92 |
+
ft_tkns,
|
93 |
+
tokens=tokens_str,
|
94 |
+
indices=[0],
|
95 |
+
html=True,
|
96 |
+
n=ctx_size,
|
97 |
+
max_examples=num_examples,
|
98 |
+
return_example_list=return_example_list,
|
99 |
+
)
|
100 |
+
return acts[0], freqs[0]
|
101 |
+
|
102 |
+
|
103 |
+
def set_ct_acts(code, layer, head=None, extra_args=None, is_attn=False):
|
104 |
+
"""Set the code and layer for the token activations."""
|
105 |
+
# convert to int
|
106 |
+
code, layer, head = int(code), int(layer), int(head) if head is not None else None
|
107 |
+
st.session_state["ct_act_code"] = code
|
108 |
+
st.session_state["ct_act_layer"] = layer
|
109 |
+
if is_attn:
|
110 |
+
st.session_state["ct_act_head"] = head
|
111 |
+
st.session_state["filter_codes"] = False
|
112 |
+
|
113 |
+
my_html = """
|
114 |
+
<script>
|
115 |
+
document.location.href = "#code-token-activations";
|
116 |
+
</script>
|
117 |
+
"""
|
118 |
+
html(my_html, height=0, width=0, scrolling=False)
|
119 |
+
|
120 |
+
|
121 |
+
def find_next_code(code, layer_code_acts, act_range=None):
|
122 |
+
"""Find the next code that has activations in the given range."""
|
123 |
+
if act_range is None:
|
124 |
+
return code
|
125 |
+
for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
|
126 |
+
if code_act_count >= act_range[0] and code_act_count <= act_range[1]:
|
127 |
+
code += code_iter
|
128 |
+
break
|
129 |
+
return code
|
130 |
+
|
131 |
+
|
132 |
+
def escape_markdown(text):
|
133 |
+
"""Escapes markdown special characters."""
|
134 |
+
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
135 |
+
for char in MD_SPECIAL_CHARS:
|
136 |
+
text = text.replace(char, "\\" + char)
|
137 |
+
return text
|
138 |
+
|
139 |
+
|
140 |
+
def add_code_to_demo_file(code_info: utils.CodeInfo, file_path: str):
|
141 |
+
"""Add code to demo file."""
|
142 |
+
# TODO: add check for duplicate code and return False if found
|
143 |
+
# TODO: convert saved codes to databases instead of txt files?
|
144 |
+
code_info.check_description_info()
|
145 |
+
with open(file_path, "a") as f:
|
146 |
+
f.write("\n")
|
147 |
+
f.write(f"# {code_info.description}:")
|
148 |
+
if code_info.regex:
|
149 |
+
f.write(f" {code_info.regex}")
|
150 |
+
f.write("\n")
|
151 |
+
f.write(f"layer: {code_info.layer}")
|
152 |
+
f.write(f", head: {code_info.head}" if code_info.head is not None else "")
|
153 |
+
f.write(f", code: {code_info.code}")
|
154 |
+
if code_info.regex:
|
155 |
+
f.write(f", prec: {code_info.prec:.4f}, recall: {code_info.recall:.4f}")
|
156 |
+
f.write(f", num_acts: {code_info.num_acts}\n")
|
157 |
+
return True
|
158 |
+
|
159 |
+
|
160 |
+
def add_save_code_button(
|
161 |
+
demo_file_path: str,
|
162 |
+
num_acts: int,
|
163 |
+
save_regex: bool = False,
|
164 |
+
prec: float = None,
|
165 |
+
recall: float = None,
|
166 |
+
button_st_container=st,
|
167 |
+
button_text: bool = False,
|
168 |
+
button_key_suffix: str = "",
|
169 |
+
):
|
170 |
+
"""Add a button on streamlit to save code to demo codes file."""
|
171 |
+
save_button = button_st_container.button(
|
172 |
+
"πΎ" + (" Save Code to Demos" if button_text else ""),
|
173 |
+
key=f"save_code_button{button_key_suffix}",
|
174 |
+
help="Save code to demo codes file",
|
175 |
+
)
|
176 |
+
if save_button:
|
177 |
+
description = st.text_input(
|
178 |
+
"Write a description for the code",
|
179 |
+
key="save_code_desc",
|
180 |
+
)
|
181 |
+
if not description:
|
182 |
+
return
|
183 |
+
|
184 |
+
description = st.session_state.get("save_code_desc", None)
|
185 |
+
if description:
|
186 |
+
layer = st.session_state["ct_act_layer"]
|
187 |
+
is_attn = st.session_state["is_attn"]
|
188 |
+
if is_attn:
|
189 |
+
head = st.session_state["ct_act_head"]
|
190 |
+
else:
|
191 |
+
head = None
|
192 |
+
|
193 |
+
code = st.session_state["ct_act_code"]
|
194 |
+
code_info = utils.CodeInfo(
|
195 |
+
layer=layer,
|
196 |
+
head=head,
|
197 |
+
code=code,
|
198 |
+
description=description,
|
199 |
+
num_acts=num_acts,
|
200 |
+
)
|
201 |
+
|
202 |
+
if save_regex:
|
203 |
+
code_info.regex = st.session_state["regex_pattern"]
|
204 |
+
code_info.prec = prec
|
205 |
+
code_info.recall = recall
|
206 |
+
|
207 |
+
saved = add_code_to_demo_file(code_info, demo_file_path)
|
208 |
+
if saved:
|
209 |
+
st.success("Code saved!", icon="π")
|
210 |
+
st.success("Code saved!", icon="π")
|
webapp_utils_full_ft_tkns_for_ts.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utility functions for running webapp using streamlit."""
|
2 |
+
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
from streamlit.components.v1 import html
|
6 |
+
|
7 |
+
import code_search_utils
|
8 |
+
import utils
|
9 |
+
|
10 |
+
_PERSIST_STATE_KEY = f"{__name__}_PERSIST"
|
11 |
+
TOTAL_SAVE_BUTTONS = 0
|
12 |
+
|
13 |
+
|
14 |
+
def persist(key: str) -> str:
|
15 |
+
"""Mark widget state as persistent."""
|
16 |
+
if _PERSIST_STATE_KEY not in st.session_state:
|
17 |
+
st.session_state[_PERSIST_STATE_KEY] = set()
|
18 |
+
|
19 |
+
st.session_state[_PERSIST_STATE_KEY].add(key)
|
20 |
+
|
21 |
+
return key
|
22 |
+
|
23 |
+
|
24 |
+
def load_widget_state():
|
25 |
+
"""Load persistent widget state."""
|
26 |
+
if _PERSIST_STATE_KEY in st.session_state:
|
27 |
+
st.session_state.update(
|
28 |
+
{
|
29 |
+
key: value
|
30 |
+
for key, value in st.session_state.items()
|
31 |
+
if key in st.session_state[_PERSIST_STATE_KEY]
|
32 |
+
}
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache_resource
|
37 |
+
def load_dataset_cache(dataset_cache_path):
|
38 |
+
"""Load cache files required for dataset from `cache_path`."""
|
39 |
+
return code_search_utils.load_dataset_cache(dataset_cache_path)
|
40 |
+
|
41 |
+
|
42 |
+
@st.cache_resource
|
43 |
+
def load_code_search_cache(codes_cache_path, dataset_cache_path):
|
44 |
+
"""Load cache files required for code search from `codes_cache_path`."""
|
45 |
+
(
|
46 |
+
tokens_str,
|
47 |
+
tokens_text,
|
48 |
+
token_byte_pos,
|
49 |
+
) = load_dataset_cache(dataset_cache_path)
|
50 |
+
(
|
51 |
+
cb_acts,
|
52 |
+
act_count_ft_tkns,
|
53 |
+
metrics,
|
54 |
+
) = code_search_utils.load_code_search_cache(codes_cache_path)
|
55 |
+
return tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics
|
56 |
+
|
57 |
+
|
58 |
+
@st.cache_data(max_entries=100)
|
59 |
+
def load_ft_tkns(model_id, layer, head=None, code=None):
|
60 |
+
"""Load the code-to-token map for a codebook."""
|
61 |
+
# model_id required to not mix cache_data for different models
|
62 |
+
assert model_id is not None
|
63 |
+
cb_at = st.session_state["cb_at"]
|
64 |
+
ccb = st.session_state["ccb"]
|
65 |
+
cb_acts = st.session_state["cb_acts"]
|
66 |
+
if head is not None:
|
67 |
+
cb_name = f"layer{layer}_{cb_at}{ccb}{head}"
|
68 |
+
else:
|
69 |
+
cb_name = f"layer{layer}_{cb_at}"
|
70 |
+
return utils.features_to_tokens(
|
71 |
+
cb_name,
|
72 |
+
cb_acts,
|
73 |
+
num_codes=st.session_state["num_codes"],
|
74 |
+
code=code,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def get_code_acts(
|
79 |
+
model_id,
|
80 |
+
tokens_str,
|
81 |
+
code,
|
82 |
+
layer,
|
83 |
+
head=None,
|
84 |
+
ctx_size=5,
|
85 |
+
num_examples=100,
|
86 |
+
return_example_list=False,
|
87 |
+
):
|
88 |
+
"""Get the token activations for a given code."""
|
89 |
+
code_to_pass = None if "tinystories" in model_id.lower() else code
|
90 |
+
ft_tkns = load_ft_tkns(model_id, layer, head, code_to_pass)
|
91 |
+
if code_to_pass is not None:
|
92 |
+
ft_tkns = [ft_tkns]
|
93 |
+
else:
|
94 |
+
ft_tkns = ft_tkns[code : code + 1]
|
95 |
+
_, freqs, acts = utils.print_ft_tkns(
|
96 |
+
ft_tkns,
|
97 |
+
tokens=tokens_str,
|
98 |
+
indices=[0],
|
99 |
+
html=True,
|
100 |
+
n=ctx_size,
|
101 |
+
max_examples=num_examples,
|
102 |
+
return_example_list=return_example_list,
|
103 |
+
)
|
104 |
+
return acts[0], freqs[0]
|
105 |
+
|
106 |
+
|
107 |
+
def set_ct_acts(code, layer, head=None, extra_args=None, is_attn=False):
|
108 |
+
"""Set the code and layer for the token activations."""
|
109 |
+
# convert to int
|
110 |
+
code, layer, head = int(code), int(layer), int(head) if head is not None else None
|
111 |
+
st.session_state["ct_act_code"] = code
|
112 |
+
st.session_state["ct_act_layer"] = layer
|
113 |
+
if is_attn:
|
114 |
+
st.session_state["ct_act_head"] = head
|
115 |
+
st.session_state["filter_codes"] = False
|
116 |
+
|
117 |
+
info_txt = (
|
118 |
+
f"layer: {layer},{f' head: {head},' if head is not None else ''} code: {code}"
|
119 |
+
)
|
120 |
+
if extra_args:
|
121 |
+
for k, v in extra_args.items():
|
122 |
+
info_txt += f", {k}: {v}"
|
123 |
+
my_html = f"""
|
124 |
+
<script>
|
125 |
+
async function myF() {{
|
126 |
+
await new Promise(r => setTimeout(r, 10));
|
127 |
+
const textarea = document.createElement("textarea");
|
128 |
+
textarea.textContent = "{info_txt}";
|
129 |
+
document.body.appendChild(textarea);
|
130 |
+
textarea.select();
|
131 |
+
document.execCommand("copy");
|
132 |
+
document.body.removeChild(textarea);
|
133 |
+
}}
|
134 |
+
myF();
|
135 |
+
window.location.hash = "code-token-activations";
|
136 |
+
console.log(window.location.hash)
|
137 |
+
</script>
|
138 |
+
"""
|
139 |
+
html(my_html, height=0, width=0, scrolling=False)
|
140 |
+
|
141 |
+
|
142 |
+
def find_next_code(code, layer_code_acts, act_range=None):
|
143 |
+
"""Find the next code that has activations in the given range."""
|
144 |
+
# code = st.session_state["ct_act_code"]
|
145 |
+
if act_range is None:
|
146 |
+
return code
|
147 |
+
for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
|
148 |
+
if code_act_count >= act_range[0] and code_act_count <= act_range[1]:
|
149 |
+
code += code_iter
|
150 |
+
# st.session_state["ct_act_code"] = code
|
151 |
+
break
|
152 |
+
return code
|
153 |
+
|
154 |
+
|
155 |
+
def escape_markdown(text):
|
156 |
+
"""Escapes markdown special characters."""
|
157 |
+
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
158 |
+
for char in MD_SPECIAL_CHARS:
|
159 |
+
text = text.replace(char, "\\" + char)
|
160 |
+
return text
|
161 |
+
|
162 |
+
|
163 |
+
def add_code_to_demo_file(code_info: utils.CodeInfo, file_path: str):
|
164 |
+
"""Add code to demo file."""
|
165 |
+
# TODO: add check for duplicate code and return False if found
|
166 |
+
# TODO: convert saved codes to databases instead of txt files?
|
167 |
+
code_info.check_description_info()
|
168 |
+
with open(file_path, "a") as f:
|
169 |
+
f.write("\n")
|
170 |
+
f.write(f"# {code_info.description}:")
|
171 |
+
if code_info.regex:
|
172 |
+
f.write(f" {code_info.regex}")
|
173 |
+
f.write("\n")
|
174 |
+
f.write(f"layer: {code_info.layer}")
|
175 |
+
f.write(f", head: {code_info.head}" if code_info.head is not None else "")
|
176 |
+
f.write(f", code: {code_info.code}")
|
177 |
+
if code_info.regex:
|
178 |
+
f.write(f", prec: {code_info.prec:.4f}, recall: {code_info.recall:.4f}")
|
179 |
+
f.write(f", num_acts: {code_info.num_acts}\n")
|
180 |
+
return True
|
181 |
+
|
182 |
+
|
183 |
+
def add_save_code_button(
|
184 |
+
demo_file_path: str,
|
185 |
+
num_acts: int,
|
186 |
+
save_regex: bool = False,
|
187 |
+
prec: float = None,
|
188 |
+
recall: float = None,
|
189 |
+
button_st_container=st,
|
190 |
+
button_text: bool = False,
|
191 |
+
button_key_suffix: str = "",
|
192 |
+
):
|
193 |
+
"""Add a button on streamlit to save code to demo codes file."""
|
194 |
+
save_button = button_st_container.button(
|
195 |
+
"πΎ" + (" Save Code to Demos" if button_text else ""),
|
196 |
+
key=f"save_code_button{button_key_suffix}",
|
197 |
+
help="Save code to demo codes file",
|
198 |
+
)
|
199 |
+
if save_button:
|
200 |
+
description = st.text_input(
|
201 |
+
"Write a description for the code",
|
202 |
+
key="save_code_desc",
|
203 |
+
)
|
204 |
+
if not description:
|
205 |
+
return
|
206 |
+
|
207 |
+
description = st.session_state.get("save_code_desc", None)
|
208 |
+
if description:
|
209 |
+
layer = st.session_state["ct_act_layer"]
|
210 |
+
is_attn = st.session_state["is_attn"]
|
211 |
+
if is_attn:
|
212 |
+
head = st.session_state["ct_act_head"]
|
213 |
+
else:
|
214 |
+
head = None
|
215 |
+
|
216 |
+
code = st.session_state["ct_act_code"]
|
217 |
+
code_info = utils.CodeInfo(
|
218 |
+
layer=layer,
|
219 |
+
head=head,
|
220 |
+
code=code,
|
221 |
+
description=description,
|
222 |
+
num_acts=num_acts,
|
223 |
+
)
|
224 |
+
|
225 |
+
if save_regex:
|
226 |
+
code_info.regex = st.session_state["regex_pattern"]
|
227 |
+
code_info.prec = prec
|
228 |
+
code_info.recall = recall
|
229 |
+
|
230 |
+
saved = add_code_to_demo_file(code_info, demo_file_path)
|
231 |
+
if saved:
|
232 |
+
st.success("Code saved!", icon="π")
|
233 |
+
st.success("Code saved!", icon="π")
|