from io import StringIO

import streamlit as st
from Bio.PDB import PDBParser

from hexviz.attention import get_pdb_file, get_pdb_from_seq

menu_items = {
    "Get Help": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
    "Report a bug": "https://huggingface.co/spaces/aksell/hexviz/discussions/new",
    "About": "Created by [Aksel Lenes](https://github.com/aksell/) from Noelia Ferruz's group at the Institute of Molecular Biology of Barcelona. Read more at https://www.aiproteindesign.com/",
}

# Session-state keys derived from the selected model; reset by clear_model_state.
_MODEL_STATE_KEYS = (
    "plot_heads",
    "plot_layers",
    "selected_head",
    "selected_layer",
    "label_tokens",
)

# Session-state keys derived from the selected structure; reset by clear_pdb_state.
_PDB_STATE_KEYS = (
    "selected_chains",
    "selected_chain",
    "sequence_slice",
    "uploaded_pdb_str",
)


def _delete_session_keys(keys):
    """Remove each key in ``keys`` from ``st.session_state`` if it is present."""
    for key in keys:
        if key in st.session_state:
            del st.session_state[key]


def get_selected_model_index(models):
    """Return the index in ``models`` of the model named in session state.

    The name is read from ``st.session_state["selected_model_name"]`` and
    compared against ``model.name.value``. Returns ``0`` when no name is stored
    or when the stored name matches none of ``models``.
    """
    selected_model_name = st.session_state.get("selected_model_name", None)
    if selected_model_name is None:
        return 0
    return next(
        (i for i, model in enumerate(models) if model.name.value == selected_model_name),
        0,
    )


# Backward-compatible alias for the original (misspelled) function name.
get_selecte_model_index = get_selected_model_index


def clear_model_state():
    """Delete model-dependent session state (``on_change`` of the model selectbox).

    Removes the plotted head/layer ranges, the selected head/layer and the
    ``label_tokens`` entry so they are re-initialised for the new model.
    """
    _delete_session_keys(_MODEL_STATE_KEYS)


def select_model(models):
    """Render the model selectbox and return the chosen model.

    The selectbox is keyed on ``selected_model_name`` (initialised to the first
    model's name) and calls ``clear_model_state`` when the selection changes.

    Returns:
        The element of ``models`` whose ``name.value`` equals the selection,
        or ``None`` if no model matches.
    """
    if "selected_model_name" not in st.session_state:
        st.session_state.selected_model_name = models[0].name.value
    selected_model_name = st.selectbox(
        "Select model",
        [model.name.value for model in models],
        key="selected_model_name",
        on_change=clear_model_state,
    )
    # Named `selected_model` so it does not shadow this function's own name.
    selected_model = next(
        (model for model in models if model.name.value == selected_model_name), None
    )
    return selected_model


def clear_pdb_state():
    """Delete structure-dependent session state (``on_change`` of the PDB ID input).

    Removes the chain selection(s), the sequence slice and any cached uploaded
    PDB text.
    """
    _delete_session_keys(_PDB_STATE_KEYS)


def select_pdb():
    """Render the PDB ID text input and return its current value.

    The input is keyed on ``pdb_id`` (default ``"2FZ5"``) and calls
    ``clear_pdb_state`` when edited.
    """
    if "pdb_id" not in st.session_state:
        st.session_state.pdb_id = "2FZ5"
    pdb_id = st.text_input(label="1.PDB ID", key="pdb_id", on_change=clear_pdb_state)
    return pdb_id


