"""Gradio front end for ProteinMPNN sequence design with AlphaFold2 validation.

The module wires three pieces together:

* ProteinMPNN (``protein_mpnn_utils``) to sample sequences for a backbone,
* a single-sequence AlphaFold2 runner (``af_backprop`` / ``alphafold``) to
  re-predict the structure of a chosen design,
* a ``gr.Blocks`` user interface that is built and launched at import time.
"""
import json
import time
import os
import sys
import glob

import gradio as gr

# The ProteinMPNN helpers live outside the package tree; both the deployed and
# the local development checkout are put on the import path.
sys.path.append("/home/user/app/ProteinMPNN/vanilla_proteinmpnn")
sys.path.append("/home/duerr/phd/08_Code/ProteinMPNN/ProteinMPNN/vanilla_proteinmpnn")

import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import os.path
from protein_mpnn_utils import (
    loss_nll,
    loss_smoothed,
    gather_edges,
    gather_nodes,
    gather_nodes_t,
    cat_neighbors_nodes,
    _scores,
    _S_to_seq,
    tied_featurize,
    parse_PDB,
)
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
import plotly.express as px
import urllib
import urllib.request
import jax.numpy as jnp
import tensorflow as tf

if "/home/user/app/af_backprop" not in sys.path:
    sys.path.append("/home/user/app/af_backprop")

# local only
if "/home/duerr/phd/08_Code/ProteinMPNN/af_backprop" not in sys.path:
    sys.path.append("/home/duerr/phd/08_Code/ProteinMPNN/af_backprop")

from utils import *

# import libraries
import colabfold as cf
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config
from alphafold.model import model as afmodel
from alphafold.common import residue_constants
import plotly.graph_objects as go
import ray
import re
import numpy as np
import jax

# Hide all GPUs from TensorFlow.
tf.config.set_visible_devices([], "GPU")


def chain_break(idx_res, Ls, length=200):
    """Offset residue indices so that chain boundaries show up as gaps.

    Args:
        idx_res: Residue index array; it is modified in place.
        Ls: Chain lengths, in order.
        length: Offset added to every residue after each chain boundary.

    Returns:
        The same ``idx_res`` object, after the in-place update.
    """
    # Minkyung's code
    # add big enough number to residue index to indicate chain breaks
    L_prev = 0
    for L_i in Ls[:-1]:
        idx_res[L_prev + L_i :] += length
        L_prev += L_i
    return idx_res


def clear_mem():
    """Delete every live buffer held by the current JAX backend."""
    backend = jax.lib.xla_bridge.get_backend()
    for buf in backend.live_buffers():
        buf.delete()


def setup_af(seq, model_name="model_5_ptm"):
    """Build a jitted single-sequence AlphaFold runner for ``seq``.

    Args:
        seq: Amino-acid sequence; chains are separated by ``/``.
        model_name: AlphaFold model whose config and parameters are loaded.

    Returns:
        ``(runner, opt)`` where ``runner`` is ``jax.jit``-compiled and takes
        ``(seq_one_hot, opt)``, and ``opt`` holds the processed ``"inputs"``
        and the model ``"params"``. The caller must add ``opt["prev"]``
        before the first call.
    """
    clear_mem()

    # setup model
    # The config follows model_name so it always matches the loaded params
    # (it used to be hard-coded to "model_5_ptm").
    cfg = config.model_config(model_name)
    cfg.model.num_recycle = 0
    cfg.data.common.num_recycle = 0
    cfg.data.eval.max_msa_clusters = 1
    cfg.data.common.max_extra_msa = 1
    cfg.data.eval.masked_msa_replace_fraction = 0
    cfg.model.global_config.subbatch_size = None

    if os.path.exists("/home/duerr"):
        datadir = "/home/duerr/phd/08_Code/ProteinMPNN"
    else:
        datadir = "/home/user/app/"
    model_params = data.get_model_haiku_params(model_name=model_name, data_dir=datadir)
    model_runner = afmodel.RunModel(cfg, model_params, is_training=False)

    # Chain lengths are measured on the sanitised residues so they agree with
    # the sequence that is actually featurised below.
    Ls = [len(re.sub("[^A-Z]", "", chain.upper())) for chain in seq.split("/")]
    seq = re.sub("[^A-Z]", "", seq.upper())
    length = len(seq)
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=seq, description="none", num_res=length
        ),
        **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0] * length]]),
    }
    feature_dict["residue_index"] = chain_break(feature_dict["residue_index"], Ls)
    inputs = model_runner.process_features(feature_dict, random_seed=0)

    def runner(seq, opt):
        # update sequence
        inputs = opt["inputs"]
        inputs.update(opt["prev"])
        update_seq(seq, inputs)
        update_aatype(inputs["target_feat"][..., 1:], inputs)

        # mask prediction
        mask = seq.sum(-1)
        inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask)
        inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask)
        inputs["residue_index"] = jnp.where(mask == 1, inputs["residue_index"], 0)

        # get prediction
        key = jax.random.PRNGKey(0)
        outputs = model_runner.apply(opt["params"], key, inputs)

        # Fed back in as opt["prev"] on the next call (manual recycling).
        prev = {
            "init_msa_first_row": outputs["representations"]["msa_first_row"][None],
            "init_pair": outputs["representations"]["pair"][None],
            "init_pos": outputs["structure_module"]["final_atom_positions"][None],
        }
        aux = {
            "final_atom_positions": outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask": outputs["structure_module"]["final_atom_mask"],
            "plddt": get_plddt(outputs),
            "pae": get_pae(outputs),
            "inputs": inputs,
            "prev": prev,
        }
        return aux

    return jax.jit(runner), {"inputs": inputs, "params": model_params}


