"Open

In [1]:
%%bash
# Abort on the first failing command (including a failure on the left side of
# a pipe) so a broken clone or download is not silently ignored.
set -euo pipefail

# Fetch the af_backprop sources and the Python dependencies once.
if [ ! -d af_backprop ]; then
  git clone https://github.com/sokrypton/af_backprop.git
  pip -q install dm-haiku py3Dmol biopython ml_collections
fi

# Download the AlphaFold parameters.  Unpack into a scratch directory and only
# rename it to `params` after the archive extracted successfully; otherwise a
# failed download would leave a `params` directory behind and the guard above
# would skip the retry on the next run.
if [ ! -d params ]; then
  rm -rf params.tmp
  mkdir params.tmp
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params.tmp
  mv params.tmp params
fi

Cloning into 'af_backprop'...


In [2]:
import os
import sys
sys.path.append('af_backprop')

import numpy as np
import matplotlib.pyplot as plt
import py3Dmol

import jax
import jax.numpy as jnp

from jax.experimental.optimizers import adam

from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.model import data, config, model, modules
from alphafold.common import residue_constants

from alphafold.model import all_atom
from alphafold.model import folding

# custom functions
from alphafold.data import prep_inputs
from utils import *

In [3]:
# setup which model params to use
model_name = "model_3_ptm"
model_config = config.model_config(model_name)

# enable checkpointing
model_config.model.global_config.use_remat = True

# number of recycles
model_config.model.num_recycle = 3
model_config.data.common.num_recycle = 3

# backprop through recycles
model_config.model.backprop_recycle = False
model_config.model.embeddings_and_evoformer.backprop_dgram = False

# custom relative features (needed for insertion/deletion)
INDELS = False
model_config.model.embeddings_and_evoformer.custom_relative_features = INDELS

# number of sequences
N = 1
model_config.data.eval.max_msa_clusters = N
model_config.data.common.max_extra_msa = 1
model_config.data.eval.masked_msa_replace_fraction = 0

# dropout
model_config = set_dropout(model_config, 0.0)

# setup model: model_params[0] holds the parameters the runner is built from
model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=".")]
model_runner = model.RunModel(model_config, model_params[0], is_training=True)

# load the other models to sample during design.
# Dedicated loop names are used so that `model_name` keeps referring to the
# model the runner/config were built from, and no unrelated `params` global
# is left behind.  Each extra parameter set is restricted to the keys the
# runner actually holds so that all entries of `model_params` share one layout.
for extra_model_name in ["model_1_ptm","model_2_ptm","model_5_ptm","model_4_ptm"]:
  extra_params = data.get_model_haiku_params(extra_model_name, '.')
  model_params.append({k: extra_params[k] for k in model_runner.params.keys()})

In [4]:
#################
# USER INPUT
#################
# Reference (native) structure that provides the active site, and the
# zero-based indices of the active-site residues inside it.
PDB_REF = "af_backprop/examples/sc_hall/1QJG.pdb"
pos_idx_ref = [13, 37, 98]

# Starting (design) structure.  MODE is the path prefix shared by the input
# PDB and the PDB files written during the search.  To start without a
# structure, set PDB = None and LEN to the desired length.
MODE = "af_backprop/examples/sc_hall/1QJS_starting"
PDB = MODE + ".pdb"
LEN = 105

# Design positions that receive the active-site residues, listed in the same
# order as pos_idx_ref (kept as "number + 5", as in the original).
pos_idx = [74 + 5, 32 + 5, 7 + 5]

In [5]:
# prep reference (native) features
OBJ_REF = protein.from_pdb_string(pdb_to_string(PDB_REF), chain_id="A")
SEQ_REF = jax.nn.one_hot(OBJ_REF.aatype,20)
START_SEQ_REF = "".join([order_restype[a] for a in OBJ_REF.aatype])

batch_ref = {'aatype': OBJ_REF.aatype,
             'all_atom_positions': OBJ_REF.atom_positions,
             'all_atom_mask': OBJ_REF.atom_mask}
batch_ref.update(all_atom.atom37_to_frames(**batch_ref))
batch_ref.update(prep_inputs.make_atom14_positions(batch_ref))
# active-site indices travel with the reference batch
batch_ref["idx"] = pos_idx_ref

