"""Web App for the Codebook Features project.""" import argparse import glob import os import streamlit as st import code_search_utils import utils import webapp_utils # --- Parse command line arguments --- parser = argparse.ArgumentParser() parser.add_argument( "--deploy", default=True, help="Deploy mode.", ) parser.add_argument( "--cache_dir", type=str, default="cache/", help="Path to directory containing cache for codebook models.", ) try: args = parser.parse_args() except SystemExit as e: # This exception will be raised if --help or invalid command line arguments # are used. Currently streamlit prevents the program from exiting normally # so we have to do a hard exit. os._exit(e.code if isinstance(e.code, int) else 1) deploy = args.deploy webapp_utils.load_widget_state() st.set_page_config( page_title="Codebook Features", page_icon="๐Ÿ“š", ) st.title("Codebook Features") # --- Load model info and cache --- pretty_model_names = { "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP", "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories 1 Layer Attention Codebook", "TinyStories-33M_ccb_attn_preproj": "TinyStories 4 Layer Attention Codebook", "TinyStories-1Layer-21M_vcb_mlp": "TinyStories 1 Layer MLP Codebook", } orig_model_name = {v: k for k, v in pretty_model_names.items()} base_cache_dir = args.cache_dir dirs = glob.glob(base_cache_dir + "models/*/") model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs] model_name_options = ["_".join(m) for m in model_name_options] model_name_options = sorted(set(model_name_options)) def_model_idx = ["attn" in m.lower() for m in model_name_options].index(True) p_model_name = st.selectbox( "Model", [pretty_model_names.get(m, m) for m in model_name_options], index=def_model_idx, key=webapp_utils.persist("model_name"), ) model_name = orig_model_name.get(p_model_name, p_model_name) is_fsm = "FSM" in p_model_name codes_cache_path = base_cache_dir + f"models/{model_name}_*" dirs = glob.glob(codes_cache_path) dirs.sort(key=os.path.getmtime) # session states codes_cache_path = dirs[-1] + "/" model_info = utils.ModelInfoForWebapp.load(codes_cache_path) num_codes = model_info.num_codes num_layers = model_info.n_layers num_heads = model_info.n_heads cb_at = model_info.cb_at gcb = model_info.gcb gcb = "_gcb" if gcb else "" is_attn = "attn" in cb_at dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/" ( tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics, ) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path) seq_len = len(tokens_str[0]) metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"] metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys} # --- Set the session states --- st.session_state["model_name_id"] = model_name st.session_state["cb_acts"] = cb_acts st.session_state["tokens_text"] = tokens_text st.session_state["tokens_str"] = tokens_str st.session_state["act_count_ft_tkns"] = act_count_ft_tkns st.session_state["num_codes"] = num_codes st.session_state["gcb"] = gcb st.session_state["cb_at"] = cb_at st.session_state["is_attn"] = is_attn st.session_state["seq_len"] = seq_len if not deploy: st.markdown("## Metrics") # hide metrics by default if st.checkbox("Show Model Metrics"): st.write(metrics) st.markdown("## Demo Codes") demo_codes_desc = ( "This section contains codes that we've found to be interpretable along " "with a description of the feature we think they are capturing. " "Click on the ๐Ÿ” search button for a code to see the tokens that code activates on." ) st.write(demo_codes_desc) demo_file_path = codes_cache_path + "demo_codes.txt" if st.checkbox("Show Demo Codes"): try: with open(demo_file_path, "r") as f: demo_codes = f.readlines() except FileNotFoundError: demo_codes = [] code_desc, code_regex = "", "" demo_codes = [code.strip() for code in demo_codes if code.strip()] num_cols = 6 if is_attn else 5 cols = st.columns([1] * (num_cols - 1) + [2]) # st.markdown(button_height_style, unsafe_allow_html=True) cols[0].markdown("Search", help="Button to see token activations for the code.") cols[1].write("Code") cols[2].write("Layer") if is_attn: cols[3].write("Head") cols[-2].markdown( "Num Acts", help="Number of tokens that the code activates on in the acts dataset.", ) cols[-1].markdown("Description", help="Interpreted description of the code.") if len(demo_codes) == 0: st.markdown( f"""
No demo codes found in file {demo_file_path}
""", unsafe_allow_html=True, ) skip = True for code_txt in demo_codes: if code_txt.startswith("##"): skip = True continue if code_txt.startswith("#"): code_desc, code_regex = code_txt[1:].split(":") code_desc, code_regex = code_desc.strip(), code_regex.strip() skip = False continue if skip: continue code_info = utils.CodeInfo.from_str(code_txt, regex=code_regex) comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}" button_key = ( f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}" + (f"head{code_info.head}" if code_info.head is not None else "") ) cols = st.columns([1] * (num_cols - 1) + [2]) button_clicked = cols[0].button( "๐Ÿ”", key=button_key, ) if button_clicked: webapp_utils.set_ct_acts( code_info.code, code_info.layer, code_info.head, None, is_attn ) cols[1].write(code_info.code) cols[2].write(str(code_info.layer)) if is_attn: cols[3].write(str(code_info.head)) cols[-2].write(str(act_count_ft_tkns[comp_info][code_info.code])) cols[-1].write(code_desc) skip = True # --- Code Search --- st.markdown("## Code Search") code_search_desc = ( "To find whether the codebooks model has captured a relevant feature from the data (e.g. pronouns)," " you can specify a regex pattern for your feature (e.g. โ€œhe|she|theyโ€) and find whether any code" " activating on the regex pattern exists.\n\n" "Since strings can contain several tokens, you can specify the token you want a code to fire on by" " using a capture group. For example, the search term โ€˜New (York)โ€™ will try to find codes that" " activate on the bigram feature โ€˜New Yorkโ€™ at the York token" ) if st.checkbox("Search with Regex"): st.write(code_search_desc) regex_pattern = st.text_input( "Enter a regex pattern", help="Wrap code token in the first group. E.g. New (York)", key="regex_pattern", ) # topk = st.slider("Top K", 1, 20, 10) prec_col, sort_col = st.columns(2) prec_threshold = prec_col.slider( "Precision Threshold", 0.0, 1.0, 0.9, help="Shows codes with precision on the regex pattern above the threshold.", ) sort_by_options = ["Precision", "Recall", "Num Acts"] sort_by_name = sort_col.radio( "Sort By", sort_by_options, index=0, horizontal=True, help="Sorts the codes by the selected metric.", ) sort_by = sort_by_options.index(sort_by_name) @st.cache_data(ttl=3600) def get_codebook_wise_codes_for_regex( regex_pattern, prec_threshold, gcb, model_name ): """Get codebook wise codes for a given regex pattern.""" assert model_name is not None # required for loading from correct cache data return code_search_utils.get_codes_from_pattern( regex_pattern, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, gcb=gcb, topk=8, prec_threshold=prec_threshold, ) if regex_pattern: codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex( regex_pattern, prec_threshold, gcb, model_name, ) st.markdown( f"Found {re_token_matches} matches", unsafe_allow_html=True, ) num_search_cols = 7 if is_attn else 6 non_deploy_offset = 0 if not deploy: non_deploy_offset = 1 num_search_cols += non_deploy_offset cols = st.columns(num_search_cols) cols[0].markdown("Search", help="Button to see token activations for the code.") cols[1].write("Layer") if is_attn: cols[2].write("Head") cols[-4 - non_deploy_offset].write("Code") cols[-3 - non_deploy_offset].write("Precision") cols[-2 - non_deploy_offset].write("Recall") cols[-1 - non_deploy_offset].markdown( "Num Acts", help="Number of tokens that the code activates on in the acts dataset.", ) if not deploy: cols[-1].markdown( "Save to Demos", help="Button to save the code to demos along with the regex pattern.", ) all_codes = codebook_wise_codes.items() all_codes = [ (cb_name, code_pr_info) for cb_name, code_pr_infos in all_codes for code_pr_info in code_pr_infos ] all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True) for cb_name, (code, prec, rec, code_acts) in all_codes: layer_head = cb_name.split("_") layer = layer_head[0][5:] head = layer_head[1][4:] if len(layer_head) > 1 else None button_key = f"search_code{code}_layer{layer}" + ( f"head{head}" if head is not None else "" ) cols = st.columns(num_search_cols) extra_args = { "prec": prec, "recall": rec, "num_acts": code_acts, "regex": regex_pattern, } button_clicked = cols[0].button("๐Ÿ”", key=button_key) if button_clicked: webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn) cols[1].write(layer) if is_attn: cols[2].write(head) cols[-4 - non_deploy_offset].write(code) cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%") cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%") cols[-1 - non_deploy_offset].write(str(code_acts)) if not deploy: webapp_utils.add_save_code_button( demo_file_path, num_acts=code_acts, save_regex=True, prec=prec, recall=rec, button_st_container=cols[-1], button_key_suffix=f"_code{code}_layer{layer}_head{head}", ) if len(all_codes) == 0: st.markdown( f"""
No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
""", unsafe_allow_html=True, ) # --- Display Code Token Activations --- st.markdown("## Code Token Activations") filter_codes = st.checkbox("Show filters", key="filter_codes", value=True) act_range, layer_code_acts = None, None if filter_codes: act_range = st.slider( "Minimum number of activations", 0, 10_000, 100, key="ct_act_range", help="Filter codes by the number of tokens they activate on.", ) cols = st.columns(5 if is_attn else 4) layer = cols[0].number_input("Layer", 0, num_layers - 1, 0, key="ct_act_layer") if is_attn: head = cols[1].number_input("Head", 0, num_heads - 1, 0, key="ct_act_head") else: head = None def_code = st.session_state.get("ct_act_code", 0) if filter_codes: layer_code_acts = act_count_ft_tkns[ f"layer{layer}{'_head'+str(head) if head is not None else ''}" ] def_code = webapp_utils.find_next_code(def_code, layer_code_acts, act_range) if "ct_act_code" in st.session_state: st.session_state["ct_act_code"] = def_code code = cols[-3].number_input( "Code", 0, num_codes - 1, def_code, key="ct_act_code", ) num_examples = cols[-2].number_input( "Max Results", -1, 1000, # setting to 1000 for efficiency purposes even though it can be more than 1000. 100, help="Number of examples to show in the results. Set to -1 to show all examples.", ) ctx_size = cols[-1].number_input( "Context Size", 1, 10, 5, help="Number of tokens to show before and after the code token.", ) acts, acts_count = webapp_utils.get_code_acts( model_name, tokens_str, code, layer, head, ctx_size, num_examples, is_fsm=is_fsm, ) st.write( f"Token Activations for Layer {layer}{f' Head {head}' if head is not None else ''} Code {code} | " f"Activates on {acts_count[0]} tokens on the acts dataset", ) if not deploy: webapp_utils.add_save_code_button( demo_file_path, acts_count[0], save_regex=False, button_text=True, button_key_suffix="_token_acts", ) st.markdown(webapp_utils.escape_markdown(acts), unsafe_allow_html=True)