def make_tied_positions_for_homomers(pdb_dict_list):
    """Tie every position across all chains of each parsed structure.

    Args:
        pdb_dict_list: Parsed structures; each maps ``"name"`` and
            ``"seq_chain_<ID>"`` keys.

    Returns:
        Dict mapping structure name to a list with one entry per (1-based)
        position of the first chain; each entry maps chain id to ``[position]``.
    """
    my_dict = {}
    for result in pdb_dict_list:
        # A, B, C, ...
        all_chain_list = sorted(
            item[-1:] for item in result if item[:9] == "seq_chain"
        )
        chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
        my_dict[result["name"]] = [
            {chain: [i] for chain in all_chain_list}  # needs to be a list
            for i in range(1, chain_length + 1)
        ]
    return my_dict


def align_structures(pdb1, pdb2, lenRes):
    """Align two structures and report their RMSD.

    ``pdb1`` is used as the CE-align reference and written to
    ``reference.pdb``; ``pdb2`` is aligned onto it to obtain the RMSD.

    Args:
        pdb1: Path of the reference structure.
        pdb2: Path of the structure aligned onto the reference.
        lenRes: Unused; kept for interface compatibility.

    Returns:
        ``(rms, "reference.pdb", "out_aligned.pdb")``.
    """
    import Bio.PDB
    import subprocess

    pdb_parser = Bio.PDB.PDBParser(QUIET=True)

    # Get the structures
    ref_structure = pdb_parser.get_structure("samle", pdb1)
    sample_structure = pdb_parser.get_structure("reference", pdb2)

    aligner = Bio.PDB.CEAligner()
    aligner.set_reference(ref_structure)
    aligner.align(sample_structure)

    io = Bio.PDB.PDBIO()
    io.set_structure(ref_structure)
    io.save("reference.pdb")

    # Doing this to get around biopython CEALIGN bug
    # NOTE(review): cealign.pml is presumably what writes out_aligned.pdb;
    # the script is not part of this file - confirm.
    subprocess.call(["pymol", "-c", "-Q", "-r", "cealign.pml"])

    return aligner.rms, "reference.pdb", "out_aligned.pdb"


def save_pdb(outs, filename, LEN):
    """Write the first ``LEN`` residues of a prediction to ``filename`` as PDB.

    The B-factor column holds ``100 * plddt`` for every atom present in the
    atom mask.
    """
    p = {
        "residue_index": outs["inputs"]["residue_index"][0][:LEN],
        "aatype": outs["inputs"]["aatype"].argmax(-1)[0][:LEN],
        "atom_positions": outs["final_atom_positions"][:LEN],
        "atom_mask": outs["final_atom_mask"][:LEN],
    }
    b_factors = 100.0 * outs["plddt"][:LEN, None] * p["atom_mask"]
    p = protein.Protein(**p, b_factors=b_factors)
    pdb_lines = protein.to_pdb(p)
    with open(filename, "w") as f:
        f.write(pdb_lines)


