{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "AlphaFold_single.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "\"Open"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# AlphaFold - single sequence input\n",
        "- WARNING - For DEMO and educational purposes only.\n",
        "- For natural proteins you often need more than a single sequence to accurately predict the structure. See the [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple-sequence alignment. That being said, this notebook could be useful for evaluating *de novo* designed proteins and learning the idealized principles of proteins.\n",
        "\n",
        "### Tips and Instructions\n",
        "- Click the little ▶ play icon to the left of each cell below.\n",
        "- Hover the mouse over an amino acid in the 3D view to see its name and position number."
      ],
      "metadata": {
        "id": "VpfCw7IzVHXv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Setup\n",
        "\n",
        "# import libraries\n",
        "from IPython.utils import io\n",
        "import os,sys,re\n",
        "import tensorflow as tf\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import matplotlib\n",
        "from matplotlib import animation\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import HTML\n",
        "\n",
        "# fetch code, python packages and AlphaFold weights (shell output is captured);\n",
        "# each group is skipped when its directory already exists, so re-running is cheap\n",
        "with io.capture_output() as captured:\n",
        "  if not os.path.isdir(\"af_backprop\"):\n",
        "    %shell git clone -b beta https://github.com/sokrypton/af_backprop.git\n",
        "    %shell pip -q install biopython dm-haiku ml-collections py3Dmol\n",
        "    %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\n",
        "  if not os.path.isdir(\"params\"):\n",
        "    %shell mkdir params\n",
        "    %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n",
        "\n",
        "# configure which device to use\n",
        "try:\n",
        "  # check if TPU is available\n",
        "  import jax.tools.colab_tpu\n",
        "  jax.tools.colab_tpu.setup_tpu()\n",
        "  print('Running on TPU')\n",
        "  DEVICE = \"tpu\"\n",
        "except Exception:\n",
        "  # TPU setup failed, fall back to GPU/CPU; catching Exception (instead of a\n",
        "  # bare except) lets KeyboardInterrupt/SystemExit propagate\n",
        "  if jax.local_devices()[0].platform == 'cpu':\n",
        "    print(\"WARNING: no GPU detected, will be using CPU\")\n",
        "    DEVICE = \"cpu\"\n",
        "  else:\n",
        "    print('Running on GPU')\n",
        "    DEVICE = \"gpu\"\n",
        "    # disable GPU on tensorflow\n",
        "    tf.config.set_visible_devices([], 'GPU')\n",
        "\n",
        "sys.path.append('/content/af_backprop')\n",
        "\n",
        "# import libraries (only importable after the download + sys.path steps above)\n",
        "from utils import update_seq, update_aatype, get_plddt, get_pae\n",
        "import colabfold as cf\n",
        "from alphafold.common import protein\n",
        "from alphafold.data import pipeline\n",
        "from alphafold.model import data, config, model\n",
        "from alphafold.common import residue_constants\n",
        "\n",
        "# custom functions\n",
        "def clear_mem():\n",
        "  '''delete every live buffer reported by the JAX backend'''\n",
        "  backend = jax.lib.xla_bridge.get_backend()\n",
        "  for buf in backend.live_buffers(): buf.delete()\n",
        "\n",
        "def setup_model(max_len, model_name=\"model_2_ptm\"):\n",
        "  '''build a jitted single-sequence AlphaFold runner for `max_len` positions\n",
        "\n",
        "  Args:\n",
        "    max_len: number of positions the placeholder features are built for\n",
        "    model_name: name of the weights passed to data.get_model_haiku_params\n",
        "  Returns:\n",
        "    (runner, opt): `runner(seq, opt)` is the jitted prediction function and\n",
        "    `opt` is a dict with the 'inputs' features and the model 'params'; the\n",
        "    caller has to add a 'prev' entry to `opt` before calling `runner`\n",
        "  '''\n",
        "  clear_mem()\n",
        "\n",
        "  # setup model\n",
        "  # NOTE(review): the config is always built from \"model_5_ptm\" while the weights\n",
        "  # come from `model_name` - presumably intentional, confirm before changing\n",
        "  cfg = config.model_config(\"model_5_ptm\")\n",
        "  cfg.model.num_recycle = 0\n",
        "  cfg.data.common.num_recycle = 0\n",
        "  cfg.data.eval.max_msa_clusters = 1\n",
        "  cfg.data.common.max_extra_msa = 1\n",
        "  cfg.data.eval.masked_msa_replace_fraction = 0\n",
        "  cfg.model.global_config.subbatch_size = None\n",
        "  model_params = data.get_model_haiku_params(model_name=model_name, data_dir=\".\")\n",
        "  model_runner = model.RunModel(cfg, model_params, is_training=False)\n",
        "\n",
        "  # placeholder poly-A features; runner() overwrites the sequence on every call\n",
        "  seq = \"A\" * max_len\n",
        "  length = len(seq)\n",
        "  feature_dict = {\n",
        "    **pipeline.make_sequence_features(sequence=seq, description=\"none\", num_res=length),\n",
        "    **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])\n",
        "  }\n",
        "  inputs = model_runner.process_features(feature_dict,random_seed=0)\n",
        "\n",
        "  def runner(seq, opt):\n",
        "    '''run one prediction for `seq`; `opt` must hold 'inputs', 'params' and 'prev'\n",
        "\n",
        "    returns a dict with the predicted atoms, confidences, the inputs used and\n",
        "    a new 'prev' that can be fed back in for the next recycle\n",
        "    '''\n",
        "    # update sequence\n",
        "    inputs = opt[\"inputs\"]\n",
        "    inputs.update(opt[\"prev\"])\n",
        "    update_seq(seq, inputs)\n",
        "    update_aatype(inputs[\"target_feat\"][...,1:], inputs)\n",
        "\n",
        "    # mask prediction (sum over last axis: 1 for one-hot rows, 0 for all-zero padding rows)\n",
        "    mask = seq.sum(-1)\n",
        "    inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n",
        "    inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n",
        "    inputs[\"residue_index\"] = jnp.where(mask==1,inputs[\"residue_index\"],0)\n",
        "\n",
        "    # get prediction\n",
        "    key = jax.random.PRNGKey(0)\n",
        "    outputs = model_runner.apply(opt[\"params\"], key, inputs)\n",
        "\n",
        "    # representations handed back to the model on the next call (manual recycling)\n",
        "    prev = {\"init_msa_first_row\":outputs['representations']['msa_first_row'][None],\n",
        "            \"init_pair\":outputs['representations']['pair'][None],\n",
        "            \"init_pos\":outputs['structure_module']['final_atom_positions'][None]}\n",
        "\n",
        "    aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n",
        "           \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n",
        "           \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n",
        "           \"inputs\":inputs, \"prev\":prev}\n",
        "    return aux\n",
        "\n",
        "  return jax.jit(runner), {\"inputs\":inputs,\"params\":model_params}\n",
        "\n",
        "def save_pdb(outs, filename, LEN):\n",
        "  '''save pdb coordinates\n",
        "\n",
        "  Args:\n",
        "    outs: one dict as returned by the runner built in setup_model\n",
        "    filename: path of the PDB file to write\n",
        "    LEN: number of leading residues to keep (drops the padding)\n",
        "  '''\n",
        "  p = {\"residue_index\":outs[\"inputs\"][\"residue_index\"][0] + 1,\n",
        "       \"aatype\":outs[\"inputs\"][\"aatype\"].argmax(-1)[0],\n",
        "       \"atom_positions\":outs[\"final_atom_positions\"],\n",
        "       \"atom_mask\":outs[\"final_atom_mask\"],\n",
        "       \"plddt\":outs[\"plddt\"]}\n",
        "  # jax.tree_util.tree_map: the top-level jax.tree_map alias is gone in recent JAX\n",
        "  p = jax.tree_util.tree_map(lambda x:x[:LEN], p)\n",
        "  # pLDDT (scaled by 100) is stored in the B-factor column of every present atom\n",
        "  b_factors = 100.0 * p.pop(\"plddt\")[:,None] * p[\"atom_mask\"]\n",
        "  p = protein.Protein(**p,b_factors=b_factors)\n",
        "  pdb_lines = protein.to_pdb(p)\n",
        "  with open(filename, 'w') as f: f.write(pdb_lines)\n",
        "\n",
        "def make_animation(positions, plddts=None, line_w=2.0, dpi=100):\n",
        "  '''render the prediction trajectory as an HTML5 video string\n",
        "\n",
        "  Args:\n",
        "    positions: array indexed as [frame, residue, atom, xyz]; atom index 1 is drawn\n",
        "    plddts: optional per-frame, per-residue pLDDT (plotted on a 0-100 axis);\n",
        "      when None the pLDDT curve, label and pLDDT-colored panel are skipped\n",
        "    line_w: line width passed to cf.plot_pseudo_3D\n",
        "    dpi: figure resolution\n",
        "  Returns:\n",
        "    the output of ArtistAnimation.to_html5_video()\n",
        "  '''\n",
        "\n",
        "  def ca_align_to_last(positions):\n",
        "    '''superimpose atom 1 of every frame onto the reoriented last frame'''\n",
        "    def align(P, Q):\n",
        "      p = P - P.mean(0,keepdims=True)\n",
        "      q = Q - Q.mean(0,keepdims=True)\n",
        "      return p @ cf.kabsch(p,q)\n",
        "\n",
        "    # centered last frame, reoriented with cf.kabsch(..., return_v=True)\n",
        "    # (presumably onto its principal axes - see colabfold.py)\n",
        "    pos = positions[-1,:,1,:] - positions[-1,:,1,:].mean(0,keepdims=True)\n",
        "    best_2D_view = pos @ cf.kabsch(pos,pos,return_v=True)\n",
        "\n",
        "    new_positions = []\n",
        "    for i in range(len(positions)):\n",
        "      new_positions.append(align(positions[i,:,1,:],best_2D_view))\n",
        "    return np.asarray(new_positions)\n",
        "\n",
        "  # align all to last recycle\n",
        "  pos = ca_align_to_last(positions)\n",
        "\n",
        "  fig, (ax1, ax2, ax3) = plt.subplots(1,3)\n",
        "  fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 1, left = 0, hspace = 0, wspace = 0)\n",
        "  fig.set_figwidth(13)\n",
        "  fig.set_figheight(5)\n",
        "  fig.set_dpi(dpi)\n",
        "\n",
        "  # shared square limits so the structure keeps its size across frames\n",
        "  xy_min = pos[...,:2].min() - 1\n",
        "  xy_max = pos[...,:2].max() + 1\n",
        "\n",
        "  for ax in [ax1,ax3]:\n",
        "    ax.set_xlim(xy_min, xy_max)\n",
        "    ax.set_ylim(xy_min, xy_max)\n",
        "    ax.axis(False)\n",
        "\n",
        "  # axis decorations are identical for every frame, so set them once\n",
        "  if plddts is None:\n",
        "    ax2.axis(False)\n",
        "    frames = [(xyz, None) for xyz in pos]\n",
        "  else:\n",
        "    ax2.set_xlabel(\"positions\")\n",
        "    ax2.set_ylabel(\"pLDDT\")\n",
        "    ax2.set_ylim(0,100)\n",
        "    frames = zip(pos, plddts)\n",
        "\n",
        "  # one list of artists per frame, as expected by ArtistAnimation\n",
        "  ims = []\n",
        "  for k,(xyz,plddt) in enumerate(frames):\n",
        "    artists = [cf.plot_pseudo_3D(xyz, ax=ax1, line_w=line_w),\n",
        "               cf.add_text(\"colored by N->C\", ax1),\n",
        "               cf.add_text(f\"recycle={k}\", ax2)]\n",
        "    if plddt is not None:\n",
        "      artists += [ax2.plot(plddt, animated=True, color=\"black\")[0],\n",
        "                  cf.add_text(f\"pLDDT={plddt.mean():.3f}\", ax3),\n",
        "                  cf.plot_pseudo_3D(xyz, c=plddt, cmin=50, cmax=90, ax=ax3, line_w=line_w)]\n",
        "    ims.append(artists)\n",
        "\n",
        "  ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120)\n",
        "  # close the figure so the notebook does not also show a static copy\n",
        "  plt.close(fig)\n",
        "  return ani.to_html5_video()"
      ],
      "metadata": {
        "cellView": "form",
        "id": "24ybo88aBiSU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%%time\n",
        "#@title Enter the amino acid sequence to fold ⬇️\n",
        "\n",
        "# initialize model runner if doesn't exist\n",
        "if \"runner\" not in dir():\n",
        "  max_length = 50 # initial compiled length, grown on demand below\n",
        "  current_seq = \"\"\n",
        "  r = -1\n",
        "  runner, I = setup_model(max_length)\n",
        "\n",
        "# collect user inputs\n",
        "sequence = 'GGGGGGGGGGGGGGGGGGGGG' #@param {type:\"string\"}\n",
        "recycles = 0 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n",
        "sequence = re.sub(\"[^A-Z]\", \"\", sequence.upper())\n",
        "length = len(sequence)\n",
        "\n",
        "# nothing to fold: stop before touching the state cached from a previous run\n",
        "if length == 0:\n",
        "  raise ValueError(\"sequence is empty after removing non-letter characters\")\n",
        "\n",
        "# if length greater than max_len, recompile for larger length\n",
        "if length > max_length:\n",
        "  max_length = length + 10 # a little buffer\n",
        "  runner, I = setup_model(max_length)\n",
        "\n",
        "# a new sequence resets the trajectory; re-running with the same sequence\n",
        "# continues from the recycles already computed\n",
        "if sequence != current_seq:\n",
        "  outs = []\n",
        "  positions = []\n",
        "  plddts = []\n",
        "  paes = []\n",
        "  r = -1\n",
        "  # letters missing from restype_order fall back to index 0 below; say so\n",
        "  unknown = sorted({aa for aa in sequence if aa not in residue_constants.restype_order})\n",
        "  if unknown:\n",
        "    print(f\"WARNING: unrecognized residue letter(s) {','.join(unknown)} will be encoded as restype index 0\")\n",
        "  # convert sequence to one_hot (padding uses -1, which one_hot turns into all-zero rows)\n",
        "  x = np.array([residue_constants.restype_order.get(aa,0) for aa in sequence])\n",
        "  x = np.pad(x,[0,max_length-length],constant_values=-1)\n",
        "  x = jax.nn.one_hot(x,20)\n",
        "\n",
        "  # restart recycle\n",
        "  I[\"prev\"] = {'init_msa_first_row': np.zeros([1, max_length, 256]),\n",
        "               'init_pair': np.zeros([1, max_length, max_length, 128]),\n",
        "               'init_pos': np.zeros([1, max_length, 37, 3])}\n",
        "  current_seq = sequence\n",
        "\n",
        "# run for defined number of recycles\n",
        "while r < recycles:\n",
        "  O = runner(x, I)\n",
        "  # jax.tree_util.tree_map: the top-level jax.tree_map alias is gone in recent JAX\n",
        "  O = jax.tree_util.tree_map(np.asarray, O)\n",
        "  positions.append(O[\"final_atom_positions\"][:length])\n",
        "  plddts.append(O[\"plddt\"][:length])\n",
        "  paes.append(O[\"pae\"][:length,:length])\n",
        "  I[\"prev\"] = O[\"prev\"]\n",
        "  outs.append(O)\n",
        "  r += 1\n",
        "\n",
        "#@markdown #### Display options\n",
        "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
        "show_sidechains = True #@param {type:\"boolean\"}\n",
        "show_mainchains = False #@param {type:\"boolean\"}\n",
        "\n",
        "print(f\"plotting prediction at recycle={recycles}\")\n",
        "save_pdb(outs[recycles], \"out.pdb\", length)\n",
        "v = cf.show_pdb(\"out.pdb\", show_sidechains, show_mainchains, color,\n",
        "                color_HP=True, size=(800,480))\n",
        "# hover callbacks: add a residue-name:number label on mouse-over, remove it on mouse-out\n",
        "v.setHoverable({},\n",
        "               True,\n",
        "               '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\" \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n",
        "               '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n",
        "v.show()\n",
        "if color == \"lDDT\": cf.plot_plddt_legend().show()\n",
        "\n",
        "# add confidence plots\n",
        "cf.plot_confidence(plddts[recycles]*100, paes[recycles]).show()"
      ],
      "metadata": {
        "cellView": "form",
        "id": "cAoC4ar8G7ZH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Animate\n",
        "#@markdown - Animate trajectory if more than 0 recycle(s)\n",
        "HTML(make_animation(np.asarray(positions),\n",
        "                    np.asarray(plddts) * 100.0))"
      ],
      "metadata": {
        "cellView": "form",
        "id": "tdjdC0KFPjWw"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}