def select_protein(pdb_code, uploaded_file, input_sequence):
    """Load and parse a structure from the highest-priority available source.

    Sources are tried in this order:
      1. ``uploaded_file`` — its UTF-8 text is also cached in
         ``st.session_state["uploaded_pdb_str"]``;
      2. ``input_sequence`` — folded via ``get_pdb_from_seq`` (ESMFold);
      3. a previously uploaded PDB string cached in session state;
      4. the entry ``pdb_code`` fetched via ``get_pdb_file``.

    Returns:
        ``(pdb_str, structure, source)``: the PDB text, the structure parsed by
        ``Bio.PDB.PDBParser``, and a human-readable description of the source.
        Returns ``(None, None, None)`` after showing ``st.error`` when
        ESMFold yields no PDB text.
    """
    parser = PDBParser()
    if uploaded_file is not None:
        pdb_str = uploaded_file.read().decode("utf-8")
        # Cache the text so later reruns can reuse it via source 3.
        st.session_state["uploaded_pdb_str"] = pdb_str
        source = f"uploaded pdb file {uploaded_file.name}"
        structure = parser.get_structure("Userfile", StringIO(pdb_str))
    elif input_sequence:
        pdb_str = get_pdb_from_seq(str(input_sequence))
        if not pdb_str:
            st.error("ESMfold error, unable to fold sequence")
            return None, None, None
        structure = parser.get_structure("ESMFold", StringIO(pdb_str))
        # Reset the chain selection; presumably it referred to the previous structure.
        if "selected_chains" in st.session_state:
            del st.session_state.selected_chains
        source = "Input sequence + ESM-fold"
    elif "uploaded_pdb_str" in st.session_state:
        pdb_str = st.session_state.uploaded_pdb_str
        source = "Uploaded file stored in cache"
        structure = parser.get_structure("userfile", StringIO(pdb_str))
    else:
        file = get_pdb_file(pdb_code)
        pdb_str = file.read()
        source = f"PDB ID: {pdb_code}"
        structure = parser.get_structure(pdb_code, StringIO(pdb_str))
    return pdb_str, structure, source


def select_heads_and_layers(sidebar, model):
    """Render sidebar controls choosing which attention heads and layers to plot.

    The sliders use 1-based positions (``min_value=1``); the returned lists are
    0-based indices covering each selected range inclusively, thinned by the
    chosen step size.

    Returns:
        ``(layer_sequence, head_sequence)``: lists of 0-based indices.
    """
    sidebar.markdown(
        """
        Select Heads and Layers
        ---
        """
    )
    if "plot_heads" not in st.session_state:
        # max(1, ...) keeps the default valid for single-head models, where
        # heads // 2 == 0 would fall below the slider's min_value of 1.
        st.session_state.plot_heads = (1, max(1, model.heads // 2))
    head_range = sidebar.slider(
        "Heads to plot", min_value=1, max_value=model.heads, key="plot_heads", step=1
    )
    if "plot_layers" not in st.session_state:
        st.session_state.plot_layers = (1, max(1, model.layers // 2))
    layer_range = sidebar.slider(
        "Layers to plot", min_value=1, max_value=model.layers, key="plot_layers", step=1
    )

    # The step size thins both the head and the layer range, so bound it by
    # the larger of the two dimensions rather than by the layer count alone.
    max_step = max(model.layers, model.heads)
    if "plot_step_size" not in st.session_state:
        st.session_state.plot_step_size = 1
    elif st.session_state.plot_step_size > max_step:
        # plot_step_size is not cleared on model change; clamp it to the new bound.
        st.session_state.plot_step_size = max_step
    step_size = sidebar.number_input(
        "Optional step size to skip heads and layers",
        key="plot_step_size",
        min_value=1,
        max_value=max_step,
    )

    # Convert 1-based inclusive slider ranges to 0-based index lists.
    layer_sequence = list(range(layer_range[0] - 1, layer_range[1], step_size))
    head_sequence = list(range(head_range[0] - 1, head_range[1], step_size))
    return layer_sequence, head_sequence


def select_sequence_slice(sequence_length):
    """Render a sidebar range slider selecting the residue segment to plot.

    Positions run from 1 to ``sequence_length``; the default segment is
    ``(1, min(50, sequence_length))``.

    Returns:
        The ``(start, end)`` tuple selected on the slider.
    """
    st.sidebar.markdown(
        """
        Sequence segment to plot
        ---
        """
    )
    if "sequence_slice" not in st.session_state:
        st.session_state.sequence_slice = (1, min(50, sequence_length))
    else:
        # Only clear_pdb_state resets this key; uploads and ESMFold folds do
        # not, so a slice from a longer sequence can exceed the new bounds.
        start, end = st.session_state.sequence_slice
        if end > sequence_length:
            st.session_state.sequence_slice = (
                min(start, sequence_length),
                sequence_length,
            )
    sequence_slice = st.sidebar.slider(
        "Sequence", key="sequence_slice", min_value=1, max_value=sequence_length, step=1
    )
    return sequence_slice