# @ray.remote(num_gpus=1, max_calls=1)
def run_alphafold(sequence, num_recycles):
    """Predict the structure of ``sequence`` and write it to ``out.pdb``.

    Args:
        sequence: Amino-acid sequence; chains are separated by ``/``.
        num_recycles: Number of recycles; the model runs ``num_recycles + 1``
            times, each run being fed the previous run's outputs.

    Returns:
        ``(plddts, pae, length)``: the per-run pLDDT arrays, the PAE of the
        last run and the number of residues.
    """
    recycles = int(num_recycles)
    RUNNER, OPT = setup_af(sequence)

    SEQ = re.sub("[^A-Z]", "", sequence.upper())
    MAX_LEN = len(SEQ)
    LEN = len(SEQ)

    x = np.array([residue_constants.restype_order.get(aa, -1) for aa in SEQ])
    x = np.pad(x, [0, MAX_LEN - LEN], constant_values=-1)
    x = jax.nn.one_hot(x, 20)

    OPT["prev"] = {
        "init_msa_first_row": np.zeros([1, MAX_LEN, 256]),
        "init_pair": np.zeros([1, MAX_LEN, MAX_LEN, 128]),
        "init_pos": np.zeros([1, MAX_LEN, 37, 3]),
    }

    positions = []
    plddts = []
    for r in range(recycles + 1):
        outs = RUNNER(x, OPT)
        outs = jax.tree_map(np.asarray, outs)
        positions.append(outs["prev"]["init_pos"][0, :LEN])
        plddts.append(outs["plddt"][:LEN])
        OPT["prev"] = outs["prev"]
        if recycles > 0:
            print(r, plddts[-1].mean())
    save_pdb(outs, "out.pdb", LEN)
    return plddts, outs["pae"], LEN


if os.path.exists("/home/duerr/phd/08_Code/ProteinMPNN"):
    path_to_model_weights = "/home/duerr/phd/08_Code/ProteinMPNN/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights"
else:
    path_to_model_weights = (
        "/home/user/app/ProteinMPNN/vanilla_proteinmpnn/vanilla_model_weights"
    )


