Spaces:

aksell commited on
Commit
7a11152
1 Parent(s): e9d64be

Don't skip heads and layers by deafult

Browse files
hexviz/Attention_Visualization.py CHANGED
@@ -58,7 +58,6 @@ label_highest = st.sidebar.checkbox("Label highest attention pairs", value=True)
58
 
59
 
60
 
61
-
62
  left, mid, right = st.columns(3)
63
  with left:
64
  selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
@@ -70,7 +69,6 @@ with right:
70
  head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
71
  head = head_one - 1
72
 
73
-
74
  if selected_model.name == ModelType.ZymCTRL:
75
  try:
76
  ec_class = structure.header["compound"]["1"]["ec"]
 
58
 
59
 
60
 
 
61
  left, mid, right = st.columns(3)
62
  with left:
63
  selected_model_name = st.selectbox("Select a model", [model.name.value for model in models], index=0)
 
69
  head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
70
  head = head_one - 1
71
 
 
72
  if selected_model.name == ModelType.ZymCTRL:
73
  try:
74
  ec_class = structure.header["compound"]["1"]["ec"]
hexviz/pages/Identify_interesting_Heads.py CHANGED
@@ -41,9 +41,9 @@ slice_start, slice_end = st.sidebar.slider("Sequence", min_value=1, max_value=l,
41
  # slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
42
  truncated_sequence = sequence[slice_start-1:slice_end]
43
 
44
- head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads), step=1)
45
- layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers), step=1)
46
- step_size = st.sidebar.number_input("Optional step size to skip heads and layers", value=2, min_value=1, max_value=selected_model.layers)
47
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
48
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
49
 
 
41
  # slice_end = st.sidebar.number_input(f"Section end(1-{l})",value=50, min_value=1, max_value=l)
42
  truncated_sequence = sequence[slice_start-1:slice_end]
43
 
44
+ head_range = st.sidebar.slider("Heads to plot", min_value=1, max_value=selected_model.heads, value=(1, selected_model.heads//2), step=1)
45
+ layer_range = st.sidebar.slider("Layers to plot", min_value=1, max_value=selected_model.layers, value=(1, selected_model.layers//2), step=1)
46
+ step_size = st.sidebar.number_input("Optional step size to skip heads and layers", value=1, min_value=1, max_value=selected_model.layers)
47
  layer_sequence = list(range(layer_range[0]-1, layer_range[1], step_size))
48
  head_sequence = list(range(head_range[0]-1, head_range[1], step_size))
49