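# NOTE(review): the next three lines look like hosting-page header text captured
# together with the file; they are kept as comments so the script parses.
# aksell's picture
# Cache parsing and folding
# 4d025a2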
import re
import time
import py3Dmol
import requests
import stmol
import streamlit as st
# Page setup; set_page_config is the first Streamlit call in this script.
st.set_page_config(layout="wide")
st.header("Grid visualization of folded protein sequences")

# Example sequences pre-filled in the text box, one per line.
_default_sequence_lines = (
    "MINDLLDISRIISGKMTLDRAEVNLTAIARQVVEEQRQAAEAKSIQLLCSTPDTNHYVFGDFDRLKQTLWNLLSNAVKFTPSGGTVELELGY",
    "MQGDSSISSSNRMFTLCKPLTVANETSTLSTTRNSKSNKRVSKQRVNLAESPERNAPSPASIKTNETEEFSTIKTTNNEVLGYEPNYVSYDF",
    "MSTHVSLENTLASLQATFFSLEARHTALETQLLSTRTELAATKQELVRVQAEISRADAQAQDLKAQILTLKEKADQAEVEAAAATQRAEESQ",
    "MVLLSTGPLPILFLGPSLAELNQKYQVVSDTLLRFTNTVTFNTLKFLGSDS",
    "MNNDEQPFIMSTSGYAGNTTSSMNSTSDFNTNNKSNTWSNRFSNFIAYFSGVGWFIGAISVIFFIIYVIVFLSRKTKPSGQKQYSRTERNNR",
    "MEAVYSFTITETGTGTVEVTPLDRTISGADIVYPPDTACVPLTVQPVINANGTWTLGSGCTGHFSVDTTGHVNCLTGGFGAAGVHTVIYTVE",
    "MGLTTSGGARGFCSLAVLQELVPRPELLFVIDRAFHSGKHAVDMQVVDQEGLGDGVATLLYAHQGLYTCLLQAEARLLGREWAAVPALEPNF",
    "MGLTTSGGARGFCSLAVLQELVPRPELLFVIDRAFHSGKHAVDMQVVDQEGLGDGVATLLYAHQGLYTCLLQAEARLLGREWAAVPALEPNF",
    "MGAAGYTGSLILAALKQNPDIAVYALNRNDEKLKDVCGQYSNLKGQVCDLSNESQVEALLSGPRKTVVNLVGPYSFYGSRVLNACIEANCHY",
)
# Newline-separated with a trailing newline after the last sequence.
default_sequences = "\n".join(_default_sequence_lines) + "\n"

# Free-text input; the raw contents are parsed further down the script.
input_sequences = st.text_area(
    label="Sequences separated by a newline (max 400 resis each)",
    value=default_sequences,
)
@st.cache_data
def get_sequences(sequences_string: str) -> list:
    """Parse newline-separated protein sequences into a list of clean strings.

    Every character that is not one of the 20 upper-case standard amino-acid
    letters (``ACDEFGHIKLMNPQRSTVWY``) is removed from each line, and the
    cleaned line is then capped at 400 residues. Lines with no valid residue
    left are skipped.

    Args:
        sequences_string: Raw text with one sequence per line.

    Returns:
        The cleaned, non-empty sequences in input order.
    """
    max_residues = 400
    sequences = []
    for line in sequences_string.split("\n"):
        # Filter first, truncate second: cutting the raw line at 400 characters
        # would let spaces, digits and other junk count toward the residue cap
        # and drop valid residues that sit past raw position 400. The filter
        # also removes surrounding whitespace, so no separate strip is needed.
        residues = re.sub("[^ACDEFGHIKLMNPQRSTVWY]", "", line)[:max_residues]
        if residues:
            sequences.append(residues)
    return sequences
# Folding endpoint, and the list that collects the PDB text of each structure
# retrieved successfully.
url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
pdb_strings = []

# Clean the text-box contents and report how many sequences remain.
sequences = get_sequences(input_sequences)
st.write(f"Found {len(sequences)} valid sequences")
@st.cache_data
def get_pdb(sequence):
    """Fold ``sequence`` via the ESMFold endpoint and return the PDB text.

    The sequence is POSTed as the request body to the module-level ``url``.
    An attempt counts as failed when the request raises a ``requests``
    exception (connection error, timeout, ...), when the response carries an
    HTTP error status, or when the body is the literal string
    ``"INTERNAL SERVER ERROR"``. Up to three attempts are made.

    Args:
        sequence: Amino-acid sequence in one-letter code.

    Returns:
        The response body of the first successful attempt, or ``None`` when
        every attempt failed.
    """
    # NOTE(review): st.cache_data memoizes the return value per sequence, so a
    # None result is presumably cached as well and the sequence would not be
    # retried until the cache is cleared -- confirm whether that is intended.
    max_attempts = 3
    # Without a timeout a stalled connection would block the script run
    # indefinitely; the value is generous because folding can take a while.
    timeout_seconds = 120

    for attempt in range(max_attempts):
        try:
            response = requests.post(url, data=sequence, timeout=timeout_seconds)
        except requests.RequestException:
            # Network-level failure: treat it like a server error and retry
            # instead of letting the exception abort the whole page.
            response = None

        if (
            response is not None
            and response.ok  # reject 4xx/5xx bodies instead of returning them as PDB text
            and response.text != "INTERNAL SERVER ERROR"
        ):
            return response.text

        # Brief pause between attempts; pointless after the last one.
        if attempt < max_attempts - 1:
            time.sleep(0.1)

    return None
# Request a structure for every parsed sequence; keep the ones that come back
# non-empty and report the rest on the page.
for seq in sequences:
    pdb = get_pdb(seq)
    if not pdb:
        st.write(f"Failed to retrieve PDB structure from ESMFold for {seq}")
        continue
    pdb_strings.append(pdb)
# Arrange the structures on a near-square grid that is at most 12 columns wide.
num_pdb_structures = len(pdb_strings)
if num_pdb_structures:
    # Smallest column count whose square holds every structure, capped at 12.
    grid_columns = 1
    while grid_columns < 12 and grid_columns ** 2 < num_pdb_structures:
        grid_columns += 1
    # Ceiling division: just enough rows to place all structures.
    grid_rows = -(-num_pdb_structures // grid_columns)
else:
    # No structures: fall back to a 1x1 grid.
    grid_columns = grid_rows = 1

# Total viewer width in pixels, chosen in the sidebar (100-2000, default 900).
viewer_width = int(st.sidebar.number_input("Viewer Width", 100, 2000, 900))

# Cells are square; the viewer height follows from the number of rows.
grid_cell_width = viewer_width // grid_columns
grid_cell_height = grid_cell_width
viewer_height = grid_cell_height * grid_rows
# One py3Dmol view split into a grid_rows x grid_columns grid of sub-viewers.
grid_shape = (grid_rows, grid_columns)
xyzview = py3Dmol.view(
    viewergrid=grid_shape,
    linked=False,
    width=viewer_width,
    height=viewer_height,
)

# Fill the grid row by row: structure i goes to cell divmod(i, grid_columns).
for index, pdb_string in enumerate(pdb_strings):
    if pdb_string:
        row, col = divmod(index, grid_columns)
        xyzview.addModel(pdb_string, "pdb", viewer=(row, col))

# Cartoon rendering with spectrum colouring; no viewer argument is passed.
xyzview.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom onto the models that were added.
xyzview.zoomTo()

# Render the grid in the Streamlit page at the computed size.
stmol.showmol(xyzview, height=viewer_height, width=viewer_width)