def setup_proteinmpnn(model_name="v_48_020", backbone_noise=0.00):
    """Load a ProteinMPNN checkpoint and return it in eval mode.

    Args:
        model_name: ProteinMPNN model name: v_48_002, v_48_010, v_48_020,
            v_48_030, v_32_002, v_32_010; v_32_020, v_32_030;
            v_48_010=version with 48 edges 0.10A noise.
        backbone_noise: Standard deviation of Gaussian noise to add to
            backbone atoms.

    Returns:
        ``(model, device)``; CUDA device 0 is used when available, else CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    hidden_dim = 128
    num_layers = 3

    checkpoint_path = os.path.join(path_to_model_weights, f"{model_name}.pt")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = ProteinMPNN(
        num_letters=21,
        node_features=hidden_dim,
        edge_features=hidden_dim,
        hidden_dim=hidden_dim,
        num_encoder_layers=num_layers,
        num_decoder_layers=num_layers,
        augment_eps=backbone_noise,
        k_neighbors=checkpoint["num_edges"],
    )
    model.to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, device


def get_pdb(pdb_code="", filepath=""):
    """Resolve the input structure to a local file path.

    Args:
        pdb_code: PDB identifier typed by the user; may be empty or ``None``.
        filepath: Uploaded file object exposing ``.name``; used only when no
            code was given.

    Returns:
        Path of the structure file, or ``None`` when neither input is usable,
        the code contains unexpected characters, or the download fails.
    """
    code = (pdb_code or "").strip()
    if not code:
        try:
            return filepath.name
        except AttributeError:
            return None

    # The code comes straight from a text box and ends up in a URL and a file
    # name, so anything but a plain identifier is rejected.
    if not re.fullmatch(r"[A-Za-z0-9_]+", code):
        return None

    local_path = f"{code}.pdb"
    if not os.path.exists(local_path):  # do not re-download an existing file
        try:
            urllib.request.urlretrieve(
                f"https://files.rcsb.org/view/{code}.pdb", local_path
            )
        except OSError:  # URLError / HTTPError are OSError subclasses
            if os.path.exists(local_path):
                os.remove(local_path)  # drop a partial download
            return None
    return local_path


def _parse_chain_ids(text):
    """Split a free-form chain specification into its alphabetic chain ids."""
    return re.findall("[A-Za-z]+", text or "")


def _join_chains_in_order(flat_seq, chain_lengths, chain_ids):
    """Cut ``flat_seq`` into chains and join them with "/" sorted by chain id."""
    pieces = []
    start = 0
    for chain_len in chain_lengths:
        pieces.append(flat_seq[start : start + chain_len])
        start += chain_len
    return "/".join(pieces[i] for i in np.argsort(chain_ids))


def _update_error(message):
    """Build the 7-value result of ``update`` for a failed run."""
    return (
        message,
        None,
        None,
        gr.File.update(visible=False),
        gr.File.update(visible=False),
        None,
        gr.Dropdown.update(choices=[]),
    )


def update(
    inp,
    file,
    designed_chain,
    fixed_chain,
    homomer,
    num_seqs,
    sampling_temp,
    model_name,
    backbone_noise,
):
    """Sample sequences with ProteinMPNN for the requested structure.

    Args:
        inp: PDB code, or empty to use ``file``.
        file: Uploaded structure file.
        designed_chain: Chain ids to design.
        fixed_chain: Chain ids kept fixed.
        homomer: Tie positions across chains when true.
        num_seqs: Number of sequences to sample.
        sampling_temp: Sampling temperature.
        model_name: ProteinMPNN checkpoint name.
        backbone_noise: Backbone noise passed to the model.

    Returns:
        A 7-tuple matching the click handler's outputs: FASTA-like message,
        probability heat map, temperature-adjusted heat map, two file updates
        for the written CSVs, the structure path, and a dropdown update with
        the designed sequences. On failure the message carries the error and
        the remaining values are empty.
    """
    pdb_path = get_pdb(pdb_code=inp, filepath=file)
    if pdb_path is None:
        return _update_error("Error processing PDB")

    model, device = setup_proteinmpnn(
        model_name=model_name, backbone_noise=backbone_noise
    )

    designed_chain_list = _parse_chain_ids(designed_chain)
    fixed_chain_list = _parse_chain_ids(fixed_chain)
    chain_list = list(set(designed_chain_list + fixed_chain_list))

    num_seq_per_target = int(num_seqs)
    batch_size = 1  # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory
    max_length = 20000  # Max sequence length
    omit_AAs = "X"  # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.
    pssm_multi = 0.0  # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions
    pssm_threshold = 0.0  # A value between -inf + inf to restric per position AAs
    pssm_log_odds_flag = 0  # 0 for False, 1 for True
    pssm_bias_flag = 0  # 0 for False, 1 for True

    NUM_BATCHES = num_seq_per_target // batch_size
    BATCH_COPIES = batch_size
    temperatures = [sampling_temp]
    omit_AAs_list = omit_AAs
    alphabet = "ACDEFGHIKLMNPQRSTVWYX"
    omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)

    fixed_positions_dict = None
    pssm_dict = None
    omit_AA_dict = None
    bias_by_res_dict = None
    bias_AAs_np = np.zeros(len(alphabet))

    ###############################################################
    pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
    if not pdb_dict_list:
        return _update_error("Error processing PDB")
    dataset_valid = StructureDatasetPDB(
        pdb_dict_list, truncate=None, max_length=max_length
    )
    if homomer:
        tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)
    else:
        tied_positions_dict = None

    chain_id_dict = {
        pdb_dict_list[0]["name"]: (designed_chain_list, fixed_chain_list)
    }

    with torch.no_grad():
        for prot in dataset_valid:
            all_probs_list = []
            all_log_probs_list = []
            batch_clones = [copy.deepcopy(prot) for _ in range(BATCH_COPIES)]
            (
                X,
                S,
                mask,
                lengths,
                chain_M,
                chain_encoding_all,
                chain_list_list,
                visible_list_list,
                masked_list_list,
                masked_chain_length_list_list,
                chain_M_pos,
                omit_AA_mask,
                residue_idx,
                dihedral_mask,
                tied_pos_list_of_lists_list,
                pssm_coef,
                pssm_bias,
                pssm_log_odds_all,
                bias_by_res_all,
                tied_beta,
            ) = tied_featurize(
                batch_clones,
                device,
                chain_id_dict,
                fixed_positions_dict,
                omit_AA_dict,
                tied_positions_dict,
                pssm_dict,
                bias_by_res_dict,
            )
            pssm_log_odds_mask = (
                pssm_log_odds_all > pssm_threshold
            ).float()  # 1.0 for true, 0.0 for false
            name_ = batch_clones[0]["name"]

            # Score of the native sequence on this backbone.
            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(
                X,
                S,
                mask,
                chain_M * chain_M_pos,
                residue_idx,
                chain_encoding_all,
                randn_1,
            )
            mask_for_loss = mask * chain_M * chain_M_pos
            scores = _scores(S, log_probs, mask_for_loss)
            native_score = scores.cpu().data.numpy()

            message = ""
            seq_list = []
            for temp in temperatures:
                # Keyword arguments shared by both sampling entry points.
                sample_kwargs = dict(
                    mask=mask,
                    temperature=temp,
                    omit_AAs_np=omit_AAs_np,
                    bias_AAs_np=bias_AAs_np,
                    chain_M_pos=chain_M_pos,
                    omit_AA_mask=omit_AA_mask,
                    pssm_coef=pssm_coef,
                    pssm_bias=pssm_bias,
                    pssm_multi=pssm_multi,
                    pssm_log_odds_flag=bool(pssm_log_odds_flag),
                    pssm_log_odds_mask=pssm_log_odds_mask,
                    pssm_bias_flag=bool(pssm_bias_flag),
                    bias_by_res=bias_by_res_all,
                )
                for j in range(NUM_BATCHES):
                    randn_2 = torch.randn(chain_M.shape, device=X.device)
                    if tied_positions_dict is None:
                        sample_dict = model.sample(
                            X,
                            randn_2,
                            S,
                            chain_M,
                            chain_encoding_all,
                            residue_idx,
                            **sample_kwargs,
                        )
                    else:
                        sample_dict = model.tied_sample(
                            X,
                            randn_2,
                            S,
                            chain_M,
                            chain_encoding_all,
                            residue_idx,
                            tied_pos=tied_pos_list_of_lists_list[0],
                            tied_beta=tied_beta,
                            **sample_kwargs,
                        )
                    S_sample = sample_dict["S"]

                    # Compute scores
                    log_probs = model(
                        X,
                        S_sample,
                        mask,
                        chain_M * chain_M_pos,
                        residue_idx,
                        chain_encoding_all,
                        randn_2,
                        use_input_decoding_order=True,
                        decoding_order=sample_dict["decoding_order"],
                    )
                    mask_for_loss = mask * chain_M * chain_M_pos
                    scores = _scores(S_sample, log_probs, mask_for_loss)
                    scores = scores.cpu().data.numpy()
                    all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                    all_log_probs_list.append(log_probs.cpu().data.numpy())

                    for b_ix in range(BATCH_COPIES):
                        masked_chain_length_list = masked_chain_length_list_list[b_ix]
                        masked_list = masked_list_list[b_ix]
                        seq_recovery_rate = torch.sum(
                            torch.sum(
                                torch.nn.functional.one_hot(S[b_ix], 21)
                                * torch.nn.functional.one_hot(S_sample[b_ix], 21),
                                axis=-1,
                            )
                            * mask_for_loss[b_ix]
                        ) / torch.sum(mask_for_loss[b_ix])
                        seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                        score = scores[b_ix]

                        # The native sequence is reported once, ahead of the
                        # first sample.
                        if b_ix == 0 and j == 0 and temp == temperatures[0]:
                            native_seq = _join_chains_in_order(
                                _S_to_seq(S[b_ix], chain_M[b_ix]),
                                masked_chain_length_list,
                                masked_list,
                            )
                            print_masked_chains = sorted(masked_list_list[0])
                            print_visible_chains = sorted(visible_list_list[0])
                            native_score_print = np.format_float_positional(
                                np.float32(native_score.mean()),
                                unique=False,
                                precision=4,
                            )
                            line = ">{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n".format(
                                name_,
                                native_score_print,
                                print_visible_chains,
                                print_masked_chains,
                                model_name,
                                native_seq,
                            )
                            message += f"{line}\n"

                        seq = _join_chains_in_order(
                            seq, masked_chain_length_list, masked_list
                        )
                        score_print = np.format_float_positional(
                            np.float32(score), unique=False, precision=4
                        )
                        seq_rec_print = np.format_float_positional(
                            np.float32(seq_recovery_rate.detach().cpu().numpy()),
                            unique=False,
                            precision=4,
                        )

                        # add non designed chains to predicted sequence
                        chain_s = ""
                        if len(visible_list_list[0]) > 0:
                            chain_M_bool = chain_M.bool()
                            not_designed = _S_to_seq(S[b_ix], ~chain_M_bool[b_ix])
                            labels = (
                                chain_encoding_all[b_ix][~chain_M_bool[b_ix]]
                                .detach()
                                .cpu()
                                .numpy()
                            )
                            for c in set(labels):
                                nd_mask = labels == c
                                chain_s += "/" + "".join(
                                    aa
                                    for aa, keep in zip(not_designed, nd_mask)
                                    if keep
                                )

                        line = (
                            ">T={}, sample={}, score={}, seq_recovery={}\n{}\n".format(
                                temp, b_ix, score_print, seq_rec_print, seq
                            )
                        )
                        seq_list.append(seq + chain_s)
                        message += f"{line}\n"

    # somehow sequences still contain X, remove again
    # Every omitted residue is stripped in one pass (the previous nested loop
    # only kept the effect of the last one).
    strip_omitted = str.maketrans("", "", omit_AAs)
    seq_list = [s.translate(strip_omitted) for s in seq_list]

    all_probs_concat = np.concatenate(all_probs_list)
    all_log_probs_concat = np.concatenate(all_log_probs_list)
    mean_probs = all_probs_concat.mean(0).T
    mean_exp_log_probs = np.exp(all_log_probs_concat).mean(0).T

    np.savetxt("all_probs_concat.csv", mean_probs, delimiter=",")
    np.savetxt("all_log_probs_concat.csv", mean_exp_log_probs, delimiter=",")

    fig = px.imshow(
        mean_exp_log_probs,
        labels=dict(x="positions", y="amino acids", color="probability"),
        y=list(alphabet),
        template="simple_white",
    )
    fig.update_xaxes(side="top")

    fig_tadjusted = px.imshow(
        mean_probs,
        labels=dict(x="positions", y="amino acids", color="probability"),
        y=list(alphabet),
        template="simple_white",
    )
    fig_tadjusted.update_xaxes(side="top")

    return (
        message,
        fig,
        fig_tadjusted,
        gr.File.update(value="all_log_probs_concat.csv", visible=True),
        gr.File.update(value="all_probs_concat.csv", visible=True),
        pdb_path,
        gr.Dropdown.update(choices=seq_list),
    )


def update_AF(startsequence, pdb, num_recycles):
    """Run AlphaFold on ``startsequence`` and build the validation outputs.

    Args:
        startsequence: Sequence chosen for validation.
        pdb: Path of the input structure the prediction is compared with.
        num_recycles: Number of recycles.

    Returns:
        ``(html, plddt_figure, plt)``: the value of ``molecule(...)``, a plotly
        figure with one pLDDT trace per recycle, and the ``matplotlib.pyplot``
        module whose current figure holds the PAE heat map.
    """
    # run alphafold (could be dispatched through ray:
    # ray.get(run_alphafold.remote(startsequence)))
    plddts, pae, num_res = run_alphafold(startsequence, num_recycles)

    last_recycle = len(plddts) - 1
    plots = []
    for recycle, plddts_val in enumerate(plddts):
        # Only the first and the last recycle are drawn initially; the rest
        # can be toggled from the legend.
        visible = True if recycle in (0, last_recycle) else "legendonly"
        plots.append(
            go.Scatter(
                x=np.arange(len(plddts_val)),
                y=plddts_val,
                # plotly hover text uses <br> for line breaks
                hovertemplate="pLDDT: %{y:.2f}<br>Residue index: %{x}<br>Recycle "
                + str(recycle),
                name=f"Recycle {recycle}",
                visible=visible,
            )
        )
    plddt_fig = go.Figure(data=plots)
    plddt_fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
        legend=dict(yanchor="bottom", y=0.01, xanchor="left", x=0.99),
    )

    # A plotly heat map of the PAE matrix was tried and doesnt work (likely
    # because too large), hence matplotlib.
    plt.figure()
    plt.title("Predicted Aligned Error")
    Ln = pae.shape[0]
    plt.imshow(pae, cmap="bwr", vmin=0, vmax=30, extent=(0, Ln, Ln, 0))
    plt.colorbar()
    plt.xlabel("Scored residue")
    plt.ylabel("Aligned residue")

    return molecule(pdb, "out.pdb", num_res), plddt_fig, plt


def read_mol(molpath):
    """Return the full text of the file at ``molpath``."""
    with open(molpath, "r") as fp:
        return fp.read()


def molecule(pdb, afpdb, num_res):
    """Align the input and predicted structures and build the viewer markup.

    Args:
        pdb: Path of the input structure.
        afpdb: Path of the AlphaFold prediction.
        num_res: Number of residues, forwarded to ``align_structures``.

    Returns:
        The markup string for the ``mol`` HTML component.
    """
    rms, input_pdb, aligned_pdb = align_structures(pdb, afpdb, num_res)
    mol = read_mol(input_pdb)
    pred_mol = read_mol(aligned_pdb)
    x = (
        """
