"""Streamlit front end for exploring ESM protein language models.

Tools offered (picked from a select box):

* self-contact prediction with the ESM-MSA-1b transformer,
* vector search for similar proteins stored in MyScale (ClickHouse),
* activity prediction by averaging the activity of nearest stored neighbours,
* a PDB structure viewer.
"""
import random
from collections import Counter
from statistics import mean
from urllib.parse import urlsplit

import streamlit as st
import torch
import esm
import requests
import matplotlib.pyplot as plt
from clickhouse_connect import get_client
from tqdm import tqdm
import biotite.structure.io as bsio
import numpy as np
import pandas as pd
import seaborn as sns
# Wildcard import kept at its original position so the names it brings in
# (showmol, render_pdb, render_pdb_resn are used below) shadow/are shadowed
# exactly as before.
from stmol import *
import py3Dmol
import scipy
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, SGDRegressor
from sklearn.pipeline import Pipeline
from streamlit.components.v1 import html

# Streamlit renamed its caching decorators (experimental_singleton ->
# cache_resource, experimental_memo -> cache_data). Prefer the new names and
# fall back to the old ones on older Streamlit versions.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton
_cache_data = getattr(st, "cache_data", None) or st.experimental_memo

ESM_FOLD_URL = 'https://api.esmatlas.com/foldSequence/v1/pdb/'
FOLD_TIMEOUT_S = 120  # seconds; without a timeout a stalled API call hangs the page forever

EXAMPLE_SEQUENCES = {
    'Cas9 Enzyme': 'GSGHMDKKYSIGLAIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRILYLQEIFSNEMAKV',
    'PETase': 'MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ',
}


@_cache_resource(show_spinner=False)
def init_esm():
    """Load the ESM-MSA-1b transformer in eval mode.

    Cached as a shared resource, so the weights are loaded once per server
    process instead of once per browser session.

    Returns:
        (model, alphabet) as returned by
        ``esm.pretrained.esm_msa1b_t12_100M_UR50S``.
    """
    msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_transformer = msa_transformer.eval()
    return msa_transformer, msa_transformer_alphabet


@_cache_resource(show_spinner=False)
def init_db():
    """Initialize the database connection.

    ``st.secrets["DB_URL"]`` is expected to look like ``<scheme>://<host>:<port>``;
    the scheme is passed to the client as its ``interface``.

    Returns:
        meta_field: empty dict used as per-app metadata storage
        client: clickhouse_connect client object
    """
    url = urlsplit(st.secrets["DB_URL"])
    client = get_client(
        host=url.hostname,
        port=url.port,
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
        interface=url.scheme,
    )
    meta_field = {}
    return meta_field, client


def perdict_contact_visualization(seq, model, batch_converter):
    """Plot the model's unsupervised contact-map prediction for one sequence.

    Args:
        seq: protein sequence.
        model: ESM model called with ``return_contacts=True``.
        batch_converter: batch converter matching *model*'s alphabet.

    Returns:
        Matplotlib figure with the ``len(seq) x len(seq)`` contact map.
    """
    _, _, batch_tokens = batch_converter([("protein1", seq)])
    with torch.no_grad():  # inference only, on CPU
        results = model(batch_tokens, return_contacts=True)
    contacts = results["contacts"][0].cpu().numpy()
    fig, ax = plt.subplots()
    ax.matshow(contacts[: len(seq), : len(seq)])
    return fig


def visualize_3D_Coordinates(coords):
    """Plot a 3-D polyline through *coords*, coloured by position along the chain.

    Args:
        coords: sized sequence of points, each indexable as ``point[0..2]`` (x, y, z).

    Returns:
        The matplotlib figure.
    """
    xs = [point[0] for point in coords]
    ys = [point[1] for point in coords]
    zs = [point[2] for point in coords]
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('3D coordinates of $C_{b}$ backbone structure')
    n_points = len(coords)
    for i in range(n_points - 1):
        ax.plot(
            xs[i:i + 2], ys[i:i + 2], zs[i:i + 2],
            color=plt.cm.viridis(i / n_points), marker='o',
        )
    return fig