# prep starting (design) features
if PDB is not None:
  OBJ = protein.from_pdb_string(pdb_to_string(PDB), chain_id="A")
  SEQ = jax.nn.one_hot(OBJ.aatype,20)
  START_SEQ = "".join([order_restype[a] for a in OBJ.aatype])

  batch = {'aatype': OBJ.aatype,
           'all_atom_positions': OBJ.atom_positions,
           'all_atom_mask': OBJ.atom_mask}
  batch.update(all_atom.atom37_to_frames(**batch))
  batch.update(prep_inputs.make_atom14_positions(batch))
else:
  # No starting structure: residue type 0 at every position, with the
  # reference active-site identities copied onto the design positions.
  # The indices are kept as an integer NumPy array; a float jnp array (what
  # jnp.zeros(LEN) yields) cannot be used to look up `order_restype`.
  seq_idx = np.zeros(LEN, dtype=int)
  seq_idx[np.asarray(pos_idx)] = np.asarray(OBJ_REF.aatype)[np.asarray(pos_idx_ref)]
  START_SEQ = "".join([order_restype[int(a)] for a in seq_idx])
  SEQ = jax.nn.one_hot(seq_idx,20)

# prep input features (the MSA is N copies of the starting sequence, no deletions)
feature_dict = {
  **pipeline.make_sequence_features(sequence=START_SEQ,description="none",num_res=len(START_SEQ)),
  **pipeline.make_msa_features(msas=[N*[START_SEQ]], deletion_matrices=[N*[[0]*len(START_SEQ)]]),
}
inputs = model_runner.process_features(feature_dict, random_seed=0)

# with more than one sequence, mark every MSA row/position as present
if N > 1:
  inputs["msa_row_mask"] = jnp.ones_like(inputs["msa_row_mask"])
  inputs["msa_mask"] = jnp.ones_like(inputs["msa_mask"])

In [6]:
# Show the residue letters at the active-site positions: first for the design
# sequence, then for the reference sequence.
for seq_string, site_indices in ((START_SEQ, pos_idx), (START_SEQ_REF, pos_idx_ref)):
  print(list(map(seq_string.__getitem__, site_indices)))

['Y', 'N', 'D']
['Y', 'N', 'D']