RMSD AlphaFold vs. native: """
        + f"{rms:.2f}"
        + """Å
AlphaFold model confidence:
 Very high (pLDDT > 90)
 Confident (90 > pLDDT > 70)
 Low (70 > pLDDT > 50)
 Very low (pLDDT < 50)
AlphaFold produces a per-residue confidence score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation.
"""
    )
    # NOTE(review): the returned template is empty in this revision, so mol,
    # pred_mol and x are computed but never shown - presumably the viewer
    # markup was lost from this string; confirm and restore.
    return f""""""


def set_examples(example):
    """Map a clicked example row onto the example components.

    The slider and radio values are wrapped in component updates; the other
    fields are passed through unchanged.
    """
    label, inp, designed_chain, fixed_chain, homomer, num_seqs, sampling_temp = example
    return [
        label,
        inp,
        designed_chain,
        fixed_chain,
        homomer,
        gr.Slider.update(value=num_seqs),
        gr.Radio.update(value=sampling_temp),
    ]


# NOTE(review): the nesting of the layout containers below is reconstructed
# from statement order - confirm against the rendered UI.
proteinMPNN = gr.Blocks()

with proteinMPNN:
    gr.Markdown("# ProteinMPNN")
    gr.Markdown(
        """This model takes as input a protein structure and based on its backbone predicts new sequences that will fold into that backbone. Optionally, we can run AlphaFold2 on the predicted sequence to check whether the predicted sequences adopt the same backbone (WIP). """
    )
    gr.Markdown("![](https://simonduerr.eu/ProteinMPNN.png)")

    with gr.Tabs():
        with gr.TabItem("Input"):
            inp = gr.Textbox(
                placeholder="PDB Code or upload file below", label="Input structure"
            )
            file = gr.File(file_count="single")

        with gr.TabItem("Settings"):
            with gr.Row():
                designed_chain = gr.Textbox(value="A", label="Designed chain")
                fixed_chain = gr.Textbox(
                    placeholder="Use commas to fix multiple chains", label="Fixed chain"
                )
            with gr.Row():
                num_seqs = gr.Slider(
                    minimum=1, maximum=50, value=1, step=1, label="Number of sequences"
                )
                sampling_temp = gr.Radio(
                    choices=[0.1, 0.15, 0.2, 0.25, 0.3],
                    value=0.1,
                    label="Sampling temperature",
                )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=[
                        "v_48_002",
                        "v_48_010",
                        "v_48_020",
                        "v_48_030",
                    ],
                    label="Model",
                    value="v_48_020",
                )
                backbone_noise = gr.Dropdown(
                    choices=[0, 0.02, 0.10, 0.20, 0.30], label="Backbone noise", value=0
                )
            with gr.Row():
                homomer = gr.Checkbox(value=False, label="Homomer?")
                gr.Markdown(
                    "for correct symmetric tying lenghts of homomer chains should be the same"
                )

    btn = gr.Button("Run")
    label = gr.Textbox(label="Label", visible=False)
    examples = gr.Dataset(
        components=[
            label,
            inp,
            designed_chain,
            fixed_chain,
            homomer,
            num_seqs,
            sampling_temp,
        ],
        samples=[
            ["Homomer design", "1O91", "A,B,C", "", True, 2, 0.1],
            ["Monomer design", "6MRR", "A", "", False, 2, 0.1],
            ["Redesign of Homomer to Heteromer", "3HTN", "A,B", "C", False, 2, 0.1],
        ],
    )
    gr.Markdown(
        """ Sampling temperature for amino acids, `T=0.0` means taking argmax, `T>>1.0` means sample randomly. Suggested values `0.1, 0.15, 0.2, 0.25, 0.3`. Higher values will lead to more diversity. """
    )

    gr.Markdown("# Output")

    with gr.Tabs():
        with gr.TabItem("Designed sequences"):
            out = gr.Textbox(label="Status")

        with gr.TabItem("Amino acid probabilities"):
            plot = gr.Plot()
            all_log_probs = gr.File(visible=False)

        with gr.TabItem("T adjusted probabilities"):
            gr.Markdown("Sampling temperature adjusted amino acid probabilties")
            plot_tadjusted = gr.Plot()
            all_probs = gr.File(visible=False)

        with gr.TabItem("Structure validation w/ AF2"):
            gr.HTML(
                """