def render_mol(pdb):
    """Render a PDB-format string as a spinning cartoon with py3Dmol/stmol."""
    pdbview = py3Dmol.view()
    pdbview.addModel(pdb, 'pdb')
    pdbview.setStyle({'cartoon': {'color': 'spectrum'}})
    pdbview.setBackgroundColor('white')
    pdbview.zoomTo()
    pdbview.zoom(2, 800)
    pdbview.spin(True)
    showmol(pdbview, height=500, width=800)


def esm_search(model, sequnce, batch_converter, top_k=5):
    """Find stored protein sequences whose embeddings are closest to *sequnce*'s.

    The query vector is the layer-12 representation of token 0 (the
    beginning-of-sequence token) of the first alignment row. It is compared
    with ``distance()`` against ``representations`` in
    ``default.esm_protein_indexer_768``.

    Args:
        model: ESM model returned by ``init_esm``.
        sequnce: query protein sequence.
        batch_converter: batch converter matching *model*'s alphabet.
        top_k: maximum number of distinct sequences to return (None = all).

    Returns:
        Up to *top_k* distinct sequences, nearest first.
    """
    _, _, batch_tokens = batch_converter([("protein1", sequnce)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12])
    # assumes MSA-transformer output of shape (batch, rows, tokens, dim) -- TODO confirm
    query_vector = results["representations"][12][0, 0, 0].tolist()
    result = st.session_state.client.query(
        f"SELECT seq, distance(representations, {query_vector}) as dist "
        "FROM default.esm_protein_indexer_768 ORDER BY dist LIMIT 500"
    )
    # Extra rows are fetched because identical sequences can appear several
    # times; dict.fromkeys de-duplicates while keeping nearest-first order
    # (a set() would scramble the ranking).
    unique_seqs = dict.fromkeys(row['seq'] for row in result.named_results())
    return list(unique_seqs)[:top_k]


@_cache_data(show_spinner=False)
def _fold_sequence(sequence):
    """POST *sequence* to the ESM Atlas fold API and return the PDB text (cached per sequence)."""
    response = requests.post(
        ESM_FOLD_URL,
        headers={'Content-Type': 'application/x-www-form-urlencoded'},
        data=sequence,
        timeout=FOLD_TIMEOUT_S,
    )
    response.raise_for_status()
    return response.content.decode('utf-8')


def show_protein_structure(sequence):
    """Predict the structure of *sequence* with ESM Atlas and render it.

    Network/HTTP failures are reported in the page instead of aborting the
    whole script run.
    """
    try:
        pdb_string = _fold_sequence(sequence)
    except requests.RequestException as err:
        st.error(f'Structure prediction failed: {err}')
        return
    render_mol(pdb_string)


@_cache_resource(show_spinner=False)
def _load_esm2():
    """Load ESM-2 (esm2_t33_650M_UR50D) once per process; return (model, batch_converter)."""
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model.eval()
    return model, alphabet.get_batch_converter()


def KNN_search(sequence):
    """Predict activity as the mean ``activity`` of the 10 nearest stored proteins.

    The query vector is the layer-33 ESM-2 representation of token 0 of
    *sequence*, compared against ``default.esm_protein_indexer``.

    Returns:
        The mean activity, or None when the query returns no rows.
    """
    model, batch_converter = _load_esm2()
    _, _, batch_tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
    query_vector = results["representations"][33][0, 0].tolist()
    result = st.session_state.client.query(
        f"SELECT activity, distance(representations, {query_vector}) as dist "
        "FROM default.esm_protein_indexer ORDER BY dist LIMIT 10"
    )
    activities = [row['activity'] for row in result.named_results()]
    if not activities:
        return None
    return sum(activities) / len(activities)