In [7]:
def get_grad_fn(model_runner, inputs, pos_idx_ref, inc_backbone=False):
  """Build the design loss function and its gradient function.

  Args:
    model_runner: object exposing `.config` and `.apply(model_params, key, inputs)`.
    inputs: processed feature dict. It is shallow-copied on every call and its
      entries are replaced (not modified in place) on the copy.
    pos_idx_ref: indices into the rows of the notebook-level `SEQ_REF` whose
      residue identities are written onto the design positions.
      NOTE(review): the structural loss helpers receive `batch_ref` (which
      carries its own "idx" entry), so this should presumably equal
      `batch_ref["idx"]` -- confirm against the helpers.
    inc_backbone: forwarded to `get_fape_loss_idx` as `backbone`.

  Returns:
    `(loss_fn, grad_fn)`. `loss_fn(params, key, model_params, opt)` returns
    `(loss, aux)` with `aux = {"losses", "outputs", "seq"}`; `grad_fn` is
    `jax.value_and_grad(loss_fn, has_aux=True, argnums=0)`.

  Besides its arguments, the loss reads the notebook-level names `SEQ_REF`,
  `batch_ref`, `model_config`, `update_seq`, `update_aatype`,
  `get_fape_loss_idx`, `get_sidechain_rmsd_idx` and `get_dgram_loss_idx`.
  """
  # Use the caller-supplied reference indices (previously this argument was
  # ignored in favour of a global lookup).
  ref_idx = jnp.asarray(pos_idx_ref)

  def mod(params, key, model_params, opt):
    pos_idx = opt["pos_idx"]
    use_rel_features = model_runner.config.model.embeddings_and_evoformer.custom_relative_features

    ############################
    # set amino acid sequence
    ############################
    seq_logits = jax.random.permutation(key, params["msa"])
    seq_soft = jax.nn.softmax(seq_logits)
    # straight-through estimator: one-hot value forward, softmax gradient backward
    seq = jax.lax.stop_gradient(jax.nn.one_hot(seq_soft.argmax(-1),20) - seq_soft) + seq_soft
    # pin the active-site identities to those of the reference
    seq = seq.at[:,pos_idx,:].set(SEQ_REF[ref_idx,:])

    # per-position blend between the discretised sequence and the raw logits
    oh_mask = opt["oh_mask"][:,None]
    pseudo_seq = oh_mask * seq + (1-oh_mask) * seq_logits

    inputs_mod = inputs.copy()
    update_seq(pseudo_seq, inputs_mod, msa_input=("msa" in params))

    if "msa_mask" in opt:
      inputs_mod["msa_mask"] = inputs_mod["msa_mask"] * opt["msa_mask"][None,:,None]
      inputs_mod["msa_row_mask"] = inputs_mod["msa_row_mask"] * opt["msa_mask"][None,:]

    ####################
    # set sidechains identity
    ####################
    B,L = inputs_mod["aatype"].shape[:2]
    ALA = jax.nn.one_hot(residue_constants.restype_order["A"],21)

    aatype = jnp.zeros((B,L,21)).at[...,:20].set(seq[0])
    ala_mask = opt["ala_mask"][:,None]
    # all-alanine alternative that still carries the reference active site
    aatype_ala = jnp.zeros((B,L,21)).at[:].set(ALA)
    aatype_ala = aatype_ala.at[:,pos_idx,:20].set(SEQ_REF[ref_idx,:])
    aatype_pseudo = ala_mask * aatype + (1-ala_mask) * aatype_ala
    update_aatype(aatype_pseudo, inputs_mod)

    ############################################################
    if use_rel_features:
      # set positions (active-site positions are always on)
      active_pos = jax.nn.sigmoid(params["active_pos"])
      active_pos = active_pos.at[jnp.asarray(pos_idx)].set(1.0)

      # hard constraint (thresholded forward, sigmoid gradient backward)
      active_pos = jax.lax.stop_gradient((active_pos > 0.5).astype(jnp.float32) - active_pos) + active_pos

      # exclusive running sum of active_pos gives each position its residue index
      residue_idx = jax.lax.scan(lambda x,y:(x+y,x), 0, active_pos)[1]
      offset = residue_idx[:, None] - residue_idx[None, :]
      rel_pos = jax.nn.softmax(-jnp.square(offset[...,None] - jnp.arange(-32,33,dtype=jnp.float32)))

      inputs_mod["rel_pos"] = jnp.tile(rel_pos[None],[B,1,1,1])
      inputs_mod["seq_mask"] = jnp.zeros_like(inputs_mod["seq_mask"]).at[...,:].set(active_pos)
      inputs_mod["msa_mask"] = jnp.zeros_like(inputs_mod["msa_mask"]).at[...,:].set(active_pos)

      # Assign new arrays instead of using `*=`: `inputs_mod` is only a shallow
      # copy, so an in-place update could write through to the arrays held by
      # `inputs`.
      inputs_mod["atom14_atom_exists"] = inputs_mod["atom14_atom_exists"] * active_pos[None,:,None]
      inputs_mod["atom37_atom_exists"] = inputs_mod["atom37_atom_exists"] * active_pos[None,:,None]
      # NOTE(review): these two use a four-axis broadcast, unlike the masks
      # above -- confirm the shapes of the residx arrays.
      inputs_mod["residx_atom14_to_atom37"] = inputs_mod["residx_atom14_to_atom37"] * active_pos[None,:,None,None]
      inputs_mod["residx_atom37_to_atom14"] = inputs_mod["residx_atom37_to_atom14"] * active_pos[None,:,None,None]

    ############################################################

    # get output
    outputs = model_runner.apply(model_params, key, inputs_mod)

    ###################
    # structure loss
    ###################
    fape_loss = get_fape_loss_idx(batch_ref, outputs, pos_idx, model_config, backbone=inc_backbone, sidechain=True)
    rmsd_loss = get_sidechain_rmsd_idx(batch_ref, outputs, pos_idx, model_config)
    dgram_loss = get_dgram_loss_idx(batch_ref, outputs, pos_idx, model_config)

    # reported values are the unweighted terms
    losses = {"fape":fape_loss,
              "rmsd":rmsd_loss,
              "dgram":dgram_loss}

    if "sc_weight_fape" in opt: fape_loss = fape_loss * opt["sc_weight_fape"]
    if "sc_weight_rmsd" in opt: rmsd_loss = rmsd_loss * opt["sc_weight_rmsd"]
    if "sc_weight_dgram" in opt: dgram_loss = dgram_loss * opt["sc_weight_dgram"]

    loss = (rmsd_loss + fape_loss + dgram_loss) * opt["sc_weight"]

    ###################
    # background loss
    ###################
    if "conf_weight" in opt:
      pae = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"])
      plddt = jax.nn.softmax(outputs['predicted_lddt']['logits'])
      # expected bin index; the pLDDT bins are reversed so lower is better for both
      pae_loss = (pae * jnp.arange(pae.shape[-1])).sum(-1)
      plddt_loss = (plddt * jnp.arange(plddt.shape[-1])[::-1]).sum(-1)

      if use_rel_features:
        # average over active positions only
        active_pos_mask = active_pos[:,None] * active_pos[None,:]
        pae_loss = (pae_loss * active_pos_mask).sum() / (1e-8 + active_pos_mask.sum())
        plddt_loss = (plddt_loss * active_pos).sum() / (1e-8 + active_pos.sum())
      else:
        pae_loss = pae_loss.mean()
        plddt_loss = plddt_loss.mean()

      loss = loss + (pae_loss + plddt_loss) * opt["conf_weight"]
      losses["pae"] = pae_loss
      losses["plddt"] = plddt_loss

    if "rg_weight" in opt:
      # root-mean-square distance of atom index 1 of every residue from their centroid
      ca_coords = outputs["structure_module"]["final_atom_positions"][:,1,:]
      rg_loss = jnp.sqrt(jnp.square(ca_coords - ca_coords.mean(0)).sum(-1).mean() + 1e-8)
      loss = loss + rg_loss * opt["rg_weight"]
      losses["rg"] = rg_loss

    if "msa" in params and "ent_weight" in opt:
      # entropy of the per-position profile averaged over the sequences
      seq_prf = seq.mean(0)
      ent_loss = -(seq_prf * jnp.log(seq_prf + 1e-8)).sum(-1).mean()
      loss = loss + ent_loss * opt["ent_weight"]
      losses["ent"] = ent_loss

    outs = {"final_atom_positions":outputs["structure_module"]["final_atom_positions"],
            "final_atom_mask":outputs["structure_module"]["final_atom_mask"]}

    if use_rel_features:
      outs["residue_idx"] = residue_idx

    seq_ = seq[0] if "msa" in params else seq

    return loss, ({"losses":losses, "outputs":outs, "seq":seq_})

  loss_fn = mod
  grad_fn = jax.value_and_grad(mod, has_aux=True, argnums=0)
  return loss_fn, grad_fn

