Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Rename concept code to topic code
Browse files
    	
        pages/{Concept_Code.py → Topic_Code_Browser.py}
    RENAMED
    
    | @@ -24,8 +24,8 @@ act_count_ft_tkns = st.session_state["act_count_ft_tkns"] | |
| 24 | 
             
            gcb = st.session_state["gcb"]
         | 
| 25 |  | 
| 26 |  | 
| 27 | 
            -
            def get_example_concept_codes(example_id): | 
| 28 | 
            -
                """Get  | 
| 29 | 
             
                token_pos_ids = [(example_id, i) for i in range(seq_len)]
         | 
| 30 | 
             
                all_codes = []
         | 
| 31 | 
             
                for cb_name, cb in cb_acts.items():
         | 
| @@ -55,11 +55,11 @@ def get_example_concept_codes(example_id): | |
| 55 |  | 
| 56 |  | 
| 57 | 
             
            def find_next_example(example_id):
         | 
| 58 | 
            -
                """Find the example after `example_id` that has  | 
| 59 | 
             
                initial_example_id = example_id
         | 
| 60 | 
             
                example_id += 1
         | 
| 61 | 
             
                while example_id != initial_example_id:
         | 
| 62 | 
            -
                    all_codes = get_example_concept_codes(example_id) | 
| 63 | 
             
                    codes_found = sum([len(code_pr_infos) for _, code_pr_infos in all_codes])
         | 
| 64 | 
             
                    if codes_found > 0:
         | 
| 65 | 
             
                        st.session_state["example_id"] = example_id
         | 
| @@ -80,7 +80,7 @@ def redirect_to_main_with_code(code, layer, head): | |
| 80 | 
             
                switch_page("Code Browser")
         | 
| 81 |  | 
| 82 |  | 
| 83 | 
            -
            def show_examples_for_concept_code(code, layer, head, code_act_ratio=0.3): | 
| 84 | 
             
                """Show examples that the code activates on."""
         | 