def train_test_split_PCA(
    dataset=None,
    fasta_path='/root/xuying_experiments/esm-main/P62593.fasta',
    emb_path='/root/xuying_experiments/esm-main/P62593_reprs',
    emb_layer=34,
    train_size=0.8,
):
    """Load precomputed mean embeddings and split them into train/test sets.

    Each FASTA header ends in ``|<scaled effect>`` (the regression target);
    its embedding is read from ``<emb_path>/<header>.pt`` under
    ``['mean_representations'][emb_layer]``.

    Args:
        dataset: unused; accepted for backward compatibility.
        fasta_path: FASTA file listing the variants.
        emb_path: directory holding one ``.pt`` embedding file per header.
        emb_layer: representation layer to read from each file.
        train_size: fraction of samples used for training.

    Returns:
        (Xs_train, Xs_test, ys_train, ys_test)
    """
    ys = []
    Xs = []
    for header, _seq in esm.data.read_fasta(fasta_path):
        ys.append(float(header.split('|')[-1]))
        embs = torch.load(f'{emb_path}/{header}.pt')
        Xs.append(embs['mean_representations'][emb_layer])
    Xs = torch.stack(Xs, dim=0).numpy()
    return tuple(train_test_split(Xs, ys, train_size=train_size, random_state=42))


def PCA_visual(Xs_train, ys_train=None, num_pca_components=60):
    """Scatter the first two PCA components of *Xs_train*.

    Args:
        Xs_train: 2-D array of embeddings.
        ys_train: optional per-sample values used to colour the points.
        num_pca_components: number of PCA components to fit.

    Returns:
        The matplotlib figure.
    """
    pca = PCA(num_pca_components)
    Xs_train_pca = pca.fit_transform(Xs_train)
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_title('Visualize Embeddings')
    sc = ax.scatter(Xs_train_pca[:, 0], Xs_train_pca[:, 1], c=ys_train, marker='.')
    ax.set_xlabel('PCA first principal component')
    ax.set_ylabel('PCA second principal component')
    if ys_train is not None:
        fig.colorbar(sc, ax=ax, label='Variant Effect')
    return fig


def KNN_trainings(Xs_train, Xs_test, ys_train, ys_test, num_pca_components=60):
    """Grid-search a PCA + k-NN regressor pipeline on the training split.

    ``Xs_test``/``ys_test`` are accepted for interface symmetry but unused.

    Returns:
        DataFrame with the five best parameter sets by ``rank_test_score``.
    """
    knn_grid = [{
        'model': [KNeighborsRegressor()],
        'model__n_neighbors': [5, 10],
        'model__weights': ['uniform', 'distance'],
        'model__algorithm': ['ball_tree', 'kd_tree', 'brute'],
        'model__leaf_size': [15, 30],
        'model__p': [1, 2],
    }]
    pipe = Pipeline(steps=[
        ('pca', PCA(num_pca_components)),
        ('model', KNeighborsRegressor()),
    ])
    grid = GridSearchCV(
        estimator=pipe,
        param_grid=knn_grid,
        scoring='r2',
        verbose=1,
        n_jobs=-1,  # use all available cores
    )
    grid.fit(Xs_train, ys_train)
    cv_results = pd.DataFrame.from_dict(grid.cv_results_)
    best = cv_results.sort_values('rank_test_score').head(5)
    return best[['param_model', 'params', 'param_model__algorithm', 'mean_test_score', 'rank_test_score']]


@_cache_resource(show_spinner=False)
def init_random_query(dims=768):
    """Return a random query vector of length *dims* and a copy of it.

    ``dims`` defaults to 768 to match the ``esm_protein_indexer_768`` table
    queried by ``esm_search`` -- TODO confirm this is the intended size.
    """
    xq = np.random.rand(dims).tolist()
    return xq, xq.copy()


st.markdown(""" """, unsafe_allow_html=True)

messages = [
    """
Evolutionary-scale prediction of atomic level protein structure

ESM is a high-capacity Transformer trained with protein sequences \
as input. After training, the secondary and tertiary structure, \
function, homology and other information of the protein are in the feature representation output by the model. \
Check out https://esmatlas.com/ for more information.

We have 120k proteins features stored in our database.

The app uses MyScale to store and query protein sequence using vector search.
"""
]