In [8]:
# gradient function (note for greedy search we won't be using grad_fn, only loss_fn)
loss_fn, grad_fn = get_grad_fn(model_runner, inputs, pos_idx_ref=pos_idx_ref)
loss_fn = jax.jit(loss_fn)

# stack model params (we exclude the last model: model_4_ptm for validation)
# `jax.tree_util.tree_map` accepts several trees and replaces the deprecated
# `jax.tree_multimap`; every leaf gains a leading "model" axis.
model_params_multi = jax.tree_util.tree_map(lambda *values: jnp.stack(values, axis=0), *model_params[:-1])
# evaluate the same design under every stacked parameter set (only the
# model_params argument is mapped)
loss_fn_multi = jax.jit(jax.vmap(loss_fn,(None,None,0,None)))

In [9]:
# sequence length and alphabet size
L = len(START_SEQ)
A = 20

# PRNG key handed to the loss function
key = jax.random.PRNGKey(0)

# array versions of the design / reference active-site indices
pos_idx_, pos_idx_ref_ = jnp.asarray(pos_idx), jnp.asarray(pos_idx_ref)

# design variables: the starting one-hot sequence with a leading MSA axis,
# plus one value per position that starts at 1
msa = SEQ[None]
params = dict(msa=msa, active_pos=jnp.ones(L))

In [10]:
def mut(params, indel=False):
  """Return a copy of `params` with one random point mutation applied.

  A (position, residue) pair is drawn uniformly from all admissible pairs:
  the position must not be listed in the notebook-level `pos_idx`, the residue
  must currently have value 0 in the first MSA row, and (unless `indel` is
  set) the position must have `active_pos == 1`. The chosen column of every
  MSA row is replaced by the one-hot vector of the new residue.

  Args:
    params: dict with "msa" of shape (..., L, A) indexed as [row, position,
      residue] and "active_pos" of length L, both supporting `.at[...]`.
    indel: if True, inactive positions are admissible too, and the chosen
      position's `active_pos` entry is toggled (1 -> -1, anything else -> 1).

  Returns:
    A new dict; the input dict and its arrays are left untouched.

  Raises:
    ValueError: if no admissible mutation exists (the previous rejection loop
      would spin forever in that situation).
  """
  # One transfer to host memory instead of one device read per attempt.
  msa_first_row = np.asarray(params["msa"])[0]
  active = np.asarray(params["active_pos"])
  L,A = params["msa"].shape[-2:]

  site_ok = np.ones(L, dtype=bool)
  site_ok[np.asarray(pos_idx, dtype=int)] = False
  if not indel:
    site_ok &= (active == 1)

  candidates = np.argwhere((msa_first_row == 0) & site_ok[:,None])
  if len(candidates) == 0:
    raise ValueError("mut: no admissible (position, residue) mutation available")
  i,a = (int(v) for v in candidates[np.random.randint(len(candidates))])

  params_ = params.copy()
  params_["msa"] = params["msa"].at[:,i,:].set(jnp.eye(A)[a])

  if indel:
    state = -1 if active[i] == 1 else 1
    params_["active_pos"] = params["active_pos"].at[i].set(state)
  return params_

