{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "AYZebfKn8gef" }, "source": [ "#ProteinMPNN w/AF2\n", "This notebook is intended as a quick demo, more features to come!\n", "\n", "Examples: \n", "1. pdb: `6MRR`, homomer: `False`, designed_chain: `A`\n", "2. pdb: `1X2I`, homomer: `True`, designed_chain: `A,B` \n", " (for correct symmetric tying lenghts of homomer chains should be the same)" ] }, { "cell_type": "code", "source": [ "#@title Setup ProteinMPNN\n", "import warnings\n", "warnings.simplefilter(action='ignore', category=FutureWarning)\n", "\n", "import json, time, os, sys, glob, re\n", "from google.colab import files\n", "import numpy as np\n", "\n", "if not os.path.isdir(\"ProteinMPNN\"):\n", " os.system(\"git clone -q https://github.com/dauparas/ProteinMPNN.git\")\n", "\n", "if \"ProteinMPNN\" not in sys.path:\n", " sys.path.append('/content/ProteinMPNN')\n", "\n", "import matplotlib.pyplot as plt\n", "import shutil\n", "import warnings\n", "import torch\n", "from torch import optim\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import random_split, Subset\n", "import copy\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import random\n", "import os.path\n", "from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB\n", "from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN\n", "\n", "device = torch.device(\"cpu\")\n", "#v_48_010=version with 48 edges 0.10A noise\n", "model_name = \"v_48_020\" #@param [\"v_48_002\", \"v_48_010\", \"v_48_020\", \"v_48_030\"]\n", "\n", "\n", "backbone_noise=0.00 # Standard deviation of Gaussian noise to add to backbone atoms\n", "\n", "path_to_model_weights='/content/ProteinMPNN/vanilla_model_weights' \n", "hidden_dim = 128\n", "num_layers = 3 \n", "model_folder_path = path_to_model_weights\n", "if model_folder_path[-1] != '/':\n", " model_folder_path = model_folder_path + '/'\n", "checkpoint_path = model_folder_path + f'{model_name}.pt'\n", "\n", "checkpoint = torch.load(checkpoint_path, map_location=device) \n", "print('Number of edges:', checkpoint['num_edges'])\n", "noise_level_print = checkpoint['noise_level']\n", "print(f'Training noise level: {noise_level_print}A')\n", "model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])\n", "model.to(device)\n", "model.load_state_dict(checkpoint['model_state_dict'])\n", "model.eval()\n", "print(\"Model loaded\")\n", "\n", "def make_tied_positions_for_homomers(pdb_dict_list):\n", " my_dict = {}\n", " for result in pdb_dict_list:\n", " all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...\n", " tied_positions_list = []\n", " chain_length = len(result[f\"seq_chain_{all_chain_list[0]}\"])\n", " for i in range(1,chain_length+1):\n", " temp_dict = {}\n", " for j, chain in enumerate(all_chain_list):\n", " temp_dict[chain] = [i] #needs to be a list\n", " tied_positions_list.append(temp_dict)\n", " my_dict[result['name']] = tied_positions_list\n", " return my_dict\n", "\n", "#########################\n", "def get_pdb(pdb_code=\"\"):\n", " if pdb_code is None or pdb_code == \"\":\n", " upload_dict = files.upload()\n", " pdb_string = upload_dict[list(upload_dict.keys())[0]]\n", " with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n", " return \"tmp.pdb\"\n", " else:\n", " os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n", " return f\"{pdb_code}.pdb\"" ], "metadata": { "id": "2nKSlaMlSpcf", "cellView": "form" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "xMVlYh8Fv2of" }, "outputs": [], "source": [ "#@title #Run ProteinMPNN\n", "\n", "#@markdown #### Input Options\n", "pdb='6MRR' #@param {type:\"string\"}\n", "pdb = pdb.replace(\" \",\"\")\n", "pdb_path = get_pdb(pdb)\n", "#@markdown - pdb code (leave blank to get an upload prompt)\n", "\n", "homomer = False #@param {type:\"boolean\"}\n", "designed_chain = \"A\" #@param {type:\"string\"}\n", "fixed_chain = \"\" #@param {type:\"string\"}\n", "\n", "if designed_chain == \"\":\n", " designed_chain_list = []\n", "else:\n", " designed_chain_list = re.sub(\"[^A-Za-z]+\",\",\", designed_chain).split(\",\")\n", "\n", "if fixed_chain == \"\":\n", " fixed_chain_list = []\n", "else:\n", " fixed_chain_list = re.sub(\"[^A-Za-z]+\",\",\", fixed_chain).split(\",\")\n", "\n", "chain_list = list(set(designed_chain_list + fixed_chain_list))\n", "\n", "#@markdown - specified which chain(s) to design and which chain(s) to keep fixed. \n", "#@markdown Use comma:`A,B` to specifiy more than one chain\n", "\n", "#chain = \"A\" #@param {type:\"string\"}\n", "#pdb_path_chains = chain\n", "##@markdown - Define which chain to redesign\n", "\n", "#@markdown #### Design Options\n", "num_seqs = 8 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n", "num_seq_per_target = num_seqs\n", "\n", "#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.\n", "sampling_temp = \"0.1\" #@param [\"0.0001\", \"0.1\", \"0.15\", \"0.2\", \"0.25\", \"0.3\", \"0.5\"]\n", "\n", "\n", "\n", "save_score=0 # 0 for False, 1 for True; save score=-log_prob to npy files\n", "save_probs=0 # 0 for False, 1 for True; save MPNN predicted probabilites per position\n", "score_only=0 # 0 for False, 1 for True; score input backbone-sequence pairs\n", "conditional_probs_only=0 # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)\n", "conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)\n", " \n", "batch_size=1 # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory\n", "max_length=20000 # Max sequence length\n", " \n", "out_folder='.' # Path to a folder to output sequences, e.g. /home/out/\n", "jsonl_path='' # Path to a folder with parsed pdb into jsonl\n", "omit_AAs='X' # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.\n", " \n", "pssm_multi=0.0 # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions\n", "pssm_threshold=0.0 # A value between -inf + inf to restric per position AAs\n", "pssm_log_odds_flag=0 # 0 for False, 1 for True\n", "pssm_bias_flag=0 # 0 for False, 1 for True\n", "\n", "\n", "##############################################################\n", "\n", "folder_for_outputs = out_folder\n", "\n", "NUM_BATCHES = num_seq_per_target//batch_size\n", "BATCH_COPIES = batch_size\n", "temperatures = [float(item) for item in sampling_temp.split()]\n", "omit_AAs_list = omit_AAs\n", "alphabet = 'ACDEFGHIKLMNPQRSTVWYX'\n", "\n", "omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)\n", "\n", "chain_id_dict = None\n", "fixed_positions_dict = None\n", "pssm_dict = None\n", "omit_AA_dict = None\n", "bias_AA_dict = None\n", "tied_positions_dict = None\n", "bias_by_res_dict = None\n", "bias_AAs_np = np.zeros(len(alphabet))\n", "\n", "\n", "###############################################################\n", "pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)\n", "dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)\n", "\n", "chain_id_dict = {}\n", "chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)\n", "\n", "print(chain_id_dict)\n", "for chain in chain_list:\n", " l = len(pdb_dict_list[0][f\"seq_chain_{chain}\"])\n", " print(f\"Length of chain {chain} is {l}\")\n", "\n", "if homomer:\n", " tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)\n", "else:\n", " tied_positions_dict = None\n", "\n", "#################################################################\n", "sequences = []\n", "with torch.no_grad():\n", " print('Generating sequences...')\n", " for ix, protein in enumerate(dataset_valid):\n", " score_list = []\n", " all_probs_list = []\n", " all_log_probs_list = []\n", " S_sample_list = []\n", " batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]\n", " X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)\n", " pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false\n", " name_ = batch_clones[0]['name']\n", "\n", " randn_1 = torch.randn(chain_M.shape, device=X.device)\n", " log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)\n", " mask_for_loss = mask*chain_M*chain_M_pos\n", " scores = _scores(S, log_probs, mask_for_loss)\n", " native_score = scores.cpu().data.numpy()\n", "\n", " for temp in temperatures:\n", " for j in range(NUM_BATCHES):\n", " randn_2 = torch.randn(chain_M.shape, device=X.device)\n", " if tied_positions_dict == None:\n", " sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)\n", " S_sample = sample_dict[\"S\"] \n", " else:\n", " sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)\n", " # Compute scores\n", " S_sample = sample_dict[\"S\"]\n", " log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict[\"decoding_order\"])\n", " mask_for_loss = mask*chain_M*chain_M_pos\n", " scores = _scores(S_sample, log_probs, mask_for_loss)\n", " scores = scores.cpu().data.numpy()\n", " all_probs_list.append(sample_dict[\"probs\"].cpu().data.numpy())\n", " all_log_probs_list.append(log_probs.cpu().data.numpy())\n", " S_sample_list.append(S_sample.cpu().data.numpy())\n", " for b_ix in range(BATCH_COPIES):\n", " masked_chain_length_list = masked_chain_length_list_list[b_ix]\n", " masked_list = masked_list_list[b_ix]\n", " seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])\n", " seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])\n", " score = scores[b_ix]\n", " score_list.append(score)\n", " native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])\n", " if b_ix == 0 and j==0 and temp==temperatures[0]:\n", " start = 0\n", " end = 0\n", " list_of_AAs = []\n", " for mask_l in masked_chain_length_list:\n", " end += mask_l\n", " list_of_AAs.append(native_seq[start:end])\n", " start = end\n", " native_seq = \"\".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))\n", " l0 = 0\n", " for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:\n", " l0 += mc_length\n", " native_seq = native_seq[:l0] + '/' + native_seq[l0:]\n", " l0 += 1\n", " sorted_masked_chain_letters = np.argsort(masked_list_list[0])\n", " print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]\n", " sorted_visible_chain_letters = np.argsort(visible_list_list[0])\n", " print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]\n", " native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)\n", " line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\\n{}\\n'.format(name_, native_score_print, print_visible_chains, print_masked_chains, model_name, native_seq)\n", " print(line.rstrip())\n", " start = 0\n", " end = 0\n", " list_of_AAs = []\n", " for mask_l in masked_chain_length_list:\n", " end += mask_l\n", " list_of_AAs.append(seq[start:end])\n", " start = end\n", "\n", " seq = \"\".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))\n", " l0 = 0\n", " for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:\n", " l0 += mc_length\n", " seq = seq[:l0] + '/' + seq[l0:]\n", " l0 += 1\n", " score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)\n", " seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)\n", " line = '>T={}, sample={}, score={}, seq_recovery={}\\n{}\\n'.format(temp,b_ix,score_print,seq_rec_print,seq)\n", " sequences.append(seq)\n", " print(line.rstrip())\n", "\n", "\n", "all_probs_concat = np.concatenate(all_probs_list)\n", "all_log_probs_concat = np.concatenate(all_log_probs_list)\n", "S_sample_concat = np.concatenate(S_sample_list)" ] }, { "cell_type": "markdown", "source": [ "# Predict with AlphaFold2 (with single-sequence input)" ], "metadata": { "id": "5mQ4VLG1dPsd" } }, { "cell_type": "code", "source": [ "#@title Setup AlphaFold\n", "\n", "# import libraries\n", "from IPython.utils import io\n", "import os,sys,re\n", "\n", "if \"af_backprop\" not in sys.path:\n", " import tensorflow as tf\n", " import jax\n", " import jax.numpy as jnp\n", " import numpy as np\n", " import matplotlib\n", " from matplotlib import animation\n", " import matplotlib.pyplot as plt\n", " from IPython.display import HTML\n", " import tqdm.notebook\n", " TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", "\n", " with io.capture_output() as captured:\n", " # install ALPHAFOLD\n", " if not os.path.isdir(\"af_backprop\"):\n", " %shell git clone https://github.com/sokrypton/af_backprop.git\n", " %shell pip -q install biopython dm-haiku ml-collections py3Dmol\n", " %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\n", " if not os.path.isdir(\"params\"):\n", " %shell mkdir params\n", " %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", "\n", " if not os.path.exists(\"MMalign\"):\n", " # install MMalign\n", " os.system(\"wget -qnc https://zhanggroup.org/MM-align/bin/module/MMalign.cpp\")\n", " os.system(\"g++ -static -O3 -ffast-math -o MMalign MMalign.cpp\")\n", "\n", " def mmalign(pdb_a,pdb_b):\n", " # pass to MMalign\n", " output = os.popen(f'./MMalign {pdb_a} {pdb_b}')\n", " # parse outputs\n", " parse_float = lambda x: float(x.split(\"=\")[1].split()[0])\n", " tms = []\n", " for line in output:\n", " line = line.rstrip()\n", " if line.startswith(\"TM-score\"): tms.append(parse_float(line))\n", " return tms\n", "\n", " # configure which device to use\n", " try:\n", " # check if TPU is available\n", " import jax.tools.colab_tpu\n", " jax.tools.colab_tpu.setup_tpu()\n", " print('Running on TPU')\n", " DEVICE = \"tpu\"\n", " except:\n", " if jax.local_devices()[0].platform == 'cpu':\n", " print(\"WARNING: no GPU detected, will be using CPU\")\n", " DEVICE = \"cpu\"\n", " else:\n", " print('Running on GPU')\n", " DEVICE = \"gpu\"\n", " # disable GPU on tensorflow\n", " tf.config.set_visible_devices([], 'GPU')\n", "\n", " # import libraries\n", " sys.path.append('af_backprop')\n", " from utils import update_seq, update_aatype, get_plddt, get_pae\n", " import colabfold as cf\n", " from alphafold.common import protein as alphafold_protein\n", " from alphafold.data import pipeline\n", " from alphafold.model import data, config\n", " from alphafold.common import residue_constants\n", " from alphafold.model import model as alphafold_model\n", "\n", "# custom functions\n", "def clear_mem():\n", " backend = jax.lib.xla_bridge.get_backend()\n", " for buf in backend.live_buffers(): buf.delete()\n", "\n", "def setup_model(max_len):\n", " clear_mem()\n", "\n", " # setup model\n", " cfg = config.model_config(\"model_3_ptm\")\n", " cfg.model.num_recycle = 0\n", " cfg.data.common.num_recycle = 0\n", " cfg.data.eval.max_msa_clusters = 1\n", " cfg.data.common.max_extra_msa = 1\n", " cfg.data.eval.masked_msa_replace_fraction = 0\n", " cfg.model.global_config.subbatch_size = None\n", "\n", " # get params\n", " model_param = data.get_model_haiku_params(model_name=\"model_3_ptm\", data_dir=\".\")\n", " model_runner = alphafold_model.RunModel(cfg, model_param, is_training=False, recycle_mode=\"none\")\n", "\n", " model_params = []\n", " for k in [1,2,3,4,5]:\n", " if k == 3:\n", " model_params.append(model_param)\n", " else:\n", " params = data.get_model_haiku_params(model_name=f\"model_{k}_ptm\", data_dir=\".\")\n", " model_params.append({k: params[k] for k in model_runner.params.keys()})\n", "\n", " seq = \"A\" * max_len\n", " length = len(seq)\n", " feature_dict = {\n", " **pipeline.make_sequence_features(sequence=seq, description=\"none\", num_res=length),\n", " **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])\n", " }\n", " inputs = model_runner.process_features(feature_dict,random_seed=0)\n", "\n", " def runner(I, params):\n", " # update sequence\n", " inputs = I[\"inputs\"]\n", " inputs.update(I[\"prev\"])\n", "\n", " seq = jax.nn.one_hot(I[\"seq\"],20)\n", " update_seq(seq, inputs)\n", " update_aatype(inputs[\"target_feat\"][...,1:], inputs)\n", "\n", " # mask prediction\n", " mask = jnp.arange(inputs[\"residue_index\"].shape[0]) < I[\"length\"]\n", " inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n", " inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n", " inputs[\"residue_index\"] = jnp.where(mask, inputs[\"residue_index\"], 0)\n", "\n", " # get prediction\n", " key = jax.random.PRNGKey(0)\n", " outputs = model_runner.apply(params, key, inputs)\n", "\n", " prev = {\"init_msa_first_row\":outputs['representations']['msa_first_row'][None],\n", " \"init_pair\":outputs['representations']['pair'][None],\n", " \"init_pos\":outputs['structure_module']['final_atom_positions'][None]}\n", " \n", " aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n", " \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n", " \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n", " \"length\":I[\"length\"], \"seq\":I[\"seq\"], \"prev\":prev,\n", " \"residue_idx\":inputs[\"residue_index\"][0]}\n", " return aux\n", "\n", " return jax.jit(runner), model_params, {\"inputs\":inputs, \"length\":max_length}\n", "\n", "def save_pdb(outs, filename, Ls=None):\n", " '''save pdb coordinates'''\n", " p = {\"residue_index\":outs[\"residue_idx\"] + 1,\n", " \"aatype\":outs[\"seq\"],\n", " \"atom_positions\":outs[\"final_atom_positions\"],\n", " \"atom_mask\":outs[\"final_atom_mask\"],\n", " \"plddt\":outs[\"plddt\"]}\n", " p = jax.tree_map(lambda x:x[:outs[\"length\"]], p)\n", " b_factors = 100 * p.pop(\"plddt\")[:,None] * p[\"atom_mask\"]\n", " p = alphafold_protein.Protein(**p,b_factors=b_factors)\n", " pdb_lines = alphafold_protein.to_pdb(p)\n", " with open(filename, 'w') as f:\n", " f.write(pdb_lines)\n", " if Ls is not None:\n", " pdb_lines = cf.read_pdb_renum(filename, Ls)\n", " with open(filename, 'w') as f:\n", " f.write(pdb_lines)" ], "metadata": { "cellView": "form", "id": "4ZBUThXU7yY8" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title Run AlphaFold\n", "num_models = 1 #@param [\"1\",\"2\",\"3\",\"4\",\"5\"] {type:\"raw\"}\n", "num_recycles = 1 #@param [\"0\",\"1\",\"2\",\"3\"] {type:\"raw\"}\n", "num_sequences = len(sequences)\n", "outs = []\n", "positions = []\n", "plddts = []\n", "paes = []\n", "LS = []\n", "\n", "with tqdm.notebook.tqdm(total=(num_recycles + 1) * num_models * num_sequences, bar_format=TQDM_BAR_FORMAT) as pbar:\n", " print(f\"seq_num model_num avg_pLDDT avg_pAE TMscore\")\n", " for s,ori_sequence in enumerate(sequences):\n", " Ls = [len(s) for s in ori_sequence.replace(\":\",\"/\").split(\"/\")]\n", " LS.append(Ls)\n", " sequence = re.sub(\"[^A-Z]\",\"\",ori_sequence)\n", " length = len(sequence)\n", "\n", " # avoid recompiling if length within 25\n", " if \"max_len\" not in dir() or length > max_len or (max_len - length) > 25:\n", " max_len = length + 25\n", " runner, params, I = setup_model(max_len)\n", "\n", " outs.append([])\n", " positions.append([])\n", " plddts.append([])\n", " paes.append([])\n", "\n", " r = -1\n", " # pad sequence to max length\n", " seq = np.array([residue_constants.restype_order.get(aa,0) for aa in sequence])\n", " seq = np.pad(seq,[0,max_len-length],constant_values=-1)\n", " I[\"inputs\"]['residue_index'][:] = cf.chain_break(np.arange(max_len), Ls, length=32)\n", " I.update({\"seq\":seq, \"length\":length})\n", " \n", " # for each model\n", " for n in range(num_models):\n", " # restart recycle\n", " I[\"prev\"] = {'init_msa_first_row': np.zeros([1, max_len, 256]),\n", " 'init_pair': np.zeros([1, max_len, max_len, 128]),\n", " 'init_pos': np.zeros([1, max_len, 37, 3])}\n", " for r in range(num_recycles + 1):\n", " O = runner(I, params[n])\n", " O = jax.tree_map(lambda x:np.asarray(x), O)\n", " I[\"prev\"] = O[\"prev\"]\n", " pbar.update(1)\n", " \n", " positions[-1].append(O[\"final_atom_positions\"][:length])\n", " plddts[-1].append(O[\"plddt\"][:length])\n", " paes[-1].append(O[\"pae\"][:length,:length])\n", " outs[-1].append(O)\n", " save_pdb(outs[-1][-1], f\"out_seq_{s}_model_{n}.pdb\", Ls=LS[-1])\n", " tmscores = mmalign(pdb_path, f\"out_seq_{s}_model_{n}.pdb\")\n", " print(f\"{s} {n}\\t{plddts[-1][-1].mean():.3}\\t{paes[-1][-1].mean():.3}\\t{tmscores[-1]:.3}\")" ], "metadata": { "cellView": "form", "id": "p2uNokqudTSH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title Display 3D structure {run: \"auto\"}\n", "#@markdown #### select which sequence to show (if more than one designed example)\n", "seq_num = 0 #@param [\"0\",\"1\",\"2\",\"3\",\"4\",\"5\",\"6\",\"7\"] {type:\"raw\"}\n", "assert seq_num < len(outs), f\"ERROR: seq_num ({seq_num}) exceeds number of designed sequences ({num_sequences})\"\n", "model_num = 0 #@param [\"0\",\"1\",\"2\",\"3\",\"4\"] {type:\"raw\"}\n", "assert model_num < len(outs[0]), f\"ERROR: model_num ({num_models}) exceeds number of model params used ({num_models})\"\n", "#@markdown #### options\n", "\n", "color = \"confidence\" #@param [\"chain\", \"confidence\", \"rainbow\"]\n", "if color == \"confidence\": color = \"lDDT\"\n", "show_sidechains = False #@param {type:\"boolean\"}\n", "show_mainchains = False #@param {type:\"boolean\"}\n", "\n", "v = cf.show_pdb(f\"out_seq_{seq_num}_model_{model_num}.pdb\", show_sidechains, show_mainchains, color,\n", " color_HP=True, size=(800,480), Ls=LS[seq_num]) \n", "v.setHoverable({}, True,\n", " '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\" \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n", " '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n", "v.show() \n", "if color == \"lDDT\":\n", " cf.plot_plddt_legend().show()\n", "\n", "# add confidence plots\n", "cf.plot_confidence(plddts[seq_num][model_num]*100, paes[seq_num][model_num], Ls=LS[seq_num]).show()" ], "metadata": { "cellView": "form", "id": "0TNhcwok8d_w" }, "execution_count": null, "outputs": [] } ], "metadata": { "colab": { "name": "quickdemo_wAF2.ipynb", "provenance": [], "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 0 }