def _sequence_input(label='protein sequence(Capital letters only)', show_hint=True):
    """Render a sequence text box plus one button per example sequence.

    Every example button is always rendered (an if/elif chain would hide the
    later buttons after the first one is clicked).

    Returns:
        The typed sequence, or the example whose button was clicked in this
        run, with surrounding whitespace stripped and upper-cased.
    """
    sequence = st.text_input(label, '')
    if show_hint:
        st.write('Try an example:')
    for button_label, example in EXAMPLE_SEQUENCES.items():
        if st.button(button_label):
            sequence = example
    # The input asks for capital letters; normalise pasted text (lower case,
    # trailing newlines) instead of failing later in tokenization.
    return sequence.strip().upper()


def _contact_page():
    """Self-contact prediction page."""
    sequence = _sequence_input(show_hint=False)
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    fig = perdict_contact_visualization(sequence, st.session_state.model, st.session_state.batch_converter)
    st.pyplot(fig)
    plt.close(fig)  # release the figure; every rerun creates a new one
    expander = st.expander("See explanation")
    expander.markdown(
        """Contact prediction is based on a logistic regression over the model's attention maps.
This methodology is based on ICLR 2021 paper, Transformer protein language models are unsupervised structure learners. (Rao et al. 2020) The MSA Transformer (ESM-MSA-1) takes a multiple sequence alignment (MSA) as input, and uses the tied row self-attention maps in the same way.
""",
        unsafe_allow_html=True,
    )


def _search_page():
    """Similar-protein search page: one tab (sequence + predicted structure) per hit."""
    sequence = _sequence_input()
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    hits = esm_search(st.session_state.model, sequence, st.session_state.batch_converter, top_k=5)
    st.text('search result (top 5): ')
    if not hits:
        st.warning('No similar proteins found.')
        return
    tabs = st.tabs([str(rank) for rank in range(1, len(hits) + 1)])
    for tab, hit in zip(tabs, hits):
        with tab:
            st.write(hit)
            show_protein_structure(hit)


def _activity_page():
    """Activity prediction page based on nearest stored neighbours."""
    st.markdown('we predict the biological activity of mutations of a protein, using fixed embeddings from ESM.')
    sequence = _sequence_input(label='protein sequence')
    if not sequence:
        return
    st.write('you have entered: ', sequence)
    with st.spinner('Predicting activity...'):
        res_knn = KNN_search(sequence)
    st.subheader('KNN predictor result')
    if res_knn is None:
        st.warning('No similar proteins found in the database.')
    else:
        st.markdown("Activity prediction: " + str(res_knn))


def _pdb_page():
    """PDB viewer page; renders only once a PDB ID is available."""
    id_PDB = st.text_input('enter PDB ID', '').strip()
    residues_marker = st.text_input('residues class', '').strip()
    st.write('Try an example:')
    if st.button('PDB ID: 1A2C / residues class: ALA'):
        id_PDB = '1A2C'
        residues_marker = 'ALA'
    st.subheader('PDB viewer')
    if id_PDB:
        viewer = render_pdb(id=id_PDB)
        if residues_marker:
            viewer = render_pdb_resn(viewer=viewer, resn_lst=[residues_marker])
        showmol(viewer)
    expander = st.expander("See explanation")
    expander.markdown("""
A PDB ID is a unique 4-character code for each entry in the Protein Data Bank. The first character must be a number between 1 and 9, and the remaining three characters can be letters or numbers.
see https://www.rcsb.org/ for more information.
""")


PAGES = {
    'self-contact prediction': _contact_page,
    'search the database for similar proteins': _search_page,
    'activity prediction with similar proteins': _activity_page,
    'PDB viewer': _pdb_page,
}

with st.spinner("Connecting DB..."):
    st.session_state.meta, st.session_state.client = init_db()

with st.spinner("Loading Models..."):
    # init_esm is cached, so this is cheap after the first load.
    st.session_state.model, st.session_state.alphabet = init_esm()
    if 'batch_converter' not in st.session_state:
        st.session_state.batch_converter = st.session_state.alphabet.get_batch_converter()
        st.session_state.query_num = 0

msg = messages[0] if st.session_state.query_num < len(messages) else messages[-1]

with st.container():
    st.title("Evolutionary Scale Modeling")
    st.info(msg)
    option = st.selectbox('Application options', tuple(PAGES))
    st.session_state.db_name_ref = 'default.esm_protein'
    PAGES[option]()