# Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Modules and code used in the core part of AlphaFold. The structure generation code is in 'folding.py'. """ import functools from alphafold.common import residue_constants from alphafold.model import all_atom from alphafold.model import common_modules from alphafold.model import folding from alphafold.model import layer_stack from alphafold.model import lddt from alphafold.model import mapping from alphafold.model import prng from alphafold.model import quat_affine from alphafold.model import utils import haiku as hk import jax import jax.numpy as jnp from alphafold.model.r3 import Rigids, Rots, Vecs def softmax_cross_entropy(logits, labels): """Computes softmax cross entropy given logits and one-hot class labels.""" loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1) return jnp.asarray(loss) def sigmoid_cross_entropy(logits, labels): """Computes sigmoid cross entropy given logits and multiple class labels.""" log_p = jax.nn.log_sigmoid(logits) # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable log_not_p = jax.nn.log_sigmoid(-logits) loss = -labels * log_p - (1. - labels) * log_not_p return jnp.asarray(loss) def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None): """Applies dropout to a tensor.""" if is_training: # and rate != 0.0: shape = list(tensor.shape) if broadcast_dim is not None: shape[broadcast_dim] = 1 keep_rate = 1.0 - rate keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape) return keep * tensor / keep_rate else: return tensor def dropout_wrapper(module, input_act, mask, safe_key, global_config, output_act=None, is_training=True, scale_rate=1.0, **kwargs): """Applies module + dropout + residual update.""" if output_act is None: output_act = input_act gc = global_config residual = module(input_act, mask, is_training=is_training, **kwargs) dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate if module.config.shared_dropout: if module.config.orientation == 'per_row': broadcast_dim = 0 else: broadcast_dim = 1 else: broadcast_dim = None residual = apply_dropout(tensor=residual, safe_key=safe_key, rate=dropout_rate * scale_rate, is_training=is_training, broadcast_dim=broadcast_dim) new_act = output_act + residual return new_act def create_extra_msa_feature(batch): """Expand extra_msa into 1hot and concat with other extra msa features. We do this as late as possible as the one_hot extra msa can be very large. Arguments: batch: a dictionary with the following keys: * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster centre. Note, that this is not one-hot encoded. * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to the left of each position in the extra MSA. * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to the left of each position in the extra MSA. Returns: Concatenated tensor of extra MSA features. 
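    Example (an illustrative sketch; the keys match the batch description
    above, but the shapes and dummy values are assumptions rather than a
    real pipeline batch):

      extra_batch = {
          'extra_msa': jnp.zeros([num_extra_seq, num_res], dtype=jnp.int32),
          'extra_has_deletion': jnp.zeros([num_extra_seq, num_res]),
          'extra_deletion_value': jnp.zeros([num_extra_seq, num_res]),
      }
      feat = create_extra_msa_feature(extra_batch)
      # feat has shape [num_extra_seq, num_res, 25]:
      # 23 one-hot classes + 2 deletion features.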
""" # 23 = 20 amino acids + 'X' for unknown + gap + bert mask msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23) msa_feat = [msa_1hot, jnp.expand_dims(batch['extra_has_deletion'], axis=-1), jnp.expand_dims(batch['extra_deletion_value'], axis=-1)] return jnp.concatenate(msa_feat, axis=-1) class AlphaFoldIteration(hk.Module): """A single recycling iteration of AlphaFold architecture. Computes ensembled (averaged) representations from the provided features. These representations are then passed to the various heads that have been requested by the configuration file. Each head also returns a loss which is combined as a weighted sum to produce the total loss. Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 """ def __init__(self, config, global_config, name='alphafold_iteration'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, ensembled_batch, non_ensembled_batch, is_training, compute_loss=False, ensemble_representations=False, return_representations=False): num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0]) if not ensemble_representations: assert ensembled_batch['seq_length'].shape[0] == 1 def slice_batch(i): b = {k: v[i] for k, v in ensembled_batch.items()} b.update(non_ensembled_batch) return b # Compute representations for each batch element and average. evoformer_module = EmbeddingsAndEvoformer( self.config.embeddings_and_evoformer, self.global_config) batch0 = slice_batch(0) representations = evoformer_module(batch0, is_training) # MSA representations are not ensembled so # we don't pass tensor into the loop. msa_representation = representations['msa'] del representations['msa'] # Average the representations (except MSA) over the batch dimension. if ensemble_representations: def body(x): """Add one element to the representations ensemble.""" i, current_representations = x feats = slice_batch(i) representations_update = evoformer_module( feats, is_training) new_representations = {} for k in current_representations: new_representations[k] = ( current_representations[k] + representations_update[k]) return i+1, new_representations if hk.running_init(): # When initializing the Haiku module, run one iteration of the # while_loop to initialize the Haiku modules used in `body`. _, representations = body((1, representations)) else: _, representations = hk.while_loop( lambda x: x[0] < num_ensemble, body, (1, representations)) for k in representations: if k != 'msa': representations[k] /= num_ensemble.astype(representations[k].dtype) representations['msa'] = msa_representation batch = batch0 # We are not ensembled from here on. if jnp.issubdtype(ensembled_batch['aatype'].dtype, jnp.integer): _, num_residues = ensembled_batch['aatype'].shape else: _, num_residues, _ = ensembled_batch['aatype'].shape if self.config.use_struct: struct_module = folding.StructureModule else: struct_module = folding.dummy heads = {} for head_name, head_config in sorted(self.config.heads.items()): if not head_config.weight: continue # Do not instantiate zero-weight heads. head_factory = { 'masked_msa': MaskedMsaHead, 'distogram': DistogramHead, 'structure_module': functools.partial(struct_module, compute_loss=compute_loss), 'predicted_lddt': PredictedLDDTHead, 'predicted_aligned_error': PredictedAlignedErrorHead, 'experimentally_resolved': ExperimentallyResolvedHead, }[head_name] heads[head_name] = (head_config, head_factory(head_config, self.global_config)) total_loss = 0. 
ret = {} ret['representations'] = representations def loss(module, head_config, ret, name, filter_ret=True): if filter_ret: value = ret[name] else: value = ret loss_output = module.loss(value, batch) ret[name].update(loss_output) loss = head_config.weight * ret[name]['loss'] return loss for name, (head_config, module) in heads.items(): # Skip PredictedLDDTHead and PredictedAlignedErrorHead until # StructureModule is executed. if name in ('predicted_lddt', 'predicted_aligned_error'): continue else: ret[name] = module(representations, batch, is_training) if 'representations' in ret[name]: # Extra representations from the head. Used by the structure module # to provide activations for the PredictedLDDTHead. representations.update(ret[name].pop('representations')) if compute_loss: total_loss += loss(module, head_config, ret, name) if self.config.use_struct: if self.config.heads.get('predicted_lddt.weight', 0.0): # Add PredictedLDDTHead after StructureModule executes. name = 'predicted_lddt' # Feed all previous results to give access to structure_module result. head_config, module = heads[name] ret[name] = module(representations, batch, is_training) if compute_loss: total_loss += loss(module, head_config, ret, name, filter_ret=False) if ('predicted_aligned_error' in self.config.heads and self.config.heads.get('predicted_aligned_error.weight', 0.0)): # Add PredictedAlignedErrorHead after StructureModule executes. name = 'predicted_aligned_error' # Feed all previous results to give access to structure_module result. head_config, module = heads[name] ret[name] = module(representations, batch, is_training) if compute_loss: total_loss += loss(module, head_config, ret, name, filter_ret=False) if compute_loss: return ret, total_loss else: return ret class AlphaFold(hk.Module): """AlphaFold model with recycling. Jumper et al. (2021) Suppl. Alg. 2 "Inference" """ def __init__(self, config, name='alphafold'): super().__init__(name=name) self.config = config self.global_config = config.global_config def __call__( self, batch, is_training, compute_loss=False, ensemble_representations=False, return_representations=False): """Run the AlphaFold model. Arguments: batch: Dictionary with inputs to the AlphaFold model. is_training: Whether the system is in training or inference mode. compute_loss: Whether to compute losses (requires extra features to be present in the batch and knowing the true structure). ensemble_representations: Whether to use ensembling of representations. return_representations: Whether to also return the intermediate representations. Returns: When compute_loss is True: a tuple of loss and output of AlphaFoldIteration. When compute_loss is False: just output of AlphaFoldIteration. The output of AlphaFoldIteration is a nested dictionary containing predictions from the various heads. 
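      For illustration only (the exact keys depend on which heads are
      enabled in the config), a typical inference output dictionary `ret`
      contains entries such as:

        ret['distogram']['logits']                       # [N_res, N_res, 64]
        ret['predicted_lddt']['logits']                  # [N_res, 50]
        ret['structure_module']['final_atom_positions']  # [N_res, 37, 3]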
""" if "scale_rate" not in batch: batch["scale_rate"] = jnp.ones((1,)) impl = AlphaFoldIteration(self.config, self.global_config) if jnp.issubdtype(batch['aatype'].dtype, jnp.integer): batch_size, num_residues = batch['aatype'].shape else: batch_size, num_residues, _ = batch['aatype'].shape def get_prev(ret): new_prev = { 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], 'prev_dgram': ret["distogram"]["logits"], } if self.config.use_struct: new_prev.update({'prev_pos': ret['structure_module']['final_atom_positions'], 'prev_plddt': ret["predicted_lddt"]["logits"]}) if "predicted_aligned_error" in ret: new_prev["prev_pae"] = ret["predicted_aligned_error"]["logits"] if not self.config.backprop_recycle: for k in ["prev_pos","prev_msa_first_row","prev_pair"]: if k in new_prev: new_prev[k] = jax.lax.stop_gradient(new_prev[k]) return new_prev def do_call(prev, recycle_idx, compute_loss=compute_loss): if self.config.resample_msa_in_recycling: num_ensemble = batch_size // (self.config.num_recycle + 1) def slice_recycle_idx(x): start = recycle_idx * num_ensemble size = num_ensemble return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0) ensembled_batch = jax.tree_map(slice_recycle_idx, batch) else: num_ensemble = batch_size ensembled_batch = batch non_ensembled_batch = jax.tree_map(lambda x: x, prev) return impl(ensembled_batch=ensembled_batch, non_ensembled_batch=non_ensembled_batch, is_training=is_training, compute_loss=compute_loss, ensemble_representations=ensemble_representations) emb_config = self.config.embeddings_and_evoformer prev = { 'prev_msa_first_row': jnp.zeros([num_residues, emb_config.msa_channel]), 'prev_pair': jnp.zeros([num_residues, num_residues, emb_config.pair_channel]), 'prev_dgram': jnp.zeros([num_residues, num_residues, 64]), } if self.config.use_struct: prev.update({'prev_pos': jnp.zeros([num_residues, residue_constants.atom_type_num, 3]), 'prev_plddt': jnp.zeros([num_residues, 50]), 'prev_pae': jnp.zeros([num_residues, num_residues, 64])}) for k in ["pos","msa_first_row","pair","dgram"]: if f"init_{k}" in batch: prev[f"prev_{k}"] = batch[f"init_{k}"][0] if self.config.num_recycle: if 'num_iter_recycling' in batch: # Training time: num_iter_recycling is in batch. # The value for each ensemble batch is the same, so arbitrarily taking # 0-th. num_iter = batch['num_iter_recycling'][0] # Add insurance that we will not run more # recyclings than the model is configured to run. num_iter = jnp.minimum(num_iter, self.config.num_recycle) else: # Eval mode or tests: use the maximum number of iterations. 
num_iter = self.config.num_recycle def add_prev(p,p_): p_["prev_dgram"] += p["prev_dgram"] if self.config.use_struct: p_["prev_plddt"] += p["prev_plddt"] p_["prev_pae"] += p["prev_pae"] return p_ ############################################################## def body(p, i): p_ = get_prev(do_call(p, recycle_idx=i, compute_loss=False)) if self.config.add_prev: p_ = add_prev(p, p_) return p_, None if hk.running_init(): prev,_ = body(prev, 0) else: prev,_ = hk.scan(body, prev, jnp.arange(num_iter)) ############################################################## else: num_iter = 0 ret = do_call(prev=prev, recycle_idx=num_iter) if self.config.add_prev: prev_ = get_prev(ret) if compute_loss: ret = ret[0], [ret[1]] if not return_representations: del (ret[0] if compute_loss else ret)['representations'] # pytype: disable=unsupported-operands if self.config.add_prev and num_iter > 0: prev_ = add_prev(prev, prev_) ret["distogram"]["logits"] = prev_["prev_dgram"]/(num_iter+1) if self.config.use_struct: ret["predicted_lddt"]["logits"] = prev_["prev_plddt"]/(num_iter+1) if "predicted_aligned_error" in ret: ret["predicted_aligned_error"]["logits"] = prev_["prev_pae"]/(num_iter+1) return ret class TemplatePairStack(hk.Module): """Pair stack for the templates. Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack" """ def __init__(self, config, global_config, name='template_pair_stack'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, pair_act, pair_mask, is_training, safe_key=None, scale_rate=1.0): """Builds TemplatePairStack module. Arguments: pair_act: Pair activations for single template, shape [N_res, N_res, c_t]. pair_mask: Pair mask, shape [N_res, N_res]. is_training: Whether the module is in training mode. safe_key: Safe key object encapsulating the random number generation key. Returns: Updated pair_act, shape [N_res, N_res, c_t]. """ if safe_key is None: safe_key = prng.SafeKey(hk.next_rng_key()) gc = self.global_config c = self.config if not c.num_block: return pair_act def block(x): """One block of the template pair stack.""" pair_act, safe_key = x dropout_wrapper_fn = functools.partial( dropout_wrapper, is_training=is_training, global_config=gc, scale_rate=scale_rate) safe_key, *sub_keys = safe_key.split(6) sub_keys = iter(sub_keys) pair_act = dropout_wrapper_fn( TriangleAttention(c.triangle_attention_starting_node, gc, name='triangle_attention_starting_node'), pair_act, pair_mask, next(sub_keys)) pair_act = dropout_wrapper_fn( TriangleAttention(c.triangle_attention_ending_node, gc, name='triangle_attention_ending_node'), pair_act, pair_mask, next(sub_keys)) pair_act = dropout_wrapper_fn( TriangleMultiplication(c.triangle_multiplication_outgoing, gc, name='triangle_multiplication_outgoing'), pair_act, pair_mask, next(sub_keys)) pair_act = dropout_wrapper_fn( TriangleMultiplication(c.triangle_multiplication_incoming, gc, name='triangle_multiplication_incoming'), pair_act, pair_mask, next(sub_keys)) pair_act = dropout_wrapper_fn( Transition(c.pair_transition, gc, name='pair_transition'), pair_act, pair_mask, next(sub_keys)) return pair_act, safe_key if gc.use_remat: block = hk.remat(block) res_stack = layer_stack.layer_stack(c.num_block)(block) pair_act, safe_key = res_stack((pair_act, safe_key)) return pair_act class Transition(hk.Module): """Transition layer. Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" Jumper et al. (2021) Suppl. Alg. 
15 "PairTransition" """ def __init__(self, config, global_config, name='transition_block'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, act, mask, is_training=True): """Builds Transition module. Arguments: act: A tensor of queries of size [batch_size, N_res, N_channel]. mask: A tensor denoting the mask of size [batch_size, N_res]. is_training: Whether the module is in training mode. Returns: A float32 tensor of size [batch_size, N_res, N_channel]. """ _, _, nc = act.shape num_intermediate = int(nc * self.config.num_intermediate_factor) mask = jnp.expand_dims(mask, axis=-1) act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='input_layer_norm')( act) transition_module = hk.Sequential([ common_modules.Linear( num_intermediate, initializer='relu', name='transition1'), jax.nn.relu, common_modules.Linear( nc, initializer=utils.final_init(self.global_config), name='transition2') ]) act = mapping.inference_subbatch( transition_module, self.global_config.subbatch_size, batched_args=[act], nonbatched_args=[], low_memory=not is_training) return act def glorot_uniform(): return hk.initializers.VarianceScaling(scale=1.0, mode='fan_avg', distribution='uniform') class Attention(hk.Module): """Multihead attention.""" def __init__(self, config, global_config, output_dim, name='attention'): super().__init__(name=name) self.config = config self.global_config = global_config self.output_dim = output_dim def __call__(self, q_data, m_data, bias, nonbatched_bias=None): """Builds Attention module. Arguments: q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. m_data: A tensor of memories from which the keys and values are projected, shape [batch_size, N_keys, m_channels]. bias: A bias for the attention, shape [batch_size, N_queries, N_keys]. nonbatched_bias: Shared bias, shape [N_queries, N_keys]. Returns: A float32 tensor of shape [batch_size, N_queries, output_dim]. 
""" # Sensible default for when the config keys are missing key_dim = self.config.get('key_dim', int(q_data.shape[-1])) value_dim = self.config.get('value_dim', int(m_data.shape[-1])) num_head = self.config.num_head assert key_dim % num_head == 0 assert value_dim % num_head == 0 key_dim = key_dim // num_head value_dim = value_dim // num_head q_weights = hk.get_parameter( 'query_w', shape=(q_data.shape[-1], num_head, key_dim), init=glorot_uniform()) k_weights = hk.get_parameter( 'key_w', shape=(m_data.shape[-1], num_head, key_dim), init=glorot_uniform()) v_weights = hk.get_parameter( 'value_w', shape=(m_data.shape[-1], num_head, value_dim), init=glorot_uniform()) q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights) v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights) logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias if nonbatched_bias is not None: logits += jnp.expand_dims(nonbatched_bias, axis=0) weights = jax.nn.softmax(logits) weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v) if self.global_config.zero_init: init = hk.initializers.Constant(0.0) else: init = glorot_uniform() if self.config.gating: gating_weights = hk.get_parameter( 'gating_w', shape=(q_data.shape[-1], num_head, value_dim), init=hk.initializers.Constant(0.0)) gating_bias = hk.get_parameter( 'gating_b', shape=(num_head, value_dim), init=hk.initializers.Constant(1.0)) gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) + gating_bias gate_values = jax.nn.sigmoid(gate_values) weighted_avg *= gate_values o_weights = hk.get_parameter( 'output_w', shape=(num_head, value_dim, self.output_dim), init=init) o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), init=hk.initializers.Constant(0.0)) output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias return output class GlobalAttention(hk.Module): """Global attention. Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 """ def __init__(self, config, global_config, output_dim, name='attention'): super().__init__(name=name) self.config = config self.global_config = global_config self.output_dim = output_dim def __call__(self, q_data, m_data, q_mask, bias): """Builds GlobalAttention module. Arguments: q_data: A tensor of queries with size [batch_size, N_queries, q_channels] m_data: A tensor of memories from which the keys and values projected. Size [batch_size, N_keys, m_channels] q_mask: A binary mask for q_data with zeros in the padded sequence elements and ones otherwise. Size [batch_size, N_queries, q_channels] (or broadcastable to this shape). bias: A bias for the attention. Returns: A float32 tensor of size [batch_size, N_queries, output_dim]. 
""" # Sensible default for when the config keys are missing key_dim = self.config.get('key_dim', int(q_data.shape[-1])) value_dim = self.config.get('value_dim', int(m_data.shape[-1])) num_head = self.config.num_head assert key_dim % num_head == 0 assert value_dim % num_head == 0 key_dim = key_dim // num_head value_dim = value_dim // num_head q_weights = hk.get_parameter( 'query_w', shape=(q_data.shape[-1], num_head, key_dim), init=glorot_uniform()) k_weights = hk.get_parameter( 'key_w', shape=(m_data.shape[-1], key_dim), init=glorot_uniform()) v_weights = hk.get_parameter( 'value_w', shape=(m_data.shape[-1], value_dim), init=glorot_uniform()) v = jnp.einsum('bka,ac->bkc', m_data, v_weights) q_avg = utils.mask_mean(q_mask, q_data, axis=1) q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5) k = jnp.einsum('bka,ac->bkc', m_data, k_weights) bias = (1e9 * (q_mask[:, None, :, 0] - 1.)) logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias weights = jax.nn.softmax(logits) weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v) if self.global_config.zero_init: init = hk.initializers.Constant(0.0) else: init = glorot_uniform() o_weights = hk.get_parameter( 'output_w', shape=(num_head, value_dim, self.output_dim), init=init) o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), init=hk.initializers.Constant(0.0)) if self.config.gating: gating_weights = hk.get_parameter( 'gating_w', shape=(q_data.shape[-1], num_head, value_dim), init=hk.initializers.Constant(0.0)) gating_bias = hk.get_parameter( 'gating_b', shape=(num_head, value_dim), init=hk.initializers.Constant(1.0)) gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) gate_values = jax.nn.sigmoid(gate_values + gating_bias) weighted_avg = weighted_avg[:, None] * gate_values output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias else: output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias output = output[:, None] return output class MSARowAttentionWithPairBias(hk.Module): """MSA per-row attention biased by the pair representation. Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" """ def __init__(self, config, global_config, name='msa_row_attention_with_pair_bias'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, msa_act, msa_mask, pair_act, is_training=False): """Builds MSARowAttentionWithPairBias module. Arguments: msa_act: [N_seq, N_res, c_m] MSA representation. msa_mask: [N_seq, N_res] mask of non-padded regions. pair_act: [N_res, N_res, c_z] pair representation. is_training: Whether the module is in training mode. Returns: Update to msa_act, shape [N_seq, N_res, c_m]. """ c = self.config assert len(msa_act.shape) == 3 assert len(msa_mask.shape) == 2 assert c.orientation == 'per_row' bias = (1e9 * (msa_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 msa_act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( msa_act) pair_act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='feat_2d_norm')( pair_act) init_factor = 1. 
/ jnp.sqrt(int(pair_act.shape[-1])) weights = hk.get_parameter( 'feat_2d_weights', shape=(pair_act.shape[-1], c.num_head), init=hk.initializers.RandomNormal(stddev=init_factor)) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) attn_mod = Attention( c, self.global_config, msa_act.shape[-1]) msa_act = mapping.inference_subbatch( attn_mod, self.global_config.subbatch_size, batched_args=[msa_act, msa_act, bias], nonbatched_args=[nonbatched_bias], low_memory=not is_training) return msa_act class MSAColumnAttention(hk.Module): """MSA per-column attention. Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" """ def __init__(self, config, global_config, name='msa_column_attention'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, msa_act, msa_mask, is_training=False): """Builds MSAColumnAttention module. Arguments: msa_act: [N_seq, N_res, c_m] MSA representation. msa_mask: [N_seq, N_res] mask of non-padded regions. is_training: Whether the module is in training mode. Returns: Update to msa_act, shape [N_seq, N_res, c_m] """ c = self.config assert len(msa_act.shape) == 3 assert len(msa_mask.shape) == 2 assert c.orientation == 'per_column' msa_act = jnp.swapaxes(msa_act, -2, -3) msa_mask = jnp.swapaxes(msa_mask, -1, -2) bias = (1e9 * (msa_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 msa_act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( msa_act) attn_mod = Attention( c, self.global_config, msa_act.shape[-1]) msa_act = mapping.inference_subbatch( attn_mod, self.global_config.subbatch_size, batched_args=[msa_act, msa_act, bias], nonbatched_args=[], low_memory=not is_training) msa_act = jnp.swapaxes(msa_act, -2, -3) return msa_act class MSAColumnGlobalAttention(hk.Module): """MSA per-column global attention. Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" """ def __init__(self, config, global_config, name='msa_column_global_attention'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, msa_act, msa_mask, is_training=False): """Builds MSAColumnGlobalAttention module. Arguments: msa_act: [N_seq, N_res, c_m] MSA representation. msa_mask: [N_seq, N_res] mask of non-padded regions. is_training: Whether the module is in training mode. Returns: Update to msa_act, shape [N_seq, N_res, c_m]. """ c = self.config assert len(msa_act.shape) == 3 assert len(msa_mask.shape) == 2 assert c.orientation == 'per_column' msa_act = jnp.swapaxes(msa_act, -2, -3) msa_mask = jnp.swapaxes(msa_mask, -1, -2) bias = (1e9 * (msa_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 msa_act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( msa_act) attn_mod = GlobalAttention( c, self.global_config, msa_act.shape[-1], name='attention') # [N_seq, N_res, 1] msa_mask = jnp.expand_dims(msa_mask, axis=-1) msa_act = mapping.inference_subbatch( attn_mod, self.global_config.subbatch_size, batched_args=[msa_act, msa_act, msa_mask, bias], nonbatched_args=[], low_memory=not is_training) msa_act = jnp.swapaxes(msa_act, -2, -3) return msa_act class TriangleAttention(hk.Module): """Triangle Attention. Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode" Jumper et al. (2021) Suppl. Alg. 
14 "TriangleAttentionEndingNode" """ def __init__(self, config, global_config, name='triangle_attention'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, pair_act, pair_mask, is_training=False): """Builds TriangleAttention module. Arguments: pair_act: [N_res, N_res, c_z] pair activations tensor pair_mask: [N_res, N_res] mask of non-padded regions in the tensor. is_training: Whether the module is in training mode. Returns: Update to pair_act, shape [N_res, N_res, c_z]. """ c = self.config assert len(pair_act.shape) == 3 assert len(pair_mask.shape) == 2 assert c.orientation in ['per_row', 'per_column'] if c.orientation == 'per_column': pair_act = jnp.swapaxes(pair_act, -2, -3) pair_mask = jnp.swapaxes(pair_mask, -1, -2) bias = (1e9 * (pair_mask - 1.))[:, None, None, :] assert len(bias.shape) == 4 pair_act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='query_norm')( pair_act) init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1])) weights = hk.get_parameter( 'feat_2d_weights', shape=(pair_act.shape[-1], c.num_head), init=hk.initializers.RandomNormal(stddev=init_factor)) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) attn_mod = Attention( c, self.global_config, pair_act.shape[-1]) pair_act = mapping.inference_subbatch( attn_mod, self.global_config.subbatch_size, batched_args=[pair_act, pair_act, bias], nonbatched_args=[nonbatched_bias], low_memory=not is_training) if c.orientation == 'per_column': pair_act = jnp.swapaxes(pair_act, -2, -3) return pair_act class MaskedMsaHead(hk.Module): """Head to predict MSA at the masked locations. The MaskedMsaHead employs a BERT-style objective to reconstruct a masked version of the full MSA, based on a linear projection of the MSA representation. Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" """ def __init__(self, config, global_config, name='masked_msa_head'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, representations, batch, is_training): """Builds MaskedMsaHead module. Arguments: representations: Dictionary of representations, must contain: * 'msa': MSA representation, shape [N_seq, N_res, c_m]. batch: Batch, unused. is_training: Whether the module is in training mode. Returns: Dictionary containing: * 'logits': logits of shape [N_seq, N_res, N_aatype] with (unnormalized) log probabilies of predicted aatype at position. """ del batch logits = common_modules.Linear( self.config.num_output, initializer=utils.final_init(self.global_config), name='logits')( representations['msa']) return dict(logits=logits) def loss(self, value, batch): errors = softmax_cross_entropy( labels=jax.nn.one_hot(batch['true_msa'], num_classes=23), logits=value['logits']) loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) / (1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1)))) return {'loss': loss} class PredictedLDDTHead(hk.Module): """Head to predict the per-residue LDDT to be used as a confidence measure. Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca" """ def __init__(self, config, global_config, name='predicted_lddt_head'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, representations, batch, is_training): """Builds ExperimentallyResolvedHead module. 
Arguments: representations: Dictionary of representations, must contain: * 'structure_module': Single representation from the structure module, shape [N_res, c_s]. batch: Batch, unused. is_training: Whether the module is in training mode. Returns: Dictionary containing : * 'logits': logits of shape [N_res, N_bins] with (unnormalized) log probabilies of binned predicted lDDT. """ act = representations['structure_module'] act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='input_layer_norm')( act) act = common_modules.Linear( self.config.num_channels, initializer='relu', name='act_0')( act) act = jax.nn.relu(act) act = common_modules.Linear( self.config.num_channels, initializer='relu', name='act_1')( act) act = jax.nn.relu(act) logits = common_modules.Linear( self.config.num_bins, initializer=utils.final_init(self.global_config), name='logits')( act) # Shape (batch_size, num_res, num_bins) return dict(logits=logits) def loss(self, value, batch): # Shape (num_res, 37, 3) pred_all_atom_pos = value['structure_module']['final_atom_positions'] # Shape (num_res, 37, 3) true_all_atom_pos = batch['all_atom_positions'] # Shape (num_res, 37) all_atom_mask = batch['all_atom_mask'] # Shape (num_res,) lddt_ca = lddt.lddt( # Shape (batch_size, num_res, 3) predicted_points=pred_all_atom_pos[None, :, 1, :], # Shape (batch_size, num_res, 3) true_points=true_all_atom_pos[None, :, 1, :], # Shape (batch_size, num_res, 1) true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32), cutoff=15., per_residue=True)[0] lddt_ca = jax.lax.stop_gradient(lddt_ca) num_bins = self.config.num_bins bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32) # protect against out of range for lddt_ca == 1 bin_index = jnp.minimum(bin_index, num_bins - 1) lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins) # Shape (num_res, num_channel) logits = value['predicted_lddt']['logits'] errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) # Shape (num_res,) mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']] mask_ca = mask_ca.astype(jnp.float32) loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8) if self.config.filter_by_resolution: # NMR & distillation have resolution = 0 loss *= ((batch['resolution'] >= self.config.min_resolution) & (batch['resolution'] <= self.config.max_resolution)).astype( jnp.float32) output = {'loss': loss} return output class PredictedAlignedErrorHead(hk.Module): """Head to predict the distance errors in the backbone alignment frames. Can be used to compute predicted TM-Score. Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" """ def __init__(self, config, global_config, name='predicted_aligned_error_head'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, representations, batch, is_training): """Builds PredictedAlignedErrorHead module. Arguments: representations: Dictionary of representations, must contain: * 'pair': pair representation, shape [N_res, N_res, c_z]. batch: Batch, unused. is_training: Whether the module is in training mode. Returns: Dictionary containing: * logits: logits for aligned error, shape [N_res, N_res, N_bins]. * bin_breaks: array containing bin breaks, shape [N_bins - 1]. 
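      A per-pair expected aligned error can be recovered from these outputs;
      the sketch below summarises the distribution with bin centres (one
      reasonable choice, mirroring the approach used in
      alphafold.common.confidence):

        probs = jax.nn.softmax(logits, axis=-1)      # [N_res, N_res, N_bins]
        step = breaks[1] - breaks[0]
        bin_centers = breaks + step / 2.
        bin_centers = jnp.append(bin_centers, bin_centers[-1] + step)
        pae = jnp.sum(probs * bin_centers, axis=-1)  # [N_res, N_res]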
""" act = representations['pair'] # Shape (num_res, num_res, num_bins) logits = common_modules.Linear( self.config.num_bins, initializer=utils.final_init(self.global_config), name='logits')(act) # Shape (num_bins,) breaks = jnp.linspace( 0., self.config.max_error_bin, self.config.num_bins - 1) return dict(logits=logits, breaks=breaks) def loss(self, value, batch): # Shape (num_res, 7) predicted_affine = quat_affine.QuatAffine.from_tensor( value['structure_module']['final_affines']) # Shape (num_res, 7) true_affine = quat_affine.QuatAffine.from_tensor( batch['backbone_affine_tensor']) # Shape (num_res) mask = batch['backbone_affine_mask'] # Shape (num_res, num_res) square_mask = mask[:, None] * mask[None, :] num_bins = self.config.num_bins # (1, num_bins - 1) breaks = value['predicted_aligned_error']['breaks'] # (1, num_bins) logits = value['predicted_aligned_error']['logits'] # Compute the squared error for each alignment. def _local_frame_points(affine): points = [jnp.expand_dims(x, axis=-2) for x in affine.translation] return affine.invert_point(points, extra_dims=1) error_dist2_xyz = [ jnp.square(a - b) for a, b in zip(_local_frame_points(predicted_affine), _local_frame_points(true_affine))] error_dist2 = sum(error_dist2_xyz) # Shape (num_res, num_res) # First num_res are alignment frames, second num_res are the residues. error_dist2 = jax.lax.stop_gradient(error_dist2) sq_breaks = jnp.square(breaks) true_bins = jnp.sum(( error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1) errors = softmax_cross_entropy( labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits) loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) / (1e-8 + jnp.sum(square_mask, axis=(-2, -1)))) if self.config.filter_by_resolution: # NMR & distillation have resolution = 0 loss *= ((batch['resolution'] >= self.config.min_resolution) & (batch['resolution'] <= self.config.max_resolution)).astype( jnp.float32) output = {'loss': loss} return output class ExperimentallyResolvedHead(hk.Module): """Predicts if an atom is experimentally resolved in a high-res structure. Only trained on high-resolution X-ray crystals & cryo-EM. Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction' """ def __init__(self, config, global_config, name='experimentally_resolved_head'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, representations, batch, is_training): """Builds ExperimentallyResolvedHead module. Arguments: representations: Dictionary of representations, must contain: * 'single': Single representation, shape [N_res, c_s]. batch: Batch, unused. is_training: Whether the module is in training mode. Returns: Dictionary containing: * 'logits': logits of shape [N_res, 37], log probability that an atom is resolved in atom37 representation, can be converted to probability by applying sigmoid. """ logits = common_modules.Linear( 37, # atom_exists.shape[-1] initializer=utils.final_init(self.global_config), name='logits')(representations['single']) return dict(logits=logits) def loss(self, value, batch): logits = value['logits'] assert len(logits.shape) == 2 # Does the atom appear in the amino acid? atom_exists = batch['atom37_atom_exists'] # Is the atom resolved in the experiment? 
Subset of atom_exists, # *except for OXT* all_atom_mask = batch['all_atom_mask'].astype(jnp.float32) xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists)) if self.config.filter_by_resolution: # NMR & distillation examples have resolution = 0. loss *= ((batch['resolution'] >= self.config.min_resolution) & (batch['resolution'] <= self.config.max_resolution)).astype( jnp.float32) output = {'loss': loss} return output class TriangleMultiplication(hk.Module): """Triangle multiplication layer ("outgoing" or "incoming"). Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" """ def __init__(self, config, global_config, name='triangle_multiplication'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, act, mask, is_training=True): """Builds TriangleMultiplication module. Arguments: act: Pair activations, shape [N_res, N_res, c_z] mask: Pair mask, shape [N_res, N_res]. is_training: Whether the module is in training mode. Returns: Outputs, same shape/type as act. """ del is_training c = self.config gc = self.global_config mask = mask[..., None] act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, name='layer_norm_input')(act) input_act = act left_projection = common_modules.Linear( c.num_intermediate_channel, name='left_projection') left_proj_act = mask * left_projection(act) right_projection = common_modules.Linear( c.num_intermediate_channel, name='right_projection') right_proj_act = mask * right_projection(act) left_gate_values = jax.nn.sigmoid(common_modules.Linear( c.num_intermediate_channel, bias_init=1., initializer=utils.final_init(gc), name='left_gate')(act)) right_gate_values = jax.nn.sigmoid(common_modules.Linear( c.num_intermediate_channel, bias_init=1., initializer=utils.final_init(gc), name='right_gate')(act)) left_proj_act *= left_gate_values right_proj_act *= right_gate_values # "Outgoing" edges equation: 'ikc,jkc->ijc' # "Incoming" edges equation: 'kjc,kic->ijc' # Note on the Suppl. Alg. 11 & 12 notation: # For the "outgoing" edges, a = left_proj_act and b = right_proj_act # For the "incoming" edges, it's swapped: # b = left_proj_act and a = right_proj_act act = jnp.einsum(c.equation, left_proj_act, right_proj_act) act = hk.LayerNorm( axis=[-1], create_scale=True, create_offset=True, name='center_layer_norm')( act) output_channel = int(input_act.shape[-1]) act = common_modules.Linear( output_channel, initializer=utils.final_init(gc), name='output_projection')(act) gate_values = jax.nn.sigmoid(common_modules.Linear( output_channel, bias_init=1., initializer=utils.final_init(gc), name='gating_linear')(input_act)) act *= gate_values return act class DistogramHead(hk.Module): """Head to predict a distogram. Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" """ def __init__(self, config, global_config, name='distogram_head'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, representations, batch, is_training): """Builds DistogramHead module. Arguments: representations: Dictionary of representations, must contain: * 'pair': pair representation, shape [N_res, N_res, c_z]. batch: Batch, unused. is_training: Whether the module is in training mode. Returns: Dictionary containing: * logits: logits for distogram, shape [N_res, N_res, N_bins]. 
* bin_breaks: array containing bin breaks, shape [N_bins - 1,]. """ half_logits = common_modules.Linear( self.config.num_bins, initializer=utils.final_init(self.global_config), name='half_logits')( representations['pair']) logits = half_logits + jnp.swapaxes(half_logits, -2, -3) breaks = jnp.linspace(self.config.first_break, self.config.last_break, self.config.num_bins - 1) return dict(logits=logits, bin_edges=breaks) def loss(self, value, batch): return _distogram_log_loss(value['logits'], value['bin_edges'], batch, self.config.num_bins) def _distogram_log_loss(logits, bin_edges, batch, num_bins): """Log loss of a distogram.""" assert len(logits.shape) == 3 positions = batch['pseudo_beta'] mask = batch['pseudo_beta_mask'] assert positions.shape[-1] == 3 sq_breaks = jnp.square(bin_edges) dist2 = jnp.sum( jnp.square( jnp.expand_dims(positions, axis=-2) - jnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True) true_bins = jnp.sum(dist2 > sq_breaks, axis=-1) errors = softmax_cross_entropy( labels=jax.nn.one_hot(true_bins, num_bins), logits=logits) square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1) avg_error = ( jnp.sum(errors * square_mask, axis=(-2, -1)) / (1e-6 + jnp.sum(square_mask, axis=(-2, -1)))) dist2 = dist2[..., 0] return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2)) class OuterProductMean(hk.Module): """Computes mean outer product. Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean" """ def __init__(self, config, global_config, num_output_channel, name='outer_product_mean'): super().__init__(name=name) self.global_config = global_config self.config = config self.num_output_channel = num_output_channel def __call__(self, act, mask, is_training=True): """Builds OuterProductMean module. Arguments: act: MSA representation, shape [N_seq, N_res, c_m]. mask: MSA mask, shape [N_seq, N_res]. is_training: Whether the module is in training mode. Returns: Update to pair representation, shape [N_res, N_res, c_z]. """ gc = self.global_config c = self.config mask = mask[..., None] act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act) left_act = mask * common_modules.Linear( c.num_outer_channel, initializer='linear', name='left_projection')( act) right_act = mask * common_modules.Linear( c.num_outer_channel, initializer='linear', name='right_projection')( act) if gc.zero_init: init_w = hk.initializers.Constant(0.0) else: init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in') output_w = hk.get_parameter( 'output_w', shape=(c.num_outer_channel, c.num_outer_channel, self.num_output_channel), init=init_w) output_b = hk.get_parameter( 'output_b', shape=(self.num_output_channel,), init=hk.initializers.Constant(0.0)) def compute_chunk(left_act): # This is equivalent to # # act = jnp.einsum('abc,ade->dceb', left_act, right_act) # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b # # but faster. left_act = jnp.transpose(left_act, [0, 2, 1]) act = jnp.einsum('acb,ade->dceb', left_act, right_act) act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b return jnp.transpose(act, [1, 0, 2]) act = mapping.inference_subbatch( compute_chunk, c.chunk_size, batched_args=[left_act], nonbatched_args=[], low_memory=True, input_subbatch_dim=1, output_subbatch_dim=0) epsilon = 1e-3 norm = jnp.einsum('abc,adc->bdc', mask, mask) act /= epsilon + norm return act def dgram_from_positions(positions, num_bins, min_bin, max_bin): """Compute distogram from amino acid positions. Arguments: positions: [N_res, 3] Position coordinates. 
num_bins: The number of bins in the distogram. min_bin: The left edge of the first bin. max_bin: The left edge of the final bin. The final bin catches everything larger than `max_bin`. Returns: Distogram with the specified number of bins. """ def squared_difference(x, y): return jnp.square(x - y) lower_breaks = jnp.linspace(min_bin, max_bin, num_bins) lower_breaks = jnp.square(lower_breaks) upper_breaks = jnp.concatenate([lower_breaks[1:],jnp.array([1e8], dtype=jnp.float32)], axis=-1) dist2 = jnp.sum( squared_difference( jnp.expand_dims(positions, axis=-2), jnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True) return ((dist2 > lower_breaks).astype(jnp.float32) * (dist2 < upper_breaks).astype(jnp.float32)) def dgram_from_positions_soft(positions, num_bins, min_bin, max_bin, temp=2.0): '''soft positions to dgram converter''' lower_breaks = jnp.append(-1e8,jnp.linspace(min_bin, max_bin, num_bins)) upper_breaks = jnp.append(lower_breaks[1:],1e8) dist = jnp.sqrt(jnp.square(positions[...,:,None,:] - positions[...,None,:,:]).sum(-1,keepdims=True) + 1e-8) o = jax.nn.sigmoid((dist - lower_breaks)/temp) * jax.nn.sigmoid((upper_breaks - dist)/temp) o = o/(o.sum(-1,keepdims=True) + 1e-8) return o[...,1:] def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): """Create pseudo beta features.""" ca_idx = residue_constants.atom_order['CA'] cb_idx = residue_constants.atom_order['CB'] if jnp.issubdtype(aatype.dtype, jnp.integer): is_gly = jnp.equal(aatype, residue_constants.restype_order['G']) is_gly_tile = jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]) pseudo_beta = jnp.where(is_gly_tile, all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) if all_atom_masks is not None: pseudo_beta_mask = jnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32) return pseudo_beta, pseudo_beta_mask else: return pseudo_beta else: is_gly = aatype[...,residue_constants.restype_order['G']] ca_pos = all_atom_positions[...,ca_idx,:] cb_pos = all_atom_positions[...,cb_idx,:] pseudo_beta = is_gly[...,None] * ca_pos + (1-is_gly[...,None]) * cb_pos if all_atom_masks is not None: ca_mask = all_atom_masks[...,ca_idx] cb_mask = all_atom_masks[...,cb_idx] pseudo_beta_mask = is_gly * ca_mask + (1-is_gly) * cb_mask return pseudo_beta, pseudo_beta_mask else: return pseudo_beta class EvoformerIteration(hk.Module): """Single iteration (block) of Evoformer stack. Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10 """ def __init__(self, config, global_config, is_extra_msa, name='evoformer_iteration'): super().__init__(name=name) self.config = config self.global_config = global_config self.is_extra_msa = is_extra_msa def __call__(self, activations, masks, is_training=True, safe_key=None, scale_rate=1.0): """Builds EvoformerIteration module. Arguments: activations: Dictionary containing activations: * 'msa': MSA activations, shape [N_seq, N_res, c_m]. * 'pair': pair activations, shape [N_res, N_res, c_z]. masks: Dictionary of masks: * 'msa': MSA mask, shape [N_seq, N_res]. * 'pair': pair mask, shape [N_res, N_res]. is_training: Whether the module is in training mode. safe_key: prng.SafeKey encapsulating rng key. Returns: Outputs, same shape/type as act. 
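    In this implementation the block applies, in order:
    MSARowAttentionWithPairBias, MSAColumn(Global)Attention, the MSA
    Transition, OuterProductMean (whose residual is added to the pair
    activations), both TriangleMultiplication updates, both
    TriangleAttention updates and the pair Transition, each wrapped in
    dropout_wrapper (dropout + residual).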
""" c = self.config gc = self.global_config msa_act, pair_act = activations['msa'], activations['pair'] if safe_key is None: safe_key = prng.SafeKey(hk.next_rng_key()) msa_mask, pair_mask = masks['msa'], masks['pair'] dropout_wrapper_fn = functools.partial( dropout_wrapper, is_training=is_training, global_config=gc, scale_rate=scale_rate) safe_key, *sub_keys = safe_key.split(10) sub_keys = iter(sub_keys) msa_act = dropout_wrapper_fn( MSARowAttentionWithPairBias( c.msa_row_attention_with_pair_bias, gc, name='msa_row_attention_with_pair_bias'), msa_act, msa_mask, safe_key=next(sub_keys), pair_act=pair_act) if not self.is_extra_msa: attn_mod = MSAColumnAttention( c.msa_column_attention, gc, name='msa_column_attention') else: attn_mod = MSAColumnGlobalAttention( c.msa_column_attention, gc, name='msa_column_global_attention') msa_act = dropout_wrapper_fn( attn_mod, msa_act, msa_mask, safe_key=next(sub_keys)) msa_act = dropout_wrapper_fn( Transition(c.msa_transition, gc, name='msa_transition'), msa_act, msa_mask, safe_key=next(sub_keys)) pair_act = dropout_wrapper_fn( OuterProductMean( config=c.outer_product_mean, global_config=self.global_config, num_output_channel=int(pair_act.shape[-1]), name='outer_product_mean'), msa_act, msa_mask, safe_key=next(sub_keys), output_act=pair_act) pair_act = dropout_wrapper_fn( TriangleMultiplication(c.triangle_multiplication_outgoing, gc, name='triangle_multiplication_outgoing'), pair_act, pair_mask, safe_key=next(sub_keys)) pair_act = dropout_wrapper_fn( TriangleMultiplication(c.triangle_multiplication_incoming, gc, name='triangle_multiplication_incoming'), pair_act, pair_mask, safe_key=next(sub_keys)) pair_act = dropout_wrapper_fn( TriangleAttention(c.triangle_attention_starting_node, gc, name='triangle_attention_starting_node'), pair_act, pair_mask, safe_key=next(sub_keys)) pair_act = dropout_wrapper_fn( TriangleAttention(c.triangle_attention_ending_node, gc, name='triangle_attention_ending_node'), pair_act, pair_mask, safe_key=next(sub_keys)) pair_act = dropout_wrapper_fn( Transition(c.pair_transition, gc, name='pair_transition'), pair_act, pair_mask, safe_key=next(sub_keys)) return {'msa': msa_act, 'pair': pair_act} class EmbeddingsAndEvoformer(hk.Module): """Embeds the input data and runs Evoformer. Produces the MSA, single and pair representations. Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 """ def __init__(self, config, global_config, name='evoformer'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, batch, is_training, safe_key=None): c = self.config gc = self.global_config if safe_key is None: safe_key = prng.SafeKey(hk.next_rng_key()) # Embed clustered MSA. # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" preprocess_1d = common_modules.Linear( c.msa_channel, name='preprocess_1d')( batch['target_feat']) preprocess_msa = common_modules.Linear( c.msa_channel, name='preprocess_msa')( batch['msa_feat']) msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa left_single = common_modules.Linear( c.pair_channel, name='left_single')( batch['target_feat']) right_single = common_modules.Linear( c.pair_channel, name='right_single')( batch['target_feat']) pair_activations = left_single[:, None] + right_single[None] mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] # Inject previous outputs for recycling. # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 # Jumper et al. (2021) Suppl. Alg. 
32 "RecyclingEmbedder" if "prev_pos" in batch: # use predicted position input prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) if c.backprop_dgram: dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos) else: dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) elif 'prev_dgram' in batch: # use predicted distogram input (from Sergey) dgram = jax.nn.softmax(batch["prev_dgram"]) dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) dgram = dgram @ dgram_map pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) if c.recycle_features: if 'prev_msa_first_row' in batch: prev_msa_first_row = hk.LayerNorm([-1], True, True, name='prev_msa_first_row_norm')( batch['prev_msa_first_row']) msa_activations = msa_activations.at[0].add(prev_msa_first_row) if 'prev_pair' in batch: pair_activations += hk.LayerNorm([-1], True, True, name='prev_pair_norm')( batch['prev_pair']) # Relative position encoding. # Jumper et al. (2021) Suppl. Alg. 4 "relpos" # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" if c.max_relative_feature: # Add one-hot-encoded clipped residue distances to the pair activations. if "rel_pos" in batch: rel_pos = batch['rel_pos'] else: if "offset" in batch: offset = batch['offset'] else: pos = batch['residue_index'] offset = pos[:, None] - pos[None, :] rel_pos = jax.nn.one_hot( jnp.clip( offset + c.max_relative_feature, a_min=0, a_max=2 * c.max_relative_feature), 2 * c.max_relative_feature + 1) pair_activations += common_modules.Linear(c.pair_channel, name='pair_activiations')(rel_pos) # Embed templates into the pair activations. # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 if c.template.enabled: template_batch = {k: batch[k] for k in batch if k.startswith('template_')} template_pair_representation = TemplateEmbedding(c.template, gc)( pair_activations, template_batch, mask_2d, is_training=is_training, scale_rate=batch["scale_rate"]) pair_activations += template_pair_representation # Embed extra MSA features. # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 extra_msa_feat = create_extra_msa_feature(batch) extra_msa_activations = common_modules.Linear( c.extra_msa_channel, name='extra_msa_activations')( extra_msa_feat) # Extra MSA Stack. # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" extra_msa_stack_input = { 'msa': extra_msa_activations, 'pair': pair_activations, } extra_msa_stack_iteration = EvoformerIteration( c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') def extra_msa_stack_fn(x): act, safe_key = x safe_key, safe_subkey = safe_key.split() extra_evoformer_output = extra_msa_stack_iteration( activations=act, masks={ 'msa': batch['extra_msa_mask'], 'pair': mask_2d }, is_training=is_training, safe_key=safe_subkey, scale_rate=batch["scale_rate"]) return (extra_evoformer_output, safe_key) if gc.use_remat: extra_msa_stack_fn = hk.remat(extra_msa_stack_fn) extra_msa_stack = layer_stack.layer_stack( c.extra_msa_stack_num_block)( extra_msa_stack_fn) extra_msa_output, safe_key = extra_msa_stack( (extra_msa_stack_input, safe_key)) pair_activations = extra_msa_output['pair'] evoformer_input = { 'msa': msa_activations, 'pair': pair_activations, } evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d} #################################################################### #################################################################### # Append num_templ rows to msa_activations with template embeddings. 
# Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8 if c.template.enabled and c.template.embed_torsion_angles: if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): num_templ, num_res = batch['template_aatype'].shape # Embed the templates aatypes. aatype = batch['template_aatype'] aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) else: num_templ, num_res, _ = batch['template_aatype'].shape aatype = batch['template_aatype'].argmax(-1) aatype_one_hot = batch['template_aatype'] # Embed the templates aatype, torsion angles and masks. # Shape (templates, residues, msa_channels) ret = all_atom.atom37_to_torsion_angles( aatype=aatype, all_atom_pos=batch['template_all_atom_positions'], all_atom_mask=batch['template_all_atom_masks'], # Ensure consistent behaviour during testing: placeholder_for_undefined=not gc.zero_init) template_features = jnp.concatenate([ aatype_one_hot, jnp.reshape(ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]), jnp.reshape(ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]), ret['torsion_angles_mask']], axis=-1) template_activations = common_modules.Linear( c.msa_channel, initializer='relu', name='template_single_embedding')(template_features) template_activations = jax.nn.relu(template_activations) template_activations = common_modules.Linear( c.msa_channel, initializer='relu', name='template_projection')(template_activations) # Concatenate the templates to the msa. evoformer_input['msa'] = jnp.concatenate([evoformer_input['msa'], template_activations], axis=0) # Concatenate templates masks to the msa masks. # Use mask from the psi angle, as it only depends on the backbone atoms # from a single residue. torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2] torsion_angle_mask = torsion_angle_mask.astype(evoformer_masks['msa'].dtype) evoformer_masks['msa'] = jnp.concatenate([evoformer_masks['msa'], torsion_angle_mask], axis=0) #################################################################### #################################################################### # Main trunk of the network # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18 evoformer_iteration = EvoformerIteration( c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') def evoformer_fn(x): act, safe_key = x safe_key, safe_subkey = safe_key.split() evoformer_output = evoformer_iteration( activations=act, masks=evoformer_masks, is_training=is_training, safe_key=safe_subkey, scale_rate=batch["scale_rate"]) return (evoformer_output, safe_key) if gc.use_remat: evoformer_fn = hk.remat(evoformer_fn) evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(evoformer_fn) evoformer_output, safe_key = evoformer_stack((evoformer_input, safe_key)) msa_activations = evoformer_output['msa'] pair_activations = evoformer_output['pair'] single_activations = common_modules.Linear( c.seq_channel, name='single_activations')(msa_activations[0]) num_sequences = batch['msa_feat'].shape[0] output = { 'single': single_activations, 'pair': pair_activations, # Crop away template rows such that they are not used in MaskedMsaHead. 'msa': msa_activations[:num_sequences, :, :], 'msa_first_row': msa_activations[0], } return output #################################################################### #################################################################### class SingleTemplateEmbedding(hk.Module): """Embeds a single template. Jumper et al. (2021) Suppl. Alg. 
2 "Inference" lines 9+11 """ def __init__(self, config, global_config, name='single_template_embedding'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, query_embedding, batch, mask_2d, is_training, scale_rate=1.0): """Build the single template embedding. Arguments: query_embedding: Query pair representation, shape [N_res, N_res, c_z]. batch: A batch of template features (note the template dimension has been stripped out as this module only runs over a single template). mask_2d: Padding mask (Note: this doesn't care if a template exists, unlike the template_pseudo_beta_mask). is_training: Whether the module is in training mode. Returns: A template embedding [N_res, N_res, c_z]. """ assert mask_2d.dtype == query_embedding.dtype dtype = query_embedding.dtype num_res = batch['template_aatype'].shape[0] num_channels = (self.config.template_pair_stack .triangle_attention_ending_node.value_dim) template_mask = batch['template_pseudo_beta_mask'] template_mask_2d = template_mask[:, None] * template_mask[None, :] template_mask_2d = template_mask_2d.astype(dtype) if "template_dgram" in batch: template_dgram = batch["template_dgram"] else: if self.config.backprop_dgram: template_dgram = dgram_from_positions_soft(batch['template_pseudo_beta'], temp=self.config.backprop_dgram_temp, **self.config.dgram_features) else: template_dgram = dgram_from_positions(batch['template_pseudo_beta'], **self.config.dgram_features) template_dgram = template_dgram.astype(dtype) to_concat = [template_dgram, template_mask_2d[:, :, None]] if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype) else: aatype = batch['template_aatype'] to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1])) to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1])) # Backbone affine mask: whether the residue has C, CA, N # (the template mask defined above only considers pseudo CB). n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')] template_mask = ( batch['template_all_atom_masks'][..., n] * batch['template_all_atom_masks'][..., ca] * batch['template_all_atom_masks'][..., c]) template_mask_2d = template_mask[:, None] * template_mask[None, :] # compute unit_vector (not used by default) if self.config.use_template_unit_vector: rot, trans = quat_affine.make_transform_from_reference( n_xyz=batch['template_all_atom_positions'][:, n], ca_xyz=batch['template_all_atom_positions'][:, ca], c_xyz=batch['template_all_atom_positions'][:, c]) affines = quat_affine.QuatAffine( quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True), translation=trans, rotation=rot, unstack_inputs=True) points = [jnp.expand_dims(x, axis=-2) for x in affines.translation] affine_vec = affines.invert_point(points, extra_dims=1) inv_distance_scalar = jax.lax.rsqrt(1e-6 + sum([jnp.square(x) for x in affine_vec])) inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] else: unit_vector = [jnp.zeros((num_res,num_res,1))] * 3 unit_vector = [x.astype(dtype) for x in unit_vector] to_concat.extend(unit_vector) template_mask_2d = template_mask_2d.astype(dtype) to_concat.append(template_mask_2d[..., None]) act = jnp.concatenate(to_concat, axis=-1) # Mask out non-template regions so we don't get arbitrary values in the # distogram for these regions. act *= template_mask_2d[..., None] # Jumper et al. (2021) Suppl. Alg. 
2 "Inference" line 9 act = common_modules.Linear( num_channels, initializer='relu', name='embedding2d')(act) # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11 act = TemplatePairStack( self.config.template_pair_stack, self.global_config)(act, mask_2d, is_training, scale_rate=scale_rate) act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act) return act class TemplateEmbedding(hk.Module): """Embeds a set of templates. Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" """ def __init__(self, config, global_config, name='template_embedding'): super().__init__(name=name) self.config = config self.global_config = global_config def __call__(self, query_embedding, template_batch, mask_2d, is_training, scale_rate=1.0): """Build TemplateEmbedding module. Arguments: query_embedding: Query pair representation, shape [N_res, N_res, c_z]. template_batch: A batch of template features. mask_2d: Padding mask (Note: this doesn't care if a template exists, unlike the template_pseudo_beta_mask). is_training: Whether the module is in training mode. Returns: A template embedding [N_res, N_res, c_z]. """ num_templates = template_batch['template_mask'].shape[0] num_channels = (self.config.template_pair_stack .triangle_attention_ending_node.value_dim) num_res = query_embedding.shape[0] dtype = query_embedding.dtype template_mask = template_batch['template_mask'] template_mask = template_mask.astype(dtype) query_num_channels = query_embedding.shape[-1] # Make sure the weights are shared across templates by constructing the # embedder here. # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 template_embedder = SingleTemplateEmbedding(self.config, self.global_config) def map_fn(batch): return template_embedder(query_embedding, batch, mask_2d, is_training, scale_rate=scale_rate) template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(template_batch) # Cross attend from the query to the templates along the residue # dimension by flattening everything else into the batch dimension. # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" flat_query = jnp.reshape(query_embedding,[num_res * num_res, 1, query_num_channels]) flat_templates = jnp.reshape( jnp.transpose(template_pair_representation, [1, 2, 0, 3]), [num_res * num_res, num_templates, num_channels]) bias = (1e9 * (template_mask[None, None, None, :] - 1.)) template_pointwise_attention_module = Attention( self.config.attention, self.global_config, query_num_channels) nonbatched_args = [bias] batched_args = [flat_query, flat_templates] embedding = mapping.inference_subbatch( template_pointwise_attention_module, self.config.subbatch_size, batched_args=batched_args, nonbatched_args=nonbatched_args, low_memory=not is_training) embedding = jnp.reshape(embedding,[num_res, num_res, query_num_channels]) # No gradients if no templates. embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype) return embedding ####################################################################