{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYZebfKn8gef"
},
"source": [
"#ProteinMPNN\n",
"This notebook is intended as a quick demo, more features to come!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "iYDU3ftml2k5"
},
"outputs": [],
"source": [
"#@title Clone github repo\n",
"import json, time, os, sys, glob\n",
"\n",
"if not os.path.isdir(\"ProteinMPNN\"):\n",
" os.system(\"git clone -q https://github.com/dauparas/ProteinMPNN.git\")\n",
"sys.path.append('/content/ProteinMPNN')"
]
},
{
"cell_type": "code",
"source": [
"#@title Setup Model\n",
"import matplotlib.pyplot as plt\n",
"import shutil\n",
"import warnings\n",
"import numpy as np\n",
"import torch\n",
"from torch import optim\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data.dataset import random_split, Subset\n",
"import copy\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import random\n",
"import os.path\n",
"from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB\n",
"from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN\n",
"\n",
"device = torch.device(\"cuda:0\" if (torch.cuda.is_available()) else \"cpu\")\n",
"#v_48_010=version with 48 edges 0.10A noise\n",
"model_name = \"v_48_020\" #@param [\"v_48_002\", \"v_48_010\", \"v_48_020\", \"v_48_030\"]\n",
"\n",
"\n",
"backbone_noise=0.00 # Standard deviation of Gaussian noise to add to backbone atoms\n",
"\n",
"path_to_model_weights='/content/ProteinMPNN/vanilla_model_weights' \n",
"hidden_dim = 128\n",
"num_layers = 3 \n",
"model_folder_path = path_to_model_weights\n",
"if model_folder_path[-1] != '/':\n",
" model_folder_path = model_folder_path + '/'\n",
"checkpoint_path = model_folder_path + f'{model_name}.pt'\n",
"\n",
"checkpoint = torch.load(checkpoint_path, map_location=device) \n",
"print('Number of edges:', checkpoint['num_edges'])\n",
"noise_level_print = checkpoint['noise_level']\n",
"print(f'Training noise level: {noise_level_print}A')\n",
"model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])\n",
"model.to(device)\n",
"model.load_state_dict(checkpoint['model_state_dict'])\n",
"model.eval()\n",
"print(\"Model loaded\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2nKSlaMlSpcf",
"outputId": "293c6f08-8256-4e08-d8ab-3c2a02ab74b8"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of edges: 48\n",
"Training noise level: 0.2A\n",
"Model loaded\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#@title Helper functions\n",
"def make_tied_positions_for_homomers(pdb_dict_list):\n",
" my_dict = {}\n",
" for result in pdb_dict_list:\n",
" all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...\n",
" tied_positions_list = []\n",
" chain_length = len(result[f\"seq_chain_{all_chain_list[0]}\"])\n",
" for i in range(1,chain_length+1):\n",
" temp_dict = {}\n",
" for j, chain in enumerate(all_chain_list):\n",
" temp_dict[chain] = [i] #needs to be a list\n",
" tied_positions_list.append(temp_dict)\n",
" my_dict[result['name']] = tied_positions_list\n",
" return my_dict"
],
"metadata": {
"id": "yjnUkQuZX-de"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Examples: \n",
"\n",
"1) pdb: 6MRR, homomer: False, designed_chain: A\n",
"\n",
"2) pdb: 1O91, homomer: True, designed_chain: A B C, for correct symmetric tying lenghts of homomer chains should be the same"
],
"metadata": {
"id": "bgRXscTaYuqM"
}
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "k4o6w2Y23wxO",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "58d89f72-5d3c-44a7-c2cb-e26c3c54d43f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'1O91': (['A', 'B', 'C'], [])}\n",
"Length of chain C is 131\n",
"Length of chain B is 131\n",
"Length of chain A is 131\n"
]
}
],
"source": [
"import re\n",
"from google.colab import files\n",
"import numpy as np\n",
"\n",
"#########################\n",
"def get_pdb(pdb_code=\"\"):\n",
" if pdb_code is None or pdb_code == \"\":\n",
" upload_dict = files.upload()\n",
" pdb_string = upload_dict[list(upload_dict.keys())[0]]\n",
" with open(\"tmp.pdb\",\"wb\") as out: out.write(pdb_string)\n",
" return \"tmp.pdb\"\n",
" else:\n",
" os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n",
" return f\"{pdb_code}.pdb\"\n",
"\n",
"#@markdown ### Input Options\n",
"pdb='1O91' #@param {type:\"string\"}\n",
"pdb_path = get_pdb(pdb)\n",
"#@markdown - pdb code (leave blank to get an upload prompt)\n",
"\n",
"homomer = True #@param {type:\"boolean\"}\n",
"designed_chain = \"A B C\" #@param {type:\"string\"}\n",
"fixed_chain = \"\" #@param {type:\"string\"}\n",
"\n",
"if designed_chain == \"\":\n",
" designed_chain_list = []\n",
"else:\n",
" designed_chain_list = re.sub(\"[^A-Za-z]+\",\",\", designed_chain).split(\",\")\n",
"\n",
"if fixed_chain == \"\":\n",
" fixed_chain_list = []\n",
"else:\n",
" fixed_chain_list = re.sub(\"[^A-Za-z]+\",\",\", fixed_chain).split(\",\")\n",
"\n",
"chain_list = list(set(designed_chain_list + fixed_chain_list))\n",
"\n",
"#@markdown - specified which chain(s) to design and which chain(s) to keep fixed. \n",
"#@markdown Use comma:`A,B` to specifiy more than one chain\n",
"\n",
"#chain = \"A\" #@param {type:\"string\"}\n",
"#pdb_path_chains = chain\n",
"##@markdown - Define which chain to redesign\n",
"\n",
"#@markdown ### Design Options\n",
"num_seqs = 1 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
"num_seq_per_target = num_seqs\n",
"\n",
"#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.\n",
"sampling_temp = \"0.1\" #@param [\"0.0001\", \"0.1\", \"0.15\", \"0.2\", \"0.25\", \"0.3\", \"0.5\"]\n",
"\n",
"\n",
"\n",
"save_score=0 # 0 for False, 1 for True; save score=-log_prob to npy files\n",
"save_probs=0 # 0 for False, 1 for True; save MPNN predicted probabilites per position\n",
"score_only=0 # 0 for False, 1 for True; score input backbone-sequence pairs\n",
"conditional_probs_only=0 # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)\n",
"conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)\n",
" \n",
"batch_size=1 # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory\n",
"max_length=20000 # Max sequence length\n",
" \n",
"out_folder='.' # Path to a folder to output sequences, e.g. /home/out/\n",
"jsonl_path='' # Path to a folder with parsed pdb into jsonl\n",
"omit_AAs='X' # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.\n",
" \n",
"pssm_multi=0.0 # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions\n",
"pssm_threshold=0.0 # A value between -inf + inf to restric per position AAs\n",
"pssm_log_odds_flag=0 # 0 for False, 1 for True\n",
"pssm_bias_flag=0 # 0 for False, 1 for True\n",
"\n",
"\n",
"##############################################################\n",
"\n",
"folder_for_outputs = out_folder\n",
"\n",
"NUM_BATCHES = num_seq_per_target//batch_size\n",
"BATCH_COPIES = batch_size\n",
"temperatures = [float(item) for item in sampling_temp.split()]\n",
"omit_AAs_list = omit_AAs\n",
"alphabet = 'ACDEFGHIKLMNPQRSTVWYX'\n",
"\n",
"omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)\n",
"\n",
"chain_id_dict = None\n",
"fixed_positions_dict = None\n",
"pssm_dict = None\n",
"omit_AA_dict = None\n",
"bias_AA_dict = None\n",
"tied_positions_dict = None\n",
"bias_by_res_dict = None\n",
"bias_AAs_np = np.zeros(len(alphabet))\n",
"\n",
"\n",
"###############################################################\n",
"pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)\n",
"dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)\n",
"\n",
"chain_id_dict = {}\n",
"chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)\n",
"\n",
"print(chain_id_dict)\n",
"for chain in chain_list:\n",
" l = len(pdb_dict_list[0][f\"seq_chain_{chain}\"])\n",
" print(f\"Length of chain {chain} is {l}\")\n",
"\n",
"if homomer:\n",
" tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)\n",
"else:\n",
" tied_positions_dict = None"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"cellView": "form",
"id": "xMVlYh8Fv2of",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7b3d1839-d872-4f48-f907-775eaf035754"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Generating sequences...\n",
">1O91, score=1.3217, fixed_chains=[], designed_chains=['A', 'B', 'C'], model_name=v_48_020\n",
"EMPAFTAELTVPFPPVGAPVKFDKLLYNGRQNYNPQTGIFTCEVPGVYYFAYHVHCKGGNVWVALFKNNEPMMYTYDEYKKGFLDQASGSAVLLLRPGDQVFLQMPSEQAAGLYAGQYVHSSFSGYLLYPM/EMPAFTAELTVPFPPVGAPVKFDKLLYNGRQNYNPQTGIFTCEVPGVYYFAYHVHCKGGNVWVALFKNNEPMMYTYDEYKKGFLDQASGSAVLLLRPGDQVFLQMPSEQAAGLYAGQYVHSSFSGYLLYPM/EMPAFTAELTVPFPPVGAPVKFDKLLYNGRQNYNPQTGIFTCEVPGVYYFAYHVHCKGGNVWVALFKNNEPMMYTYDEYKKGFLDQASGSAVLLLRPGDQVFLQMPSEQAAGLYAGQYVHSSFSGYLLYPM\n",
">T=0.1, sample=0, score=0.6755, seq_recovery=0.4885\n",
"EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK/EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK/EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK\n"
]
}
],
"source": [
"#@title RUN\n",
"with torch.no_grad():\n",
" print('Generating sequences...')\n",
" for ix, protein in enumerate(dataset_valid):\n",
" score_list = []\n",
" all_probs_list = []\n",
" all_log_probs_list = []\n",
" S_sample_list = []\n",
" batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]\n",
" X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)\n",
" pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false\n",
" name_ = batch_clones[0]['name']\n",
"\n",
" randn_1 = torch.randn(chain_M.shape, device=X.device)\n",
" log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)\n",
" mask_for_loss = mask*chain_M*chain_M_pos\n",
" scores = _scores(S, log_probs, mask_for_loss)\n",
" native_score = scores.cpu().data.numpy()\n",
"\n",
" for temp in temperatures:\n",
" for j in range(NUM_BATCHES):\n",
" randn_2 = torch.randn(chain_M.shape, device=X.device)\n",
" if tied_positions_dict == None:\n",
" sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)\n",
" S_sample = sample_dict[\"S\"] \n",
" else:\n",
" sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)\n",
" # Compute scores\n",
" S_sample = sample_dict[\"S\"]\n",
" log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict[\"decoding_order\"])\n",
" mask_for_loss = mask*chain_M*chain_M_pos\n",
" scores = _scores(S_sample, log_probs, mask_for_loss)\n",
" scores = scores.cpu().data.numpy()\n",
" all_probs_list.append(sample_dict[\"probs\"].cpu().data.numpy())\n",
" all_log_probs_list.append(log_probs.cpu().data.numpy())\n",
" S_sample_list.append(S_sample.cpu().data.numpy())\n",
" for b_ix in range(BATCH_COPIES):\n",
" masked_chain_length_list = masked_chain_length_list_list[b_ix]\n",
" masked_list = masked_list_list[b_ix]\n",
" seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])\n",
" seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])\n",
" score = scores[b_ix]\n",
" score_list.append(score)\n",
" native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])\n",
" if b_ix == 0 and j==0 and temp==temperatures[0]:\n",
" start = 0\n",
" end = 0\n",
" list_of_AAs = []\n",
" for mask_l in masked_chain_length_list:\n",
" end += mask_l\n",
" list_of_AAs.append(native_seq[start:end])\n",
" start = end\n",
" native_seq = \"\".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))\n",
" l0 = 0\n",
" for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:\n",
" l0 += mc_length\n",
" native_seq = native_seq[:l0] + '/' + native_seq[l0:]\n",
" l0 += 1\n",
" sorted_masked_chain_letters = np.argsort(masked_list_list[0])\n",
" print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]\n",
" sorted_visible_chain_letters = np.argsort(visible_list_list[0])\n",
" print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]\n",
" native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)\n",
" line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\\n{}\\n'.format(name_, native_score_print, print_visible_chains, print_masked_chains, model_name, native_seq)\n",
" print(line.rstrip())\n",
" start = 0\n",
" end = 0\n",
" list_of_AAs = []\n",
" for mask_l in masked_chain_length_list:\n",
" end += mask_l\n",
" list_of_AAs.append(seq[start:end])\n",
" start = end\n",
"\n",
" seq = \"\".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))\n",
" l0 = 0\n",
" for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:\n",
" l0 += mc_length\n",
" seq = seq[:l0] + '/' + seq[l0:]\n",
" l0 += 1\n",
" score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)\n",
" seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)\n",
" line = '>T={}, sample={}, score={}, seq_recovery={}\\n{}\\n'.format(temp,b_ix,score_print,seq_rec_print,seq)\n",
" print(line.rstrip())\n",
"\n",
"\n",
"all_probs_concat = np.concatenate(all_probs_list)\n",
"all_log_probs_concat = np.concatenate(all_log_probs_list)\n",
"S_sample_concat = np.concatenate(S_sample_list)"
]
},
{
"cell_type": "code",
"source": [
"#@markdown ### Amino acid probabilties\n",
"import plotly.express as px\n",
"fig = px.imshow(np.exp(all_log_probs_concat).mean(0).T,\n",
" labels=dict(x=\"positions\", y=\"amino acids\", color=\"probability\"),\n",
" y=list(alphabet),\n",
" template=\"simple_white\"\n",
" )\n",
"\n",
"fig.update_xaxes(side=\"top\")\n",
"\n",
"\n",
"fig.show()"
],
"metadata": {
"cellView": "form",
"id": "xbsaRChNfybM",
"outputId": "e87c30dc-d960-4878-acae-f67a3b69fc1c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 542
}
},
"execution_count": 14,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"