| 85 | 
             
                ex_acts, _ = webapp_utils.get_code_acts(
         | 
| 86 | 
             
                    model_name,
         | 
| @@ -103,14 +103,16 @@ def show_examples_for_concept_code(code, layer, head, code_act_ratio=0.3): | |
| 103 |  | 
| 104 | 
             
            is_attn = st.session_state["is_attn"]
         | 
| 105 |  | 
| 106 | 
            -
            st.markdown("##  | 
| 107 | 
            -
             | 
| 108 | 
            -
                " | 
| 109 | 
            -
                " | 
| 110 | 
            -
                " | 
| 111 | 
            -
                " | 
|  | |
|  | |
| 112 | 
             
            )
         | 
| 113 | 
            -
            st.write( | 
| 114 |  | 
| 115 | 
             
            ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
         | 
| 116 | 
             
            example_id = ex_col.number_input(
         | 
| @@ -124,7 +126,7 @@ recall_threshold = r_col.slider( | |
| 124 | 
             
                "Recall Threshold",
         | 
| 125 | 
             
                0.0,
         | 
| 126 | 
             
                1.0,
         | 
| 127 | 
            -
                0. | 
| 128 | 
             
                key="recall",
         | 
| 129 | 
             
                help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
         | 
| 130 | 
             
            )
         | 
| @@ -167,7 +169,7 @@ cols[-1].markdown( | |
| 167 | 
             
                help="Number of tokens that the code activates on in the acts dataset.",
         | 
| 168 | 
             
            )
         | 
| 169 |  | 
| 170 | 
            -
            all_codes = get_example_concept_codes(example_id) | 
| 171 | 
             
            all_codes = [
         | 
| 172 | 
             
                (cb_name, code_pr_info)
         | 
| 173 | 
             
                for cb_name, code_pr_infos in all_codes
         | 
| @@ -192,7 +194,7 @@ for cb_name, (code, p, r, acts) in all_codes: | |
| 192 | 
             
                cols[-1].write(str(acts))
         | 
| 193 |  | 
| 194 | 
             
                if code_button:
         | 
| 195 | 
            -
                    show_examples_for_concept_code( | 
| 196 | 
             
                        code,
         | 
| 197 | 
             
                        layer,
         | 
| 198 | 
             
                        head,
         | 
| @@ -200,6 +202,7 @@ for cb_name, (code, p, r, acts) in all_codes: | |
| 200 | 
             
                    )
         | 
| 201 | 
             
            if len(all_codes) == 0:
         | 
| 202 | 
             
                st.markdown(
         | 
| 203 | 
            -
                    f"<div style='text-align:center'>No codes found at recall threshold | 
|  | |
| 204 | 
             
                    unsafe_allow_html=True,
         | 
| 205 | 
             
                )
         | 
|  | |
| 24 | 
             
            gcb = st.session_state["gcb"]
         | 
| 25 |  | 
| 26 |  | 
| 27 | 
            +
            def get_example_topic_codes(example_id):
         | 
| 28 | 
            +
                """Get topic codes for the given example id."""
         | 
| 29 | 
             
                token_pos_ids = [(example_id, i) for i in range(seq_len)]
         | 
| 30 | 
             
                all_codes = []
         | 
| 31 | 
             
                for cb_name, cb in cb_acts.items():
         | 
|  | |
| 55 |  | 
| 56 |  | 
| 57 | 
             
            def find_next_example(example_id):
         | 
| 58 | 
            +
                """Find the example after `example_id` that has topic codes."""
         | 
| 59 | 
             
                initial_example_id = example_id
         | 
| 60 | 
             
                example_id += 1
         | 
| 61 | 
             
                while example_id != initial_example_id:
         | 
| 62 | 
            +
                    all_codes = get_example_topic_codes(example_id)
         | 
| 63 | 
             
                    codes_found = sum([len(code_pr_infos) for _, code_pr_infos in all_codes])
         | 
| 64 | 
             
                    if codes_found > 0:
         | 
| 65 | 
             
                        st.session_state["example_id"] = example_id
         | 
|  | |
| 80 | 
             
                switch_page("Code Browser")
         | 
| 81 |  | 
| 82 |  | 
| 83 | 
            +
            def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3):
         | 
| 84 | 
             
                """Show examples that the code activates on."""
         | 
| 85 | 
             
                ex_acts, _ = webapp_utils.get_code_acts(
         | 
| 86 | 
             
                    model_name,
         | 
|  | |
| 103 |  | 
| 104 | 
             
            is_attn = st.session_state["is_attn"]
         | 
| 105 |  | 
| 106 | 
            +
            st.markdown("## Topic Code")
         | 
| 107 | 
            +
            topic_code_description = (
         | 
| 108 | 
            +
                "Topic codes are codes that activate many different times on passages that describe a particular"
         | 
| 109 | 
            +
                " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking"
         | 
| 110 | 
            +
                " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction"
         | 
| 111 | 
            +
                " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible"
         | 
| 112 | 
            +
                " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at"
         | 
| 113 | 
            +
                " least one code firing on that example above the Recall Threshold."
         | 
| 114 | 
             
            )
         | 
| 115 | 
            +
            st.write(topic_code_description)
         | 
| 116 |  | 
| 117 | 
             
            ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
         | 
| 118 | 
             
            example_id = ex_col.number_input(
         | 
|  | |
| 126 | 
             
                "Recall Threshold",
         | 
| 127 | 
             
                0.0,
         | 
| 128 | 
             
                1.0,
         | 
| 129 | 
            +
                0.2,
         | 
| 130 | 
             
                key="recall",
         | 
| 131 | 
             
                help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
         | 
| 132 | 
             
            )
         | 
|  | |
| 169 | 
             
                help="Number of tokens that the code activates on in the acts dataset.",
         | 
| 170 | 
             
            )
         | 
| 171 |  | 
| 172 | 
            +
            all_codes = get_example_topic_codes(example_id)
         | 
| 173 | 
             
            all_codes = [
         | 
| 174 | 
             
                (cb_name, code_pr_info)
         | 
| 175 | 
             
                for cb_name, code_pr_infos in all_codes
         | 
|  | |
| 194 | 
             
                cols[-1].write(str(acts))
         | 
| 195 |  | 
| 196 | 
             
                if code_button:
         | 
| 197 | 
            +
                    show_examples_for_topic_code(
         | 
| 198 | 
             
                        code,
         | 
| 199 | 
             
                        layer,
         | 
| 200 | 
             
                        head,
         | 
|  | |
| 202 | 
             
                    )
         | 
| 203 | 
             
            if len(all_codes) == 0:
         | 
| 204 | 
             
                st.markdown(
         | 
| 205 | 
            +
                    f"<div style='text-align:center'>No codes found at recall threshold = {recall_threshold}."
         | 
| 206 | 
            +
                    " Consider decreasing the recall threshold.</div>",
         | 
| 207 | 
             
                    unsafe_allow_html=True,
         | 
| 208 | 
             
                )
         |