Don't skip heads and layers by deafult
Browse files
hexviz/Attention_Visualization.py
CHANGED
@@ -58,7 +58,6 @@ label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
|
|
58 |
|
59 |
|
60 |
|
61 |
-
|
62 |
left, mid, right = st.columns(3)
|
63 |
with left:
|
64 |
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
@@ -70,7 +69,6 @@ with right:
|
|
70 |
head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
|
71 |
head = head_one - 1
|
72 |
|
73 |
-
|
74 |
if selected_model.name == ModelType.ZymCTRL:
|
75 |
try:
|
76 |
ec_class = structure.header["compound"]["1"]["ec"]
|
|
|
58 |
|
59 |
|
60 |
|
|
|
61 |
left, mid, right = st.columns(3)
|
62 |
with left:
|
63 |
selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
|
|
|
69 |
head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
|
70 |
head = head_one - 1
|
71 |
|
|
|
72 |
if selected_model.name == ModelType.ZymCTRL:
|
73 |
try:
|
74 |
ec_class = structure.header["compound"]["1"]["ec"]
|
hexviz/pages/Identify_interesting_Heads.py
CHANGED
@@ -41,9 +41,9 @@ slice_start, slice_end = st.sidebar.slider("Sequence", min_value=1, max_value=l,
|
|
41 |
# slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
|
42 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
43 |
|
44 |
-
head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
|
45 |
-
layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
|
46 |
-
step_size = st.sidebar.number_input("Optional step size to skip heads and layers", value=
|
47 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
48 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
49 |
|
|
|
41 |
# slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
|
42 |
truncated_sequence = sequence[slice_start-1:slice_end]
|
43 |
|
44 |
+
head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads//2), step=1)
|
45 |
+
layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers//2), step=1)
|
46 |
+
step_size = st.sidebar.number_input("Optional step size to skip heads and layers", value=1, min_value=1, max_value=selected_model.layers)
|
47 |
layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
|
48 |
head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
|
49 |
|