Multi-model refinement: greedy point-mutation search, scored by the mean loss over several AlphaFold parameter sets

In [12]:
# All masks start fully "on": one entry per position for the one-hot and
# alanine blends, one entry per MSA row for the MSA mask.
oh_mask, ala_mask = jnp.ones((L,)), jnp.ones((L,))
msa_mask = jnp.ones((N,))

# Options consumed by the loss function (masks, loss weights, design positions).
opt = dict(
  oh_mask=oh_mask,
  msa_mask=msa_mask,
  ala_mask=ala_mask,
  sc_weight=1.0,
  sc_weight_rmsd=1.0,
  sc_weight_fape=1.0,
  sc_weight_dgram=0.0,
  conf_weight=0.01,
  pos_idx=pos_idx_,
)

# Baseline: score the starting sequence under every stacked parameter set.
loss, outs = loss_fn_multi(params, key, model_params_multi, opt)

# mean total loss, mean rmsd term, mean fape term
print(np.mean(loss), np.mean(outs["losses"]["rmsd"]), np.mean(outs["losses"]["fape"]))

# rmsd term per parameter set
print(outs["losses"]["rmsd"])

5.0831347 4.0395207 0.5736508
[0.32047477 6.103425 3.9624836 5.7717004 ]


In [13]:
# Running bests, initialised from the baseline evaluation.
LOSS = np.mean(loss)
OVERALL_RMSD = np.mean(outs["losses"]["rmsd"])
OVERALL_FAPE = np.mean(outs["losses"]["fape"])
OVERALL_LOSS = LOSS
key = jax.random.PRNGKey(0)

# Warm-up: in each of 10 rounds, score 20 single mutations of the current
# sequence and move to the one with the lowest mean loss over the models
# (the move is made even if it is worse than the current sequence).
for n in range(10):
  buff_p,buff_l,buff_o = [],[],[]
  for m in range(20):
    key,subkey = jax.random.split(key)
    do_indel = False  # indels are not sampled in this pass
    p = mut(params, indel=do_indel)
    l,o = loss_fn_multi(p, subkey, model_params_multi, opt)
    print("-----------", m, np.mean(o["losses"]["rmsd"]), list(o["losses"]["rmsd"]))
    buff_p.append(p); buff_l.append(l); buff_o.append(o)

  best = np.argmin(np.asarray(buff_l).mean(-1))
  params, LOSS, outs = buff_p[best], buff_l[best], buff_o[best]
  LOSS = np.mean(LOSS)
  RMSD = np.mean(outs["losses"]["rmsd"])
  FAPE = np.mean(outs["losses"]["fape"])

  # keep only the first stacked model's outputs for writing structures
  outs = jax.tree_util.tree_map(lambda x: x[0], outs)
  if RMSD < OVERALL_RMSD:
    OVERALL_RMSD = RMSD
    save_pdb(outs,f"{MODE}_best_rmsd.pdb")
  if FAPE < OVERALL_FAPE:
    OVERALL_FAPE = FAPE
    save_pdb(outs,f"{MODE}_best_fape.pdb")
  if LOSS < OVERALL_LOSS:
    OVERALL_LOSS = LOSS
    save_pdb(outs,f"{MODE}_best_loss.pdb")
  print(n, LOSS, RMSD, FAPE, (params["active_pos"] > 0).sum(), len(buff_l))

