Spaces:
Sleeping
Sleeping
Select what tokens to label in the heatmap
Browse files
hexviz/pages/1_🗺️Identify_Interesting_Heads.py
CHANGED
@@ -130,19 +130,17 @@ with mid:
|
|
130 |
)
|
131 |
head = head_one - 1
|
132 |
with right:
|
133 |
-
st.
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
### <a href="{URL}Attention_Visualization" target="_self">🧬View attention from head on structure</a>
|
138 |
-
""",
|
139 |
-
unsafe_allow_html=True,
|
140 |
-
)
|
141 |
|
142 |
if selected_model.name == ModelType.PROT_T5:
|
143 |
# Remove leading underscores from residue tokens
|
144 |
tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
|
145 |
|
|
|
|
|
|
|
146 |
|
147 |
single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
|
148 |
st.pyplot(single_head_fig)
|
|
|
130 |
)
|
131 |
head = head_one - 1
|
132 |
with right:
|
133 |
+
if "label_tokens" not in st.session_state:
|
134 |
+
st.session_state.label_tokens = []
|
135 |
+
tokens_to_label = st.multiselect("Label tokens", options=tokens, key="label_tokens")
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
if selected_model.name == ModelType.PROT_T5:
|
138 |
# Remove leading underscores from residue tokens
|
139 |
tokens = [token[1:] if str(token) != "</s>" else token for token in tokens]
|
140 |
|
141 |
+
if len(tokens_to_label) > 0:
|
142 |
+
tokens = [token if token in tokens_to_label else "" for token in tokens]
|
143 |
+
|
144 |
|
145 |
single_head_fig = plot_single_heatmap(attention, layer, head, tokens=tokens)
|
146 |
st.pyplot(single_head_fig)
|
hexviz/view.py
CHANGED
@@ -36,6 +36,8 @@ def clear_model_state():
|
|
36 |
del st.session_state.plot_layers
|
37 |
if "plot_heads" in st.session_state:
|
38 |
del st.session_state.plot_heads
|
|
|
|
|
39 |
|
40 |
|
41 |
def select_model(models):
|
|
|
36 |
del st.session_state.plot_layers
|
37 |
if "plot_heads" in st.session_state:
|
38 |
del st.session_state.plot_heads
|
39 |
+
if "label_tokens" in st.session_state:
|
40 |
+
del st.session_state.label_tokens
|
41 |
|
42 |
|
43 |
def select_model(models):
|