aksell commited on
Commit
4a93494
1 Parent(s): cf8177e

Select what tokens to label in the heatmap

Browse files
hexviz/pages/1_🗺️Identify_Interesting_Heads.py CHANGED
@@ -130,19 +130,17 @@ with mid:
130
  )
131
  head = head_one - 1
132
  with right:
133
- st.markdown(
134
- f"""
135
-
136
-
137
- ### <a href="{URL}Attention_Visualization" target="_self">🧬View attention from head on structure</a>
138
- """,
139
- unsafe_allow_html=True,
140
- )
141
 
142
  if selected_model.name == ModelType.PROT_T5:
143
  # Remove leading underscores from residue tokens
144
  tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
145
 
 
 
 
146
 
147
  single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
148
  st.pyplot(single_head_fig)
 
130
  )
131
  head = head_one - 1
132
  with right:
133
+ if "label_tokens" not in st.session_state:
134
+ st.session_state.label_tokens = []
135
+ tokens_to_label = st.multiselect("Label tokens", options=tokens, key="label_tokens")
 
 
 
 
 
136
 
137
  if selected_model.name == ModelType.PROT_T5:
138
  # Remove leading underscores from residue tokens
139
  tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
140
 
141
+ if len(tokens_to_label) > 0:
142
+ tokens = [token if token in tokens_to_label else "" for token in tokens]
143
+
144
 
145
  single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
146
  st.pyplot(single_head_fig)
hexviz/view.py CHANGED
@@ -36,6 +36,8 @@ def clear_model_state():
36
  del st.session_state.plot_layers
37
  if "plot_heads" in st.session_state:
38
  del st.session_state.plot_heads
 
 
39
 
40
 
41
  def select_model(models):
 
36
  del st.session_state.plot_layers
37
  if "plot_heads" in st.session_state:
38
  del st.session_state.plot_heads
39
+ if "label_tokens" in st.session_state:
40
+ del st.session_state.label_tokens
41
 
42
 
43
  def select_model(models):