----------- 0 10.582149 [12.104791, 11.438794, 6.8327885, 11.95222]
----------- 1 5.7060966 [0.34404072, 7.8016953, 6.8321495, 7.8465023]
----------- 2 4.9283195 [0.38857314, 7.33674, 3.9765234, 8.011441]
----------- 3 4.596478 [1.2857033, 8.185179, 3.6270323, 5.2879977]
----------- 4 4.33639 [0.34355134, 7.6549, 3.8751266, 5.471982]
----------- 5 6.8581476 [7.102624, 8.735716, 4.742427, 6.851823]
----------- 6 3.9228725 [0.37756792, 6.210618, 3.7853143, 5.317991]
----------- 7 4.8257957 [0.3333986, 8.438012, 5.8835196, 4.648253]
----------- 8 4.437786 [0.32170418, 8.433742, 4.022462, 4.9732375]
----------- 9 4.3738766 [0.3662424, 8.010605, 3.854046, 5.264613]
----------- 10 5.4476604 [0.3221591, 7.0530643, 5.5048957, 8.910523]
----------- 11 4.311565 [0.33136475, 6.864394, 4.1079164, 5.9425855]
----------- 12 8.771259 [7.6971173, 9.20276, 8.589025, 9.596133]
----------- 13 4.458911 [0.32208133, 8.985687, 4.7378616, 3.7900143]
----------- 14 4.6013637 [0.35995653, 6.008712, 4.1167917, 

In [None]:
# Main greedy search. Relies on `params`, `key`, `LOSS` and the OVERALL_*
# running bests left behind by the warm-up cell.
for n in range(300):
  buff_p,buff_l,buff_o = [],[],[]
  # Try up to 20 single mutations, stopping at the first one that improves on
  # the current mean loss.
  for _ in range(20):
    key,subkey = jax.random.split(key)
    do_indel = (INDELS and np.random.uniform() < 0.25)
    p = mut(params, indel=do_indel)
    l,o = loss_fn_multi(p, subkey, model_params_multi, opt)
    buff_p.append(p); buff_l.append(l); buff_o.append(o)
    if np.mean(l) < LOSS: break

  # Move to the best candidate tried; when none improved, this is still the
  # least-bad one, so the loss can go up between rounds.
  best = np.argmin(np.asarray(buff_l).mean(-1))
  params, LOSS, outs = buff_p[best], buff_l[best], buff_o[best]
  LOSS = np.mean(LOSS)
  RMSD = np.mean(outs["losses"]["rmsd"])
  FAPE = np.mean(outs["losses"]["fape"])

  # keep only the first stacked model's outputs for writing structures
  outs = jax.tree_util.tree_map(lambda x: x[0], outs)
  if RMSD < OVERALL_RMSD:
    OVERALL_RMSD = RMSD
    save_pdb(outs,f"{MODE}_best_rmsd.pdb")
  if FAPE < OVERALL_FAPE:
    OVERALL_FAPE = FAPE
    save_pdb(outs,f"{MODE}_best_fape.pdb")
  if LOSS < OVERALL_LOSS:
    OVERALL_LOSS = LOSS
    save_pdb(outs,f"{MODE}_best_loss.pdb")

  # validation: score the accepted design with the held-out last parameter set
  l4,o4 = loss_fn(params, subkey, model_params[-1], opt)
  print(n, LOSS, RMSD, FAPE, (params["active_pos"] > 0).sum(), len(buff_l), o4["losses"]["rmsd"])

0 1.6684058 0.97566986 0.33917558 105 13 6.570461
1 1.7213771 1.0266101 0.34318787 105 20 0.5521997
2 1.7309557 1.0141158 0.35105625 105 20 2.0901546
3 1.6896502 0.9780132 0.34791833 105 2 3.075305
4 1.2836819 0.63118124 0.2795934 105 9 3.0850897
5 1.0221133 0.47783357 0.26446223 105 2 0.58060026
6 1.0106819 0.471611 0.2663944 105 17 0.579923
7 0.9707656 0.44111815 0.26466346 105 8 0.5299034
8 0.95611095 0.46036988 0.26778838 105 2 0.4459814
9 0.94456077 0.44426697 0.2588649 105 7 0.51522875
10 0.92737645 0.41471896 0.26362437 105 11 0.4391624
11 0.9707828 0.43690652 0.2637179 105 20 3.0559409
12 1.0268421 0.48600835 0.27266476 105 20 1.6863861
13 1.000179 0.45579082 0.26858282 105 9 1.6001016
14 1.0144405 0.46453598 0.2691497 105 20 3.1118512
15 0.9921475 0.45157805 0.26792496 105 6 3.04773
16 1.0790019 0.51705253 0.27363735 105 20 1.599278
17 0.9422333 0.44149858 0.26067227 105 5 0.49108532
18 0.95633763 0.4589308 0.26509196 105 20 0.46267003
19 0.91339755 0.4367942 0.26139438 105 2 