hexviz / hexviz /🧬Attention_Visualization.py
Aksel Lenes
Default to white protein
c848ab7
import re
import numpy as np
import pandas as pd
import py3Dmol
import stmol
import streamlit as st
from stmol import showmol
from hexviz.attention import clean_and_validate_sequence, get_attention_pairs, get_chains
from hexviz.config import URL
from hexviz.ec_number import ECNumber
from hexviz.models import Model, ModelType
from hexviz.view import menu_items, select_model, select_pdb, select_protein
st.set_page_config(layout="centered", menu_items=menu_items)
st.title("Attention Visualization on proteins")
for k, v in st.session_state.items():
st.session_state[k] = v
models = [
Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
Model(name=ModelType.ZymCTRL, layers=36, heads=16),
Model(name=ModelType.PROT_BERT, layers=30, heads=16),
Model(name=ModelType.PROT_T5, layers=24, heads=32),
]
with st.expander("Input a PDB id, upload a PDB file or input a sequence", expanded=True):
pdb_id = select_pdb()
uploaded_file = st.file_uploader("2.Upload PDB", type=["pdb"])
input_sequence = st.text_area("3.Input sequence", "", key="input_sequence", max_chars=400)
sequence, error = clean_and_validate_sequence(input_sequence)
if error:
st.error(error)
pdb_str, structure, source = select_protein(pdb_id, uploaded_file, sequence)
st.write(f"Visualizing: {source}")
st.sidebar.markdown(
"""
Configure visualization
---
"""
)
chains = get_chains(structure)
if "selected_chains" not in st.session_state:
st.session_state.selected_chains = chains
selected_chains = st.sidebar.multiselect(
label="Select Chain(s)", options=chains, key="selected_chains"
)
if "show_ligands" not in st.session_state:
st.session_state.show_ligands = True
show_ligands = st.sidebar.checkbox("Show ligands", key="show_ligands")
if "color_protein" not in st.session_state:
st.session_state.color_protein = False
color_protein = st.sidebar.checkbox("Color protein", key="color_protein")
st.sidebar.markdown(
"""
Attention parameters
---
"""
)
min_attn = st.sidebar.slider("Minimum attention", min_value=0.0, max_value=0.4, value=0.1)
n_highest_resis = st.sidebar.number_input(
"Num highest attention resis to label", value=2, min_value=1, max_value=100
)
label_highest = st.sidebar.checkbox("Label highest attention residues", value=True)
sidechain_highest = st.sidebar.checkbox("Show sidechains", value=True)
with st.sidebar.expander("Label residues manually"):
hl_chain = st.selectbox(label="Chain to label", options=selected_chains, index=0)
hl_resi_list = st.multiselect(label="Selected Residues", options=list(range(1, 5000)))
label_resi = st.checkbox(label="Label Residues", value=True)
left, mid, right = st.columns(3)
with left:
selected_model = select_model(models)
with mid:
if "selected_layer" not in st.session_state:
st.session_state["selected_layer"] = 5
layer_one = (
st.selectbox(
"Layer",
options=[i for i in range(1, selected_model.layers + 1)],
key="selected_layer",
)
or 5
)
layer = layer_one - 1
with right:
if "selected_head" not in st.session_state:
st.session_state["selected_head"] = 1
head_one = st.selectbox(
"Head",
options=[i for i in range(1, selected_model.heads + 1)],
key="selected_head",
)
head = head_one - 1
ec_number = ""
if selected_model.name == ModelType.ZymCTRL:
st.sidebar.markdown(
"""
ZymCTRL EC number
---
"""
)
try:
ec_number = structure.header["compound"]["1"]["ec"]
except KeyError:
pass
ec_number = st.sidebar.text_input("Enzyme Comission number (EC)", ec_number)
# Validate EC number
if not re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", ec_number):
st.sidebar.error(
"Please enter a valid Enzyme Commission number in the format of 4 integers separated by periods (e.g., 1.2.3.21)"
)
if ec_number:
if selected_chains:
shown_chains = [ch for ch in structure.get_chains() if ch.id in selected_chains]
else:
shown_chains = list(structure.get_chains())
EC_tags = []
colors = ["blue", "green", "orange", "red"]
radius = 1
EC_numbers = ec_number.split(".")
for ch in shown_chains:
first_residues = []
i = 1
while len(first_residues) < 2:
try:
first_residues.append(ch[i]["CA"].coord.tolist())
except KeyError:
pass
i += 1
res_1, res_2 = first_residues
# Calculate the vector from res_1 to res_2
vector = [res_2[i] - res_1[i] for i in range(3)]
# Reverse the vector
reverse_vector = [-v for v in vector]
# Normalize the reverse vector
reverse_vector_normalized = np.array(reverse_vector) / np.linalg.norm(reverse_vector)
# Reveres coordinates to have the first EC number be furthers from the start of the
# protein like it is in the sequence with the EC serial number next to the sequence
# and the EC main class at the the very start, further from the sequence.
coordinates = reversed(
[
[res_1[j] + i * 2 * radius * reverse_vector_normalized[j] for j in range(3)]
for i in range(4)
]
)
EC_tag = [
ECNumber(number=num, coordinate=coord, color=color, radius=radius)
for num, coord, color in zip(EC_numbers, coordinates, colors)
]
EC_tags.append(EC_tag)
EC_colored = [f":{color}[{num}]" for num, color in zip(EC_numbers, colors)]
st.sidebar.write("Visualized as colored spheres: " + ".".join(EC_colored))
attention_pairs, top_residues = get_attention_pairs(
pdb_str=pdb_str,
chain_ids=selected_chains,
layer=layer,
head=head,
threshold=min_attn,
model_type=selected_model.name,
top_n=n_highest_resis,
ec_numbers=EC_tags if ec_number else None,
)
sorted_by_attention = sorted(attention_pairs, key=lambda x: x[0], reverse=True)
def get_3dview(pdb):
xyzview = py3Dmol.view(height=800, width=800) # TODO you can set the pixel dims here!
xyzview.addModel(pdb_str, "pdb")
xyzview.setStyle({"cartoon": {"color": "spectrum" if color_protein else "white"}})
stmol.add_hover(xyzview, backgroundColor="black", fontColor="white")
# Show all ligands as stick (heteroatoms)
if show_ligands:
xyzview.addStyle({"hetflag": True}, {"stick": {"radius": 0.2}})
# If no chains are selected, show all chains
if selected_chains:
hidden_chains = [x for x in chains if x not in selected_chains]
for chain in hidden_chains:
xyzview.setStyle({"chain": chain}, {"cross": {"hidden": "true"}})
# Hide ligands for chain too
xyzview.addStyle({"chain": chain, "hetflag": True}, {"cross": {"hidden": "true"}})
if len(selected_chains) == 1:
xyzview.zoomTo({"chain": f"{selected_chains[0]}"})
else:
xyzview.zoomTo()
for att_weight, first, second in attention_pairs:
stmol.add_cylinder(
xyzview,
start=first,
end=second,
cylradius=att_weight,
cylColor="red",
dashed=False,
)
if selected_model.name == ModelType.ZymCTRL and ec_number:
for EC_tag in EC_tags:
for EC_num in EC_tag:
stmol.add_sphere(
xyzview,
spcenter=EC_num.coordinate,
radius=EC_num.radius,
spColor=EC_num.color,
)
if label_resi:
for hl_resi in hl_resi_list:
xyzview.addResLabels(
{"chain": hl_chain, "resi": hl_resi},
{
"backgroundColor": "lightgray",
"fontColor": "black",
"backgroundOpacity": 0.5,
},
)
if label_highest:
for _, chain, res in top_residues:
one_indexed_res = res + 1
xyzview.addResLabels(
{"chain": chain, "resi": one_indexed_res},
{
"backgroundColor": "lightgray",
"fontColor": "black",
"backgroundOpacity": 0.5,
},
)
if sidechain_highest:
xyzview.addStyle({"chain": chain, "resi": res}, {"stick": {"radius": 0.2}})
return xyzview
xyzview = get_3dview(pdb_id)
showmol(xyzview, height=800, width=800)
st.markdown(
f"""
Pick a PDB ID, layer and head to visualize attention from the selected protein language model ({selected_model.name.value}).
""",
unsafe_allow_html=True,
)
chain_dict = {f"{chain.id}": list(chain.get_residues()) for chain in list(structure.get_chains())}
data = []
for fraction_of_total_attention, chain, resi in top_residues:
try:
res = chain_dict[chain][resi]
except KeyError:
continue
pct_of_total_attention = round(fraction_of_total_attention * 100, 3)
el = (pct_of_total_attention, f"{res.resname:3}{res.id[1]}({chain})")
data.append(el)
df = pd.DataFrame(data, columns=["% of total attention", "Residue"])
df = df.style.format(
{"% of total attention": "{:.3f}"} # Set 3 decimal places for "% of total attention"
)
st.markdown(
f"The {n_highest_resis} residues (per chain) with the highest attention to them are labeled in the visualization and listed here:"
)
st.table(df)
st.markdown(
f"""
### Check out the other pages
<a href="{URL}Identify_Interesting_Heads" target="_self">🗺️Identify Interesting Heads</a> plots attention heatmapt for many heads and
layers from a model. This can help you pick what specific attention heads to look at for your protein.
<a href="{URL}Birds_Eye_View" target="_self">🦅Birds Eye View</a> visualizes attention on structures in a large grid, see how attention patterns
change through layers and varies accross several heads.
<a href="{URL}Documentation" target="_self">📄Documentation</a> has information on protein
language models, attention analysis and hexviz.""",
unsafe_allow_html=True,
)
"""
The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
"""