AF2 code is experimental and relies on @sokrypton's trick to speed up compile/module runtime. Results might differ from DeepMind's published results. Predictions are made using model_5_ptm and without MSA based on the selected single sequence (designed_chain + fixed_chain).

"""
            )
            with gr.Row():
                with gr.Row():
                    chosen_seq = gr.Dropdown(
                        choices=[], label="Select a sequence for validation"
                    )
                    num_recycles = gr.Dropdown(
                        choices=[0, 1, 3, 5], value=3, label="num Recycles"
                    )
                btnAF = gr.Button("Run AF2 on sequence")
            with gr.Row():
                mol = gr.HTML()
                with gr.Column():
                    plotAF_plddt = gr.Plot(label="pLDDT")
                    # remove maxh80 class from css
                    plotAF_pae = gr.Plot(label="PAE")

    # Carries the structure path from update() over to update_AF().
    tempFile = gr.Variable()

    btn.click(
        fn=update,
        inputs=[
            inp,
            file,
            designed_chain,
            fixed_chain,
            homomer,
            num_seqs,
            sampling_temp,
            model_name,
            backbone_noise,
        ],
        outputs=[
            out,
            plot,
            plot_tadjusted,
            all_log_probs,
            all_probs,
            tempFile,
            chosen_seq,
        ],
    )
    btnAF.click(
        fn=update_AF,
        inputs=[chosen_seq, tempFile, num_recycles],
        outputs=[mol, plotAF_plddt, plotAF_pae],
    )
    examples.click(fn=set_examples, inputs=examples, outputs=examples.components)
    gr.Markdown(
        """Citation: **Robust deep learning based protein sequence design using ProteinMPNN**
Justas Dauparas, Ivan Anishchenko, Nathaniel Bennett, Hua Bai, Robert J. Ragotte, Lukas F. Milles, Basile I. M. Wicky, Alexis Courbet, Robbert J. de Haas, Neville Bethel, Philip J. Y. Leung, Timothy F. Huddy, Sam Pellock, Doug Tischer, Frederick Chan, Brian Koepnick, Hannah Nguyen, Alex Kang, Banumathi Sankaran, Asim Bera, Neil P. King, David Baker
bioRxiv 2022.06.03.494563; doi: [10.1101/2022.06.03.494563](https://doi.org/10.1101/2022.06.03.494563)

Server built by [@simonduerr](https://twitter.com/simonduerr) and hosted by Huggingface"""
    )

# ray.init(runtime_env={"working_dir": "./af_backprop"})

proteinMPNN.launch(share=True, debug=True)