diff --git a/LICENSE b/LICENSE index fbb83b08d35cfbd2f598fa69f7eda783d74cd849..c6e692d7aa90b1a51eabc1921675bc7bac76bc1a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 Justas Dauparas, Simon Duerr +Copyright (c) 2022 Justas Dauparas,Sergey Ovichinnikov, Simon Duerr Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/af_backprop/README.md b/af_backprop/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e1df10bd64cfe49c206a8fda13055c7efd2ea23c --- /dev/null +++ b/af_backprop/README.md @@ -0,0 +1,6 @@ +# af_backprop +various modifications to alphafold to allow backprop through the model + +### projects that use af_backprop +- [SMURF](https://github.com/spetti/SMURF): End-to-end learning of multiple sequence alignments with differentiable Smith-Waterman +- [ColabDesign](https://github.com/sokrypton/ColabDesign): Making Protein Design accessible to all via Google Colab! diff --git a/af_backprop/alphafold/__init__.py b/af_backprop/alphafold/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fd7f8294b9d7be770127c356f0b6564f1baa6c --- /dev/null +++ b/af_backprop/alphafold/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An implementation of the inference pipeline of AlphaFold v2.0.""" diff --git a/af_backprop/alphafold/common/__init__.py b/af_backprop/alphafold/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c65d69d5f61b7b9547153c47d84e7f545e2636 --- /dev/null +++ b/af_backprop/alphafold/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common data types and constants used within Alphafold.""" diff --git a/af_backprop/alphafold/common/confidence.py b/af_backprop/alphafold/common/confidence.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1085f75d2c8528a05abc12a483c3685af8c7c1 --- /dev/null +++ b/af_backprop/alphafold/common/confidence.py @@ -0,0 +1,155 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for processing confidence metrics.""" + +from typing import Dict, Optional, Tuple +import numpy as np +import scipy.special + + +def compute_plddt(logits: np.ndarray) -> np.ndarray: + """Computes per-residue pLDDT from logits. + + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + + Returns: + plddt: [num_res] per-residue pLDDT. + """ + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) + probs = scipy.special.softmax(logits, axis=-1) + predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1) + return predicted_lddt_ca * 100 + + +def _calculate_bin_centers(breaks: np.ndarray): + """Gets the bin centers from the bin edges. + + Args: + breaks: [num_bins - 1] the error bin edges. + + Returns: + bin_centers: [num_bins] the error bin centers. + """ + step = (breaks[1] - breaks[0]) + + # Add half-step to get the center + bin_centers = breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]], + axis=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: np.ndarray, + aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Calculates expected aligned distance errors for every pair of residues. + + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + + # Tuple of expected aligned distance error and max possible error. + return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1), + np.asarray(bin_centers[-1])) + + +def compute_predicted_aligned_error( + logits: np.ndarray, + breaks: np.ndarray) -> Dict[str, np.ndarray]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + aligned_confidence_probs = scipy.special.softmax( + logits, + axis=-1) + predicted_aligned_error, max_predicted_aligned_error = ( + _calculate_expected_aligned_error( + alignment_confidence_breaks=breaks, + aligned_distance_error_probs=aligned_confidence_probs)) + return { + 'aligned_confidence_probs': aligned_confidence_probs, + 'predicted_aligned_error': predicted_aligned_error, + 'max_predicted_aligned_error': max_predicted_aligned_error, + } + + +def predicted_tm_score( + logits: np.ndarray, + breaks: np.ndarray, + residue_weights: Optional[np.ndarray] = None) -> np.ndarray: + """Computes predicted TM alignment score. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + + Returns: + ptm_score: the predicted TM alignment score. + """ + + # residue_weights has to be in [0, 1], but can be floating-point, i.e. the + # exp. resolved head's probability. + if residue_weights is None: + residue_weights = np.ones(logits.shape[0]) + + bin_centers = _calculate_bin_centers(breaks) + + num_res = np.sum(residue_weights) + # Clip num_res to avoid negative/undefined d0. + clipped_num_res = max(num_res, 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. (5) in + # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + # Yang & Skolnick "Scoring function for automated + # assessment of protein structure template quality" 2004 + d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 + + # Convert logits to probs + probs = scipy.special.softmax(logits, axis=-1) + + # TM-Score term for every bin + tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0)) + # E_distances tm(distance) + predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1) + + normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum()) + per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1) + return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) diff --git a/af_backprop/alphafold/common/protein.py b/af_backprop/alphafold/common/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..2848f5bbc52d646ddc22a8f2e1c6b4d98ae1ffce --- /dev/null +++ b/af_backprop/alphafold/common/protein.py @@ -0,0 +1,229 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional +from alphafold.common import residue_constants +from Bio.PDB import PDBParser +import numpy as np + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If None, then the pdb file must contain a single chain (which + will be parsed). If chain_id is specified (e.g. A), then only that chain + is parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.') + model = models[0] + + if chain_id is not None: + chain = model[chain_id] + else: + chains = list(model.get_chains()) + if len(chains) != 1: + raise ValueError( + 'Only single chain PDBs are supported when chain_id not specified. ' + f'Found {len(chains)} chains.') + else: + chain = chains[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + b_factors = [] + + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + b_factors.append(res_b_factors) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + b_factors=np.array(b_factors)) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + pdb_lines.append('MODEL 1') + atom_index = 1 + chain_id = 'A' + # Add all atom sites. + for i in range(aatype.shape[0]): + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = '' + # PDB is a columnar format, every space matters here! + atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_id:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the chain. + chain_end = 'TER' + chain_termination_line = ( + f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' + f'{chain_id:>1}{residue_index[-1]:>4}') + pdb_lines.append(chain_termination_line) + pdb_lines.append('ENDMDL') + + pdb_lines.append('END') + pdb_lines.append('') + return '\n'.join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(features: FeatureDict, result: ModelOutput, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + fold_output = result['structure_module'] + if b_factors is None: + b_factors = np.zeros_like(fold_output['final_atom_mask']) + + return Protein( + aatype=features['aatype'][0], + atom_positions=fold_output['final_atom_positions'], + atom_mask=fold_output['final_atom_mask'], + residue_index=features['residue_index'][0] + 1, + b_factors=b_factors) diff --git a/af_backprop/alphafold/common/residue_constants.py b/af_backprop/alphafold/common/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee38c72e76aa90a7b8abc4b7ec43552f28cc715 --- /dev/null +++ b/af_backprop/alphafold/common/residue_constants.py @@ -0,0 +1,911 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import functools +from typing import List, Mapping, Tuple + +import numpy as np +import tree + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1']], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']], + 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE']], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, -1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3', + 'CH2', 'N', 'NE1', 'O'], + 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', + 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'] +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': {'OD1': 'OD2'}, + 'GLU': {'OE1': 'OE2'}, + 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'}, + 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple( + 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]]]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples + residue_virtual_bonds: dict that maps resname --> list of Bond tuples + residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + stereo_chemical_props_path = ( + 'alphafold/common/stereo_chemical_props.txt') + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle(atom1, atom2, atom3, + float(angle_degree) / 180. * np.pi, + float(stddev_degree) / 180. * np.pi)) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, + residue_virtual_bonds, + residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD', + 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3', + 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2', + 'CZ3', 'NZ', 'OXT' +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''], + 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''], + 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'], + 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''], + 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], + +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V' +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError('The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError(f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1]*(4-len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. + ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + +############################################### +restype_atom14_to_atom37 = [] +restype_atom37_to_atom14 = [] +for rt in restypes: + atom_names = restype_name_to_atom14_names[restype_1to3[rt]] + restype_atom14_to_atom37.append([(atom_order[name] if name else 0) for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in atom_types]) +restype_atom14_to_atom37.append([0] * 14) +restype_atom37_to_atom14.append([0] * 37) +restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) +restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) +################################################ + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position + + atom_names = residue_atoms[resname] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = {name: np.array(pos) for name, _, pos + in rigid_group_atom_positions[resname]} + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1., 0., 0.]), + translation=atom_positions['N']) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C']) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2]) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1., 0., 0.]), + translation=axis_end_atom_position) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } diff --git a/af_backprop/alphafold/data/__init__.py b/af_backprop/alphafold/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9821d212c3c7781e601ea8d2137493942d0937d4 --- /dev/null +++ b/af_backprop/alphafold/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data pipeline for model features.""" diff --git a/af_backprop/alphafold/data/mmcif_parsing.py b/af_backprop/alphafold/data/mmcif_parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..18375165a526a780f2e602a3800e12833dbb3e67 --- /dev/null +++ b/af_backprop/alphafold/data/mmcif_parsing.py @@ -0,0 +1,384 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parses the mmCIF file format.""" +import collections +import dataclasses +import io +from typing import Any, Mapping, Optional, Sequence, Tuple + +from absl import logging +from Bio import PDB +from Bio.Data import SCOPData + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PdbStructure = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, + 1: ResidueAtPosition, + ...}} + raw_string: The raw string used to construct the MmcifObject. + """ + file_id: str + header: PdbHeader + structure: PdbStructure + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +class ParseError(Exception): + """An error indicating that an mmCIF file could not be parsed.""" + + +def mmcif_loop_to_list(prefix: str, + parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([len(xs) == len(data[0]) for xs in data]), ( + 'mmCIF error: Not all loops are the same length: %s' % cols) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict(prefix: str, + index: str, + parsed_info: MmCIFDict, + ) -> Mapping[str, Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +def parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure('', handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + seq_start_num = {chain_id: min([monomer.num for monomer in seq]) + for chain_id, seq in valid_chains.items()} + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). + mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = ' ' + if atom.hetatm_atom == 'HETATM': + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. + if atom.residue_name in ('HOH', 'WAT'): + hetflag = 'W' + else: + hetflag = 'H_' + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = ' ' + position = ResiduePosition(chain_id=atom.author_chain_id, + residue_number=int(atom.author_seq_num), + insertion_code=insertion_code) + seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + current = seq_to_structure_mappings.get(atom.author_chain_id, {}) + current[seq_idx] = ResidueAtPosition(position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + current_mapping = seq_to_structure_mappings[author_chain] + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition(position=None, + name=monomer.id, + is_missing=True, + hetflag=' ') + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') + seq.append(code if len(code) == 1 else 'X') + seq = ''.join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PdbStructure) -> PdbStructure: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list('_exptl.', parsed_info) + header['structure_method'] = ','.join([ + experiment['_exptl.method'].lower() for experiment in experiments]) + + # Note: The release_date here corresponds to the oldest revision. We prefer to + # use this for dataset filtering over the deposition_date. + if '_pdbx_audit_revision_history.revision_date' in parsed_info: + header['release_date'] = get_release_date(parsed_info) + else: + logging.warning('Could not determine release_date: %s', + parsed_info['_entry.id']) + + header['resolution'] = 0.00 + for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution', + '_reflns.d_resolution_high'): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header['resolution'] = float(raw_resolution) + except ValueError: + logging.warning('Invalid resolution format: %s', parsed_info[res_key]) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension + parsed_info['_atom_site.label_comp_id'], + parsed_info['_atom_site.auth_asym_id'], + parsed_info['_atom_site.label_asym_id'], + parsed_info['_atom_site.auth_seq_id'], + parsed_info['_atom_site.label_seq_id'], + parsed_info['_atom_site.pdbx_PDB_ins_code'], + parsed_info['_atom_site.group_PDB'], + parsed_info['_atom_site.pdbx_PDB_model_num'], + )] + + +def _get_protein_chains( + *, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. + entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) + + polymers = collections.defaultdict(list) + for entity_poly_seq in entity_poly_seqs: + polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( + Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'], + num=int(entity_poly_seq['_entity_poly_seq.num']))) + + # Get chemical compositions. Will allow us to identify which of these polymers + # are proteins. + chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym['_struct_asym.id'] + entity_id = struct_asym['_struct_asym.entity_id'] + entity_to_mmcif_chains[entity_id].append(chain_id) + + # Identify and return the valid protein chains. + valid_chains = {} + for entity_id, seq_info in polymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject polymers without any peptide-like components, such as DNA/RNA. + if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type'] + for monomer in seq_info]): + for chain_id in chain_ids: + valid_chains[chain_id] = seq_info + return valid_chains + + +def _is_set(data: str) -> bool: + """Returns False if data is a special mmCIF character indicating 'unset'.""" + return data not in ('.', '?') diff --git a/af_backprop/alphafold/data/parsers.py b/af_backprop/alphafold/data/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..edc21bbeb897520baae2352dbfb4ac0ebfbb7a59 --- /dev/null +++ b/af_backprop/alphafold/data/parsers.py @@ -0,0 +1,364 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for parsing various file formats.""" +import collections +import dataclasses +import re +import string +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +DeletionMatrix = Sequence[Sequence[int]] + + +@dataclasses.dataclass(frozen=True) +class TemplateHit: + """Class representing a template hit.""" + index: int + name: str + aligned_cols: int + sum_probs: float + query: str + hit_sequence: str + indices_query: List[int] + indices_hit: List[int] + + +def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith('>'): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append('') + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +def parse_stockholm( + stockholm_string: str +) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]: + """Parses sequences and deletion matrix from stockholm format alignment. + + Args: + stockholm_string: The string contents of a stockholm file. The first + sequence in the file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * The names of the targets matched, including the jackhmmer subsequence + suffix. + """ + name_to_sequence = collections.OrderedDict() + for line in stockholm_string.splitlines(): + line = line.strip() + if not line or line.startswith(('#', '//')): + continue + name, sequence = line.split() + if name not in name_to_sequence: + name_to_sequence[name] = '' + name_to_sequence[name] += sequence + + msa = [] + deletion_matrix = [] + + query = '' + keep_columns = [] + for seq_index, sequence in enumerate(name_to_sequence.values()): + if seq_index == 0: + # Gather the columns with gaps from the query + query = sequence + keep_columns = [i for i, res in enumerate(query) if res != '-'] + + # Remove the columns with gaps in the query from all sequences. + aligned_sequence = ''.join([sequence[c] for c in keep_columns]) + + msa.append(aligned_sequence) + + # Count the number of deletions w.r.t. query. + deletion_vec = [] + deletion_count = 0 + for seq_res, query_res in zip(sequence, query): + if seq_res != '-' or query_res != '-': + if query_res == '-': + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + return msa, deletion_matrix, list(name_to_sequence.keys()) + + +def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: + """Parses sequences and deletion matrix from a3m format alignment. + + Args: + a3m_string: The string contents of a a3m file. The first sequence in the + file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + """ + sequences, _ = parse_fasta(a3m_string) + deletion_matrix = [] + for msa_sequence in sequences: + deletion_vec = [] + deletion_count = 0 + for j in msa_sequence: + if j.islower(): + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + # Make the MSA matrix out of aligned (deletion-free) sequences. + deletion_table = str.maketrans('', '', string.ascii_lowercase) + aligned_sequences = [s.translate(deletion_table) for s in sequences] + return aligned_sequences, deletion_matrix + + +def _convert_sto_seq_to_a3m( + query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]: + for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): + if is_query_res_non_gap: + yield sequence_res + elif sequence_res != '-': + yield sequence_res.lower() + + +def convert_stockholm_to_a3m(stockholm_format: str, + max_sequences: Optional[int] = None) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + for line in stockholm_format.splitlines(): + reached_max_sequences = max_sequences and len(sequences) >= max_sequences + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + for line in stockholm_format.splitlines(): + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... + columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + # Convert sto format to a3m line by line + a3m_sequences = {} + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != '-' for res in query_sequence] + for seqname, sto_sequence in sequences.items(): + a3m_sequences[seqname] = ''.join( + _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)) + + fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" + for k in a3m_sequences) + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. + + +def _get_hhr_line_regex_groups( + regex_pattern: str, line: str) -> Sequence[Optional[str]]: + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f'Could not parse query line {line}') + return match.groups() + + +def _update_hhr_residue_indices_list( + sequence: str, start_index: int, indices_list: List[int]): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == '-': + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. + number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' + ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' + ']*Template_Neff=(.*)') + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + 'Could not parse section: %s. Expected this: \n%s to contain summary.' % + (detailed_lines, detailed_lines[2])) + (prob_true, e_value, _, aligned_cols, _, _, sum_probs, + neff) = [float(x) for x in match.groups()] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that block. + query = '' + hit_sequence = '' + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and + not line.startswith('Q ss_pred') and + not line.startswith('Q Consensus')): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == '-']) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. + query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith('T '): + # Parse the hit sequence. + if (not line.startswith('T ss_dssp') and + not line.startswith('T ss_pred') and + not line.startswith('T Consensus')): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit) + + return TemplateHit( + index=number_of_hit, + name=name_hit, + aligned_cols=int(aligned_cols), + sum_probs=sum_probs, + query=query, + hit_sequence=hit_sequence, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. + for i in range(len(block_starts) - 1): + hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) + return hits + + +def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: + """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" + e_values = {'query': 0} + lines = [line for line in tblout.splitlines() if line[0] != '#'] + # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are + # space-delimited. Relevant fields are (1) target name: and + # (5) E-value (full sequence) (numbering from 1). + for line in lines: + fields = line.split() + e_value = fields[4] + target_name = fields[0] + e_values[target_name] = float(e_value) + return e_values diff --git a/af_backprop/alphafold/data/pipeline.py b/af_backprop/alphafold/data/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..461bce875ab6f9cad4e2b0897c44a6cf1ef399ae --- /dev/null +++ b/af_backprop/alphafold/data/pipeline.py @@ -0,0 +1,209 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the input features for the AlphaFold model.""" + +import os +from typing import Mapping, Optional, Sequence +from absl import logging +from alphafold.common import residue_constants +from alphafold.data import parsers +from alphafold.data import templates +from alphafold.data.tools import hhblits +from alphafold.data.tools import hhsearch +from alphafold.data.tools import jackhmmer +import numpy as np + +# Internal import (7716). + +FeatureDict = Mapping[str, np.ndarray] + + +def make_sequence_features( + sequence: str, description: str, num_res: int) -> FeatureDict: + """Constructs a feature dict of sequence features.""" + features = {} + features['aatype'] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True) + features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) + features['domain_name'] = np.array([description.encode('utf-8')], + dtype=np.object_) + features['residue_index'] = np.array(range(num_res), dtype=np.int32) + features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) + features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_) + return features + + +def make_msa_features( + msas: Sequence[Sequence[str]], + deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError(f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) + + num_res = len(msas[0][0]) + num_alignments = len(int_msa) + features = {} + features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) + features['msa'] = np.array(int_msa, dtype=np.int32) + features['num_alignments'] = np.array( + [num_alignments] * num_res, dtype=np.int32) + return features + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, + jackhmmer_binary_path: str, + hhblits_binary_path: str, + hhsearch_binary_path: str, + uniref90_database_path: str, + mgnify_database_path: str, + bfd_database_path: Optional[str], + uniclust30_database_path: Optional[str], + small_bfd_database_path: Optional[str], + pdb70_database_path: str, + template_featurizer: templates.TemplateHitFeaturizer, + use_small_bfd: bool, + mgnify_max_hits: int = 501, + uniref_max_hits: int = 10000): + """Constructs a feature dict for a given FASTA file.""" + self._use_small_bfd = use_small_bfd + self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniref90_database_path) + if use_small_bfd: + self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=small_bfd_database_path) + else: + self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( + binary_path=hhblits_binary_path, + databases=[bfd_database_path, uniclust30_database_path]) + self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=mgnify_database_path) + self.hhsearch_pdb70_runner = hhsearch.HHSearch( + binary_path=hhsearch_binary_path, + databases=[pdb70_database_path]) + self.template_featurizer = template_featurizer + self.mgnify_max_hits = mgnify_max_hits + self.uniref_max_hits = uniref_max_hits + + def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: + """Runs alignment tools on the input sequence and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + if len(input_seqs) != 1: + raise ValueError( + f'More than one input sequence found in {input_fasta_path}.') + input_sequence = input_seqs[0] + input_description = input_descs[0] + num_res = len(input_sequence) + + jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( + input_fasta_path)[0] + jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( + input_fasta_path)[0] + + uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( + jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits) + hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m) + + uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') + with open(uniref90_out_path, 'w') as f: + f.write(jackhmmer_uniref90_result['sto']) + + mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') + with open(mgnify_out_path, 'w') as f: + f.write(jackhmmer_mgnify_result['sto']) + + pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr') + with open(pdb70_out_path, 'w') as f: + f.write(hhsearch_result) + + uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm( + jackhmmer_uniref90_result['sto']) + mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm( + jackhmmer_mgnify_result['sto']) + hhsearch_hits = parsers.parse_hhr(hhsearch_result) + mgnify_msa = mgnify_msa[:self.mgnify_max_hits] + mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits] + + if self._use_small_bfd: + jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( + input_fasta_path)[0] + + bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m') + with open(bfd_out_path, 'w') as f: + f.write(jackhmmer_small_bfd_result['sto']) + + bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm( + jackhmmer_small_bfd_result['sto']) + else: + hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query( + input_fasta_path) + + bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') + with open(bfd_out_path, 'w') as f: + f.write(hhblits_bfd_uniclust_result['a3m']) + + bfd_msa, bfd_deletion_matrix = parsers.parse_a3m( + hhblits_bfd_uniclust_result['a3m']) + + templates_result = self.template_featurizer.get_templates( + query_sequence=input_sequence, + query_pdb_code=None, + query_release_date=None, + hits=hhsearch_hits) + + sequence_features = make_sequence_features( + sequence=input_sequence, + description=input_description, + num_res=num_res) + + msa_features = make_msa_features( + msas=(uniref90_msa, bfd_msa, mgnify_msa), + deletion_matrices=(uniref90_deletion_matrix, + bfd_deletion_matrix, + mgnify_deletion_matrix)) + + logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) + logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) + logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) + logging.info('Final (deduplicated) MSA size: %d sequences.', + msa_features['num_alignments'][0]) + logging.info('Total number of templates (NB: this can include bad ' + 'templates and is later filtered to top 4): %d.', + templates_result.features['template_domain_names'].shape[0]) + + return {**sequence_features, **msa_features, **templates_result.features} diff --git a/af_backprop/alphafold/data/prep_inputs.py b/af_backprop/alphafold/data/prep_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd871964796444b11f8cbd0573c73cdce13851a --- /dev/null +++ b/af_backprop/alphafold/data/prep_inputs.py @@ -0,0 +1,133 @@ +import numpy as np +from alphafold.common import residue_constants + +def make_atom14_positions(prot): + """Constructs denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + # Add dummy mapping for restype 'UNK'. + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.] * 14) + + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + # Create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein. + residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]] + residx_atom14_mask = restype_atom14_mask[prot["aatype"]] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis( + prot["all_atom_mask"], residx_atom14_to_atom37, axis=1).astype(np.float32) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * ( + np.take_along_axis(prot["all_atom_positions"], + residx_atom14_to_atom37[..., None], + axis=1)) + + prot["atom14_atom_exists"] = residx_atom14_mask + prot["atom14_gt_exists"] = residx_atom14_gt_mask + prot["atom14_gt_positions"] = residx_atom14_gt_positions + + prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37 + + # Create the gather indices for mapping back. + residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]] + prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14 + + # Create the corresponding mask. + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = restype_atom37_mask[prot["aatype"]] + prot["atom37_atom_exists"] = residx_atom37_mask + + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ["UNK"] + + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[prot["aatype"]] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = np.einsum("rac,rab->rbc", + residx_atom14_gt_positions, + renaming_transform) + prot["atom14_alt_gt_positions"] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). + alternative_gt_mask = np.einsum("ra,rab->rb", + residx_atom14_gt_mask, + renaming_transform) + + prot["atom14_alt_gt_exists"] = alternative_gt_mask + + # Create an ambiguous atoms mask. shape: (21, 14). + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + # From this create an ambiguous_mask for the given sequence. + prot["atom14_atom_is_ambiguous"] = ( + restype_atom14_is_ambiguous[prot["aatype"]]) + + return prot \ No newline at end of file diff --git a/af_backprop/alphafold/data/templates.py b/af_backprop/alphafold/data/templates.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fc45865604bcd2aeef429c8d183f96ef4bc3a0 --- /dev/null +++ b/af_backprop/alphafold/data/templates.py @@ -0,0 +1,910 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for getting templates and calculating template features.""" +import dataclasses +import datetime +import glob +import os +import re +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple + +from absl import logging +from alphafold.common import residue_constants +from alphafold.data import mmcif_parsing +from alphafold.data import parsers +from alphafold.data.tools import kalign +import numpy as np + +# Internal import (7716). + + +class Error(Exception): + """Base class for exceptions.""" + + +class NoChainsError(Error): + """An error indicating that template mmCIF didn't have any chains.""" + + +class SequenceNotInTemplateError(Error): + """An error indicating that template mmCIF didn't contain the sequence.""" + + +class NoAtomDataInTemplateError(Error): + """An error indicating that template mmCIF didn't contain atom positions.""" + + +class TemplateAtomMaskAllZerosError(Error): + """An error indicating that template mmCIF had all atom positions masked.""" + + +class QueryToTemplateAlignError(Error): + """An error indicating that the query can't be aligned to the template.""" + + +class CaDistanceError(Error): + """An error indicating that a CA atom distance exceeds a threshold.""" + + +class MultipleChainsError(Error): + """An error indicating that multiple chains were found for a given ID.""" + + +# Prefilter exceptions. +class PrefilterError(Exception): + """A base class for template prefilter exceptions.""" + + +class DateError(PrefilterError): + """An error indicating that the hit date was after the max allowed date.""" + + +class PdbIdError(PrefilterError): + """An error indicating that the hit PDB ID was identical to the query.""" + + +class AlignRatioError(PrefilterError): + """An error indicating that the hit align ratio to the query was too small.""" + + +class DuplicateError(PrefilterError): + """An error indicating that the hit was an exact subsequence of the query.""" + + +class LengthError(PrefilterError): + """An error indicating that the hit was too short.""" + + +TEMPLATE_FEATURES = { + 'template_aatype': np.float32, + 'template_all_atom_masks': np.float32, + 'template_all_atom_positions': np.float32, + 'template_domain_names': np.object, + 'template_sequence': np.object, + 'template_sum_probs': np.float32, +} + + +def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: + """Returns PDB id and chain id for an HHSearch Hit.""" + # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. + id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name) + if not id_match: + raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}') + pdb_id, chain_id = id_match.group(0).split('_') + return pdb_id.lower(), chain_id + + +def _is_after_cutoff( + pdb_id: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: Optional[datetime.datetime]) -> bool: + """Checks if the template date is after the release date cutoff. + + Args: + pdb_id: 4 letter pdb code. + release_dates: Dictionary mapping PDB ids to their structure release dates. + release_date_cutoff: Max release date that is valid for this query. + + Returns: + True if the template release date is after the cutoff, False otherwise. + """ + if release_date_cutoff is None: + raise ValueError('The release_date_cutoff must not be None.') + if pdb_id in release_dates: + return release_dates[pdb_id] > release_date_cutoff + else: + # Since this is just a quick prefilter to reduce the number of mmCIF files + # we need to parse, we don't have to worry about returning True here. + logging.warning('Template structure not in release dates dict: %s', pdb_id) + return False + + +def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: + """Parses the data file from PDB that lists which PDB ids are obsolete.""" + with open(obsolete_file_path) as f: + result = {} + for line in f: + line = line.strip() + # We skip obsolete entries that don't contain a mapping to a new entry. + if line.startswith('OBSLTE') and len(line) > 30: + # Format: Date From To + # 'OBSLTE 31-JUL-94 116L 216L' + from_id = line[20:24].lower() + to_id = line[29:33].lower() + result[from_id] = to_id + return result + + +def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: + """Parses release dates file, returns a mapping from PDBs to release dates.""" + if path.endswith('txt'): + release_dates = {} + with open(path, 'r') as f: + for line in f: + pdb_id, date = line.split(':') + date = date.strip() + # Python 3.6 doesn't have datetime.date.fromisoformat() which is about + # 90x faster than strptime. However, splitting the string manually is + # about 10x faster than strptime. + release_dates[pdb_id.strip()] = datetime.datetime( + year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10])) + return release_dates + else: + raise ValueError('Invalid format of the release date file %s.' % path) + + +def _assess_hhsearch_hit( + hit: parsers.TemplateHit, + hit_pdb_code: str, + query_sequence: str, + query_pdb_code: Optional[str], + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: datetime.datetime, + max_subsequence_ratio: float = 0.95, + min_align_ratio: float = 0.1) -> bool: + """Determines if template is valid (without parsing the template mmcif file). + + Args: + hit: HhrHit for the template. + hit_pdb_code: The 4 letter pdb code of the template hit. This might be + different from the value in the actual hit since the original pdb might + have become obsolete. + query_sequence: Amino acid sequence of the query. + query_pdb_code: 4 letter pdb code of the query. + release_dates: Dictionary mapping pdb codes to their structure release + dates. + release_date_cutoff: Max release date that is valid for this query. + max_subsequence_ratio: Exclude any exact matches with this much overlap. + min_align_ratio: Minimum overlap between the template and query. + + Returns: + True if the hit passed the prefilter. Raises an exception otherwise. + + Raises: + DateError: If the hit date was after the max allowed date. + PdbIdError: If the hit PDB ID was identical to the query. + AlignRatioError: If the hit align ratio to the query was too small. + DuplicateError: If the hit was an exact subsequence of the query. + LengthError: If the hit was too short. + """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace('-', '') + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. + duplicate = (template_sequence in query_sequence and + length_ratio > max_subsequence_ratio) + + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date ' + f'({release_date_cutoff}).') + + if query_pdb_code is not None: + if query_pdb_code.lower() == hit_pdb_code.lower(): + raise PdbIdError('PDB code identical to Query PDB code.') + + if align_ratio <= min_align_ratio: + raise AlignRatioError('Proportion of residues aligned to query too small. ' + f'Align ratio: {align_ratio}.') + + if duplicate: + raise DuplicateError('Template is an exact subsequence of query with large ' + f'coverage. Length ratio: {length_ratio}.') + + if len(template_sequence) < 10: + raise LengthError(f'Template too short. Length: {len(template_sequence)}.') + + return True + + +def _find_template_in_pdb( + template_chain_id: str, + template_sequence: str, + mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]: + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. + + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info( + 'Found an exact template match %s_%s.', pdb_id, template_chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found a sequence-only match %s_%s.', pdb_id, chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit. + regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] + regex = re.compile(''.join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, chain_id) + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. + raise SequenceNotInTemplateError( + 'Could not find the template sequence in %s_%s. Template sequence: %s, ' + 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, + mmcif_object.chain_to_seqres)) + + +def _realign_pdb_template_to_query( + old_template_sequence: str, + template_chain_id: str, + mmcif_object: mmcif_parsing.MmcifObject, + old_mapping: Mapping[int, int], + kalign_binary_path: str) -> Tuple[str, Mapping[int, int]]: + """Aligns template from the mmcif_object to the query. + + In case PDB70 contains a different version of the template sequence, we need + to perform a realignment to the actual sequence that is in the mmCIF file. + This method performs such realignment, but returns the new sequence and + mapping only if the sequence in the mmCIF file is 90% identical to the old + sequence. + + Note that the old_template_sequence comes from the hit, and contains only that + part of the chain that matches with the query while the new_template_sequence + is the full chain. + + Args: + old_template_sequence: The template sequence that was returned by the PDB + template search (typically done using HHSearch). + template_chain_id: The template chain id was returned by the PDB template + search (typically done using HHSearch). This is used to find the right + chain in the mmcif_object chain_to_seqres mapping. + mmcif_object: A mmcif_object which holds the actual template data. + old_mapping: A mapping from the query sequence to the template sequence. + This mapping will be used to compute the new mapping from the query + sequence to the actual mmcif_object template sequence by aligning the + old_template_sequence and the actual template sequence. + kalign_binary_path: The path to a kalign executable. + + Returns: + A tuple (new_template_sequence, new_query_to_template_mapping) where: + * new_template_sequence is the actual template sequence that was found in + the mmcif_object. + * new_query_to_template_mapping is the new mapping from the query to the + actual template found in the mmcif_object. + + Raises: + QueryToTemplateAlignError: + * If there was an error thrown by the alignment tool. + * Or if the actual template sequence differs by more than 10% from the + old_template_sequence. + """ + aligner = kalign.Kalign(binary_path=kalign_binary_path) + new_template_sequence = mmcif_object.chain_to_seqres.get( + template_chain_id, '') + + # Sometimes the template chain id is unknown. But if there is only a single + # sequence within the mmcif_object, it is safe to assume it is that one. + if not new_template_sequence: + if len(mmcif_object.chain_to_seqres) == 1: + logging.info('Could not find %s in %s, but there is only 1 sequence, so ' + 'using that one.', + template_chain_id, + mmcif_object.file_id) + new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0] + else: + raise QueryToTemplateAlignError( + f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. ' + 'If there are no mmCIF parsing errors, it is possible it was not a ' + 'protein chain.') + + try: + (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m( + aligner.align([old_template_sequence, new_template_sequence])) + except Exception as e: + raise QueryToTemplateAlignError( + 'Could not align old template %s to template %s (%s_%s). Error: %s' % + (old_template_sequence, new_template_sequence, mmcif_object.file_id, + template_chain_id, str(e))) + + logging.info('Old aligned template: %s\nNew aligned template: %s', + old_aligned_template, new_aligned_template) + + old_to_new_template_mapping = {} + old_template_index = -1 + new_template_index = -1 + num_same = 0 + for old_template_aa, new_template_aa in zip( + old_aligned_template, new_aligned_template): + if old_template_aa != '-': + old_template_index += 1 + if new_template_aa != '-': + new_template_index += 1 + if old_template_aa != '-' and new_template_aa != '-': + old_to_new_template_mapping[old_template_index] = new_template_index + if old_template_aa == new_template_aa: + num_same += 1 + + # Require at least 90 % sequence identity wrt to the shorter of the sequences. + if float(num_same) / min( + len(old_template_sequence), len(new_template_sequence)) < 0.9: + raise QueryToTemplateAlignError( + 'Insufficient similarity of the sequence in the database: %s to the ' + 'actual sequence in the mmCIF file %s_%s: %s. We require at least ' + '90 %% similarity wrt to the shorter of the sequences. This is not a ' + 'problem unless you think this is a template that should be included.' % + (old_template_sequence, mmcif_object.file_id, template_chain_id, + new_template_sequence)) + + new_query_to_template_mapping = {} + for query_index, old_template_index in old_mapping.items(): + new_query_to_template_mapping[query_index] = ( + old_to_new_template_mapping.get(old_template_index, -1)) + + new_template_sequence = new_template_sequence.replace('-', '') + + return new_template_sequence, new_query_to_template_mapping + + +def _check_residue_distances(all_positions: np.ndarray, + all_positions_mask: np.ndarray, + max_ca_ca_distance: float): + """Checks if the distance between unmasked neighbor residues is ok.""" + ca_position = residue_constants.atom_order['CA'] + prev_is_unmasked = False + prev_calpha = None + for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): + this_is_unmasked = bool(mask[ca_position]) + if this_is_unmasked: + this_calpha = coords[ca_position] + if prev_is_unmasked: + distance = np.linalg.norm(this_calpha - prev_calpha) + if distance > max_ca_ca_distance: + raise CaDistanceError( + 'The distance between residues %d and %d is %f > limit %f.' % ( + i, i + 1, distance, max_ca_ca_distance)) + prev_calpha = this_calpha + prev_is_unmasked = this_is_unmasked + + +def _get_atom_positions( + mmcif_object: mmcif_parsing.MmcifObject, + auth_chain_id: str, + max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]: + """Gets atom positions and mask from a list of Biopython Residues.""" + num_res = len(mmcif_object.chain_to_seqres[auth_chain_id]) + + relevant_chains = [c for c in mmcif_object.structure.get_chains() + if c.id == auth_chain_id] + if len(relevant_chains) != 1: + raise MultipleChainsError( + f'Expected exactly one chain in structure with id {auth_chain_id}.') + chain = relevant_chains[0] + + all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3]) + all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num], + dtype=np.int64) + for res_index in range(num_res): + pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) + res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index] + if not res_at_position.is_missing: + res = chain[(res_at_position.hetflag, + res_at_position.position.residue_number, + res_at_position.position.insertion_code)] + for atom in res.get_atoms(): + atom_name = atom.get_name() + x, y, z = atom.get_coord() + if atom_name in residue_constants.atom_order.keys(): + pos[residue_constants.atom_order[atom_name]] = [x, y, z] + mask[residue_constants.atom_order[atom_name]] = 1.0 + elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': + # Put the coordinates of the selenium atom in the sulphur column. + pos[residue_constants.atom_order['SD']] = [x, y, z] + mask[residue_constants.atom_order['SD']] = 1.0 + + all_positions[res_index] = pos + all_positions_mask[res_index] = mask + _check_residue_distances( + all_positions, all_positions_mask, max_ca_ca_distance) + return all_positions, all_positions_mask + + +def _extract_template_features( + mmcif_object: mmcif_parsing.MmcifObject, + pdb_id: str, + mapping: Mapping[int, int], + template_sequence: str, + query_sequence: str, + template_chain_id: str, + kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]: + """Parses atom positions in the target structure and aligns with the query. + + Atoms for each residue in the template structure are indexed to coincide + with their corresponding residue in the query sequence, according to the + alignment mapping provided. + + Args: + mmcif_object: mmcif_parsing.MmcifObject representing the template. + pdb_id: PDB code for the template. + mapping: Dictionary mapping indices in the query sequence to indices in + the template sequence. + template_sequence: String describing the amino acid sequence for the + template protein. + query_sequence: String describing the amino acid sequence for the query + protein. + template_chain_id: String ID describing which chain in the structure proto + should be used. + kalign_binary_path: The path to a kalign executable used for template + realignment. + + Returns: + A tuple with: + * A dictionary containing the extra features derived from the template + protein structure. + * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. + NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id)) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = ( + f'The exact sequence {template_sequence} was not found in ' + f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.') + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path) + logging.info('Sequence in %s_%s: %s successfully realigned to %s', + pdb_id, chain_id, template_sequence, seqres) + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. + all_atom_positions, all_atom_mask = _get_atom_positions( + mmcif_object, chain_id, max_ca_ca_distance=150.0) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError( + 'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex)) + ) from ex + + all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0]) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append( + np.zeros((residue_constants.atom_type_num, 3))) + templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num)) + output_templates_sequence.append('-') + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError( + 'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' % + (pdb_id, chain_id, min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset)) + + output_templates_sequence = ''.join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + + return ( + { + 'template_all_atom_positions': np.array(templates_all_atom_positions), + 'template_all_atom_masks': np.array(templates_all_atom_masks), + 'template_sequence': output_templates_sequence.encode(), + 'template_aatype': np.array(templates_aatype), + 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), + }, + warning) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. + """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace('-', '') + hit_sequence = hit_sequence.replace('-', '') + hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence) + + # Index of -1 used for gap characters. Subtract the min index ignoring gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [ + x - min_idx if x > -1 else -1 for x in indices_hit + ] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query] + + # Zip the corrected indices, ignore case where both seqs have gap characters. + mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if (q_t >= len(hit_sequence) or + q_i + hhsearch_query_offset >= len(original_query_sequence)): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +def _process_single_hit( + query_sequence: str, + query_pdb_code: Optional[str], + hit: parsers.TemplateHit, + mmcif_dir: str, + max_template_date: datetime.datetime, + release_dates: Mapping[str, datetime.datetime], + obsolete_pdbs: Mapping[str, str], + kalign_binary_path: str, + strict_error_check: bool = False) -> SingleHitResult: + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being obsolete. + try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + release_dates=release_dates, + release_date_cutoff=max_template_date) + except PrefilterError as e: + msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' + logging.info('%s: %s', query_pdb_code, msg) + if strict_error_check and isinstance( + e, (DateError, PdbIdError, DuplicateError)): + # In strict mode we treat some prefilter cases as errors. + return SingleHitResult(features=None, error=msg, warning=None) + + return SingleHitResult(features=None, error=None, warning=None) + + mapping = _build_query_to_hit_index_mapping( + hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query, + query_sequence) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace('-', '') + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') + logging.info('Reading PDB entry from %s. Query: %s, template: %s', + cif_path, query_sequence, template_sequence) + # Fail if we can't find the mmCIF file. + with open(cif_path, 'r') as cif_file: + cif_string = cif_file.read() + + parsing_result = mmcif_parsing.parse( + file_id=hit_pdb_code, mmcif_string=cif_string) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime( + parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') + if hit_release_date > max_template_date: + error = ('Template %s date (%s) > max template date (%s).' % + (hit_pdb_code, hit_release_date, max_template_date)) + if strict_error_check: + return SingleHitResult(features=None, error=error, warning=None) + else: + logging.warning(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + kalign_binary_path=kalign_binary_path) + features['template_sum_probs'] = [hit.sum_probs] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult( + features=features, error=None, warning=realign_warning) + except (NoChainsError, NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. + warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' + % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index, + str(e), parsing_result.errors)) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + else: + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' + % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index, + str(e), parsing_result.errors)) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer: + """A class for turning hhr hits to template features.""" + + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str], + obsolete_pdbs_path: Optional[str], + strict_error_check: bool = False): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. + release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. + * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): + logging.error('Could not find CIFs in %s', self._mmcif_dir) + raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') + + try: + self._max_template_date = datetime.datetime.strptime( + max_template_date, '%Y-%m-%d') + except ValueError: + raise ValueError( + 'max_template_date must be set and have format YYYY-MM-DD.') + self._max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info('Using precomputed release dates %s.', release_dates_path) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + def get_templates( + self, + query_sequence: str, + query_pdb_code: Optional[str], + query_release_date: Optional[datetime.datetime], + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_pdb_code) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + # Always use a max_template_date. Set to query_release_date minus 60 days + # if that's earlier. + template_cutoff_date = self._max_template_date + if query_release_date: + delta = datetime.timedelta(days=60) + if query_release_date - delta < template_cutoff_date: + template_cutoff_date = query_release_date - delta + assert template_cutoff_date < query_release_date + assert template_cutoff_date <= self._max_template_date + + num_hits = 0 + errors = [] + warnings = [] + + for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True): + # We got all the templates we wanted, stop processing hits. + if num_hits >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + query_pdb_code=query_pdb_code, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=template_cutoff_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info('Skipped invalid hit %s, error: %s, warning: %s', + hit.name, result.error, result.warning) + else: + # Increment the hit counter, since we got features out of this hit. + num_hits += 1 + for k in template_features: + template_features[k].append(result.features[k]) + + for name in template_features: + if num_hits > 0: + template_features[name] = np.stack( + template_features[name], axis=0).astype(TEMPLATE_FEATURES[name]) + else: + # Make sure the feature has correct dtype even if empty. + template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name]) + + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) diff --git a/af_backprop/alphafold/data/tools/__init__.py b/af_backprop/alphafold/data/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..903d09793c39d08491dd9a6fecefd463b058a251 --- /dev/null +++ b/af_backprop/alphafold/data/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python wrappers for third party tools.""" diff --git a/af_backprop/alphafold/data/tools/hhblits.py b/af_backprop/alphafold/data/tools/hhblits.py new file mode 100644 index 0000000000000000000000000000000000000000..e0aa098a6f6a2e702340aafbde7a5a045b674543 --- /dev/null +++ b/af_backprop/alphafold/data/tools/hhblits.py @@ -0,0 +1,155 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to run HHblits from Python.""" + +import glob +import os +import subprocess +from typing import Any, Mapping, Optional, Sequence + +from absl import logging +from alphafold.data.tools import utils +# Internal import (7716). + + +_HHBLITS_DEFAULT_P = 20 +_HHBLITS_DEFAULT_Z = 500 + + +class HHBlits: + """Python wrapper of the HHblits binary.""" + + def __init__(self, + *, + binary_path: str, + databases: Sequence[str], + n_cpu: int = 4, + n_iter: int = 3, + e_value: float = 0.001, + maxseq: int = 1_000_000, + realign_max: int = 100_000, + maxfilt: int = 100_000, + min_prefilter_hits: int = 1000, + all_seqs: bool = False, + alt: Optional[int] = None, + p: int = _HHBLITS_DEFAULT_P, + z: int = _HHBLITS_DEFAULT_Z): + """Initializes the Python HHblits wrapper. + + Args: + binary_path: The path to the HHblits executable. + databases: A sequence of HHblits database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + n_cpu: The number of CPUs to give HHblits. + n_iter: The number of HHblits iterations. + e_value: The E-value, see HHblits docs for more details. + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. + maxfilt: Max number of hits allowed to pass the 2nd prefilter. + HHblits default: 20000. + min_prefilter_hits: Min number of hits to pass prefilter. + HHblits default: 100. + all_seqs: Return all sequences in the MSA / Do not filter the result MSA. + HHblits default: False. + alt: Show up to this many alternative alignments. + p: Minimum Prob for a hit to be included in the output hhr file. + HHblits default: 20. + z: Hard cap on number of hits reported in the hhr file. + HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. + + Raises: + RuntimeError: If HHblits binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHBlits database %s', database_path) + raise ValueError(f'Could not find HHBlits database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.maxseq = maxseq + self.realign_max = realign_max + self.maxfilt = maxfilt + self.min_prefilter_hits = min_prefilter_hits + self.all_seqs = all_seqs + self.alt = alt + self.p = p + self.z = z + + def query(self, input_fasta_path: str) -> Mapping[str, Any]: + """Queries the database using HHblits.""" + with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', input_fasta_path, + '-cpu', str(self.n_cpu), + '-oa3m', a3m_path, + '-o', '/dev/null', + '-n', str(self.n_iter), + '-e', str(self.e_value), + '-maxseq', str(self.maxseq), + '-realign_max', str(self.realign_max), + '-maxfilt', str(self.maxfilt), + '-min_prefilter_hits', str(self.min_prefilter_hits)] + if self.all_seqs: + cmd += ['-all'] + if self.alt: + cmd += ['-alt', str(self.alt)] + if self.p != _HHBLITS_DEFAULT_P: + cmd += ['-p', str(self.p)] + if self.z != _HHBLITS_DEFAULT_Z: + cmd += ['-Z', str(self.z)] + cmd += db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('HHblits query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Logs have a 15k character limit, so log HHblits error line by line. + logging.error('HHblits failed. HHblits stderr begin:') + for error_line in stderr.decode('utf-8').splitlines(): + if error_line.strip(): + logging.error(error_line.strip()) + logging.error('HHblits stderr end') + raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % ( + stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) + + with open(a3m_path) as f: + a3m = f.read() + + raw_output = dict( + a3m=a3m, + output=stdout, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value) + return raw_output diff --git a/af_backprop/alphafold/data/tools/hhsearch.py b/af_backprop/alphafold/data/tools/hhsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..fac137e0172f53e7c7ef943c5fa73dcb69f72246 --- /dev/null +++ b/af_backprop/alphafold/data/tools/hhsearch.py @@ -0,0 +1,91 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to run HHsearch from Python.""" + +import glob +import os +import subprocess +from typing import Sequence + +from absl import logging + +from alphafold.data.tools import utils +# Internal import (7716). + + +class HHSearch: + """Python wrapper of the HHsearch binary.""" + + def __init__(self, + *, + binary_path: str, + databases: Sequence[str], + maxseq: int = 1_000_000): + """Initializes the Python HHsearch wrapper. + + Args: + binary_path: The path to the HHsearch executable. + databases: A sequence of HHsearch database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + + Raises: + RuntimeError: If HHsearch binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + self.maxseq = maxseq + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHsearch database %s', database_path) + raise ValueError(f'Could not find HHsearch database {database_path}') + + def query(self, a3m: str) -> str: + """Queries the database using HHsearch using a given a3m.""" + with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + input_path = os.path.join(query_tmp_dir, 'query.a3m') + hhr_path = os.path.join(query_tmp_dir, 'output.hhr') + with open(input_path, 'w') as f: + f.write(a3m) + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [self.binary_path, + '-i', input_path, + '-o', hhr_path, + '-maxseq', str(self.maxseq) + ] + db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing('HHsearch query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Stderr is truncated to prevent proto size errors in Beam. + raise RuntimeError( + 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( + stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) + + with open(hhr_path) as f: + hhr = f.read() + return hhr diff --git a/af_backprop/alphafold/data/tools/hmmbuild.py b/af_backprop/alphafold/data/tools/hmmbuild.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c573047450f5f17e791ad9a54f1b436e71b095 --- /dev/null +++ b/af_backprop/alphafold/data/tools/hmmbuild.py @@ -0,0 +1,138 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" + +import os +import re +import subprocess + +from absl import logging +from alphafold.data.tools import utils +# Internal import (7716). + + +class Hmmbuild(object): + """Python wrapper of the hmmbuild binary.""" + + def __init__(self, + *, + binary_path: str, + singlemx: bool = False): + """Initializes the Python hmmbuild wrapper. + + Args: + binary_path: The path to the hmmbuild executable. + singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to + just use a common substitution score matrix. + + Raises: + RuntimeError: If hmmbuild binary not found within the path. + """ + self.binary_path = binary_path + self.singlemx = singlemx + + def build_profile_from_sto(self, sto: str, model_construction='fast') -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + sto: A string with the aligned sequences in the Stockholm format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + return self._build_profile(sto, model_construction=model_construction) + + def build_profile_from_a3m(self, a3m: str) -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + a3m: A string with the aligned sequences in the A3M format. + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + lines = [] + for line in a3m.splitlines(): + if not line.startswith('>'): + line = re.sub('[a-z]+', '', line) # Remove inserted residues. + lines.append(line + '\n') + msa = ''.join(lines) + return self._build_profile(msa, model_construction='fast') + + def _build_profile(self, msa: str, model_construction: str = 'fast') -> str: + """Builds a HMM for the aligned sequences given as an MSA string. + + Args: + msa: A string with the aligned sequences, in A3M or STO format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + ValueError: If unspecified arguments are provided. + """ + if model_construction not in {'hand', 'fast'}: + raise ValueError(f'Invalid model_construction {model_construction} - only' + 'hand and fast supported.') + + with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + input_query = os.path.join(query_tmp_dir, 'query.msa') + output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') + + with open(input_query, 'w') as f: + f.write(msa) + + cmd = [self.binary_path] + # If adding flags, we have to do so before the output and input: + + if model_construction == 'hand': + cmd.append(f'--{model_construction}') + if self.singlemx: + cmd.append('--singlemx') + cmd.extend([ + '--amino', + output_hmm_path, + input_query, + ]) + + logging.info('Launching subprocess %s', cmd) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + with utils.timing('hmmbuild query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), stderr.decode('utf-8')) + + if retcode: + raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' + % (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_hmm_path, encoding='utf-8') as f: + hmm = f.read() + + return hmm diff --git a/af_backprop/alphafold/data/tools/hmmsearch.py b/af_backprop/alphafold/data/tools/hmmsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..a60d3e760e217f175b7daeffb803837e23391b0a --- /dev/null +++ b/af_backprop/alphafold/data/tools/hmmsearch.py @@ -0,0 +1,90 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Python wrapper for hmmsearch - search profile against a sequence db.""" + +import os +import subprocess +from typing import Optional, Sequence + +from absl import logging +from alphafold.data.tools import utils +# Internal import (7716). + + +class Hmmsearch(object): + """Python wrapper of the hmmsearch binary.""" + + def __init__(self, + *, + binary_path: str, + database_path: str, + flags: Optional[Sequence[str]] = None): + """Initializes the Python hmmsearch wrapper. + + Args: + binary_path: The path to the hmmsearch executable. + database_path: The path to the hmmsearch database (FASTA format). + flags: List of flags to be used by hmmsearch. + + Raises: + RuntimeError: If hmmsearch binary not found within the path. + """ + self.binary_path = binary_path + self.database_path = database_path + self.flags = flags + + if not os.path.exists(self.database_path): + logging.error('Could not find hmmsearch database %s', database_path) + raise ValueError(f'Could not find hmmsearch database {database_path}') + + def query(self, hmm: str) -> str: + """Queries the database using hmmsearch using a given hmm.""" + with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') + a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m') + with open(hmm_input_path, 'w') as f: + f.write(hmm) + + cmd = [ + self.binary_path, + '--noali', # Don't include the alignment in stdout. + '--cpu', '8' + ] + # If adding flags, we have to do so before the output and input: + if self.flags: + cmd.extend(self.flags) + cmd.extend([ + '-A', a3m_out_path, + hmm_input_path, + self.database_path, + ]) + + logging.info('Launching sub-process %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'hmmsearch ({os.path.basename(self.database_path)}) query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError( + 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( + stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(a3m_out_path) as f: + a3m_out = f.read() + + return a3m_out diff --git a/af_backprop/alphafold/data/tools/jackhmmer.py b/af_backprop/alphafold/data/tools/jackhmmer.py new file mode 100644 index 0000000000000000000000000000000000000000..194d266c1251de25d2f85ba3a2b338ca0adf95e0 --- /dev/null +++ b/af_backprop/alphafold/data/tools/jackhmmer.py @@ -0,0 +1,198 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library to run Jackhmmer from Python.""" + +from concurrent import futures +import glob +import os +import subprocess +from typing import Any, Callable, Mapping, Optional, Sequence +from urllib import request + +from absl import logging + +from alphafold.data.tools import utils +# Internal import (7716). + + +class Jackhmmer: + """Python wrapper of the Jackhmmer binary.""" + + def __init__(self, + *, + binary_path: str, + database_path: str, + n_cpu: int = 8, + n_iter: int = 1, + e_value: float = 0.0001, + z_value: Optional[int] = None, + get_tblout: bool = False, + filter_f1: float = 0.0005, + filter_f2: float = 0.00005, + filter_f3: float = 0.0000005, + incdom_e: Optional[float] = None, + dom_e: Optional[float] = None, + num_streamed_chunks: Optional[int] = None, + streaming_callback: Optional[Callable[[int], None]] = None): + """Initializes the Python Jackhmmer wrapper. + + Args: + binary_path: The path to the jackhmmer executable. + database_path: The path to the jackhmmer database (FASTA format). + n_cpu: The number of CPUs to give Jackhmmer. + n_iter: The number of Jackhmmer iterations. + e_value: The E-value, see Jackhmmer docs for more details. + z_value: The Z-value, see Jackhmmer docs for more details. + get_tblout: Whether to save tblout string. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + dom_e: Domain e-value criteria for inclusion in tblout. + num_streamed_chunks: Number of database chunks to stream over. + streaming_callback: Callback function run after each chunk iteration with + the iteration number as argument. + """ + self.binary_path = binary_path + self.database_path = database_path + self.num_streamed_chunks = num_streamed_chunks + + if not os.path.exists(self.database_path) and num_streamed_chunks is None: + logging.error('Could not find Jackhmmer database %s', database_path) + raise ValueError(f'Could not find Jackhmmer database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.z_value = z_value + self.filter_f1 = filter_f1 + self.filter_f2 = filter_f2 + self.filter_f3 = filter_f3 + self.incdom_e = incdom_e + self.dom_e = dom_e + self.get_tblout = get_tblout + self.streaming_callback = streaming_callback + + def _query_chunk(self, input_fasta_path: str, database_path: str + ) -> Mapping[str, Any]: + """Queries the database chunk using Jackhmmer.""" + with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + sto_path = os.path.join(query_tmp_dir, 'output.sto') + + # The F1/F2/F3 are the expected proportion to pass each of the filtering + # stages (which get progressively more expensive), reducing these + # speeds up the pipeline at the expensive of sensitivity. They are + # currently set very low to make querying Mgnify run in a reasonable + # amount of time. + cmd_flags = [ + # Don't pollute stdout with Jackhmmer output. + '-o', '/dev/null', + '-A', sto_path, + '--noali', + '--F1', str(self.filter_f1), + '--F2', str(self.filter_f2), + '--F3', str(self.filter_f3), + '--incE', str(self.e_value), + # Report only sequences with E-values <= x in per-sequence output. + '-E', str(self.e_value), + '--cpu', str(self.n_cpu), + '-N', str(self.n_iter) + ] + if self.get_tblout: + tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') + cmd_flags.extend(['--tblout', tblout_path]) + + if self.z_value: + cmd_flags.extend(['-Z', str(self.z_value)]) + + if self.dom_e is not None: + cmd_flags.extend(['--domE', str(self.dom_e)]) + + if self.incdom_e is not None: + cmd_flags.extend(['--incdomE', str(self.incdom_e)]) + + cmd = [self.binary_path] + cmd_flags + [input_fasta_path, + database_path] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'Jackhmmer ({os.path.basename(database_path)}) query'): + _, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError( + 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8')) + + # Get e-values for each target name + tbl = '' + if self.get_tblout: + with open(tblout_path) as f: + tbl = f.read() + + with open(sto_path) as f: + sto = f.read() + + raw_output = dict( + sto=sto, + tbl=tbl, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value) + + return raw_output + + def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: + """Queries the database using Jackhmmer.""" + if self.num_streamed_chunks is None: + return [self._query_chunk(input_fasta_path, self.database_path)] + + db_basename = os.path.basename(self.database_path) + db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' + db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # Remove existing files to prevent OOM + for f in glob.glob(db_local_chunk('[0-9]*')): + try: + os.remove(f) + except OSError: + print(f'OSError while deleting {f}') + + # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk + with futures.ThreadPoolExecutor(max_workers=2) as executor: + chunked_output = [] + for i in range(1, self.num_streamed_chunks + 1): + # Copy the chunk locally + if i == 1: + future = executor.submit( + request.urlretrieve, db_remote_chunk(i), db_local_chunk(i)) + if i < self.num_streamed_chunks: + next_future = executor.submit( + request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1)) + + # Run Jackhmmer with the chunk + future.result() + chunked_output.append( + self._query_chunk(input_fasta_path, db_local_chunk(i))) + + # Remove the local copy of the chunk + os.remove(db_local_chunk(i)) + future = next_future + if self.streaming_callback: + self.streaming_callback(i) + return chunked_output diff --git a/af_backprop/alphafold/data/tools/kalign.py b/af_backprop/alphafold/data/tools/kalign.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4e58a43205c138b7f29c07f39a87ea741d2656 --- /dev/null +++ b/af_backprop/alphafold/data/tools/kalign.py @@ -0,0 +1,104 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Python wrapper for Kalign.""" +import os +import subprocess +from typing import Sequence + +from absl import logging + +from alphafold.data.tools import utils +# Internal import (7716). + + +def _to_a3m(sequences: Sequence[str]) -> str: + """Converts sequences to an a3m file.""" + names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] + a3m = [] + for sequence, name in zip(sequences, names): + a3m.append(u'>' + name + u'\n') + a3m.append(sequence + u'\n') + return ''.join(a3m) + + +class Kalign: + """Python wrapper of the Kalign binary.""" + + def __init__(self, *, binary_path: str): + """Initializes the Python Kalign wrapper. + + Args: + binary_path: The path to the Kalign binary. + + Raises: + RuntimeError: If Kalign binary not found within the path. + """ + self.binary_path = binary_path + + def align(self, sequences: Sequence[str]) -> str: + """Aligns the sequences and returns the alignment in A3M string. + + Args: + sequences: A list of query sequence strings. The sequences have to be at + least 6 residues long (Kalign requires this). Note that the order in + which you give the sequences might alter the output slightly as + different alignment tree might get constructed. + + Returns: + A string with the alignment in a3m format. + + Raises: + RuntimeError: If Kalign fails. + ValueError: If any of the sequences is less than 6 residues long. + """ + logging.info('Aligning %d sequences', len(sequences)) + + for s in sequences: + if len(s) < 6: + raise ValueError('Kalign requires all sequences to be at least 6 ' + 'residues long. Got %s (%d residues).' % (s, len(s))) + + with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') + output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + with open(input_fasta_path, 'w') as f: + f.write(_to_a3m(sequences)) + + cmd = [ + self.binary_path, + '-i', input_fasta_path, + '-o', output_a3m_path, + '-format', 'fasta', + ] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + with utils.timing('Kalign query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), stderr.decode('utf-8')) + + if retcode: + raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' + % (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_a3m_path) as f: + a3m = f.read() + + return a3m diff --git a/af_backprop/alphafold/data/tools/utils.py b/af_backprop/alphafold/data/tools/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e65b8824d3f240e869ca073a8264f32cb224813c --- /dev/null +++ b/af_backprop/alphafold/data/tools/utils.py @@ -0,0 +1,40 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for data pipeline tools.""" +import contextlib +import shutil +import tempfile +import time +from typing import Optional + +from absl import logging + + +@contextlib.contextmanager +def tmpdir_manager(base_dir: Optional[str] = None): + """Context manager that deletes a temporary directory on exit.""" + tmpdir = tempfile.mkdtemp(dir=base_dir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +@contextlib.contextmanager +def timing(msg: str): + logging.info('Started %s', msg) + tic = time.time() + yield + toc = time.time() + logging.info('Finished %s in %.3f seconds', msg, toc - tic) diff --git a/af_backprop/alphafold/model/__init__.py b/af_backprop/alphafold/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2efc8d3e1439f8d264268adcde82231f784636 --- /dev/null +++ b/af_backprop/alphafold/model/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Alphafold model.""" diff --git a/af_backprop/alphafold/model/all_atom.py b/af_backprop/alphafold/model/all_atom.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7896d80d9a9dfd1f79312fc67c10ada68df394 --- /dev/null +++ b/af_backprop/alphafold/model/all_atom.py @@ -0,0 +1,1155 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ops for all atom representations. + +Generally we employ two different representations for all atom coordinates, +one is atom37 where each heavy atom corresponds to a given position in a 37 +dimensional array, This mapping is non amino acid specific, but each slot +corresponds to an atom of a given name, for example slot 12 always corresponds +to 'C delta 1', positions that are not present for a given amino acid are +zeroed out and denoted by a mask. +The other representation we employ is called atom14, this is a more dense way +of representing atoms with 14 slots. Here a given slot will correspond to a +different kind of atom depending on amino acid type, for example slot 5 +corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine. +14 is chosen because it is the maximum number of heavy atoms for any standard +amino acid. +The order of slots can be found in 'residue_constants.residue_atoms'. +Internally the model uses the atom14 representation because it is +computationally more efficient. +The internal atom14 representation is turned into the atom37 at the output of +the network to facilitate easier conversion to existing protein datastructures. +""" + +from typing import Dict, Optional +from alphafold.common import residue_constants + +from alphafold.model import r3 +from alphafold.model import utils +import jax +import jax.numpy as jnp +import numpy as np + + +def squared_difference(x, y): + return jnp.square(x - y) + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return jnp.asarray(chi_atom_indices) + + +def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...) + batch: Dict[str, jnp.ndarray] + ) -> jnp.ndarray: # (N, 37, ...) + """Convert atom14 to atom37 representation.""" + assert len(atom14_data.shape) in [2, 3] + assert 'residx_atom37_to_atom14' in batch + assert 'atom37_atom_exists' in batch + + if jnp.issubdtype(batch['residx_atom37_to_atom14'].dtype, jnp.integer): + atom37_data = utils.batched_gather(atom14_data, batch['residx_atom37_to_atom14'], batch_dims=1) + else: + atom37_data = jnp.einsum("na...,nba->nb...", atom14_data, batch['residx_atom37_to_atom14']) + + if len(atom14_data.shape) == 2: + atom37_data *= batch['atom37_atom_exists'] + elif len(atom14_data.shape) == 3: + atom37_data *= batch['atom37_atom_exists'][:, :, None].astype(atom37_data.dtype) + return atom37_data + +def atom37_to_atom14( + atom37_data: jnp.ndarray, # (N, 37, ...) + batch: Dict[str, jnp.ndarray]) -> jnp.ndarray: # (N, 14, ...) + """Convert atom14 to atom37 representation.""" + assert len(atom37_data.shape) in [2, 3] + assert 'residx_atom14_to_atom37' in batch + assert 'atom14_atom_exists' in batch + + if jnp.issubdtype(batch['residx_atom14_to_atom37'].dtype, jnp.integer): + atom14_data = utils.batched_gather(atom37_data, batch['residx_atom14_to_atom37'], batch_dims=1) + else: + atom14_data = jnp.einsum("na...,nba->nb...", atom37_data, batch['residx_atom14_to_atom37']) + + if len(atom37_data.shape) == 2: + atom14_data *= batch['atom14_atom_exists'].astype(atom14_data.dtype) + elif len(atom37_data.shape) == 3: + atom14_data *= batch['atom14_atom_exists'][:, :, None].astype(atom14_data.dtype) + return atom14_data + + +def atom37_to_frames( + aatype: jnp.ndarray, # (...) + all_atom_positions: jnp.ndarray, # (..., 37, 3) + all_atom_mask: jnp.ndarray, # (..., 37) +) -> Dict[str, jnp.ndarray]: + """Computes the frames for the up to 8 rigid groups for each residue. + + The rigid groups are defined by the possible torsions in a given amino acid. + We group the atoms according to their dependence on the torsion angles into + "rigid groups". E.g., the position of atoms in the chi2-group depend on + chi1 and chi2, but do not depend on chi3 or chi4. + Jumper et al. (2021) Suppl. Table 2 and corresponding text. + + Args: + aatype: Amino acid type, given as array with integers. + all_atom_positions: atom37 representation of all atom coordinates. + all_atom_mask: atom37 representation of mask on all atom coordinates. + Returns: + Dictionary containing: + * 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions' + represented as flat 12 dimensional array. + * 'rigidgroups_gt_exists': Mask denoting whether the atom positions for + the given frame are available in the ground truth, e.g. if they were + resolved in the experiment. + * 'rigidgroups_group_exists': Mask denoting whether given group is in + principle present for given amino acid type. + * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is + affected by naming ambiguity. + * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming + corresponding to 'all_atom_positions' represented as flat + 12 dimensional array. + """ + # 0: 'backbone group', + # 1: 'pre-omega-group', (empty) + # 2: 'phi-group', (currently empty, because it defines only hydrogens) + # 3: 'psi-group', + # 4,5,6,7: 'chi1,2,3,4-group' + aatype_in_shape = aatype.shape + + # If there is a batch axis, just flatten it away, and reshape everything + # back at the end of the function. + aatype = jnp.reshape(aatype, [-1]) + all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3]) + all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) + + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(residue_constants.restypes): + resname = residue_constants.restype_1to3[restype_letter] + for chi_idx in range(4): + if residue_constants.chi_angles_mask[restype][chi_idx]: + atom_names = residue_constants.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[ + restype, chi_idx + 4, :] = atom_names[1:] + + # Create mask for existing rigid groups. + restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_mask[:, 0] = 1 + restype_rigidgroup_mask[:, 3] = 1 + restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask + + # Translate atom names into atom37 indices. + lookuptable = residue_constants.atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + restype_rigidgroup_base_atom_names) + + # Compute the gather indices for all residues in the chain. + # shape (N, 8, 3) + residx_rigidgroup_base_atom37_idx = utils.batched_gather( + restype_rigidgroup_base_atom37_idx, aatype) + + # Gather the base atom positions for each rigid group. + base_atom_pos = utils.batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + + # Compute the Rigids. + gt_frames = r3.rigids_from_3_points( + point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]), + origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]), + point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :]) + ) + + # Compute a mask whether the group exists. + # (N, 8) + group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype) + + # Compute a mask whether ground truth exists for the group + gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3) + all_atom_mask.astype(jnp.float32), + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8) + + # Adapt backbone frame to old convention (mirror x-axis and z-axis). + rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rots[0, 0, 0] = -1 + rots[0, 2, 2] = -1 + gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots)) + + # The frames for ambiguous rigid groups are just rotated by 180 degree around + # the x-axis. The ambiguous group is always the last chi-group. + restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for resname, _ in residue_constants.residue_atom_renaming_swaps.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + residx_rigidgroup_is_ambiguous = utils.batched_gather( + restype_rigidgroup_is_ambiguous, aatype) + residx_rigidgroup_ambiguity_rot = utils.batched_gather( + restype_rigidgroup_rots, aatype) + + # Create the alternative ground truth frames. + alt_gt_frames = r3.rigids_mul_rots( + gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot)) + + gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames) + alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames) + + # reshape back to original residue layout + gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12)) + gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,)) + group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,)) + gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12)) + residx_rigidgroup_is_ambiguous = jnp.reshape(residx_rigidgroup_is_ambiguous, + aatype_in_shape + (8,)) + alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12, + aatype_in_shape + (8, 12,)) + + return { + 'rigidgroups_gt_frames': gt_frames_flat12, # (..., 8, 12) + 'rigidgroups_gt_exists': gt_exists, # (..., 8) + 'rigidgroups_group_exists': group_exists, # (..., 8) + 'rigidgroups_group_is_ambiguous': + residx_rigidgroup_is_ambiguous, # (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames_flat12, # (..., 8, 12) + } + + +def atom37_to_torsion_angles( + aatype: jnp.ndarray, # (B, N) + all_atom_pos: jnp.ndarray, # (B, N, 37, 3) + all_atom_mask: jnp.ndarray, # (B, N, 37) + placeholder_for_undefined=False, +) -> Dict[str, jnp.ndarray]: + """Computes the 7 torsion angles (in sin, cos encoding) for each residue. + + The 7 torsion angles are in the order + '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]', + here pre_omega denotes the omega torsion angle between the given amino acid + and the previous amino acid. + + Args: + aatype: Amino acid type, given as array with integers. + all_atom_pos: atom37 representation of all atom coordinates. + all_atom_mask: atom37 representation of mask on all atom coordinates. + placeholder_for_undefined: flag denoting whether to set masked torsion + angles to zero. + Returns: + Dict containing: + * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final + 2 dimensions denote sin and cos respectively + * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but + with the angle shifted by pi for all chi angles affected by the naming + ambiguities. + * 'torsion_angles_mask': Mask for which chi angles are present. + """ + + # Map aatype > 20 to 'Unknown' (20). + aatype = jnp.minimum(aatype, 20) + + # Compute the backbone angles. + num_batch, num_res = aatype.shape + + pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32) + prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1) + + pad = jnp.zeros([num_batch, 1, 37], jnp.float32) + prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1) + + # For each torsion angle collect the 4 atom positions that define this angle. + # shape (B, N, atoms=4, xyz=3) + pre_omega_atom_pos = jnp.concatenate( + [prev_all_atom_pos[:, :, 1:3, :], # prev CA, C + all_atom_pos[:, :, 0:2, :] # this N, CA + ], axis=-2) + phi_atom_pos = jnp.concatenate( + [prev_all_atom_pos[:, :, 2:3, :], # prev C + all_atom_pos[:, :, 0:3, :] # this N, CA, C + ], axis=-2) + psi_atom_pos = jnp.concatenate( + [all_atom_pos[:, :, 0:3, :], # this N, CA, C + all_atom_pos[:, :, 4:5, :] # this O + ], axis=-2) + + # Collect the masks from these atoms. + # Shape [batch, num_res] + pre_omega_mask = ( + jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1) # prev CA, C + * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1)) # this N, CA + phi_mask = ( + prev_all_atom_mask[:, :, 2] # prev C + * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1)) # this N, CA, C + psi_mask = ( + jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) * # this N, CA, C + all_atom_mask[:, :, 4]) # this O + + # Collect the atoms for the chi-angles. + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices() + # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4]. + atom_indices = utils.batched_gather( + params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0) + # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3]. + chis_atom_pos = utils.batched_gather( + params=all_atom_pos, indices=atom_indices, axis=-2, + batch_dims=2) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = jnp.asarray(chi_angles_mask) + + # Compute the chi angle mask. I.e. which chis angles exist according to the + # aatype. Shape [batch, num_res, chis=4]. + chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype, + axis=0, batch_dims=0) + + # Constrain the chis_mask to those chis, where the ground truth coordinates of + # all defining four atoms are available. + # Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = utils.batched_gather( + params=all_atom_mask, indices=atom_indices, axis=-1, + batch_dims=2) + # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4]. + chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1]) + chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32) + + # Stack all torsion angle atom positions. + # Shape (B, N, torsions=7, atoms=4, xyz=3) + torsions_atom_pos = jnp.concatenate( + [pre_omega_atom_pos[:, :, None, :, :], + phi_atom_pos[:, :, None, :, :], + psi_atom_pos[:, :, None, :, :], + chis_atom_pos + ], axis=2) + + # Stack up masks for all torsion angles. + # shape (B, N, torsions=7) + torsion_angles_mask = jnp.concatenate( + [pre_omega_mask[:, :, None], + phi_mask[:, :, None], + psi_mask[:, :, None], + chis_mask + ], axis=2) + + # Create a frame from the first three atoms: + # First atom: point on x-y-plane + # Second atom: point on negative x-axis + # Third atom: origin + # r3.Rigids (B, N, torsions=7) + torsion_frames = r3.rigids_from_3_points( + point_on_neg_x_axis=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]), + origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]), + point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :])) + + # Compute the position of the forth atom in this frame (y and z coordinate + # define the chi angle) + # r3.Vecs (B, N, torsions=7) + forth_atom_rel_pos = r3.rigids_mul_vecs( + r3.invert_rigids(torsion_frames), + r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :])) + + # Normalize to have the sin and cos of the torsion angle. + # jnp.ndarray (B, N, torsions=7, sincos=2) + torsion_angles_sin_cos = jnp.stack( + [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1) + torsion_angles_sin_cos /= jnp.sqrt( + jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + + 1e-8) + + # Mirror psi, because we computed it from the Oxygen-atom. + torsion_angles_sin_cos *= jnp.asarray( + [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None] + + # Create alternative angles for ambiguous atom names. + chi_is_ambiguous = utils.batched_gather( + jnp.asarray(residue_constants.chi_pi_periodic), aatype) + mirror_torsion_angles = jnp.concatenate( + [jnp.ones([num_batch, num_res, 3]), + 1.0 - 2.0 * chi_is_ambiguous], axis=-1) + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None]) + + if placeholder_for_undefined: + # Add placeholder torsions in place of undefined torsion angles + # (e.g. N-terminus pre-omega) + placeholder_torsions = jnp.stack([ + jnp.ones(torsion_angles_sin_cos.shape[:-1]), + jnp.zeros(torsion_angles_sin_cos.shape[:-1]) + ], axis=-1) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ + ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ + ..., None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + + return { + 'torsion_angles_sin_cos': torsion_angles_sin_cos, # (B, N, 7, 2) + 'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos, # (B, N, 7, 2) + 'torsion_angles_mask': torsion_angles_mask # (B, N, 7) + } + + +def torsion_angles_to_frames( + aatype: jnp.ndarray, # (N) + backb_to_global: r3.Rigids, # (N) + torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2) +) -> r3.Rigids: # (N, 8) + """Compute rigid group frames from torsion angles. + + Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10 + Jumper et al. (2021) Suppl. Alg. 25 "makeRotX" + + Args: + aatype: aatype for each residue + backb_to_global: Rigid transformations describing transformation from + backbone frame to global frame. + torsion_angles_sin_cos: sin and cosine of the 7 torsion angles + Returns: + Frames corresponding to all the Sidechain Rigid Transforms + """ + if jnp.issubdtype(aatype.dtype, jnp.integer): + assert len(aatype.shape) == 1 + else: + assert len(aatype.shape) == 2 + assert len(backb_to_global.rot.xx.shape) == 1 + assert len(torsion_angles_sin_cos.shape) == 3 + assert torsion_angles_sin_cos.shape[1] == 7 + assert torsion_angles_sin_cos.shape[2] == 2 + + # Gather the default frames for all rigid groups. + # r3.Rigids with shape (N, 8) + + if jnp.issubdtype(aatype.dtype, jnp.integer): + m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame, aatype) + else: + m = jnp.einsum("...a,abcd->...bcd",aatype,residue_constants.restype_rigid_group_default_frame) + + default_frames = r3.rigids_from_tensor4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + if jnp.issubdtype(aatype.dtype, jnp.integer): + num_residues, = aatype.shape + else: + num_residues,_ = aatype.shape + sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],axis=-1) + cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],axis=-1) + zeros = jnp.zeros_like(sin_angles) + ones = jnp.ones_like(sin_angles) + + # all_rots are r3.Rots with shape (N, 8) + all_rots = r3.Rots(ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. + all_frames = r3.rigids_mul_rots(default_frames, all_rots) + + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. + chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames) + chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames) + chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames) + + chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames) + chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb, + chi2_frame_to_frame) + chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb, + chi3_frame_to_frame) + chi4_frame_to_backb = r3.rigids_mul_rigids(chi3_frame_to_backb, + chi4_frame_to_frame) + + # Recombine them to a r3.Rigids with shape (N, 8). + def _concat_frames(xall, x5, x6, x7): + return jnp.concatenate( + [xall[:, 0:5], x5[:, None], x6[:, None], x7[:, None]], axis=-1) + + all_frames_to_backb = jax.tree_map( + _concat_frames, + all_frames, + chi2_frame_to_backb, + chi3_frame_to_backb, + chi4_frame_to_backb) + + # Create the global frames. + # shape (N, 8) + all_frames_to_global = r3.rigids_mul_rigids( + jax.tree_map(lambda x: x[:, None], backb_to_global), + all_frames_to_backb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + aatype: jnp.ndarray, # (N) + all_frames_to_global: r3.Rigids # (N, 8) +) -> r3.Vecs: # (N, 14) + """Put atom literature positions (atom14 encoding) in each rigid group. + + Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11 + + Args: + aatype: aatype for each residue. + all_frames_to_global: All per residue coordinate frames. + Returns: + Positions of all atom coordinates in global frame. + """ + + # Pick the appropriate transform for every atom. + if jnp.issubdtype(aatype.dtype, jnp.integer): + residx_to_group_idx = utils.batched_gather(residue_constants.restype_atom14_to_rigid_group, aatype) + group_mask = jax.nn.one_hot(residx_to_group_idx, num_classes=8) # shape (N, 14, 8) + else: + group_mask = jnp.einsum("...a,abc->...bc",aatype, jax.nn.one_hot(residue_constants.restype_atom14_to_rigid_group, 8)) + + # r3.Rigids with shape (N, 14) + map_atoms_to_global = jax.tree_map( + lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), + all_frames_to_global) + + # Gather the literature atom positions for each residue. + # r3.Vecs with shape (N, 14) + if jnp.issubdtype(aatype.dtype, jnp.integer): + group_pos = utils.batched_gather(residue_constants.restype_atom14_rigid_group_positions, aatype) + else: + group_pos = jnp.einsum("...a,abc->...bc", aatype, residue_constants.restype_atom14_rigid_group_positions) + lit_positions = r3.vecs_from_tensor(group_pos) + + # Transform each atom from its local frame to the global frame. + # r3.Vecs with shape (N, 14) + pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions) + + # Mask out non-existing atoms. + if jnp.issubdtype(aatype.dtype, jnp.integer): + mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype) + else: + mask = jnp.einsum("...a,ab->...b",aatype,residue_constants.restype_atom14_mask) + pred_positions = jax.tree_map(lambda x: x * mask, pred_positions) + return pred_positions + + +def extreme_ca_ca_distance_violations( + pred_atom_positions: jnp.ndarray, # (N, 37(14), 3) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + max_angstrom_tolerance=1.5 + ) -> jnp.ndarray: + """Counts residues whose Ca is a large distance from its neighbour. + + Measures the fraction of CA-CA pairs between consecutive amino acids that are + more than 'max_angstrom_tolerance' apart. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + max_angstrom_tolerance: Maximum distance allowed to not count as violation. + Returns: + Fraction of consecutive CA-CA pairs with violation. + """ + this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + ca_ca_distance = jnp.sqrt( + 1e-6 + jnp.sum(squared_difference(this_ca_pos, next_ca_pos), axis=-1)) + violations = (ca_ca_distance - + residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + return utils.mask_mean(mask=mask, value=violations) + + +def between_residue_bond_loss( + pred_atom_positions: jnp.ndarray, # (N, 37(14), 3) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + aatype: jnp.ndarray, # (N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0 +) -> Dict[str, jnp.ndarray]: + """Flat-bottom loss to penalize structural violations between residues. + + This is a loss penalizing any violation of the geometry around the peptide + bond between consecutive amino acids. This loss corresponds to + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45. + + Args: + pred_atom_positions: Atom positions in atom37/14 representation + pred_atom_mask: Atom mask in atom37/14 representation + residue_index: Residue index for given amino acid, this is assumed to be + monotonically increasing. + aatype: Amino acid type of given residue + tolerance_factor_soft: soft tolerance factor measured in standard deviations + of pdb distributions + tolerance_factor_hard: hard tolerance factor measured in standard deviations + of pdb distributions + + Returns: + Dict containing: + * 'c_n_loss_mean': Loss for peptide bond length violations + * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned + by CA, C, N + * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned + by C, N, CA + * 'per_residue_loss_sum': sum of all losses for each residue + * 'per_residue_violation_mask': mask denoting all residues with violation + present. + """ + assert len(pred_atom_positions.shape) == 3 + assert len(pred_atom_mask.shape) == 2 + assert len(residue_index.shape) == 1 + assert len(aatype.shape) == 1 + + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[:-1, 1, :] # (N - 1, 3) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + this_c_pos = pred_atom_positions[:-1, 2, :] # (N - 1, 3) + this_c_mask = pred_atom_mask[:-1, 2] # (N - 1) + next_n_pos = pred_atom_positions[1:, 0, :] # (N - 1, 3) + next_n_mask = pred_atom_mask[1:, 0] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1, :] # (N - 1, 3) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + + # Compute loss for the C--N bond. + c_n_bond_length = jnp.sqrt( + 1e-6 + jnp.sum(squared_difference(this_c_pos, next_n_pos), axis=-1)) + + # The C-N bond to proline has slightly different length because of the ring. + next_is_proline = ( + aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(jnp.float32) + gt_length = ( + (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ( + (1. - next_is_proline) * + residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = jnp.sqrt(1e-6 + + jnp.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = jax.nn.relu( + c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_violation_mask = mask * ( + c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum( + squared_difference(this_ca_pos, this_c_pos), axis=-1)) + n_ca_bond_length = jnp.sqrt(1e-6 + jnp.sum( + squared_difference(next_n_pos, next_ca_pos), axis=-1)) + + c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[:, None] + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None] + n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None] + + ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = jax.nn.relu( + ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > + (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = jax.nn.relu( + c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_ca_violation_mask = mask * ( + c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = (c_n_loss_per_residue + + ca_c_n_loss_per_residue + + c_n_ca_loss_per_residue) + per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) + + jnp.pad(per_residue_loss_sum, [[1, 0]])) + + # Compute hard violations. + violation_mask = jnp.max( + jnp.stack([c_n_violation_mask, + ca_c_n_violation_mask, + c_n_ca_violation_mask]), axis=0) + violation_mask = jnp.maximum( + jnp.pad(violation_mask, [[0, 1]]), + jnp.pad(violation_mask, [[1, 0]])) + + return {'c_n_loss_mean': c_n_loss, # shape () + 'ca_c_n_loss_mean': ca_c_n_loss, # shape () + 'c_n_ca_loss_mean': c_n_ca_loss, # shape () + 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) + 'per_residue_violation_mask': violation_mask # shape (N) + } + + +def between_residue_clash_loss( + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_exists: jnp.ndarray, # (N, 14) + atom14_atom_radius: jnp.ndarray, # (N, 14) + residue_index: jnp.ndarray, # (N) + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5 +) -> Dict[str, jnp.ndarray]: + """Loss to penalize steric clashes between residues. + + This is a loss penalizing any steric clashes due to non bonded atoms in + different peptides coming too close. This loss corresponds to the part with + different residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_atom_radius: Van der Waals radius for each atom. + residue_index: Residue index for given amino acid. + overlap_tolerance_soft: Soft tolerance factor. + overlap_tolerance_hard: Hard tolerance factor. + + Returns: + Dict containing: + * 'mean_loss': average clash loss + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + assert len(atom14_pred_positions.shape) == 3 + assert len(atom14_atom_exists.shape) == 2 + assert len(atom14_atom_radius.shape) == 2 + assert len(residue_index.shape) == 1 + + # Create the distance matrix. + # (N, N, 14, 14) + dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_pred_positions[:, None, :, None, :], + atom14_pred_positions[None, :, None, :, :]), + axis=-1)) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom14_atom_exists[:, None, :, None] * + atom14_atom_exists[None, :, None, :]) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask *= ( + residue_index[:, None, None, None] < residue_index[None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. + c_one_hot = jax.nn.one_hot(2, num_classes=14) + n_one_hot = jax.nn.one_hot(0, num_classes=14) + neighbour_mask = ((residue_index[:, None, None, None] + + 1) == residue_index[None, :, None, None]) + c_n_bonds = neighbour_mask * c_one_hot[None, None, :, + None] * n_one_hot[None, None, None, :] + dists_mask *= (1. - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG') + cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = (cys_sg_one_hot[None, None, :, None] * + cys_sg_one_hot[None, None, None, :]) + dists_mask *= (1. - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * (atom14_atom_radius[:, None, :, None] + + atom14_atom_radius[None, :, None, :]) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * jax.nn.relu( + dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = (jnp.sum(dists_to_low_error) + / (1e-6 + jnp.sum(dists_mask))) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) + + jnp.sum(dists_to_low_error, axis=[1, 3])) + + # Compute the hard clash mask. + # shape (N, N, 14, 14) + clash_mask = dists_mask * ( + dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. + # shape (N, 14) + per_atom_clash_mask = jnp.maximum( + jnp.max(clash_mask, axis=[0, 2]), + jnp.max(clash_mask, axis=[1, 3])) + + return {'mean_loss': mean_loss, # shape () + 'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14) + } + + +def within_residue_violations( + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_exists: jnp.ndarray, # (N, 14) + atom14_dists_lower_bound: jnp.ndarray, # (N, 14, 14) + atom14_dists_upper_bound: jnp.ndarray, # (N, 14, 14) + tighten_bounds_for_loss=0.0, +) -> Dict[str, jnp.ndarray]: + """Loss to penalize steric clashes within residues. + + This is a loss penalizing any steric violations or clashes of non-bonded atoms + in a given peptide. This loss corresponds to the part with + the same residues of + Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46. + + Args: + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + atom14_dists_lower_bound: Lower bound on allowed distances. + atom14_dists_upper_bound: Upper bound on allowed distances + tighten_bounds_for_loss: Extra factor to tighten loss + + Returns: + Dict containing: + * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14) + * 'per_atom_clash_mask': mask whether atom clashes with any other atom + shape (N, 14) + """ + assert len(atom14_pred_positions.shape) == 3 + assert len(atom14_atom_exists.shape) == 2 + assert len(atom14_dists_lower_bound.shape) == 3 + assert len(atom14_dists_upper_bound.shape) == 3 + + # Compute the mask for each residue. + # shape (N, 14, 14) + dists_masks = (1. - jnp.eye(14, 14)[None]) + dists_masks *= (atom14_atom_exists[:, :, None] * + atom14_atom_exists[:, None, :]) + + # Distance matrix + # shape (N, 14, 14) + dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_pred_positions[:, :, None, :], + atom14_pred_positions[:, None, :, :]), + axis=-1)) + + # Compute the loss. + # shape (N, 14, 14) + dists_to_low_error = jax.nn.relu( + atom14_dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = jax.nn.relu( + dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(loss, axis=1) + + jnp.sum(loss, axis=2)) + + # Compute the violations mask. + # shape (N, 14, 14) + violations = dists_masks * ((dists < atom14_dists_lower_bound) | + (dists > atom14_dists_upper_bound)) + + # Compute the per atom violations. + # shape (N, 14) + per_atom_violations = jnp.maximum( + jnp.max(violations, axis=1), jnp.max(violations, axis=2)) + + return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_violations': per_atom_violations # shape (N, 14) + } + + +def find_optimal_renaming( + atom14_gt_positions: jnp.ndarray, # (N, 14, 3) + atom14_alt_gt_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_is_ambiguous: jnp.ndarray, # (N, 14) + atom14_gt_exists: jnp.ndarray, # (N, 14) + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + atom14_atom_exists: jnp.ndarray, # (N, 14) +) -> jnp.ndarray: # (N): + """Find optimal renaming for ground truth that maximizes LDDT. + + Jumper et al. (2021) Suppl. Alg. 26 + "renameSymmetricGroundTruthAtoms" lines 1-5 + + Args: + atom14_gt_positions: Ground truth positions in global frame of ground truth. + atom14_alt_gt_positions: Alternate ground truth positions in global frame of + ground truth with coordinates of ambiguous atoms swapped relative to + 'atom14_gt_positions'. + atom14_atom_is_ambiguous: Mask denoting whether atom is among ambiguous + atoms, see Jumper et al. (2021) Suppl. Table 3 + atom14_gt_exists: Mask denoting whether atom at positions exists in ground + truth. + atom14_pred_positions: Predicted positions of atoms in + global prediction frame + atom14_atom_exists: Mask denoting whether atom at positions exists for given + amino acid type + + Returns: + Float array of shape [N] with 1. where atom14_alt_gt_positions is closer to + prediction and 0. otherwise + """ + assert len(atom14_gt_positions.shape) == 3 + assert len(atom14_alt_gt_positions.shape) == 3 + assert len(atom14_atom_is_ambiguous.shape) == 2 + assert len(atom14_gt_exists.shape) == 2 + assert len(atom14_pred_positions.shape) == 3 + assert len(atom14_atom_exists.shape) == 2 + + # Create the pred distance matrix. + # shape (N, N, 14, 14) + pred_dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_pred_positions[:, None, :, None, :], + atom14_pred_positions[None, :, None, :, :]), + axis=-1)) + + # Compute distances for ground truth with original and alternative names. + # shape (N, N, 14, 14) + gt_dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_gt_positions[:, None, :, None, :], + atom14_gt_positions[None, :, None, :, :]), + axis=-1)) + alt_gt_dists = jnp.sqrt(1e-10 + jnp.sum( + squared_difference( + atom14_alt_gt_positions[:, None, :, None, :], + atom14_alt_gt_positions[None, :, None, :, :]), + axis=-1)) + + # Compute LDDT's. + # shape (N, N, 14, 14) + lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists)) + alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists)) + + # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms + # in cols. + # shape (N ,N, 14, 14) + mask = (atom14_gt_exists[:, None, :, None] * # rows + atom14_atom_is_ambiguous[:, None, :, None] * # rows + atom14_gt_exists[None, :, None, :] * # cols + (1. - atom14_atom_is_ambiguous[None, :, None, :])) # cols + + # Aggregate distances for each residue to the non-amibuguous atoms. + # shape (N) + per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3]) + alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3]) + + # Decide for each residue, whether alternative naming is better. + # shape (N) + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32) + + return alt_naming_is_better # shape (N) + + +def frame_aligned_point_error( + pred_frames: r3.Rigids, # shape (num_frames) + target_frames: r3.Rigids, # shape (num_frames) + frames_mask: jnp.ndarray, # shape (num_frames) + pred_positions: r3.Vecs, # shape (num_positions) + target_positions: r3.Vecs, # shape (num_positions) + positions_mask: jnp.ndarray, # shape (num_positions) + length_scale: float, + l1_clamp_distance: Optional[float] = None, + epsilon=1e-4) -> jnp.ndarray: # shape () + """Measure point error under different alignments. + + Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE" + + Computes error between two structures with B points under A alignments derived + from the given pairs of frames. + Args: + pred_frames: num_frames reference frames for 'pred_positions'. + target_frames: num_frames reference frames for 'target_positions'. + frames_mask: Mask for frame pairs to use. + pred_positions: num_positions predicted positions of the structure. + target_positions: num_positions target positions of the structure. + positions_mask: Mask on which positions to score. + length_scale: length scale to divide loss by. + l1_clamp_distance: Distance cutoff on error beyond which gradients will + be zero. + epsilon: small value used to regularize denominator for masked average. + Returns: + Masked Frame Aligned Point Error. + """ + assert pred_frames.rot.xx.ndim == 1 + assert target_frames.rot.xx.ndim == 1 + assert frames_mask.ndim == 1, frames_mask.ndim + assert pred_positions.x.ndim == 1 + assert target_positions.x.ndim == 1 + assert positions_mask.ndim == 1 + + # Compute array of predicted positions in the predicted frames. + # r3.Vecs (num_frames, num_positions) + local_pred_pos = r3.rigids_mul_vecs( + jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)), + jax.tree_map(lambda x: x[None, :], pred_positions)) + + # Compute array of target positions in the target frames. + # r3.Vecs (num_frames, num_positions) + local_target_pos = r3.rigids_mul_vecs( + jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)), + jax.tree_map(lambda x: x[None, :], target_positions)) + + # Compute errors between the structures. + # jnp.ndarray (num_frames, num_positions) + error_dist = jnp.sqrt( + r3.vecs_squared_distance(local_pred_pos, local_target_pos) + + epsilon) + + if l1_clamp_distance: + error_dist = jnp.clip(error_dist, 0, l1_clamp_distance) + + normed_error = error_dist / length_scale + normed_error *= jnp.expand_dims(frames_mask, axis=-1) + normed_error *= jnp.expand_dims(positions_mask, axis=-2) + + normalization_factor = ( + jnp.sum(frames_mask, axis=-1) * + jnp.sum(positions_mask, axis=-1)) + return (jnp.sum(normed_error, axis=(-2, -1)) / + (epsilon + normalization_factor)) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +RENAMING_MATRICES = _make_renaming_matrices() + + +def get_alt_atom14(aatype, positions, mask): + """Get alternative atom14 positions. + + Constructs renamed atom positions for ambiguous residues. + + Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree- + rotation-symmetry" + + Args: + aatype: Amino acid at given position + positions: Atom positions as r3.Vecs in atom14 representation, (N, 14) + mask: Atom masks in atom14 representation, (N, 14) + Returns: + renamed atom positions, renamed atom mask + """ + # pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14) + renaming_transform = utils.batched_gather( + jnp.asarray(RENAMING_MATRICES), aatype) + + positions = jax.tree_map(lambda x: x[:, :, None], positions) + alternative_positions = jax.tree_map( + lambda x: jnp.sum(x, axis=1), positions * renaming_transform) + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position) + alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1) + + return alternative_positions, alternative_mask diff --git a/af_backprop/alphafold/model/common_modules.py b/af_backprop/alphafold/model/common_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f239c870bde49e1e5b1a7e6622c5ef4f44a37b3f --- /dev/null +++ b/af_backprop/alphafold/model/common_modules.py @@ -0,0 +1,84 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of common Haiku modules for use in protein folding.""" +import haiku as hk +import jax.numpy as jnp + + +class Linear(hk.Module): + """Protein folding specific Linear Module. + + This differs from the standard Haiku Linear in a few ways: + * It supports inputs of arbitrary rank + * Initializers are specified by strings + """ + + def __init__(self, + num_output: int, + initializer: str = 'linear', + use_bias: bool = True, + bias_init: float = 0., + name: str = 'linear'): + """Constructs Linear Module. + + Args: + num_output: number of output channels. + initializer: What initializer to use, should be one of {'linear', 'relu', + 'zeros'} + use_bias: Whether to include trainable bias + bias_init: Value used to initialize bias. + name: name of module, used for name scopes. + """ + + super().__init__(name=name) + self.num_output = num_output + self.initializer = initializer + self.use_bias = use_bias + self.bias_init = bias_init + + def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Connects Module. + + Args: + inputs: Tensor of shape [..., num_channel] + + Returns: + output of shape [..., num_output] + """ + n_channels = int(inputs.shape[-1]) + + weight_shape = [n_channels, self.num_output] + if self.initializer == 'linear': + weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) + elif self.initializer == 'relu': + weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) + elif self.initializer == 'zeros': + weight_init = hk.initializers.Constant(0.0) + + weights = hk.get_parameter('weights', weight_shape, inputs.dtype, + weight_init) + + # this is equivalent to einsum('...c,cd->...d', inputs, weights) + # but turns out to be slightly faster + inputs = jnp.swapaxes(inputs, -1, -2) + output = jnp.einsum('...cb,cd->...db', inputs, weights) + output = jnp.swapaxes(output, -1, -2) + + if self.use_bias: + bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, + hk.initializers.Constant(self.bias_init)) + output += bias + + return output diff --git a/af_backprop/alphafold/model/config.py b/af_backprop/alphafold/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a725395651c462b68528e0a5d10da14a3c098552 --- /dev/null +++ b/af_backprop/alphafold/model/config.py @@ -0,0 +1,412 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Model config.""" + +import copy +from alphafold.model.tf import shape_placeholders +import ml_collections + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + + +def model_config(name: str) -> ml_collections.ConfigDict: + """Get the ConfigDict of a CASP14 model.""" + + if name not in CONFIG_DIFFS: + raise ValueError(f'Invalid model name {name}.') + cfg = copy.deepcopy(CONFIG) + cfg.update_from_flattened_dict(CONFIG_DIFFS[name]) + return cfg + + +CONFIG_DIFFS = { + 'model_1': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1 + 'data.common.max_extra_msa': 5120, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True + }, + 'model_2': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2 + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True + }, + 'model_3': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1 + 'data.common.max_extra_msa': 5120, + }, + 'model_4': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2 + 'data.common.max_extra_msa': 5120, + }, + 'model_5': { + # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3 + }, + + # The following models are fine-tuned from the corresponding models above + # with an additional predicted_aligned_error head that can produce + # predicted TM-score (pTM) and predicted aligned errors. + 'model_1_ptm': { + 'data.common.max_extra_msa': 5120, + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_2_ptm': { + 'data.common.reduce_msa_clusters_by_max_templates': True, + 'data.common.use_templates': True, + 'model.embeddings_and_evoformer.template.embed_torsion_angles': True, + 'model.embeddings_and_evoformer.template.enabled': True, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_3_ptm': { + 'data.common.max_extra_msa': 5120, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_4_ptm': { + 'data.common.max_extra_msa': 5120, + 'model.heads.predicted_aligned_error.weight': 0.1 + }, + 'model_5_ptm': { + 'model.heads.predicted_aligned_error.weight': 0.1 + } +} + +CONFIG = ml_collections.ConfigDict({ + 'data': { + 'common': { + 'masked_msa': { + 'profile_prob': 0.1, + 'same_prob': 0.1, + 'uniform_prob': 0.1 + }, + 'max_extra_msa': 1024, + 'msa_cluster_features': True, + 'num_recycle': 3, + 'reduce_msa_clusters_by_max_templates': False, + 'resample_msa_in_recycling': True, + 'template_features': [ + 'template_all_atom_positions', 'template_sum_probs', + 'template_aatype', 'template_all_atom_masks', + 'template_domain_names' + ], + 'unsupervised_features': [ + 'aatype', 'residue_index', 'sequence', 'msa', 'domain_name', + 'num_alignments', 'seq_length', 'between_segment_residues', + 'deletion_matrix' + ], + 'use_templates': False, + }, + 'eval': { + 'feat': { + 'aatype': [NUM_RES], + 'all_atom_mask': [NUM_RES, None], + 'all_atom_positions': [NUM_RES, None, None], + 'alt_chi_angles': [NUM_RES, None], + 'atom14_alt_gt_exists': [NUM_RES, None], + 'atom14_alt_gt_positions': [NUM_RES, None, None], + 'atom14_atom_exists': [NUM_RES, None], + 'atom14_atom_is_ambiguous': [NUM_RES, None], + 'atom14_gt_exists': [NUM_RES, None], + 'atom14_gt_positions': [NUM_RES, None, None], + 'atom37_atom_exists': [NUM_RES, None], + 'backbone_affine_mask': [NUM_RES], + 'backbone_affine_tensor': [NUM_RES, None], + 'bert_mask': [NUM_MSA_SEQ, NUM_RES], + 'chi_angles': [NUM_RES, None], + 'chi_mask': [NUM_RES, None], + 'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES], + 'extra_msa_row_mask': [NUM_EXTRA_SEQ], + 'is_distillation': [], + 'msa_feat': [NUM_MSA_SEQ, NUM_RES, None], + 'msa_mask': [NUM_MSA_SEQ, NUM_RES], + 'msa_row_mask': [NUM_MSA_SEQ], + 'pseudo_beta': [NUM_RES, None], + 'pseudo_beta_mask': [NUM_RES], + 'random_crop_to_size_seed': [None], + 'residue_index': [NUM_RES], + 'residx_atom14_to_atom37': [NUM_RES, None], + 'residx_atom37_to_atom14': [NUM_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [NUM_RES, None, None], + 'rigidgroups_group_exists': [NUM_RES, None], + 'rigidgroups_group_is_ambiguous': [NUM_RES, None], + 'rigidgroups_gt_exists': [NUM_RES, None], + 'rigidgroups_gt_frames': [NUM_RES, None, None], + 'seq_length': [], + 'seq_mask': [NUM_RES], + 'target_feat': [NUM_RES, None], + 'template_aatype': [NUM_TEMPLATES, NUM_RES], + 'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None], + 'template_all_atom_positions': [ + NUM_TEMPLATES, NUM_RES, None, None], + 'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES], + 'template_backbone_affine_tensor': [ + NUM_TEMPLATES, NUM_RES, None], + 'template_mask': [NUM_TEMPLATES], + 'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None], + 'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES], + 'template_sum_probs': [NUM_TEMPLATES, None], + 'true_msa': [NUM_MSA_SEQ, NUM_RES] + }, + 'fixed_size': True, + 'subsample_templates': False, # We want top templates. + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 512, + 'max_templates': 4, + 'num_ensemble': 1, + }, + }, + 'model': { + 'embeddings_and_evoformer': { + 'evoformer_num_block': 48, + 'evoformer': { + 'msa_row_attention_with_pair_bias': { + 'dropout_rate': 0.15, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'msa_column_attention': { + 'dropout_rate': 0.0, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'msa_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'outer_product_mean': { + 'chunk_size': 128, + 'dropout_rate': 0.0, + 'num_outer_channel': 32, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'extra_msa_channel': 64, + 'extra_msa_stack_num_block': 4, + 'max_relative_feature': 32, + 'custom_relative_features': False, + 'msa_channel': 256, + 'pair_channel': 128, + 'prev_pos': { + 'min_bin': 3.25, + 'max_bin': 20.75, + 'num_bins': 15 + }, + 'recycle_features': True, + 'recycle_pos': True, + 'recycle_dgram': False, + 'backprop_dgram': False, + 'backprop_dgram_temp': 1.0, + 'seq_channel': 384, + 'template': { + 'attention': { + 'gating': False, + 'key_dim': 64, + 'num_head': 4, + 'value_dim': 64 + }, + 'dgram_features': { + 'min_bin': 3.25, + 'max_bin': 50.75, + 'num_bins': 39 + }, + 'backprop_dgram': False, + 'backprop_dgram_temp': 1.0, + 'embed_torsion_angles': False, + 'enabled': False, + 'template_pair_stack': { + 'num_block': 2, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'key_dim': 64, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True, + 'value_dim': 64 + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'key_dim': 64, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True, + 'value_dim': 64 + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 2, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'max_templates': 4, + 'subbatch_size': 128, + 'use_template_unit_vector': False, + } + }, + 'global_config': { + 'mixed_precision': False, + 'deterministic': False, + 'subbatch_size': 4, + 'use_remat': False, + 'zero_init': True + }, + 'heads': { + 'distogram': { + 'first_break': 2.3125, + 'last_break': 21.6875, + 'num_bins': 64, + 'weight': 0.3 + }, + 'predicted_aligned_error': { + # `num_bins - 1` bins uniformly space the + # [0, max_error_bin A] range. + # The final bin covers [max_error_bin A, +infty] + # 31A gives bins with 0.5A width. + 'max_error_bin': 31., + 'num_bins': 64, + 'num_channels': 128, + 'filter_by_resolution': True, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'weight': 0.0, + }, + 'experimentally_resolved': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'weight': 0.01 + }, + 'structure_module': { + 'num_layer': 8, + 'fape': { + 'clamp_distance': 10.0, + 'clamp_type': 'relu', + 'loss_unit_distance': 10.0 + }, + 'angle_norm_weight': 0.01, + 'chi_weight': 0.5, + 'clash_overlap_tolerance': 1.5, + 'compute_in_graph_metrics': True, + 'dropout': 0.1, + 'num_channel': 384, + 'num_head': 12, + 'num_layer_in_transition': 3, + 'num_point_qk': 4, + 'num_point_v': 8, + 'num_scalar_qk': 16, + 'num_scalar_v': 16, + 'position_scale': 10.0, + 'sidechain': { + 'atom_clamp_distance': 10.0, + 'num_channel': 128, + 'num_residual_block': 2, + 'weight_frac': 0.5, + 'length_scale': 10., + }, + 'structural_violation_loss_weight': 1.0, + 'violation_tolerance_factor': 12.0, + 'weight': 1.0 + }, + 'predicted_lddt': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 50, + 'num_channels': 128, + 'weight': 0.01 + }, + 'masked_msa': { + 'num_output': 23, + 'weight': 2.0 + }, + }, + 'num_recycle': 3, + 'backprop_recycle': False, + 'resample_msa_in_recycling': True, + 'add_prev': False, + 'use_struct': True, + }, +}) diff --git a/af_backprop/alphafold/model/data.py b/af_backprop/alphafold/model/data.py new file mode 100644 index 0000000000000000000000000000000000000000..249cdb3158f94d3f3ff4ae04c971b903284fdc09 --- /dev/null +++ b/af_backprop/alphafold/model/data.py @@ -0,0 +1,39 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convenience functions for reading data.""" + +import io +import os +from typing import List +from alphafold.model import utils +import haiku as hk +import numpy as np +# Internal import (7716). + + +def casp_model_names(data_dir: str) -> List[str]: + params = os.listdir(os.path.join(data_dir, 'params')) + return [os.path.splitext(filename)[0] for filename in params] + + +def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params: + """Get the Haiku parameters from a model name.""" + + path = os.path.join(data_dir, 'params', f'params_{model_name}.npz') + + with open(path, 'rb') as f: + params = np.load(io.BytesIO(f.read()), allow_pickle=False) + + return utils.flat_params_to_haiku(params) diff --git a/af_backprop/alphafold/model/features.py b/af_backprop/alphafold/model/features.py new file mode 100644 index 0000000000000000000000000000000000000000..b31b277e02d66aa94013cef914ed035e7f041edc --- /dev/null +++ b/af_backprop/alphafold/model/features.py @@ -0,0 +1,102 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Code to generate processed features.""" +import copy +from typing import List, Mapping, Tuple +from alphafold.model.tf import input_pipeline +from alphafold.model.tf import proteins_dataset +import ml_collections +import numpy as np +import tensorflow.compat.v1 as tf + +FeatureDict = Mapping[str, np.ndarray] + + +def make_data_config( + config: ml_collections.ConfigDict, + num_res: int, + ) -> Tuple[ml_collections.ConfigDict, List[str]]: + """Makes a data config for the input pipeline.""" + cfg = copy.deepcopy(config.data) + + feature_names = cfg.common.unsupervised_features + if cfg.common.use_templates: + feature_names += cfg.common.template_features + + with cfg.unlocked(): + cfg.eval.crop_size = num_res + + return cfg, feature_names + + +def tf_example_to_features(tf_example: tf.train.Example, + config: ml_collections.ConfigDict, + random_seed: int = 0) -> FeatureDict: + """Converts tf_example to numpy feature dictionary.""" + num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0]) + cfg, feature_names = make_data_config(config, num_res=num_res) + + if 'deletion_matrix_int' in set(tf_example.features.feature): + deletion_matrix_int = ( + tf_example.features.feature['deletion_matrix_int'].int64_list.value) + feat = tf.train.Feature(float_list=tf.train.FloatList( + value=map(float, deletion_matrix_int))) + tf_example.features.feature['deletion_matrix'].CopyFrom(feat) + del tf_example.features.feature['deletion_matrix_int'] + + tf_graph = tf.Graph() + with tf_graph.as_default(), tf.device('/device:CPU:0'): + tf.compat.v1.set_random_seed(random_seed) + tensor_dict = proteins_dataset.create_tensor_dict( + raw_data=tf_example.SerializeToString(), + features=feature_names) + processed_batch = input_pipeline.process_tensors_from_config( + tensor_dict, cfg) + + tf_graph.finalize() + + with tf.Session(graph=tf_graph) as sess: + features = sess.run(processed_batch) + + return {k: v for k, v in features.items() if v.dtype != 'O'} + + +def np_example_to_features(np_example: FeatureDict, + config: ml_collections.ConfigDict, + random_seed: int = 0) -> FeatureDict: + """Preprocesses NumPy feature dict using TF pipeline.""" + np_example = dict(np_example) + num_res = int(np_example['seq_length'][0]) + cfg, feature_names = make_data_config(config, num_res=num_res) + + if 'deletion_matrix_int' in np_example: + np_example['deletion_matrix'] = ( + np_example.pop('deletion_matrix_int').astype(np.float32)) + + tf_graph = tf.Graph() + with tf_graph.as_default(), tf.device('/device:CPU:0'): + tf.compat.v1.set_random_seed(random_seed) + tensor_dict = proteins_dataset.np_to_tensor_dict( + np_example=np_example, features=feature_names) + + processed_batch = input_pipeline.process_tensors_from_config( + tensor_dict, cfg) + + tf_graph.finalize() + + with tf.Session(graph=tf_graph) as sess: + features = sess.run(processed_batch) + + return {k: v for k, v in features.items() if v.dtype != 'O'} diff --git a/af_backprop/alphafold/model/folding.py b/af_backprop/alphafold/model/folding.py new file mode 100644 index 0000000000000000000000000000000000000000..b802229aa6235e1fd00e55674cdd180d2724b193 --- /dev/null +++ b/af_backprop/alphafold/model/folding.py @@ -0,0 +1,1016 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules and utilities for the structure module.""" + +import functools +from typing import Dict +from alphafold.common import residue_constants +from alphafold.model import all_atom +from alphafold.model import common_modules +from alphafold.model import prng +from alphafold.model import quat_affine +from alphafold.model import r3 +from alphafold.model import utils +import haiku as hk +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np + + +def squared_difference(x, y): + return jnp.square(x - y) + + +class InvariantPointAttention(hk.Module): + """Invariant Point attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + + Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention" + """ + + def __init__(self, + config, + global_config, + dist_epsilon=1e-8, + name='invariant_point_attention'): + """Initialize. + + Args: + config: Structure Module Config + global_config: Global Config of Model. + dist_epsilon: Small value to avoid NaN in distance calculation. + name: Haiku Module name. + """ + super().__init__(name=name) + + self._dist_epsilon = dist_epsilon + self._zero_initialize_last = global_config.zero_init + + self.config = config + + self.global_config = global_config + + def __call__(self, inputs_1d, inputs_2d, mask, affine): + """Compute geometry-aware attention. + + Given a set of query residues (defined by affines and associated scalar + features), this function computes geometry-aware attention between the + query residues and target residues. + + The residues produce points in their local reference frame, which + are converted into the global frame in order to compute attention via + euclidean distance. + + Equivalently, the target residues produce points in their local frame to be + used as attention values, which are converted into the query residues' + local frames. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases and values. + mask: (N, 1) mask to indicate which elements of inputs_1d participate + in the attention. + affine: QuatAffine object describing the position and orientation of + every element in inputs_1d. + + Returns: + Transformation of the input embedding. + """ + num_residues, _ = inputs_1d.shape + + # Improve readability by removing a large number of 'self's. + num_head = self.config.num_head + num_scalar_qk = self.config.num_scalar_qk + num_point_qk = self.config.num_point_qk + num_scalar_v = self.config.num_scalar_v + num_point_v = self.config.num_point_v + num_output = self.config.num_channel + + assert num_scalar_qk > 0 + assert num_point_qk > 0 + assert num_point_v > 0 + + # Construct scalar queries of shape: + # [num_query_residues, num_head, num_points] + q_scalar = common_modules.Linear( + num_head * num_scalar_qk, name='q_scalar')( + inputs_1d) + q_scalar = jnp.reshape( + q_scalar, [num_residues, num_head, num_scalar_qk]) + + # Construct scalar keys/values of shape: + # [num_target_residues, num_head, num_points] + kv_scalar = common_modules.Linear( + num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')( + inputs_1d) + kv_scalar = jnp.reshape(kv_scalar, + [num_residues, num_head, + num_scalar_v + num_scalar_qk]) + k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1) + + # Construct query points of shape: + # [num_residues, num_head, num_point_qk] + + # First construct query points in local frame. + q_point_local = common_modules.Linear( + num_head * 3 * num_point_qk, name='q_point_local')( + inputs_1d) + q_point_local = jnp.split(q_point_local, 3, axis=-1) + # Project query points into global frame. + q_point_global = affine.apply_to_point(q_point_local, extra_dims=1) + # Reshape query point for later use. + q_point = [ + jnp.reshape(x, [num_residues, num_head, num_point_qk]) + for x in q_point_global] + + # Construct key and value points. + # Key points have shape [num_residues, num_head, num_point_qk] + # Value points have shape [num_residues, num_head, num_point_v] + + # Construct key and value points in local frame. + kv_point_local = common_modules.Linear( + num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')( + inputs_1d) + kv_point_local = jnp.split(kv_point_local, 3, axis=-1) + # Project key and value points into global frame. + kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1) + kv_point_global = [ + jnp.reshape(x, [num_residues, + num_head, (num_point_qk + num_point_v)]) + for x in kv_point_global] + # Split key and value points. + k_point, v_point = list( + zip(*[ + jnp.split(x, [num_point_qk,], axis=-1) + for x in kv_point_global + ])) + + # We assume that all queries and keys come iid from N(0, 1) distribution + # and compute the variances of the attention logits. + # Each scalar pair (q, k) contributes Var q*k = 1 + scalar_variance = max(num_scalar_qk, 1) * 1. + # Each point pair (q, k) contributes Var [0.5 ||q||^2 - ] = 9 / 2 + point_variance = max(num_point_qk, 1) * 9. / 2 + + # Allocate equal variance to scalar, point and attention 2d parts so that + # the sum is 1. + + num_logit_terms = 3 + + scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance)) + point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance)) + attention_2d_weights = np.sqrt(1.0 / (num_logit_terms)) + + # Trainable per-head weights for points. + trainable_point_weights = jax.nn.softplus(hk.get_parameter( + 'trainable_point_weights', shape=[num_head], + # softplus^{-1} (1) + init=hk.initializers.Constant(np.log(np.exp(1.) - 1.)))) + point_weights *= jnp.expand_dims(trainable_point_weights, axis=1) + + v_point = [jnp.swapaxes(x, -2, -3) for x in v_point] + + q_point = [jnp.swapaxes(x, -2, -3) for x in q_point] + k_point = [jnp.swapaxes(x, -2, -3) for x in k_point] + dist2 = [ + squared_difference(qx[:, :, None, :], kx[:, None, :, :]) + for qx, kx in zip(q_point, k_point) + ] + dist2 = sum(dist2) + attn_qk_point = -0.5 * jnp.sum( + point_weights[:, None, None, :] * dist2, axis=-1) + + v = jnp.swapaxes(v_scalar, -2, -3) + q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3) + k = jnp.swapaxes(k_scalar, -2, -3) + attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) + attn_logits = attn_qk_scalar + attn_qk_point + + attention_2d = common_modules.Linear( + num_head, name='attention_2d')( + inputs_2d) + + attention_2d = jnp.transpose(attention_2d, [2, 0, 1]) + attention_2d = attention_2d_weights * attention_2d + attn_logits += attention_2d + + mask_2d = mask * jnp.swapaxes(mask, -1, -2) + attn_logits -= 1e5 * (1. - mask_2d) + + # [num_head, num_query_residues, num_target_residues] + attn = jax.nn.softmax(attn_logits) + + # [num_head, num_query_residues, num_head * num_scalar_v] + result_scalar = jnp.matmul(attn, v) + + # For point result, implement matmul manually so that it will be a float32 + # on TPU. This is equivalent to + # result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx) + # for vx in v_point] + # but on the TPU, doing the multiply and reduce_sum ensures the + # computation happens in float32 instead of bfloat16. + result_point_global = [jnp.sum( + attn[:, :, :, None] * vx[:, None, :, :], + axis=-2) for vx in v_point] + + # [num_query_residues, num_head, num_head * num_(scalar|point)_v] + result_scalar = jnp.swapaxes(result_scalar, -2, -3) + result_point_global = [ + jnp.swapaxes(x, -2, -3) + for x in result_point_global] + + # Features used in the linear output projection. Should have the size + # [num_query_residues, ?] + output_features = [] + + result_scalar = jnp.reshape( + result_scalar, [num_residues, num_head * num_scalar_v]) + output_features.append(result_scalar) + + result_point_global = [ + jnp.reshape(r, [num_residues, num_head * num_point_v]) + for r in result_point_global] + result_point_local = affine.invert_point(result_point_global, extra_dims=1) + output_features.extend(result_point_local) + + output_features.append(jnp.sqrt(self._dist_epsilon + + jnp.square(result_point_local[0]) + + jnp.square(result_point_local[1]) + + jnp.square(result_point_local[2]))) + + # Dimensions: h = heads, i and j = residues, + # c = inputs_2d channels + # Contraction happens over the second residue dimension, similarly to how + # the usual attention is performed. + result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d) + num_out = num_head * result_attention_over_2d.shape[-1] + output_features.append( + jnp.reshape(result_attention_over_2d, + [num_residues, num_out])) + + final_init = 'zeros' if self._zero_initialize_last else 'linear' + + final_act = jnp.concatenate(output_features, axis=-1) + + return common_modules.Linear( + num_output, + initializer=final_init, + name='output_projection')(final_act) + + +class FoldIteration(hk.Module): + """A single iteration of the main structure module loop. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21 + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. + """ + + def __init__(self, config, global_config, + name='fold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + activations, + sequence_mask, + update_affine, + is_training, + initial_act, + safe_key=None, + static_feat_2d=None, + aatype=None, + scale_rate=1.0): + c = self.config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + def safe_dropout_fn(tensor, safe_key): + return prng.safe_dropout( + tensor=tensor, + safe_key=safe_key, + rate=c.dropout * scale_rate, + is_deterministic=self.global_config.deterministic, + is_training=is_training) + + affine = quat_affine.QuatAffine.from_tensor(activations['affine']) + + act = activations['act'] + attention_module = InvariantPointAttention(self.config, self.global_config) + # Attention + attn = attention_module( + inputs_1d=act, + inputs_2d=static_feat_2d, + mask=sequence_mask, + affine=affine) + act += attn + safe_key, *sub_keys = safe_key.split(3) + sub_keys = iter(sub_keys) + act = safe_dropout_fn(act, next(sub_keys)) + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='attention_layer_norm')( + act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Transition + input_act = act + for i in range(c.num_layer_in_transition): + init = 'relu' if i < c.num_layer_in_transition - 1 else final_init + act = common_modules.Linear( + c.num_channel, + initializer=init, + name='transition')( + act) + if i < c.num_layer_in_transition - 1: + act = jax.nn.relu(act) + act += input_act + act = safe_dropout_fn(act, next(sub_keys)) + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='transition_layer_norm')(act) + + if update_affine: + # This block corresponds to + # Jumper et al. (2021) Alg. 23 "Backbone update" + affine_update_size = 6 + + # Affine update + affine_update = common_modules.Linear( + affine_update_size, + initializer=final_init, + name='affine_update')( + act) + + affine = affine.pre_compose(affine_update) + + sc = MultiRigidSidechain(c.sidechain, self.global_config)( + affine.scale_translation(c.position_scale), [act, initial_act], aatype) + + outputs = {'affine': affine.to_tensor(), 'sc': sc} + + # affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient) + + new_activations = { + 'act': act, + 'affine': affine.to_tensor() + } + return new_activations, outputs + + +def generate_affines(representations, batch, config, global_config, + is_training, safe_key): + """Generate predicted affines for a single chain. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + + This is the main part of the structure module - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Representations dictionary. + batch: Batch dictionary. + config: Config for the structure module. + global_config: Global config. + is_training: Whether the model is being trained. + safe_key: A prng.SafeKey object that wraps a PRNG key. + + Returns: + A dictionary containing residue affines and sidechain positions. + """ + c = config + sequence_mask = batch['seq_mask'][:, None] + + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='single_layer_norm')( + representations['single']) + + initial_act = act + act = common_modules.Linear( + c.num_channel, name='initial_projection')( + act) + + affine = generate_new_affine(sequence_mask) + + fold_iteration = FoldIteration( + c, global_config, name='fold_iteration') + + assert len(batch['seq_mask'].shape) == 1 + + activations = {'act': act, + 'affine': affine.to_tensor(), + } + + act_2d = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='pair_layer_norm')( + representations['pair']) + + def fold_iter(x,_): + x["key"], key = x["key"].split() + x["act"], out = fold_iteration( + x["act"], + initial_act=initial_act, + static_feat_2d=act_2d, + safe_key=key, + sequence_mask=sequence_mask, + update_affine=True, + is_training=is_training, + aatype=batch['aatype'], + scale_rate=batch["scale_rate"]) + return x, out + x = {"act":activations,"key":safe_key} + x, output = hk.scan(fold_iter, x, None, c.num_layer) + activations = x["act"] + + # Include the activations in the output dict for use by the LDDT-Head. + output['act'] = activations['act'] + + return output + + +class dummy(hk.Module): + def __init__(self, config, global_config, compute_loss=True): + super().__init__(name="dummy") + def __call__(self, representations, batch, is_training, safe_key=None): + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + return {} + +class StructureModule(hk.Module): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + """ + + def __init__(self, config, global_config, compute_loss=True, + name='structure_module'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + self.compute_loss = compute_loss + + def __call__(self, representations, batch, is_training, + safe_key=None): + c = self.config + ret = {} + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = generate_affines( + representations=representations, + batch=batch, + config=self.config, + global_config=self.global_config, + is_training=is_training, + safe_key=safe_key) + + ret['representations'] = {'structure_module': output['act']} + + ret['traj'] = output['affine'] * jnp.array([1.] * 4 + [c.position_scale] * 3) + ret['sidechains'] = output['sc'] + atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1] + ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3) + ret['final_atom14_mask'] = batch['atom14_atom_exists'] # (N, 14) + + atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions, batch) + atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None] + ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3) + ret['final_atom_mask'] = batch['atom37_atom_exists'] # (N, 37) + ret['final_affines'] = ret['traj'][-1] + + return ret + + def loss(self, value, batch): + ret = {'loss': 0.} + + ret['metrics'] = {} + # If requested, compute in-graph metrics. + if self.config.compute_in_graph_metrics: + atom14_pred_positions = value['final_atom14_positions'] + # Compute renaming and violations. + value.update(compute_renamed_ground_truth(batch, atom14_pred_positions)) + value['violations'] = find_structural_violations( + batch, atom14_pred_positions, self.config) + + # Several violation metrics: + violation_metrics = compute_violation_metrics( + batch=batch, + atom14_pred_positions=atom14_pred_positions, + violations=value['violations']) + ret['metrics'].update(violation_metrics) + + backbone_loss(ret, batch, value, self.config) + + if 'renamed_atom14_gt_positions' not in value: + value.update(compute_renamed_ground_truth( + batch, value['final_atom14_positions'])) + sc_loss = sidechain_loss(batch, value, self.config) + + ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] + + self.config.sidechain.weight_frac * sc_loss['loss']) + ret['sidechain_fape'] = sc_loss['fape'] + + supervised_chi_loss(ret, batch, value, self.config) + + if self.config.structural_violation_loss_weight: + if 'violations' not in value: + value['violations'] = find_structural_violations( + batch, value['final_atom14_positions'], self.config) + structural_violation_loss(ret, batch, value, self.config) + + return ret + + +def compute_renamed_ground_truth( + batch: Dict[str, jnp.ndarray], + atom14_pred_positions: jnp.ndarray, + ) -> Dict[str, jnp.ndarray]: + """Find optimal renaming of ground truth based on the predicted positions. + + Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" + + This renamed ground truth is then used for all losses, + such that each loss moves the atoms in the same direction. + Shape (N). + + Args: + batch: Dictionary containing: + * atom14_gt_positions: Ground truth positions. + * atom14_alt_gt_positions: Ground truth positions with renaming swaps. + * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by + renaming swaps. + * atom14_gt_exists: Mask for which atoms exist in ground truth. + * atom14_alt_gt_exists: Mask for which atoms exist in ground truth + after renaming. + * atom14_atom_exists: Mask for whether each atom is part of the given + amino acid type. + atom14_pred_positions: Array of atom positions in global frame with shape + (N, 14, 3). + Returns: + Dictionary containing: + alt_naming_is_better: Array with 1.0 where alternative swap is better. + renamed_atom14_gt_positions: Array of optimal ground truth positions + after renaming swaps are performed. + renamed_atom14_gt_exists: Mask after renaming swap is performed. + """ + alt_naming_is_better = all_atom.find_optimal_renaming( + atom14_gt_positions=batch['atom14_gt_positions'], + atom14_alt_gt_positions=batch['atom14_alt_gt_positions'], + atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'], + atom14_gt_exists=batch['atom14_gt_exists'], + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists']) + + renamed_atom14_gt_positions = ( + (1. - alt_naming_is_better[:, None, None]) + * batch['atom14_gt_positions'] + + alt_naming_is_better[:, None, None] + * batch['atom14_alt_gt_positions']) + + renamed_atom14_gt_mask = ( + (1. - alt_naming_is_better[:, None]) * batch['atom14_gt_exists'] + + alt_naming_is_better[:, None] * batch['atom14_alt_gt_exists']) + + return { + 'alt_naming_is_better': alt_naming_is_better, # (N) + 'renamed_atom14_gt_positions': renamed_atom14_gt_positions, # (N, 14, 3) + 'renamed_atom14_gt_exists': renamed_atom14_gt_mask, # (N, 14) + } + + +def backbone_loss(ret, batch, value, config): + """Backbone FAPE Loss. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17 + + Args: + ret: Dictionary to write outputs into, needs to contain 'loss'. + batch: Batch, needs to contain 'backbone_affine_tensor', + 'backbone_affine_mask'. + value: Dictionary containing structure module output, needs to contain + 'traj', a trajectory of rigids. + config: Configuration of loss, should contain 'fape.clamp_distance' and + 'fape.loss_unit_distance'. + """ + affine_trajectory = quat_affine.QuatAffine.from_tensor(value['traj']) + rigid_trajectory = r3.rigids_from_quataffine(affine_trajectory) + + if 'backbone_affine_tensor' in batch: + gt_affine = quat_affine.QuatAffine.from_tensor(batch['backbone_affine_tensor']) + backbone_mask = batch['backbone_affine_mask'] + else: + n_xyz = batch['all_atom_positions'][...,0,:] + ca_xyz = batch['all_atom_positions'][...,1,:] + c_xyz = batch['all_atom_positions'][...,2,:] + rot, trans = quat_affine.make_transform_from_reference(n_xyz, ca_xyz, c_xyz) + gt_affine = quat_affine.QuatAffine(quaternion=None, + translation=trans, + rotation=rot, + unstack_inputs=True) + backbone_mask = batch['all_atom_mask'][...,0] + + gt_rigid = r3.rigids_from_quataffine(gt_affine) + + fape_loss_fn = functools.partial( + all_atom.frame_aligned_point_error, + l1_clamp_distance=config.fape.clamp_distance, + length_scale=config.fape.loss_unit_distance) + + fape_loss_fn = jax.vmap(fape_loss_fn, (0, None, None, 0, None, None)) + fape_loss = fape_loss_fn(rigid_trajectory, gt_rigid, backbone_mask, + rigid_trajectory.trans, gt_rigid.trans, + backbone_mask) + + if 'use_clamped_fape' in batch: + # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details" + use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32) + unclamped_fape_loss_fn = functools.partial( + all_atom.frame_aligned_point_error, + l1_clamp_distance=None, + length_scale=config.fape.loss_unit_distance) + unclamped_fape_loss_fn = jax.vmap(unclamped_fape_loss_fn, + (0, None, None, 0, None, None)) + fape_loss_unclamped = unclamped_fape_loss_fn(rigid_trajectory, gt_rigid, + backbone_mask, + rigid_trajectory.trans, + gt_rigid.trans, + backbone_mask) + + fape_loss = (fape_loss * use_clamped_fape + fape_loss_unclamped * (1 - use_clamped_fape)) + + ret['fape'] = fape_loss[-1] + ret['loss'] += jnp.mean(fape_loss) + + +def sidechain_loss(batch, value, config): + """All Atom FAPE Loss using renamed rigids.""" + # Rename Frames + # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7 + alt_naming_is_better = value['alt_naming_is_better'] + renamed_gt_frames = ( + (1. - alt_naming_is_better[:, None, None]) + * batch['rigidgroups_gt_frames'] + + alt_naming_is_better[:, None, None] + * batch['rigidgroups_alt_gt_frames']) + + flat_gt_frames = r3.rigids_from_tensor_flat12(jnp.reshape(renamed_gt_frames, [-1, 12])) + flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1]) + + flat_gt_positions = r3.vecs_from_tensor(jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3])) + flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1]) + + # Compute frame_aligned_point_error score for the final layer. + pred_frames = value['sidechains']['frames'] + pred_positions = value['sidechains']['atom_pos'] + + def _slice_last_layer_and_flatten(x): + return jnp.reshape(x[-1], [-1]) + + flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, pred_positions) + # FAPE Loss on sidechains + fape = all_atom.frame_aligned_point_error( + pred_frames=flat_pred_frames, + target_frames=flat_gt_frames, + frames_mask=flat_frames_mask, + pred_positions=flat_pred_positions, + target_positions=flat_gt_positions, + positions_mask=flat_positions_mask, + l1_clamp_distance=config.sidechain.atom_clamp_distance, + length_scale=config.sidechain.length_scale) + + return { + 'fape': fape, + 'loss': fape} + + +def structural_violation_loss(ret, batch, value, config): + """Computes loss for structural violations.""" + assert config.sidechain.weight_frac + + # Put all violation losses together to one large loss. + violations = value['violations'] + num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32) + ret['loss'] += (config.structural_violation_loss_weight * ( + violations['between_residues']['bonds_c_n_loss_mean'] + + violations['between_residues']['angles_ca_c_n_loss_mean'] + + violations['between_residues']['angles_c_n_ca_loss_mean'] + + jnp.sum( + violations['between_residues']['clashes_per_atom_loss_sum'] + + violations['within_residues']['per_atom_loss_sum']) / + (1e-6 + num_atoms))) + + +def find_structural_violations( + batch: Dict[str, jnp.ndarray], + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + config: ml_collections.ConfigDict + ): + """Computes several checks for structural violations.""" + + # Compute between residue backbone violations of bonds and angles. + connection_violations = all_atom.between_residue_bond_loss( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32), + residue_index=batch['residue_index'].astype(jnp.float32), + aatype=batch['aatype'], + tolerance_factor_soft=config.violation_tolerance_factor, + tolerance_factor_hard=config.violation_tolerance_factor) + + # Compute the Van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # Shape: (N, 14). + atomtype_radius = [ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ] + atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather( + atomtype_radius, batch['residx_atom14_to_atom37']) + + # Compute the between residue clash loss. + between_residue_clashes = all_atom.between_residue_clash_loss( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists'], + atom14_atom_radius=atom14_atom_radius, + residue_index=batch['residue_index'], + overlap_tolerance_soft=config.clash_overlap_tolerance, + overlap_tolerance_hard=config.clash_overlap_tolerance) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). + restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=config.clash_overlap_tolerance, + bond_length_tolerance_factor=config.violation_tolerance_factor) + atom14_dists_lower_bound = utils.batched_gather( + restype_atom14_bounds['lower_bound'], batch['aatype']) + atom14_dists_upper_bound = utils.batched_gather( + restype_atom14_bounds['upper_bound'], batch['aatype']) + within_residue_violations = all_atom.within_residue_violations( + atom14_pred_positions=atom14_pred_positions, + atom14_atom_exists=batch['atom14_atom_exists'], + atom14_dists_lower_bound=atom14_dists_lower_bound, + atom14_dists_upper_bound=atom14_dists_upper_bound, + tighten_bounds_for_loss=0.0) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = jnp.max(jnp.stack([ + connection_violations['per_residue_violation_mask'], + jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1), + jnp.max(within_residue_violations['per_atom_violations'], + axis=-1)]), axis=0) + + return { + 'between_residues': { + 'bonds_c_n_loss_mean': + connection_violations['c_n_loss_mean'], # () + 'angles_ca_c_n_loss_mean': + connection_violations['ca_c_n_loss_mean'], # () + 'angles_c_n_ca_loss_mean': + connection_violations['c_n_ca_loss_mean'], # () + 'connections_per_residue_loss_sum': + connection_violations['per_residue_loss_sum'], # (N) + 'connections_per_residue_violation_mask': + connection_violations['per_residue_violation_mask'], # (N) + 'clashes_mean_loss': + between_residue_clashes['mean_loss'], # () + 'clashes_per_atom_loss_sum': + between_residue_clashes['per_atom_loss_sum'], # (N, 14) + 'clashes_per_atom_clash_mask': + between_residue_clashes['per_atom_clash_mask'], # (N, 14) + }, + 'within_residues': { + 'per_atom_loss_sum': + within_residue_violations['per_atom_loss_sum'], # (N, 14) + 'per_atom_violations': + within_residue_violations['per_atom_violations'], # (N, 14), + }, + 'total_per_residue_violations_mask': + per_residue_violations_mask, # (N) + } + + +def compute_violation_metrics( + batch: Dict[str, jnp.ndarray], + atom14_pred_positions: jnp.ndarray, # (N, 14, 3) + violations: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Compute several metrics to assess the structural violations.""" + + ret = {} + extreme_ca_ca_violations = all_atom.extreme_ca_ca_distance_violations( + pred_atom_positions=atom14_pred_positions, + pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32), + residue_index=batch['residue_index'].astype(jnp.float32)) + ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations + ret['violations_between_residue_bond'] = utils.mask_mean( + mask=batch['seq_mask'], + value=violations['between_residues'][ + 'connections_per_residue_violation_mask']) + ret['violations_between_residue_clash'] = utils.mask_mean( + mask=batch['seq_mask'], + value=jnp.max( + violations['between_residues']['clashes_per_atom_clash_mask'], + axis=-1)) + ret['violations_within_residue'] = utils.mask_mean( + mask=batch['seq_mask'], + value=jnp.max( + violations['within_residues']['per_atom_violations'], axis=-1)) + ret['violations_per_residue'] = utils.mask_mean( + mask=batch['seq_mask'], + value=violations['total_per_residue_violations_mask']) + return ret + + +def supervised_chi_loss(ret, batch, value, config): + """Computes loss for direct chi angle supervision. + + Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss" + + Args: + ret: Dictionary to write outputs into, needs to contain 'loss'. + batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'. + value: Dictionary containing structure module output, needs to contain + value['sidechains']['angles_sin_cos'] for angles and + value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized + angles. + config: Configuration of loss, should contain 'chi_weight' and + 'angle_norm_weight', 'angle_norm_weight' scales angle norm term, + 'chi_weight' scales torsion term. + """ + eps = 1e-6 + + sequence_mask = batch['seq_mask'] + num_res = sequence_mask.shape[0] + chi_mask = batch['chi_mask'].astype(jnp.float32) + pred_angles = jnp.reshape( + value['sidechains']['angles_sin_cos'], [-1, num_res, 7, 2]) + pred_angles = pred_angles[:, :, 3:] + + residue_type_one_hot = jax.nn.one_hot( + batch['aatype'], residue_constants.restype_num + 1, + dtype=jnp.float32)[None] + chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot, + jnp.asarray(residue_constants.chi_pi_periodic)) + + true_chi = batch['chi_angles'][None] + sin_true_chi = jnp.sin(true_chi) + cos_true_chi = jnp.cos(true_chi) + sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1) + + # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic + shifted_mask = (1 - 2 * chi_pi_periodic)[..., None] + sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi + + sq_chi_error = jnp.sum( + squared_difference(sin_cos_true_chi, pred_angles), -1) + sq_chi_error_shifted = jnp.sum( + squared_difference(sin_cos_true_chi_shifted, pred_angles), -1) + sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted) + + sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error) + ret['chi_loss'] = sq_chi_loss + ret['loss'] += config.chi_weight * sq_chi_loss + unnormed_angles = jnp.reshape( + value['sidechains']['unnormalized_angles_sin_cos'], [-1, num_res, 7, 2]) + angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps) + norm_error = jnp.abs(angle_norm - 1.) + angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None], + value=norm_error) + + ret['angle_norm_loss'] = angle_norm_loss + ret['loss'] += config.angle_norm_weight * angle_norm_loss + + +def generate_new_affine(sequence_mask): + num_residues, _ = sequence_mask.shape + quaternion = jnp.tile( + jnp.reshape(jnp.asarray([1., 0., 0., 0.]), [1, 4]), + [num_residues, 1]) + + translation = jnp.zeros([num_residues, 3]) + return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True) + + +def l2_normalize(x, axis=-1, epsilon=1e-12): + return x / jnp.sqrt( + jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon)) + + +class MultiRigidSidechain(hk.Module): + """Class to make side chain atoms.""" + + def __init__(self, config, global_config, name='rigid_sidechain'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, affine, representations_list, aatype): + """Predict side chains using multi-rigid representations. + + Args: + affine: The affines for each residue (translations in angstroms). + representations_list: A list of activations to predict side chains from. + aatype: Amino acid types. + + Returns: + Dict containing atom positions and frames (in angstroms). + """ + act = [ + common_modules.Linear( # pylint: disable=g-complex-comprehension + self.config.num_channel, + name='input_projection')(jax.nn.relu(x)) + for x in representations_list + ] + # Sum the activation list (equivalent to concat then Linear). + act = sum(act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Mapping with some residual blocks. + for _ in range(self.config.num_residual_block): + old_act = act + act = common_modules.Linear( + self.config.num_channel, + initializer='relu', + name='resblock1')( + jax.nn.relu(act)) + act = common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='resblock2')( + jax.nn.relu(act)) + act += old_act + + # Map activations to torsion angles. Shape: (num_res, 14). + num_res = act.shape[0] + unnormalized_angles = common_modules.Linear( + 14, name='unnormalized_angles')( + jax.nn.relu(act)) + unnormalized_angles = jnp.reshape( + unnormalized_angles, [num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # jnp.ndarray (N, 7, 2) + } + + # Map torsion angles to frames. + backb_to_global = r3.rigids_from_quataffine(affine) + + # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" + + # r3.Rigids with shape (N, 8). + all_frames_to_global = all_atom.torsion_angles_to_frames( + aatype, + backb_to_global, + angles) + + # Use frames and literature positions to create the final atom coordinates. + # r3.Vecs with shape (N, 14). + pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + outputs.update({ + 'atom_pos': pred_positions, # r3.Vecs (N, 14) + 'frames': all_frames_to_global, # r3.Rigids (N, 8) + }) + return outputs diff --git a/af_backprop/alphafold/model/layer_stack.py b/af_backprop/alphafold/model/layer_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbb0dcb26445ec8ce57149f31aba9fc4de2863c --- /dev/null +++ b/af_backprop/alphafold/model/layer_stack.py @@ -0,0 +1,274 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Function to stack repeats of a layer function without shared parameters.""" + +import collections +import contextlib +import functools +import inspect +from typing import Any, Callable, Optional, Tuple, Union + +import haiku as hk +import jax +import jax.numpy as jnp + +LayerStackCarry = collections.namedtuple('LayerStackCarry', ['x', 'rng']) +LayerStackScanned = collections.namedtuple('LayerStackScanned', + ['i', 'args_ys']) + +# WrappedFn should take in arbitrarily nested `jnp.ndarray`, and return the +# exact same type. We cannot express this with `typing`. So we just use it +# to inform the user. In reality, the typing below will accept anything. +NestedArray = Any +WrappedFn = Callable[..., Union[NestedArray, Tuple[NestedArray]]] + + +def _check_no_varargs(f): + if list(inspect.signature( + f).parameters.values())[0].kind == inspect.Parameter.VAR_POSITIONAL: + raise ValueError( + 'The function `f` should not have any `varargs` (that is *args) ' + 'argument. Instead, it should only use explicit positional' + 'arguments.') + + +@contextlib.contextmanager +def nullcontext(): + yield + + +def maybe_with_rng(key): + if key is not None: + return hk.with_rng(key) + else: + return nullcontext() + + +def maybe_fold_in(key, data): + if key is not None: + return jax.random.fold_in(key, data) + else: + return None + + +class _LayerStack(hk.Module): + """Module to compose parameterized functions, implemented as a scan.""" + + def __init__(self, + count: int, + unroll: int, + name: Optional[str] = None): + """Iterate a function `f` `count` times, with non-shared parameters.""" + super().__init__(name=name) + self._count = count + self._unroll = unroll + + def __call__(self, x, *args_ys): + count = self._count + if hk.running_init(): + # At initialization time, we run just one layer but add an extra first + # dimension to every initialized tensor, making sure to use different + # random keys for different slices. + def creator(next_creator, shape, dtype, init, context): + del context + + def multi_init(shape, dtype): + assert shape[0] == count + key = hk.maybe_next_rng_key() + + def rng_context_init(slice_idx): + slice_key = maybe_fold_in(key, slice_idx) + with maybe_with_rng(slice_key): + return init(shape[1:], dtype) + + return jax.vmap(rng_context_init)(jnp.arange(count)) + + return next_creator((count,) + tuple(shape), dtype, multi_init) + + def getter(next_getter, value, context): + trailing_dims = len(context.original_shape) + 1 + sliced_value = jax.lax.index_in_dim( + value, index=0, axis=value.ndim - trailing_dims, keepdims=False) + return next_getter(sliced_value) + + with hk.experimental.custom_creator( + creator), hk.experimental.custom_getter(getter): + if len(args_ys) == 1 and args_ys[0] is None: + args0 = (None,) + else: + args0 = [ + jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False) + for ys in args_ys + ] + x, z = self._call_wrapped(x, *args0) + if z is None: + return x, z + + # Broadcast state to hold each layer state. + def broadcast_state(layer_state): + return jnp.broadcast_to( + layer_state, [count,] + list(layer_state.shape)) + zs = jax.tree_util.tree_map(broadcast_state, z) + return x, zs + else: + # Use scan during apply, threading through random seed so that it's + # unique for each layer. + def layer(carry: LayerStackCarry, scanned: LayerStackScanned): + rng = carry.rng + + def getter(next_getter, value, context): + # Getter slices the full param at the current loop index. + trailing_dims = len(context.original_shape) + 1 + assert value.shape[value.ndim - trailing_dims] == count, ( + f'Attempting to use a parameter stack of size ' + f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of ' + f'size {count}.') + + sliced_value = jax.lax.dynamic_index_in_dim( + value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False) + return next_getter(sliced_value) + + with hk.experimental.custom_getter(getter): + if rng is None: + out_x, z = self._call_wrapped(carry.x, *scanned.args_ys) + else: + rng, rng_ = jax.random.split(rng) + with hk.with_rng(rng_): + out_x, z = self._call_wrapped(carry.x, *scanned.args_ys) + return LayerStackCarry(x=out_x, rng=rng), z + + carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key()) + scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32), + args_ys=args_ys) + + carry, zs = hk.scan( + layer, carry, scanned, length=count, unroll=self._unroll) + return carry.x, zs + + def _call_wrapped(self, + x: jnp.ndarray, + *args, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + raise NotImplementedError() + + +class _LayerStackNoState(_LayerStack): + """_LayerStack impl with no per-layer state provided to the function.""" + + def __init__(self, + f: WrappedFn, + count: int, + unroll: int, + name: Optional[str] = None): + super().__init__(count=count, unroll=unroll, name=name) + _check_no_varargs(f) + self._f = f + + @hk.transparent + def _call_wrapped(self, args, y): + del y + ret = self._f(*args) + if len(args) == 1: + # If the function takes a single argument, the wrapped function receives + # a tuple of length 1, and therefore it must return a tuple of length 1. + ret = (ret,) + return ret, None + + +class _LayerStackWithState(_LayerStack): + """_LayerStack impl with per-layer state provided to the function.""" + + def __init__(self, + f: WrappedFn, + count: int, + unroll: int, + name: Optional[str] = None): + super().__init__(count=count, unroll=unroll, name=name) + self._f = f + + @hk.transparent + def _call_wrapped(self, x, *args): + return self._f(x, *args) + + +def layer_stack(num_layers: int, + with_state=False, + unroll: int = 1, + name: Optional[str] = None): + """Utility to wrap a Haiku function and recursively apply it to an input. + + A function is valid if it uses only explicit position parameters, and + its return type matches its input type. The position parameters can be + arbitrarily nested structures with `jnp.ndarray` at the leaf nodes. Note + that kwargs are not supported, neither are functions with variable number + of parameters (specified by `*args`). + + If `with_state=False` then the new, wrapped function can be understood as + performing the following: + ``` + for i in range(num_layers): + x = f(x) + return x + ``` + + And if `with_state=True`, assuming `f` takes two arguments on top of `x`: + ``` + for i in range(num_layers): + x, zs[i] = f(x, ys_0[i], ys_1[i]) + return x, zs + ``` + The code using `layer_stack` for the above function would be: + ``` + def f(x, y_0, y_1): + ... + return new_x, z + x, zs = layer_stack.layer_stack(num_layers, + with_state=True)(f)(x, ys_0, ys_1) + ``` + + Crucially, any parameters created inside `f` will not be shared across + iterations. + + Args: + num_layers: The number of times to iterate the wrapped function. + with_state: Whether or not to pass per-layer state to the wrapped function. + unroll: the unroll used by `scan`. + name: Name of the Haiku context. + + Returns: + Callable that will produce a layer stack when called with a valid function. + """ + def iterate(f): + if with_state: + @functools.wraps(f) + def wrapped(x, *args): + for ys in args: + assert ys.shape[0] == num_layers + return _LayerStackWithState( + f, num_layers, unroll=unroll, name=name)(x, *args) + else: + _check_no_varargs(f) + @functools.wraps(f) + def wrapped(*args): + ret = _LayerStackNoState( + f, num_layers, unroll=unroll, name=name)(args, None)[0] + if len(args) == 1: + # If the function takes a single argument, we must also return a + # single value, and not a tuple of length 1. + ret = ret[0] + return ret + + return wrapped + return iterate diff --git a/af_backprop/alphafold/model/lddt.py b/af_backprop/alphafold/model/lddt.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2a3f9c9c427b8da547cabf75ff48a7b9fc1844 --- /dev/null +++ b/af_backprop/alphafold/model/lddt.py @@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""lDDT protein distance score.""" +import jax.numpy as jnp + + +def lddt(predicted_points, + true_points, + true_points_mask, + cutoff=15., + per_residue=False): + """Measure (approximate) lDDT for a batch of coordinates. + + lDDT reference: + Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local + superposition-free score for comparing protein structures and models using + distance difference tests. Bioinformatics 29, 2722–2728 (2013). + + lDDT is a measure of the difference between the true distance matrix and the + distance matrix of the predicted points. The difference is computed only on + points closer than cutoff *in the true structure*. + + This function does not compute the exact lDDT value that the original paper + describes because it does not include terms for physical feasibility + (e.g. bond length violations). Therefore this is only an approximate + lDDT score. + + Args: + predicted_points: (batch, length, 3) array of predicted 3D points + true_points: (batch, length, 3) array of true 3D points + true_points_mask: (batch, length, 1) binary-valued float array. This mask + should be 1 for points that exist in the true points. + cutoff: Maximum distance for a pair of points to be included + per_residue: If true, return score for each residue. Note that the overall + lDDT is not exactly the mean of the per_residue lDDT's because some + residues have more contacts than others. + + Returns: + An (approximate, see above) lDDT score in the range 0-1. + """ + + assert len(predicted_points.shape) == 3 + assert predicted_points.shape[-1] == 3 + assert true_points_mask.shape[-1] == 1 + assert len(true_points_mask.shape) == 3 + + # Compute true and predicted distance matrices. + dmat_true = jnp.sqrt(1e-10 + jnp.sum( + (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) + + dmat_predicted = jnp.sqrt(1e-10 + jnp.sum( + (predicted_points[:, :, None] - + predicted_points[:, None, :])**2, axis=-1)) + + dists_to_score = ( + (dmat_true < cutoff).astype(jnp.float32) * true_points_mask * + jnp.transpose(true_points_mask, [0, 2, 1]) * + (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction. + ) + + # Shift unscored distances to be far away. + dist_l1 = jnp.abs(dmat_true - dmat_predicted) + + # True lDDT uses a number of fixed bins. + # We ignore the physical plausibility correction to lDDT, though. + score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) + + (dist_l1 < 1.0).astype(jnp.float32) + + (dist_l1 < 2.0).astype(jnp.float32) + + (dist_l1 < 4.0).astype(jnp.float32)) + + # Normalize over the appropriate axes. + reduce_axes = (-1,) if per_residue else (-2, -1) + norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes)) + score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes)) + + return score diff --git a/af_backprop/alphafold/model/mapping.py b/af_backprop/alphafold/model/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..6041f7c769d4e32f34af2e3528bfba9128cc521f --- /dev/null +++ b/af_backprop/alphafold/model/mapping.py @@ -0,0 +1,218 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Specialized mapping functions.""" + +import functools + +from typing import Any, Callable, Optional, Sequence, Union + +import haiku as hk +import jax +import jax.numpy as jnp + + +PYTREE = Any +PYTREE_JAX_ARRAY = Any + +partial = functools.partial +PROXY = object() + + +def _maybe_slice(array, i, slice_size, axis): + if axis is PROXY: + return array + else: + return jax.lax.dynamic_slice_in_dim( + array, i, slice_size=slice_size, axis=axis) + + +def _maybe_get_size(array, axis): + if axis == PROXY: + return -1 + else: + return array.shape[axis] + + +def _expand_axes(axes, values, name='sharded_apply'): + values_tree_def = jax.tree_flatten(values)[1] + flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) + # Replace None's with PROXY + flat_axes = [PROXY if x is None else x for x in flat_axes] + return jax.tree_unflatten(values_tree_def, flat_axes) + + +def sharded_map( + fun: Callable[..., PYTREE_JAX_ARRAY], + shard_size: Union[int, None] = 1, + in_axes: Union[int, PYTREE] = 0, + out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]: + """Sharded vmap. + + Maps `fun` over axes, in a way similar to vmap, but does so in shards of + `shard_size`. This allows a smooth trade-off between memory usage + (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: integer or pytree denoting to what axis in the output the mapped + over axis maps. + + Returns: + function with smap applied. + """ + vmapped_fun = hk.vmap(fun, in_axes, out_axes) + return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) + + +def sharded_apply( + fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic + shard_size: Union[int, None] = 1, + in_axes: Union[int, PYTREE] = 0, + out_axes: Union[int, PYTREE] = 0, + new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]: + """Sharded apply. + + Applies `fun` over shards to axes, in a way similar to vmap, + but does so in shards of `shard_size`. Shards are stacked after. + This allows a smooth trade-off between + memory usage (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: integer or pytree denoting to what axis in the output the mapped + over axis maps. + new_out_axes: whether to stack outputs on new axes. This assumes that the + output sizes for each shard (including the possible remainder shard) are + the same. + + Returns: + function with smap applied. + """ + docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} ' + 'but with additional array axes over which {fun} is mapped.') + if new_out_axes: + raise NotImplementedError('New output axes not yet implemented.') + + # shard size None denotes no sharding + if shard_size is None: + return fun + + @jax.util.wraps(fun, docstr=docstr) + def mapped_fn(*args): + # Expand in axes and Determine Loop range + in_axes_ = _expand_axes(in_axes, args) + + in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_) + flat_sizes = jax.tree_flatten(in_sizes)[0] + in_size = max(flat_sizes) + assert all(i in {in_size, -1} for i in flat_sizes) + + num_extra_shards = (in_size - 1) // shard_size + + # Fix Up if necessary + last_shard_size = in_size % shard_size + last_shard_size = shard_size if last_shard_size == 0 else last_shard_size + + def apply_fun_to_slice(slice_start, slice_size): + input_slice = jax.tree_util.tree_map( + lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis + ), args, in_axes_) + return fun(*input_slice) + + remainder_shape_dtype = hk.eval_shape( + partial(apply_fun_to_slice, 0, last_shard_size)) + out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype) + out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype) + out_axes_ = _expand_axes(out_axes, remainder_shape_dtype) + + if num_extra_shards > 0: + regular_shard_shape_dtype = hk.eval_shape( + partial(apply_fun_to_slice, 0, shard_size)) + shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype) + + def make_output_shape(axis, shard_shape, remainder_shape): + return shard_shape[:axis] + ( + shard_shape[axis] * num_extra_shards + + remainder_shape[axis],) + shard_shape[axis + 1:] + + out_shapes = jax.tree_util.tree_map(make_output_shape, out_axes_, shard_shapes, + out_shapes) + + # Calls dynamic Update slice with different argument order + # This is here since tree_multimap only works with positional arguments + def dynamic_update_slice_in_dim(full_array, update, axis, i): + return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis) + + def compute_shard(outputs, slice_start, slice_size): + slice_out = apply_fun_to_slice(slice_start, slice_size) + update_slice = partial( + dynamic_update_slice_in_dim, i=slice_start) + return jax.tree_util.tree_map(update_slice, outputs, slice_out, out_axes_) + + def scan_iteration(outputs, i): + new_outputs = compute_shard(outputs, i, shard_size) + return new_outputs, () + + slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size) + + def allocate_buffer(dtype, shape): + return jnp.zeros(shape, dtype=dtype) + + outputs = jax.tree_util.tree_map(allocate_buffer, out_dtypes, out_shapes) + + if slice_starts.shape[0] > 0: + outputs, _ = hk.scan(scan_iteration, outputs, slice_starts) + + if last_shard_size != shard_size: + remainder_start = in_size - last_shard_size + outputs = compute_shard(outputs, remainder_start, last_shard_size) + + return outputs + + return mapped_fn + + +def inference_subbatch( + module: Callable[..., PYTREE_JAX_ARRAY], + subbatch_size: int, + batched_args: Sequence[PYTREE_JAX_ARRAY], + nonbatched_args: Sequence[PYTREE_JAX_ARRAY], + low_memory: bool = True, + input_subbatch_dim: int = 0, + output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY: + """Run through subbatches (like batch apply but with split and concat).""" + assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test + + if not low_memory: + args = list(batched_args) + list(nonbatched_args) + return module(*args) + + if output_subbatch_dim is None: + output_subbatch_dim = input_subbatch_dim + + def run_module(*batched_args): + args = list(batched_args) + list(nonbatched_args) + return module(*args) + sharded_module = sharded_apply(run_module, + shard_size=subbatch_size, + in_axes=input_subbatch_dim, + out_axes=output_subbatch_dim) + return sharded_module(*batched_args) diff --git a/af_backprop/alphafold/model/model.py b/af_backprop/alphafold/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd72d2d17693c0a315d4af4ca7bba8f4bc42e992 --- /dev/null +++ b/af_backprop/alphafold/model/model.py @@ -0,0 +1,145 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Code for constructing the model.""" +from typing import Any, Mapping, Optional, Union + +from absl import logging +from alphafold.common import confidence +from alphafold.model import features +from alphafold.model import modules +import haiku as hk +import jax +import ml_collections +import numpy as np +import tensorflow.compat.v1 as tf +import tree + + +def get_confidence_metrics( + prediction_result: Mapping[str, Any]) -> Mapping[str, Any]: + """Post processes prediction_result to get confidence metrics.""" + + confidence_metrics = {} + confidence_metrics['plddt'] = confidence.compute_plddt( + prediction_result['predicted_lddt']['logits']) + if 'predicted_aligned_error' in prediction_result: + confidence_metrics.update(confidence.compute_predicted_aligned_error( + prediction_result['predicted_aligned_error']['logits'], + prediction_result['predicted_aligned_error']['breaks'])) + confidence_metrics['ptm'] = confidence.predicted_tm_score( + prediction_result['predicted_aligned_error']['logits'], + prediction_result['predicted_aligned_error']['breaks']) + + return confidence_metrics + + +class RunModel: + """Container for JAX model.""" + + def __init__(self, + config: ml_collections.ConfigDict, + params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None, + is_training=True, + return_representations=True): + self.config = config + self.params = params + + def _forward_fn(batch): + model = modules.AlphaFold(self.config.model) + return model( + batch, + is_training=is_training, + compute_loss=False, + ensemble_representations=False, + return_representations=return_representations) + + self.apply = jax.jit(hk.transform(_forward_fn).apply) + self.init = jax.jit(hk.transform(_forward_fn).init) + + def init_params(self, feat: features.FeatureDict, random_seed: int = 0): + """Initializes the model parameters. + + If none were provided when this class was instantiated then the parameters + are randomly initialized. + + Args: + feat: A dictionary of NumPy feature arrays as output by + RunModel.process_features. + random_seed: A random seed to use to initialize the parameters if none + were set when this class was initialized. + """ + if not self.params: + # Init params randomly. + rng = jax.random.PRNGKey(random_seed) + self.params = hk.data_structures.to_mutable_dict( + self.init(rng, feat)) + logging.warning('Initialized parameters randomly') + + def process_features( + self, + raw_features: Union[tf.train.Example, features.FeatureDict], + random_seed: int) -> features.FeatureDict: + """Processes features to prepare for feeding them into the model. + + Args: + raw_features: The output of the data pipeline either as a dict of NumPy + arrays or as a tf.train.Example. + random_seed: The random seed to use when processing the features. + + Returns: + A dict of NumPy feature arrays suitable for feeding into the model. + """ + if isinstance(raw_features, dict): + return features.np_example_to_features( + np_example=raw_features, + config=self.config, + random_seed=random_seed) + else: + return features.tf_example_to_features( + tf_example=raw_features, + config=self.config, + random_seed=random_seed) + + def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: + self.init_params(feat) + logging.info('Running eval_shape with shape(feat) = %s', + tree.map_structure(lambda x: x.shape, feat)) + shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat) + logging.info('Output shape was %s', shape) + return shape + + def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]: + """Makes a prediction by inferencing the model on the provided features. + + Args: + feat: A dictionary of NumPy feature arrays as output by + RunModel.process_features. + + Returns: + A dictionary of model outputs. + """ + self.init_params(feat) + logging.info('Running predict with shape(feat) = %s', + tree.map_structure(lambda x: x.shape, feat)) + result = self.apply(self.params, jax.random.PRNGKey(0), feat) + # This block is to ensure benchmark timings are accurate. Some blocking is + # already happening when computing get_confidence_metrics, and this ensures + # all outputs are blocked on. + jax.tree_map(lambda x: x.block_until_ready(), result) + if self.config.use_struct: + result.update(get_confidence_metrics(result)) + logging.info('Output shape was %s', + tree.map_structure(lambda x: x.shape, result)) + return result diff --git a/af_backprop/alphafold/model/modules.py b/af_backprop/alphafold/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ca2147d8e2170f338e2f9c6d0272f3cd4893bc --- /dev/null +++ b/af_backprop/alphafold/model/modules.py @@ -0,0 +1,2164 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules and code used in the core part of AlphaFold. + +The structure generation code is in 'folding.py'. +""" +import functools +from alphafold.common import residue_constants +from alphafold.model import all_atom +from alphafold.model import common_modules +from alphafold.model import folding +from alphafold.model import layer_stack +from alphafold.model import lddt +from alphafold.model import mapping +from alphafold.model import prng +from alphafold.model import quat_affine +from alphafold.model import utils +import haiku as hk +import jax +import jax.numpy as jnp + +from alphafold.model.r3 import Rigids, Rots, Vecs + + +def softmax_cross_entropy(logits, labels): + """Computes softmax cross entropy given logits and one-hot class labels.""" + loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1) + return jnp.asarray(loss) + + +def sigmoid_cross_entropy(logits, labels): + """Computes sigmoid cross entropy given logits and multiple class labels.""" + log_p = jax.nn.log_sigmoid(logits) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter is more numerically stable + log_not_p = jax.nn.log_sigmoid(-logits) + loss = -labels * log_p - (1. - labels) * log_not_p + return jnp.asarray(loss) + + +def apply_dropout(*, tensor, safe_key, rate, is_training, broadcast_dim=None): + """Applies dropout to a tensor.""" + if is_training: # and rate != 0.0: + shape = list(tensor.shape) + if broadcast_dim is not None: + shape[broadcast_dim] = 1 + keep_rate = 1.0 - rate + keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=shape) + return keep * tensor / keep_rate + else: + return tensor + + +def dropout_wrapper(module, + input_act, + mask, + safe_key, + global_config, + output_act=None, + is_training=True, + scale_rate=1.0, + **kwargs): + """Applies module + dropout + residual update.""" + if output_act is None: + output_act = input_act + + gc = global_config + residual = module(input_act, mask, is_training=is_training, **kwargs) + dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate + + if module.config.shared_dropout: + if module.config.orientation == 'per_row': + broadcast_dim = 0 + else: + broadcast_dim = 1 + else: + broadcast_dim = None + + residual = apply_dropout(tensor=residual, + safe_key=safe_key, + rate=dropout_rate * scale_rate, + is_training=is_training, + broadcast_dim=broadcast_dim) + + new_act = output_act + residual + + return new_act + + +def create_extra_msa_feature(batch): + """Expand extra_msa into 1hot and concat with other extra msa features. + + We do this as late as possible as the one_hot extra msa can be very large. + + Arguments: + batch: a dictionary with the following keys: + * 'extra_msa': [N_extra_seq, N_res] MSA that wasn't selected as a cluster + centre. Note, that this is not one-hot encoded. + * 'extra_has_deletion': [N_extra_seq, N_res] Whether there is a deletion to + the left of each position in the extra MSA. + * 'extra_deletion_value': [N_extra_seq, N_res] The number of deletions to + the left of each position in the extra MSA. + + Returns: + Concatenated tensor of extra MSA features. + """ + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + msa_1hot = jax.nn.one_hot(batch['extra_msa'], 23) + msa_feat = [msa_1hot, + jnp.expand_dims(batch['extra_has_deletion'], axis=-1), + jnp.expand_dims(batch['extra_deletion_value'], axis=-1)] + return jnp.concatenate(msa_feat, axis=-1) + + +class AlphaFoldIteration(hk.Module): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. Each head also returns a + loss which is combined as a weighted sum to produce the total loss. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 3-22 + """ + + def __init__(self, config, global_config, name='alphafold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + ensembled_batch, + non_ensembled_batch, + is_training, + compute_loss=False, + ensemble_representations=False, + return_representations=False): + + num_ensemble = jnp.asarray(ensembled_batch['seq_length'].shape[0]) + + if not ensemble_representations: + assert ensembled_batch['seq_length'].shape[0] == 1 + + def slice_batch(i): + b = {k: v[i] for k, v in ensembled_batch.items()} + b.update(non_ensembled_batch) + return b + + # Compute representations for each batch element and average. + evoformer_module = EmbeddingsAndEvoformer( + self.config.embeddings_and_evoformer, self.global_config) + batch0 = slice_batch(0) + representations = evoformer_module(batch0, is_training) + + # MSA representations are not ensembled so + # we don't pass tensor into the loop. + msa_representation = representations['msa'] + del representations['msa'] + + # Average the representations (except MSA) over the batch dimension. + if ensemble_representations: + def body(x): + """Add one element to the representations ensemble.""" + i, current_representations = x + feats = slice_batch(i) + representations_update = evoformer_module( + feats, is_training) + + new_representations = {} + for k in current_representations: + new_representations[k] = ( + current_representations[k] + representations_update[k]) + return i+1, new_representations + + if hk.running_init(): + # When initializing the Haiku module, run one iteration of the + # while_loop to initialize the Haiku modules used in `body`. + _, representations = body((1, representations)) + else: + _, representations = hk.while_loop( + lambda x: x[0] < num_ensemble, + body, + (1, representations)) + + for k in representations: + if k != 'msa': + representations[k] /= num_ensemble.astype(representations[k].dtype) + + representations['msa'] = msa_representation + batch = batch0 # We are not ensembled from here on. + + if jnp.issubdtype(ensembled_batch['aatype'].dtype, jnp.integer): + _, num_residues = ensembled_batch['aatype'].shape + else: + _, num_residues, _ = ensembled_batch['aatype'].shape + + if self.config.use_struct: + struct_module = folding.StructureModule + else: + struct_module = folding.dummy + + heads = {} + for head_name, head_config in sorted(self.config.heads.items()): + if not head_config.weight: + continue # Do not instantiate zero-weight heads. + head_factory = { + 'masked_msa': MaskedMsaHead, + 'distogram': DistogramHead, + 'structure_module': functools.partial(struct_module, compute_loss=compute_loss), + 'predicted_lddt': PredictedLDDTHead, + 'predicted_aligned_error': PredictedAlignedErrorHead, + 'experimentally_resolved': ExperimentallyResolvedHead, + }[head_name] + heads[head_name] = (head_config, + head_factory(head_config, self.global_config)) + + total_loss = 0. + ret = {} + ret['representations'] = representations + + def loss(module, head_config, ret, name, filter_ret=True): + if filter_ret: + value = ret[name] + else: + value = ret + loss_output = module.loss(value, batch) + ret[name].update(loss_output) + loss = head_config.weight * ret[name]['loss'] + return loss + + for name, (head_config, module) in heads.items(): + # Skip PredictedLDDTHead and PredictedAlignedErrorHead until + # StructureModule is executed. + if name in ('predicted_lddt', 'predicted_aligned_error'): + continue + else: + ret[name] = module(representations, batch, is_training) + if 'representations' in ret[name]: + # Extra representations from the head. Used by the structure module + # to provide activations for the PredictedLDDTHead. + representations.update(ret[name].pop('representations')) + if compute_loss: + total_loss += loss(module, head_config, ret, name) + + if self.config.use_struct: + if self.config.heads.get('predicted_lddt.weight', 0.0): + # Add PredictedLDDTHead after StructureModule executes. + name = 'predicted_lddt' + # Feed all previous results to give access to structure_module result. + head_config, module = heads[name] + ret[name] = module(representations, batch, is_training) + if compute_loss: + total_loss += loss(module, head_config, ret, name, filter_ret=False) + + if ('predicted_aligned_error' in self.config.heads + and self.config.heads.get('predicted_aligned_error.weight', 0.0)): + # Add PredictedAlignedErrorHead after StructureModule executes. + name = 'predicted_aligned_error' + # Feed all previous results to give access to structure_module result. + head_config, module = heads[name] + ret[name] = module(representations, batch, is_training) + if compute_loss: + total_loss += loss(module, head_config, ret, name, filter_ret=False) + + if compute_loss: + return ret, total_loss + else: + return ret + +class AlphaFold(hk.Module): + """AlphaFold model with recycling. + + Jumper et al. (2021) Suppl. Alg. 2 "Inference" + """ + + def __init__(self, config, name='alphafold'): + super().__init__(name=name) + self.config = config + self.global_config = config.global_config + + def __call__( + self, + batch, + is_training, + compute_loss=False, + ensemble_representations=False, + return_representations=False): + """Run the AlphaFold model. + + Arguments: + batch: Dictionary with inputs to the AlphaFold model. + is_training: Whether the system is in training or inference mode. + compute_loss: Whether to compute losses (requires extra features + to be present in the batch and knowing the true structure). + ensemble_representations: Whether to use ensembling of representations. + return_representations: Whether to also return the intermediate + representations. + + Returns: + When compute_loss is True: + a tuple of loss and output of AlphaFoldIteration. + When compute_loss is False: + just output of AlphaFoldIteration. + + The output of AlphaFoldIteration is a nested dictionary containing + predictions from the various heads. + """ + if "scale_rate" not in batch: + batch["scale_rate"] = jnp.ones((1,)) + impl = AlphaFoldIteration(self.config, self.global_config) + if jnp.issubdtype(batch['aatype'].dtype, jnp.integer): + batch_size, num_residues = batch['aatype'].shape + else: + batch_size, num_residues, _ = batch['aatype'].shape + + def get_prev(ret): + new_prev = { + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + 'prev_dgram': ret["distogram"]["logits"], + } + if self.config.use_struct: + new_prev.update({'prev_pos': ret['structure_module']['final_atom_positions'], + 'prev_plddt': ret["predicted_lddt"]["logits"]}) + + if "predicted_aligned_error" in ret: + new_prev["prev_pae"] = ret["predicted_aligned_error"]["logits"] + + if not self.config.backprop_recycle: + for k in ["prev_pos","prev_msa_first_row","prev_pair"]: + if k in new_prev: + new_prev[k] = jax.lax.stop_gradient(new_prev[k]) + + return new_prev + + def do_call(prev, + recycle_idx, + compute_loss=compute_loss): + if self.config.resample_msa_in_recycling: + num_ensemble = batch_size // (self.config.num_recycle + 1) + def slice_recycle_idx(x): + start = recycle_idx * num_ensemble + size = num_ensemble + return jax.lax.dynamic_slice_in_dim(x, start, size, axis=0) + ensembled_batch = jax.tree_map(slice_recycle_idx, batch) + else: + num_ensemble = batch_size + ensembled_batch = batch + non_ensembled_batch = jax.tree_map(lambda x: x, prev) + + return impl(ensembled_batch=ensembled_batch, + non_ensembled_batch=non_ensembled_batch, + is_training=is_training, + compute_loss=compute_loss, + ensemble_representations=ensemble_representations) + + + emb_config = self.config.embeddings_and_evoformer + prev = { + 'prev_msa_first_row': jnp.zeros([num_residues, emb_config.msa_channel]), + 'prev_pair': jnp.zeros([num_residues, num_residues, emb_config.pair_channel]), + 'prev_dgram': jnp.zeros([num_residues, num_residues, 64]), + } + if self.config.use_struct: + prev.update({'prev_pos': jnp.zeros([num_residues, residue_constants.atom_type_num, 3]), + 'prev_plddt': jnp.zeros([num_residues, 50]), + 'prev_pae': jnp.zeros([num_residues, num_residues, 64])}) + + for k in ["pos","msa_first_row","pair","dgram"]: + if f"init_{k}" in batch: prev[f"prev_{k}"] = batch[f"init_{k}"][0] + + if self.config.num_recycle: + if 'num_iter_recycling' in batch: + # Training time: num_iter_recycling is in batch. + # The value for each ensemble batch is the same, so arbitrarily taking + # 0-th. + num_iter = batch['num_iter_recycling'][0] + + # Add insurance that we will not run more + # recyclings than the model is configured to run. + num_iter = jnp.minimum(num_iter, self.config.num_recycle) + else: + # Eval mode or tests: use the maximum number of iterations. + num_iter = self.config.num_recycle + + def add_prev(p,p_): + p_["prev_dgram"] += p["prev_dgram"] + if self.config.use_struct: + p_["prev_plddt"] += p["prev_plddt"] + p_["prev_pae"] += p["prev_pae"] + return p_ + + ############################################################## + def body(p, i): + p_ = get_prev(do_call(p, recycle_idx=i, compute_loss=False)) + if self.config.add_prev: + p_ = add_prev(p, p_) + return p_, None + if hk.running_init(): + prev,_ = body(prev, 0) + else: + prev,_ = hk.scan(body, prev, jnp.arange(num_iter)) + ############################################################## + + else: + num_iter = 0 + + ret = do_call(prev=prev, recycle_idx=num_iter) + if self.config.add_prev: + prev_ = get_prev(ret) + if compute_loss: + ret = ret[0], [ret[1]] + + if not return_representations: + del (ret[0] if compute_loss else ret)['representations'] # pytype: disable=unsupported-operands + + if self.config.add_prev and num_iter > 0: + prev_ = add_prev(prev, prev_) + ret["distogram"]["logits"] = prev_["prev_dgram"]/(num_iter+1) + if self.config.use_struct: + ret["predicted_lddt"]["logits"] = prev_["prev_plddt"]/(num_iter+1) + if "predicted_aligned_error" in ret: + ret["predicted_aligned_error"]["logits"] = prev_["prev_pae"]/(num_iter+1) + + return ret + +class TemplatePairStack(hk.Module): + """Pair stack for the templates. + + Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack" + """ + + def __init__(self, config, global_config, name='template_pair_stack'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, pair_act, pair_mask, is_training, safe_key=None, scale_rate=1.0): + """Builds TemplatePairStack module. + + Arguments: + pair_act: Pair activations for single template, shape [N_res, N_res, c_t]. + pair_mask: Pair mask, shape [N_res, N_res]. + is_training: Whether the module is in training mode. + safe_key: Safe key object encapsulating the random number generation key. + + Returns: + Updated pair_act, shape [N_res, N_res, c_t]. + """ + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + gc = self.global_config + c = self.config + + if not c.num_block: + return pair_act + + def block(x): + """One block of the template pair stack.""" + pair_act, safe_key = x + + dropout_wrapper_fn = functools.partial( + dropout_wrapper, is_training=is_training, global_config=gc, scale_rate=scale_rate) + + safe_key, *sub_keys = safe_key.split(6) + sub_keys = iter(sub_keys) + + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + pair_act, + pair_mask, + next(sub_keys)) + pair_act = dropout_wrapper_fn( + Transition(c.pair_transition, gc, name='pair_transition'), + pair_act, + pair_mask, + next(sub_keys)) + + return pair_act, safe_key + + if gc.use_remat: + block = hk.remat(block) + + res_stack = layer_stack.layer_stack(c.num_block)(block) + pair_act, safe_key = res_stack((pair_act, safe_key)) + return pair_act + + +class Transition(hk.Module): + """Transition layer. + + Jumper et al. (2021) Suppl. Alg. 9 "MSATransition" + Jumper et al. (2021) Suppl. Alg. 15 "PairTransition" + """ + + def __init__(self, config, global_config, name='transition_block'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, act, mask, is_training=True): + """Builds Transition module. + + Arguments: + act: A tensor of queries of size [batch_size, N_res, N_channel]. + mask: A tensor denoting the mask of size [batch_size, N_res]. + is_training: Whether the module is in training mode. + + Returns: + A float32 tensor of size [batch_size, N_res, N_channel]. + """ + _, _, nc = act.shape + + num_intermediate = int(nc * self.config.num_intermediate_factor) + mask = jnp.expand_dims(mask, axis=-1) + + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='input_layer_norm')( + act) + + transition_module = hk.Sequential([ + common_modules.Linear( + num_intermediate, + initializer='relu', + name='transition1'), jax.nn.relu, + common_modules.Linear( + nc, + initializer=utils.final_init(self.global_config), + name='transition2') + ]) + + act = mapping.inference_subbatch( + transition_module, + self.global_config.subbatch_size, + batched_args=[act], + nonbatched_args=[], + low_memory=not is_training) + + return act + + +def glorot_uniform(): + return hk.initializers.VarianceScaling(scale=1.0, + mode='fan_avg', + distribution='uniform') + + +class Attention(hk.Module): + """Multihead attention.""" + + def __init__(self, config, global_config, output_dim, name='attention'): + super().__init__(name=name) + + self.config = config + self.global_config = global_config + self.output_dim = output_dim + + def __call__(self, q_data, m_data, bias, nonbatched_bias=None): + """Builds Attention module. + + Arguments: + q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. + m_data: A tensor of memories from which the keys and values are + projected, shape [batch_size, N_keys, m_channels]. + bias: A bias for the attention, shape [batch_size, N_queries, N_keys]. + nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + + Returns: + A float32 tensor of shape [batch_size, N_queries, output_dim]. + """ + # Sensible default for when the config keys are missing + key_dim = self.config.get('key_dim', int(q_data.shape[-1])) + value_dim = self.config.get('value_dim', int(m_data.shape[-1])) + num_head = self.config.num_head + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + q_weights = hk.get_parameter( + 'query_w', shape=(q_data.shape[-1], num_head, key_dim), + init=glorot_uniform()) + k_weights = hk.get_parameter( + 'key_w', shape=(m_data.shape[-1], num_head, key_dim), + init=glorot_uniform()) + v_weights = hk.get_parameter( + 'value_w', shape=(m_data.shape[-1], num_head, value_dim), + init=glorot_uniform()) + + q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) + k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights) + v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights) + logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias + if nonbatched_bias is not None: + logits += jnp.expand_dims(nonbatched_bias, axis=0) + weights = jax.nn.softmax(logits) + weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v) + + if self.global_config.zero_init: + init = hk.initializers.Constant(0.0) + else: + init = glorot_uniform() + + if self.config.gating: + gating_weights = hk.get_parameter( + 'gating_w', + shape=(q_data.shape[-1], num_head, value_dim), + init=hk.initializers.Constant(0.0)) + gating_bias = hk.get_parameter( + 'gating_b', + shape=(num_head, value_dim), + init=hk.initializers.Constant(1.0)) + + gate_values = jnp.einsum('bqc, chv->bqhv', q_data, + gating_weights) + gating_bias + + gate_values = jax.nn.sigmoid(gate_values) + + weighted_avg *= gate_values + + o_weights = hk.get_parameter( + 'output_w', shape=(num_head, value_dim, self.output_dim), + init=init) + o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), + init=hk.initializers.Constant(0.0)) + + output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias + + return output + + +class GlobalAttention(hk.Module): + """Global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" lines 2-7 + """ + + def __init__(self, config, global_config, output_dim, name='attention'): + super().__init__(name=name) + + self.config = config + self.global_config = global_config + self.output_dim = output_dim + + def __call__(self, q_data, m_data, q_mask, bias): + """Builds GlobalAttention module. + + Arguments: + q_data: A tensor of queries with size [batch_size, N_queries, + q_channels] + m_data: A tensor of memories from which the keys and values + projected. Size [batch_size, N_keys, m_channels] + q_mask: A binary mask for q_data with zeros in the padded sequence + elements and ones otherwise. Size [batch_size, N_queries, q_channels] + (or broadcastable to this shape). + bias: A bias for the attention. + + Returns: + A float32 tensor of size [batch_size, N_queries, output_dim]. + """ + # Sensible default for when the config keys are missing + key_dim = self.config.get('key_dim', int(q_data.shape[-1])) + value_dim = self.config.get('value_dim', int(m_data.shape[-1])) + num_head = self.config.num_head + assert key_dim % num_head == 0 + assert value_dim % num_head == 0 + key_dim = key_dim // num_head + value_dim = value_dim // num_head + + q_weights = hk.get_parameter( + 'query_w', shape=(q_data.shape[-1], num_head, key_dim), + init=glorot_uniform()) + k_weights = hk.get_parameter( + 'key_w', shape=(m_data.shape[-1], key_dim), + init=glorot_uniform()) + v_weights = hk.get_parameter( + 'value_w', shape=(m_data.shape[-1], value_dim), + init=glorot_uniform()) + + v = jnp.einsum('bka,ac->bkc', m_data, v_weights) + + q_avg = utils.mask_mean(q_mask, q_data, axis=1) + + q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5) + k = jnp.einsum('bka,ac->bkc', m_data, k_weights) + bias = (1e9 * (q_mask[:, None, :, 0] - 1.)) + logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias + weights = jax.nn.softmax(logits) + weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v) + + if self.global_config.zero_init: + init = hk.initializers.Constant(0.0) + else: + init = glorot_uniform() + + o_weights = hk.get_parameter( + 'output_w', shape=(num_head, value_dim, self.output_dim), + init=init) + o_bias = hk.get_parameter('output_b', shape=(self.output_dim,), + init=hk.initializers.Constant(0.0)) + + if self.config.gating: + gating_weights = hk.get_parameter( + 'gating_w', + shape=(q_data.shape[-1], num_head, value_dim), + init=hk.initializers.Constant(0.0)) + gating_bias = hk.get_parameter( + 'gating_b', + shape=(num_head, value_dim), + init=hk.initializers.Constant(1.0)) + + gate_values = jnp.einsum('bqc, chv->bqhv', q_data, gating_weights) + gate_values = jax.nn.sigmoid(gate_values + gating_bias) + weighted_avg = weighted_avg[:, None] * gate_values + output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias + else: + output = jnp.einsum('bhc,hco->bo', weighted_avg, o_weights) + o_bias + output = output[:, None] + return output + + +class MSARowAttentionWithPairBias(hk.Module): + """MSA per-row attention biased by the pair representation. + + Jumper et al. (2021) Suppl. Alg. 7 "MSARowAttentionWithPairBias" + """ + + def __init__(self, config, global_config, + name='msa_row_attention_with_pair_bias'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + msa_act, + msa_mask, + pair_act, + is_training=False): + """Builds MSARowAttentionWithPairBias module. + + Arguments: + msa_act: [N_seq, N_res, c_m] MSA representation. + msa_mask: [N_seq, N_res] mask of non-padded regions. + pair_act: [N_res, N_res, c_z] pair representation. + is_training: Whether the module is in training mode. + + Returns: + Update to msa_act, shape [N_seq, N_res, c_m]. + """ + c = self.config + + assert len(msa_act.shape) == 3 + assert len(msa_mask.shape) == 2 + assert c.orientation == 'per_row' + + bias = (1e9 * (msa_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + msa_act = hk.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + msa_act) + + pair_act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='feat_2d_norm')( + pair_act) + + init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1])) + weights = hk.get_parameter( + 'feat_2d_weights', + shape=(pair_act.shape[-1], c.num_head), + init=hk.initializers.RandomNormal(stddev=init_factor)) + nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) + + attn_mod = Attention( + c, self.global_config, msa_act.shape[-1]) + msa_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[msa_act, msa_act, bias], + nonbatched_args=[nonbatched_bias], + low_memory=not is_training) + + return msa_act + + +class MSAColumnAttention(hk.Module): + """MSA per-column attention. + + Jumper et al. (2021) Suppl. Alg. 8 "MSAColumnAttention" + """ + + def __init__(self, config, global_config, name='msa_column_attention'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + msa_act, + msa_mask, + is_training=False): + """Builds MSAColumnAttention module. + + Arguments: + msa_act: [N_seq, N_res, c_m] MSA representation. + msa_mask: [N_seq, N_res] mask of non-padded regions. + is_training: Whether the module is in training mode. + + Returns: + Update to msa_act, shape [N_seq, N_res, c_m] + """ + c = self.config + + assert len(msa_act.shape) == 3 + assert len(msa_mask.shape) == 2 + assert c.orientation == 'per_column' + + msa_act = jnp.swapaxes(msa_act, -2, -3) + msa_mask = jnp.swapaxes(msa_mask, -1, -2) + + bias = (1e9 * (msa_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + msa_act = hk.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + msa_act) + + attn_mod = Attention( + c, self.global_config, msa_act.shape[-1]) + msa_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[msa_act, msa_act, bias], + nonbatched_args=[], + low_memory=not is_training) + + msa_act = jnp.swapaxes(msa_act, -2, -3) + + return msa_act + + +class MSAColumnGlobalAttention(hk.Module): + """MSA per-column global attention. + + Jumper et al. (2021) Suppl. Alg. 19 "MSAColumnGlobalAttention" + """ + + def __init__(self, config, global_config, name='msa_column_global_attention'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + msa_act, + msa_mask, + is_training=False): + """Builds MSAColumnGlobalAttention module. + + Arguments: + msa_act: [N_seq, N_res, c_m] MSA representation. + msa_mask: [N_seq, N_res] mask of non-padded regions. + is_training: Whether the module is in training mode. + + Returns: + Update to msa_act, shape [N_seq, N_res, c_m]. + """ + c = self.config + + assert len(msa_act.shape) == 3 + assert len(msa_mask.shape) == 2 + assert c.orientation == 'per_column' + + msa_act = jnp.swapaxes(msa_act, -2, -3) + msa_mask = jnp.swapaxes(msa_mask, -1, -2) + + bias = (1e9 * (msa_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + msa_act = hk.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + msa_act) + + attn_mod = GlobalAttention( + c, self.global_config, msa_act.shape[-1], + name='attention') + # [N_seq, N_res, 1] + msa_mask = jnp.expand_dims(msa_mask, axis=-1) + msa_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[msa_act, msa_act, msa_mask, bias], + nonbatched_args=[], + low_memory=not is_training) + + msa_act = jnp.swapaxes(msa_act, -2, -3) + + return msa_act + + +class TriangleAttention(hk.Module): + """Triangle Attention. + + Jumper et al. (2021) Suppl. Alg. 13 "TriangleAttentionStartingNode" + Jumper et al. (2021) Suppl. Alg. 14 "TriangleAttentionEndingNode" + """ + + def __init__(self, config, global_config, name='triangle_attention'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, pair_act, pair_mask, is_training=False): + """Builds TriangleAttention module. + + Arguments: + pair_act: [N_res, N_res, c_z] pair activations tensor + pair_mask: [N_res, N_res] mask of non-padded regions in the tensor. + is_training: Whether the module is in training mode. + + Returns: + Update to pair_act, shape [N_res, N_res, c_z]. + """ + c = self.config + + assert len(pair_act.shape) == 3 + assert len(pair_mask.shape) == 2 + assert c.orientation in ['per_row', 'per_column'] + + if c.orientation == 'per_column': + pair_act = jnp.swapaxes(pair_act, -2, -3) + pair_mask = jnp.swapaxes(pair_mask, -1, -2) + + bias = (1e9 * (pair_mask - 1.))[:, None, None, :] + assert len(bias.shape) == 4 + + pair_act = hk.LayerNorm( + axis=[-1], create_scale=True, create_offset=True, name='query_norm')( + pair_act) + + init_factor = 1. / jnp.sqrt(int(pair_act.shape[-1])) + weights = hk.get_parameter( + 'feat_2d_weights', + shape=(pair_act.shape[-1], c.num_head), + init=hk.initializers.RandomNormal(stddev=init_factor)) + nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) + + attn_mod = Attention( + c, self.global_config, pair_act.shape[-1]) + pair_act = mapping.inference_subbatch( + attn_mod, + self.global_config.subbatch_size, + batched_args=[pair_act, pair_act, bias], + nonbatched_args=[nonbatched_bias], + low_memory=not is_training) + + if c.orientation == 'per_column': + pair_act = jnp.swapaxes(pair_act, -2, -3) + + return pair_act + + +class MaskedMsaHead(hk.Module): + """Head to predict MSA at the masked locations. + + The MaskedMsaHead employs a BERT-style objective to reconstruct a masked + version of the full MSA, based on a linear projection of + the MSA representation. + Jumper et al. (2021) Suppl. Sec. 1.9.9 "Masked MSA prediction" + """ + + def __init__(self, config, global_config, name='masked_msa_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch, is_training): + """Builds MaskedMsaHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'msa': MSA representation, shape [N_seq, N_res, c_m]. + batch: Batch, unused. + is_training: Whether the module is in training mode. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_seq, N_res, N_aatype] with + (unnormalized) log probabilies of predicted aatype at position. + """ + del batch + logits = common_modules.Linear( + self.config.num_output, + initializer=utils.final_init(self.global_config), + name='logits')( + representations['msa']) + return dict(logits=logits) + + def loss(self, value, batch): + errors = softmax_cross_entropy( + labels=jax.nn.one_hot(batch['true_msa'], num_classes=23), + logits=value['logits']) + loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) / + (1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1)))) + return {'loss': loss} + + +class PredictedLDDTHead(hk.Module): + """Head to predict the per-residue LDDT to be used as a confidence measure. + + Jumper et al. (2021) Suppl. Sec. 1.9.6 "Model confidence prediction (pLDDT)" + Jumper et al. (2021) Suppl. Alg. 29 "predictPerResidueLDDT_Ca" + """ + + def __init__(self, config, global_config, name='predicted_lddt_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch, is_training): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'structure_module': Single representation from the structure module, + shape [N_res, c_s]. + batch: Batch, unused. + is_training: Whether the module is in training mode. + + Returns: + Dictionary containing : + * 'logits': logits of shape [N_res, N_bins] with + (unnormalized) log probabilies of binned predicted lDDT. + """ + act = representations['structure_module'] + + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='input_layer_norm')( + act) + + act = common_modules.Linear( + self.config.num_channels, + initializer='relu', + name='act_0')( + act) + act = jax.nn.relu(act) + + act = common_modules.Linear( + self.config.num_channels, + initializer='relu', + name='act_1')( + act) + act = jax.nn.relu(act) + + logits = common_modules.Linear( + self.config.num_bins, + initializer=utils.final_init(self.global_config), + name='logits')( + act) + # Shape (batch_size, num_res, num_bins) + return dict(logits=logits) + + def loss(self, value, batch): + # Shape (num_res, 37, 3) + pred_all_atom_pos = value['structure_module']['final_atom_positions'] + # Shape (num_res, 37, 3) + true_all_atom_pos = batch['all_atom_positions'] + # Shape (num_res, 37) + all_atom_mask = batch['all_atom_mask'] + + # Shape (num_res,) + lddt_ca = lddt.lddt( + # Shape (batch_size, num_res, 3) + predicted_points=pred_all_atom_pos[None, :, 1, :], + # Shape (batch_size, num_res, 3) + true_points=true_all_atom_pos[None, :, 1, :], + # Shape (batch_size, num_res, 1) + true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32), + cutoff=15., + per_residue=True)[0] + lddt_ca = jax.lax.stop_gradient(lddt_ca) + + num_bins = self.config.num_bins + bin_index = jnp.floor(lddt_ca * num_bins).astype(jnp.int32) + + # protect against out of range for lddt_ca == 1 + bin_index = jnp.minimum(bin_index, num_bins - 1) + lddt_ca_one_hot = jax.nn.one_hot(bin_index, num_classes=num_bins) + + # Shape (num_res, num_channel) + logits = value['predicted_lddt']['logits'] + errors = softmax_cross_entropy(labels=lddt_ca_one_hot, logits=logits) + + # Shape (num_res,) + mask_ca = all_atom_mask[:, residue_constants.atom_order['CA']] + mask_ca = mask_ca.astype(jnp.float32) + loss = jnp.sum(errors * mask_ca) / (jnp.sum(mask_ca) + 1e-8) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + loss *= ((batch['resolution'] >= self.config.min_resolution) + & (batch['resolution'] <= self.config.max_resolution)).astype( + jnp.float32) + + output = {'loss': loss} + return output + + +class PredictedAlignedErrorHead(hk.Module): + """Head to predict the distance errors in the backbone alignment frames. + + Can be used to compute predicted TM-Score. + Jumper et al. (2021) Suppl. Sec. 1.9.7 "TM-score prediction" + """ + + def __init__(self, config, global_config, + name='predicted_aligned_error_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch, is_training): + """Builds PredictedAlignedErrorHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + batch: Batch, unused. + is_training: Whether the module is in training mode. + + Returns: + Dictionary containing: + * logits: logits for aligned error, shape [N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1]. + """ + + act = representations['pair'] + + # Shape (num_res, num_res, num_bins) + logits = common_modules.Linear( + self.config.num_bins, + initializer=utils.final_init(self.global_config), + name='logits')(act) + # Shape (num_bins,) + breaks = jnp.linspace( + 0., self.config.max_error_bin, self.config.num_bins - 1) + return dict(logits=logits, breaks=breaks) + + def loss(self, value, batch): + # Shape (num_res, 7) + predicted_affine = quat_affine.QuatAffine.from_tensor( + value['structure_module']['final_affines']) + # Shape (num_res, 7) + true_affine = quat_affine.QuatAffine.from_tensor( + batch['backbone_affine_tensor']) + # Shape (num_res) + mask = batch['backbone_affine_mask'] + # Shape (num_res, num_res) + square_mask = mask[:, None] * mask[None, :] + num_bins = self.config.num_bins + # (1, num_bins - 1) + breaks = value['predicted_aligned_error']['breaks'] + # (1, num_bins) + logits = value['predicted_aligned_error']['logits'] + + # Compute the squared error for each alignment. + def _local_frame_points(affine): + points = [jnp.expand_dims(x, axis=-2) for x in affine.translation] + return affine.invert_point(points, extra_dims=1) + error_dist2_xyz = [ + jnp.square(a - b) + for a, b in zip(_local_frame_points(predicted_affine), + _local_frame_points(true_affine))] + error_dist2 = sum(error_dist2_xyz) + # Shape (num_res, num_res) + # First num_res are alignment frames, second num_res are the residues. + error_dist2 = jax.lax.stop_gradient(error_dist2) + + sq_breaks = jnp.square(breaks) + true_bins = jnp.sum(( + error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1) + + errors = softmax_cross_entropy( + labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits) + + loss = (jnp.sum(errors * square_mask, axis=(-2, -1)) / + (1e-8 + jnp.sum(square_mask, axis=(-2, -1)))) + + if self.config.filter_by_resolution: + # NMR & distillation have resolution = 0 + loss *= ((batch['resolution'] >= self.config.min_resolution) + & (batch['resolution'] <= self.config.max_resolution)).astype( + jnp.float32) + + output = {'loss': loss} + return output + + +class ExperimentallyResolvedHead(hk.Module): + """Predicts if an atom is experimentally resolved in a high-res structure. + + Only trained on high-resolution X-ray crystals & cryo-EM. + Jumper et al. (2021) Suppl. Sec. 1.9.10 '"Experimentally resolved" prediction' + """ + + def __init__(self, config, global_config, + name='experimentally_resolved_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch, is_training): + """Builds ExperimentallyResolvedHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'single': Single representation, shape [N_res, c_s]. + batch: Batch, unused. + is_training: Whether the module is in training mode. + + Returns: + Dictionary containing: + * 'logits': logits of shape [N_res, 37], + log probability that an atom is resolved in atom37 representation, + can be converted to probability by applying sigmoid. + """ + logits = common_modules.Linear( + 37, # atom_exists.shape[-1] + initializer=utils.final_init(self.global_config), + name='logits')(representations['single']) + return dict(logits=logits) + + def loss(self, value, batch): + logits = value['logits'] + assert len(logits.shape) == 2 + + # Does the atom appear in the amino acid? + atom_exists = batch['atom37_atom_exists'] + # Is the atom resolved in the experiment? Subset of atom_exists, + # *except for OXT* + all_atom_mask = batch['all_atom_mask'].astype(jnp.float32) + + xent = sigmoid_cross_entropy(labels=all_atom_mask, logits=logits) + loss = jnp.sum(xent * atom_exists) / (1e-8 + jnp.sum(atom_exists)) + + if self.config.filter_by_resolution: + # NMR & distillation examples have resolution = 0. + loss *= ((batch['resolution'] >= self.config.min_resolution) + & (batch['resolution'] <= self.config.max_resolution)).astype( + jnp.float32) + + output = {'loss': loss} + return output + + +class TriangleMultiplication(hk.Module): + """Triangle multiplication layer ("outgoing" or "incoming"). + + Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing" + Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming" + """ + + def __init__(self, config, global_config, name='triangle_multiplication'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, act, mask, is_training=True): + """Builds TriangleMultiplication module. + + Arguments: + act: Pair activations, shape [N_res, N_res, c_z] + mask: Pair mask, shape [N_res, N_res]. + is_training: Whether the module is in training mode. + + Returns: + Outputs, same shape/type as act. + """ + del is_training + c = self.config + gc = self.global_config + + mask = mask[..., None] + + act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True, + name='layer_norm_input')(act) + input_act = act + + left_projection = common_modules.Linear( + c.num_intermediate_channel, + name='left_projection') + left_proj_act = mask * left_projection(act) + + right_projection = common_modules.Linear( + c.num_intermediate_channel, + name='right_projection') + right_proj_act = mask * right_projection(act) + + left_gate_values = jax.nn.sigmoid(common_modules.Linear( + c.num_intermediate_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='left_gate')(act)) + + right_gate_values = jax.nn.sigmoid(common_modules.Linear( + c.num_intermediate_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='right_gate')(act)) + + left_proj_act *= left_gate_values + right_proj_act *= right_gate_values + + # "Outgoing" edges equation: 'ikc,jkc->ijc' + # "Incoming" edges equation: 'kjc,kic->ijc' + # Note on the Suppl. Alg. 11 & 12 notation: + # For the "outgoing" edges, a = left_proj_act and b = right_proj_act + # For the "incoming" edges, it's swapped: + # b = left_proj_act and a = right_proj_act + act = jnp.einsum(c.equation, left_proj_act, right_proj_act) + + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='center_layer_norm')( + act) + + output_channel = int(input_act.shape[-1]) + + act = common_modules.Linear( + output_channel, + initializer=utils.final_init(gc), + name='output_projection')(act) + + gate_values = jax.nn.sigmoid(common_modules.Linear( + output_channel, + bias_init=1., + initializer=utils.final_init(gc), + name='gating_linear')(input_act)) + act *= gate_values + + return act + + +class DistogramHead(hk.Module): + """Head to predict a distogram. + + Jumper et al. (2021) Suppl. Sec. 1.9.8 "Distogram prediction" + """ + + def __init__(self, config, global_config, name='distogram_head'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, representations, batch, is_training): + """Builds DistogramHead module. + + Arguments: + representations: Dictionary of representations, must contain: + * 'pair': pair representation, shape [N_res, N_res, c_z]. + batch: Batch, unused. + is_training: Whether the module is in training mode. + + Returns: + Dictionary containing: + * logits: logits for distogram, shape [N_res, N_res, N_bins]. + * bin_breaks: array containing bin breaks, shape [N_bins - 1,]. + """ + half_logits = common_modules.Linear( + self.config.num_bins, + initializer=utils.final_init(self.global_config), + name='half_logits')( + representations['pair']) + + logits = half_logits + jnp.swapaxes(half_logits, -2, -3) + breaks = jnp.linspace(self.config.first_break, self.config.last_break, + self.config.num_bins - 1) + + return dict(logits=logits, bin_edges=breaks) + + def loss(self, value, batch): + return _distogram_log_loss(value['logits'], value['bin_edges'], + batch, self.config.num_bins) + + +def _distogram_log_loss(logits, bin_edges, batch, num_bins): + """Log loss of a distogram.""" + + assert len(logits.shape) == 3 + positions = batch['pseudo_beta'] + mask = batch['pseudo_beta_mask'] + + assert positions.shape[-1] == 3 + + sq_breaks = jnp.square(bin_edges) + + dist2 = jnp.sum( + jnp.square( + jnp.expand_dims(positions, axis=-2) - + jnp.expand_dims(positions, axis=-3)), + axis=-1, + keepdims=True) + + true_bins = jnp.sum(dist2 > sq_breaks, axis=-1) + + errors = softmax_cross_entropy( + labels=jax.nn.one_hot(true_bins, num_bins), logits=logits) + + square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1) + + avg_error = ( + jnp.sum(errors * square_mask, axis=(-2, -1)) / + (1e-6 + jnp.sum(square_mask, axis=(-2, -1)))) + dist2 = dist2[..., 0] + return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2)) + + +class OuterProductMean(hk.Module): + """Computes mean outer product. + + Jumper et al. (2021) Suppl. Alg. 10 "OuterProductMean" + """ + + def __init__(self, + config, + global_config, + num_output_channel, + name='outer_product_mean'): + super().__init__(name=name) + self.global_config = global_config + self.config = config + self.num_output_channel = num_output_channel + + def __call__(self, act, mask, is_training=True): + """Builds OuterProductMean module. + + Arguments: + act: MSA representation, shape [N_seq, N_res, c_m]. + mask: MSA mask, shape [N_seq, N_res]. + is_training: Whether the module is in training mode. + + Returns: + Update to pair representation, shape [N_res, N_res, c_z]. + """ + gc = self.global_config + c = self.config + + mask = mask[..., None] + act = hk.LayerNorm([-1], True, True, name='layer_norm_input')(act) + + left_act = mask * common_modules.Linear( + c.num_outer_channel, + initializer='linear', + name='left_projection')( + act) + + right_act = mask * common_modules.Linear( + c.num_outer_channel, + initializer='linear', + name='right_projection')( + act) + + if gc.zero_init: + init_w = hk.initializers.Constant(0.0) + else: + init_w = hk.initializers.VarianceScaling(scale=2., mode='fan_in') + + output_w = hk.get_parameter( + 'output_w', + shape=(c.num_outer_channel, c.num_outer_channel, + self.num_output_channel), + init=init_w) + output_b = hk.get_parameter( + 'output_b', shape=(self.num_output_channel,), + init=hk.initializers.Constant(0.0)) + + def compute_chunk(left_act): + # This is equivalent to + # + # act = jnp.einsum('abc,ade->dceb', left_act, right_act) + # act = jnp.einsum('dceb,cef->bdf', act, output_w) + output_b + # + # but faster. + left_act = jnp.transpose(left_act, [0, 2, 1]) + act = jnp.einsum('acb,ade->dceb', left_act, right_act) + act = jnp.einsum('dceb,cef->dbf', act, output_w) + output_b + return jnp.transpose(act, [1, 0, 2]) + + act = mapping.inference_subbatch( + compute_chunk, + c.chunk_size, + batched_args=[left_act], + nonbatched_args=[], + low_memory=True, + input_subbatch_dim=1, + output_subbatch_dim=0) + + epsilon = 1e-3 + norm = jnp.einsum('abc,adc->bdc', mask, mask) + act /= epsilon + norm + + return act + +def dgram_from_positions(positions, num_bins, min_bin, max_bin): + """Compute distogram from amino acid positions. + Arguments: + positions: [N_res, 3] Position coordinates. + num_bins: The number of bins in the distogram. + min_bin: The left edge of the first bin. + max_bin: The left edge of the final bin. The final bin catches + everything larger than `max_bin`. + Returns: + Distogram with the specified number of bins. + """ + def squared_difference(x, y): + return jnp.square(x - y) + + lower_breaks = jnp.linspace(min_bin, max_bin, num_bins) + lower_breaks = jnp.square(lower_breaks) + upper_breaks = jnp.concatenate([lower_breaks[1:],jnp.array([1e8], dtype=jnp.float32)], axis=-1) + dist2 = jnp.sum( + squared_difference( + jnp.expand_dims(positions, axis=-2), + jnp.expand_dims(positions, axis=-3)), + axis=-1, keepdims=True) + + return ((dist2 > lower_breaks).astype(jnp.float32) * (dist2 < upper_breaks).astype(jnp.float32)) + +def dgram_from_positions_soft(positions, num_bins, min_bin, max_bin, temp=2.0): + '''soft positions to dgram converter''' + lower_breaks = jnp.append(-1e8,jnp.linspace(min_bin, max_bin, num_bins)) + upper_breaks = jnp.append(lower_breaks[1:],1e8) + dist = jnp.sqrt(jnp.square(positions[...,:,None,:] - positions[...,None,:,:]).sum(-1,keepdims=True) + 1e-8) + o = jax.nn.sigmoid((dist - lower_breaks)/temp) * jax.nn.sigmoid((upper_breaks - dist)/temp) + o = o/(o.sum(-1,keepdims=True) + 1e-8) + return o[...,1:] + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + """Create pseudo beta features.""" + + ca_idx = residue_constants.atom_order['CA'] + cb_idx = residue_constants.atom_order['CB'] + + if jnp.issubdtype(aatype.dtype, jnp.integer): + is_gly = jnp.equal(aatype, residue_constants.restype_order['G']) + is_gly_tile = jnp.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]) + pseudo_beta = jnp.where(is_gly_tile, all_atom_positions[..., ca_idx, :], all_atom_positions[..., cb_idx, :]) + + if all_atom_masks is not None: + pseudo_beta_mask = jnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + pseudo_beta_mask = pseudo_beta_mask.astype(jnp.float32) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + else: + is_gly = aatype[...,residue_constants.restype_order['G']] + ca_pos = all_atom_positions[...,ca_idx,:] + cb_pos = all_atom_positions[...,cb_idx,:] + pseudo_beta = is_gly[...,None] * ca_pos + (1-is_gly[...,None]) * cb_pos + if all_atom_masks is not None: + ca_mask = all_atom_masks[...,ca_idx] + cb_mask = all_atom_masks[...,cb_idx] + pseudo_beta_mask = is_gly * ca_mask + (1-is_gly) * cb_mask + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + +class EvoformerIteration(hk.Module): + """Single iteration (block) of Evoformer stack. + Jumper et al. (2021) Suppl. Alg. 6 "EvoformerStack" lines 2-10 + """ + + def __init__(self, config, global_config, is_extra_msa, + name='evoformer_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + self.is_extra_msa = is_extra_msa + + def __call__(self, activations, masks, is_training=True, safe_key=None, scale_rate=1.0): + """Builds EvoformerIteration module. + + Arguments: + activations: Dictionary containing activations: + * 'msa': MSA activations, shape [N_seq, N_res, c_m]. + * 'pair': pair activations, shape [N_res, N_res, c_z]. + masks: Dictionary of masks: + * 'msa': MSA mask, shape [N_seq, N_res]. + * 'pair': pair mask, shape [N_res, N_res]. + is_training: Whether the module is in training mode. + safe_key: prng.SafeKey encapsulating rng key. + + Returns: + Outputs, same shape/type as act. + """ + c = self.config + gc = self.global_config + + msa_act, pair_act = activations['msa'], activations['pair'] + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + msa_mask, pair_mask = masks['msa'], masks['pair'] + + dropout_wrapper_fn = functools.partial( + dropout_wrapper, + is_training=is_training, + global_config=gc, + scale_rate=scale_rate) + + safe_key, *sub_keys = safe_key.split(10) + sub_keys = iter(sub_keys) + + msa_act = dropout_wrapper_fn( + MSARowAttentionWithPairBias( + c.msa_row_attention_with_pair_bias, gc, + name='msa_row_attention_with_pair_bias'), + msa_act, + msa_mask, + safe_key=next(sub_keys), + pair_act=pair_act) + + if not self.is_extra_msa: + attn_mod = MSAColumnAttention( + c.msa_column_attention, gc, name='msa_column_attention') + else: + attn_mod = MSAColumnGlobalAttention( + c.msa_column_attention, gc, name='msa_column_global_attention') + msa_act = dropout_wrapper_fn( + attn_mod, + msa_act, + msa_mask, + safe_key=next(sub_keys)) + + msa_act = dropout_wrapper_fn( + Transition(c.msa_transition, gc, name='msa_transition'), + msa_act, + msa_mask, + safe_key=next(sub_keys)) + + pair_act = dropout_wrapper_fn( + OuterProductMean( + config=c.outer_product_mean, + global_config=self.global_config, + num_output_channel=int(pair_act.shape[-1]), + name='outer_product_mean'), + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) + + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + pair_act = dropout_wrapper_fn( + TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + + pair_act = dropout_wrapper_fn( + Transition(c.pair_transition, gc, name='pair_transition'), + pair_act, + pair_mask, + safe_key=next(sub_keys)) + + return {'msa': msa_act, 'pair': pair_act} + + +class EmbeddingsAndEvoformer(hk.Module): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5-18 + """ + + def __init__(self, config, global_config, name='evoformer'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, batch, is_training, safe_key=None): + + c = self.config + gc = self.global_config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + # Embed clustered MSA. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 5 + # Jumper et al. (2021) Suppl. Alg. 3 "InputEmbedder" + preprocess_1d = common_modules.Linear( + c.msa_channel, name='preprocess_1d')( + batch['target_feat']) + + preprocess_msa = common_modules.Linear( + c.msa_channel, name='preprocess_msa')( + batch['msa_feat']) + + msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + + left_single = common_modules.Linear( + c.pair_channel, name='left_single')( + batch['target_feat']) + right_single = common_modules.Linear( + c.pair_channel, name='right_single')( + batch['target_feat']) + pair_activations = left_single[:, None] + right_single[None] + mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] + + # Inject previous outputs for recycling. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 + # Jumper et al. (2021) Suppl. Alg. 32 "RecyclingEmbedder" + + if "prev_pos" in batch: + # use predicted position input + prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) + if c.backprop_dgram: + dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos) + else: + dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) + + elif 'prev_dgram' in batch: + # use predicted distogram input (from Sergey) + dgram = jax.nn.softmax(batch["prev_dgram"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + dgram = dgram @ dgram_map + + pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) + + if c.recycle_features: + if 'prev_msa_first_row' in batch: + prev_msa_first_row = hk.LayerNorm([-1], + True, + True, + name='prev_msa_first_row_norm')( + batch['prev_msa_first_row']) + msa_activations = msa_activations.at[0].add(prev_msa_first_row) + + if 'prev_pair' in batch: + pair_activations += hk.LayerNorm([-1], + True, + True, + name='prev_pair_norm')( + batch['prev_pair']) + + # Relative position encoding. + # Jumper et al. (2021) Suppl. Alg. 4 "relpos" + # Jumper et al. (2021) Suppl. Alg. 5 "one_hot" + if c.max_relative_feature: + # Add one-hot-encoded clipped residue distances to the pair activations. + if "rel_pos" in batch: + rel_pos = batch['rel_pos'] + else: + if "offset" in batch: + offset = batch['offset'] + else: + pos = batch['residue_index'] + offset = pos[:, None] - pos[None, :] + rel_pos = jax.nn.one_hot( + jnp.clip( + offset + c.max_relative_feature, + a_min=0, + a_max=2 * c.max_relative_feature), + 2 * c.max_relative_feature + 1) + pair_activations += common_modules.Linear(c.pair_channel, name='pair_activiations')(rel_pos) + + # Embed templates into the pair activations. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-13 + + if c.template.enabled: + template_batch = {k: batch[k] for k in batch if k.startswith('template_')} + template_pair_representation = TemplateEmbedding(c.template, gc)( + pair_activations, + template_batch, + mask_2d, + is_training=is_training, + scale_rate=batch["scale_rate"]) + + pair_activations += template_pair_representation + + # Embed extra MSA features. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 14-16 + extra_msa_feat = create_extra_msa_feature(batch) + extra_msa_activations = common_modules.Linear( + c.extra_msa_channel, + name='extra_msa_activations')( + extra_msa_feat) + + # Extra MSA Stack. + # Jumper et al. (2021) Suppl. Alg. 18 "ExtraMsaStack" + extra_msa_stack_input = { + 'msa': extra_msa_activations, + 'pair': pair_activations, + } + + extra_msa_stack_iteration = EvoformerIteration( + c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') + + def extra_msa_stack_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + extra_evoformer_output = extra_msa_stack_iteration( + activations=act, + masks={ + 'msa': batch['extra_msa_mask'], + 'pair': mask_2d + }, + is_training=is_training, + safe_key=safe_subkey, scale_rate=batch["scale_rate"]) + return (extra_evoformer_output, safe_key) + + if gc.use_remat: + extra_msa_stack_fn = hk.remat(extra_msa_stack_fn) + + extra_msa_stack = layer_stack.layer_stack( + c.extra_msa_stack_num_block)( + extra_msa_stack_fn) + extra_msa_output, safe_key = extra_msa_stack( + (extra_msa_stack_input, safe_key)) + + pair_activations = extra_msa_output['pair'] + + evoformer_input = { + 'msa': msa_activations, + 'pair': pair_activations, + } + + evoformer_masks = {'msa': batch['msa_mask'], 'pair': mask_2d} + + #################################################################### + #################################################################### + + # Append num_templ rows to msa_activations with template embeddings. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 7-8 + if c.template.enabled and c.template.embed_torsion_angles: + if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): + num_templ, num_res = batch['template_aatype'].shape + # Embed the templates aatypes. + aatype = batch['template_aatype'] + aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) + else: + num_templ, num_res, _ = batch['template_aatype'].shape + aatype = batch['template_aatype'].argmax(-1) + aatype_one_hot = batch['template_aatype'] + + # Embed the templates aatype, torsion angles and masks. + # Shape (templates, residues, msa_channels) + ret = all_atom.atom37_to_torsion_angles( + aatype=aatype, + all_atom_pos=batch['template_all_atom_positions'], + all_atom_mask=batch['template_all_atom_masks'], + # Ensure consistent behaviour during testing: + placeholder_for_undefined=not gc.zero_init) + + template_features = jnp.concatenate([ + aatype_one_hot, + jnp.reshape(ret['torsion_angles_sin_cos'], [num_templ, num_res, 14]), + jnp.reshape(ret['alt_torsion_angles_sin_cos'], [num_templ, num_res, 14]), + ret['torsion_angles_mask']], axis=-1) + + template_activations = common_modules.Linear( + c.msa_channel, + initializer='relu', + name='template_single_embedding')(template_features) + template_activations = jax.nn.relu(template_activations) + template_activations = common_modules.Linear( + c.msa_channel, + initializer='relu', + name='template_projection')(template_activations) + + # Concatenate the templates to the msa. + evoformer_input['msa'] = jnp.concatenate([evoformer_input['msa'], template_activations], axis=0) + + # Concatenate templates masks to the msa masks. + # Use mask from the psi angle, as it only depends on the backbone atoms + # from a single residue. + torsion_angle_mask = ret['torsion_angles_mask'][:, :, 2] + torsion_angle_mask = torsion_angle_mask.astype(evoformer_masks['msa'].dtype) + evoformer_masks['msa'] = jnp.concatenate([evoformer_masks['msa'], torsion_angle_mask], axis=0) + + #################################################################### + #################################################################### + + # Main trunk of the network + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 17-18 + evoformer_iteration = EvoformerIteration( + c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') + + def evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + evoformer_output = evoformer_iteration( + activations=act, + masks=evoformer_masks, + is_training=is_training, + safe_key=safe_subkey, scale_rate=batch["scale_rate"]) + return (evoformer_output, safe_key) + + if gc.use_remat: + evoformer_fn = hk.remat(evoformer_fn) + + evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)(evoformer_fn) + evoformer_output, safe_key = evoformer_stack((evoformer_input, safe_key)) + + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + + single_activations = common_modules.Linear( + c.seq_channel, name='single_activations')(msa_activations[0]) + + num_sequences = batch['msa_feat'].shape[0] + output = { + 'single': single_activations, + 'pair': pair_activations, + # Crop away template rows such that they are not used in MaskedMsaHead. + 'msa': msa_activations[:num_sequences, :, :], + 'msa_first_row': msa_activations[0], + } + + return output + +#################################################################### +#################################################################### +class SingleTemplateEmbedding(hk.Module): + """Embeds a single template. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9+11 + """ + + def __init__(self, config, global_config, name='single_template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, batch, mask_2d, is_training, scale_rate=1.0): + """Build the single template embedding. + Arguments: + query_embedding: Query pair representation, shape [N_res, N_res, c_z]. + batch: A batch of template features (note the template dimension has been + stripped out as this module only runs over a single template). + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + is_training: Whether the module is in training mode. + Returns: + A template embedding [N_res, N_res, c_z]. + """ + assert mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_res = batch['template_aatype'].shape[0] + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[:, None] * template_mask[None, :] + template_mask_2d = template_mask_2d.astype(dtype) + + if "template_dgram" in batch: + template_dgram = batch["template_dgram"] + else: + if self.config.backprop_dgram: + template_dgram = dgram_from_positions_soft(batch['template_pseudo_beta'], + temp=self.config.backprop_dgram_temp, + **self.config.dgram_features) + else: + template_dgram = dgram_from_positions(batch['template_pseudo_beta'], + **self.config.dgram_features) + template_dgram = template_dgram.astype(dtype) + + to_concat = [template_dgram, template_mask_2d[:, :, None]] + + if jnp.issubdtype(batch['template_aatype'].dtype, jnp.integer): + aatype = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1, dtype=dtype) + else: + aatype = batch['template_aatype'] + + to_concat.append(jnp.tile(aatype[None, :, :], [num_res, 1, 1])) + to_concat.append(jnp.tile(aatype[:, None, :], [1, num_res, 1])) + + # Backbone affine mask: whether the residue has C, CA, N + # (the template mask defined above only considers pseudo CB). + n, ca, c = [residue_constants.atom_order[a] for a in ('N', 'CA', 'C')] + template_mask = ( + batch['template_all_atom_masks'][..., n] * + batch['template_all_atom_masks'][..., ca] * + batch['template_all_atom_masks'][..., c]) + template_mask_2d = template_mask[:, None] * template_mask[None, :] + + # compute unit_vector (not used by default) + if self.config.use_template_unit_vector: + rot, trans = quat_affine.make_transform_from_reference( + n_xyz=batch['template_all_atom_positions'][:, n], + ca_xyz=batch['template_all_atom_positions'][:, ca], + c_xyz=batch['template_all_atom_positions'][:, c]) + affines = quat_affine.QuatAffine( + quaternion=quat_affine.rot_to_quat(rot, unstack_inputs=True), + translation=trans, + rotation=rot, + unstack_inputs=True) + points = [jnp.expand_dims(x, axis=-2) for x in affines.translation] + affine_vec = affines.invert_point(points, extra_dims=1) + inv_distance_scalar = jax.lax.rsqrt(1e-6 + sum([jnp.square(x) for x in affine_vec])) + inv_distance_scalar *= template_mask_2d.astype(inv_distance_scalar.dtype) + unit_vector = [(x * inv_distance_scalar)[..., None] for x in affine_vec] + else: + unit_vector = [jnp.zeros((num_res,num_res,1))] * 3 + + unit_vector = [x.astype(dtype) for x in unit_vector] + to_concat.extend(unit_vector) + + template_mask_2d = template_mask_2d.astype(dtype) + to_concat.append(template_mask_2d[..., None]) + + act = jnp.concatenate(to_concat, axis=-1) + + # Mask out non-template regions so we don't get arbitrary values in the + # distogram for these regions. + act *= template_mask_2d[..., None] + + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 9 + act = common_modules.Linear( + num_channels, + initializer='relu', + name='embedding2d')(act) + + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 11 + act = TemplatePairStack( + self.config.template_pair_stack, self.global_config)(act, mask_2d, is_training, scale_rate=scale_rate) + + act = hk.LayerNorm([-1], True, True, name='output_layer_norm')(act) + return act + + +class TemplateEmbedding(hk.Module): + """Embeds a set of templates. + Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" + """ + + def __init__(self, config, global_config, name='template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_batch, mask_2d, is_training, scale_rate=1.0): + """Build TemplateEmbedding module. + Arguments: + query_embedding: Query pair representation, shape [N_res, N_res, c_z]. + template_batch: A batch of template features. + mask_2d: Padding mask (Note: this doesn't care if a template exists, + unlike the template_pseudo_beta_mask). + is_training: Whether the module is in training mode. + Returns: + A template embedding [N_res, N_res, c_z]. + """ + + num_templates = template_batch['template_mask'].shape[0] + num_channels = (self.config.template_pair_stack + .triangle_attention_ending_node.value_dim) + num_res = query_embedding.shape[0] + + dtype = query_embedding.dtype + template_mask = template_batch['template_mask'] + template_mask = template_mask.astype(dtype) + + query_num_channels = query_embedding.shape[-1] + + # Make sure the weights are shared across templates by constructing the + # embedder here. + # Jumper et al. (2021) Suppl. Alg. 2 "Inference" lines 9-12 + template_embedder = SingleTemplateEmbedding(self.config, self.global_config) + + def map_fn(batch): + return template_embedder(query_embedding, batch, mask_2d, is_training, scale_rate=scale_rate) + + template_pair_representation = mapping.sharded_map(map_fn, in_axes=0)(template_batch) + + # Cross attend from the query to the templates along the residue + # dimension by flattening everything else into the batch dimension. + # Jumper et al. (2021) Suppl. Alg. 17 "TemplatePointwiseAttention" + flat_query = jnp.reshape(query_embedding,[num_res * num_res, 1, query_num_channels]) + + flat_templates = jnp.reshape( + jnp.transpose(template_pair_representation, [1, 2, 0, 3]), + [num_res * num_res, num_templates, num_channels]) + + bias = (1e9 * (template_mask[None, None, None, :] - 1.)) + + template_pointwise_attention_module = Attention( + self.config.attention, self.global_config, query_num_channels) + nonbatched_args = [bias] + batched_args = [flat_query, flat_templates] + + embedding = mapping.inference_subbatch( + template_pointwise_attention_module, + self.config.subbatch_size, + batched_args=batched_args, + nonbatched_args=nonbatched_args, + low_memory=not is_training) + embedding = jnp.reshape(embedding,[num_res, num_res, query_num_channels]) + + # No gradients if no templates. + embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype) + + return embedding +#################################################################### diff --git a/af_backprop/alphafold/model/prng.py b/af_backprop/alphafold/model/prng.py new file mode 100644 index 0000000000000000000000000000000000000000..9f50d4c0e3f186817f04c7eb3a7850f4dbad256f --- /dev/null +++ b/af_backprop/alphafold/model/prng.py @@ -0,0 +1,70 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of utilities surrounding PRNG usage in protein folding.""" + +import haiku as hk +import jax + +def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training): + """Applies dropout to a tensor.""" + if is_training and not is_deterministic: + keep_rate = 1.0 - rate + keep = jax.random.bernoulli(safe_key.get(), keep_rate, shape=tensor.shape) + return keep * tensor / keep_rate + else: + return tensor + +class SafeKey: + """Safety wrapper for PRNG keys.""" + + def __init__(self, key): + self._key = key + self._used = False + + def _assert_not_used(self): + if self._used: + raise RuntimeError('Random key has been used previously.') + + def get(self): + self._assert_not_used() + self._used = True + return self._key + + def split(self, num_keys=2): + self._assert_not_used() + self._used = True + new_keys = jax.random.split(self._key, num_keys) + return jax.tree_map(SafeKey, tuple(new_keys)) + + def duplicate(self, num_keys=2): + self._assert_not_used() + self._used = True + return tuple(SafeKey(self._key) for _ in range(num_keys)) + + +def _safe_key_flatten(safe_key): + # Flatten transfers "ownership" to the tree + return (safe_key._key,), safe_key._used # pylint: disable=protected-access + + +def _safe_key_unflatten(aux_data, children): + ret = SafeKey(children[0]) + ret._used = aux_data # pylint: disable=protected-access + return ret + + +jax.tree_util.register_pytree_node( + SafeKey, _safe_key_flatten, _safe_key_unflatten) + diff --git a/af_backprop/alphafold/model/quat_affine.py b/af_backprop/alphafold/model/quat_affine.py new file mode 100644 index 0000000000000000000000000000000000000000..9ebcd20f3e2948c905242dc3e09df6684b99ace7 --- /dev/null +++ b/af_backprop/alphafold/model/quat_affine.py @@ -0,0 +1,459 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quaternion geometry modules. + +This introduces a representation of coordinate frames that is based around a +‘QuatAffine’ object. This object describes an array of coordinate frames. +It consists of vectors corresponding to the +origin of the frames as well as orientations which are stored in two +ways, as unit quaternions as well as a rotation matrices. +The rotation matrices are derived from the unit quaternions and the two are kept +in sync. +For an explanation of the relation between unit quaternions and rotations see +https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation + +This representation is used in the model for the backbone frames. + +One important thing to note here, is that while we update both representations +the jit compiler is going to ensure that only the parts that are +actually used are executed. +""" + + +import functools +from typing import Tuple + +import jax +import jax.numpy as jnp +import numpy as np + +# pylint: disable=bad-whitespace +QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +QUAT_TO_ROT[0, 0] = [[ 1, 0, 0], [ 0, 1, 0], [ 0, 0, 1]] # rr +QUAT_TO_ROT[1, 1] = [[ 1, 0, 0], [ 0,-1, 0], [ 0, 0,-1]] # ii +QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [ 0, 1, 0], [ 0, 0,-1]] # jj +QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [ 0,-1, 0], [ 0, 0, 1]] # kk + +QUAT_TO_ROT[1, 2] = [[ 0, 2, 0], [ 2, 0, 0], [ 0, 0, 0]] # ij +QUAT_TO_ROT[1, 3] = [[ 0, 0, 2], [ 0, 0, 0], [ 2, 0, 0]] # ik +QUAT_TO_ROT[2, 3] = [[ 0, 0, 0], [ 0, 0, 2], [ 0, 2, 0]] # jk + +QUAT_TO_ROT[0, 1] = [[ 0, 0, 0], [ 0, 0,-2], [ 0, 2, 0]] # ir +QUAT_TO_ROT[0, 2] = [[ 0, 0, 2], [ 0, 0, 0], [-2, 0, 0]] # jr +QUAT_TO_ROT[0, 3] = [[ 0,-2, 0], [ 2, 0, 0], [ 0, 0, 0]] # kr + +QUAT_MULTIPLY = np.zeros((4, 4, 4), dtype=np.float32) +QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], + [ 0,-1, 0, 0], + [ 0, 0,-1, 0], + [ 0, 0, 0,-1]] + +QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0,-1, 0]] + +QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], + [ 0, 0, 0,-1], + [ 1, 0, 0, 0], + [ 0, 1, 0, 0]] + +QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0,-1, 0, 0], + [ 1, 0, 0, 0]] + +QUAT_MULTIPLY_BY_VEC = QUAT_MULTIPLY[:, 1:, :] +# pylint: enable=bad-whitespace + + +def rot_to_quat(rot, unstack_inputs=False): + """Convert rotation matrix to quaternion. + + Note that this function calls self_adjoint_eig which is extremely expensive on + the GPU. If at all possible, this function should run on the CPU. + + Args: + rot: rotation matrix (see below for format). + unstack_inputs: If true, rotation matrix should be shape (..., 3, 3) + otherwise the rotation matrix should be a list of lists of tensors. + + Returns: + Quaternion as (..., 4) tensor. + """ + if unstack_inputs: + rot = [jnp.moveaxis(x, -1, 0) for x in jnp.moveaxis(rot, -2, 0)] + + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + # pylint: disable=bad-whitespace + k = [[ xx + yy + zz, zy - yz, xz - zx, yx - xy,], + [ zy - yz, xx - yy - zz, xy + yx, xz + zx,], + [ xz - zx, xy + yx, yy - xx - zz, yz + zy,], + [ yx - xy, xz + zx, yz + zy, zz - xx - yy,]] + # pylint: enable=bad-whitespace + + k = (1./3.) * jnp.stack([jnp.stack(x, axis=-1) for x in k], + axis=-2) + + # Get eigenvalues in non-decreasing order and associated. + _, qs = jnp.linalg.eigh(k) + return qs[..., -1] + + +def rot_list_to_tensor(rot_list): + """Convert list of lists to rotation tensor.""" + return jnp.stack( + [jnp.stack(rot_list[0], axis=-1), + jnp.stack(rot_list[1], axis=-1), + jnp.stack(rot_list[2], axis=-1)], + axis=-2) + + +def vec_list_to_tensor(vec_list): + """Convert list to vector tensor.""" + return jnp.stack(vec_list, axis=-1) + + +def quat_to_rot(normalized_quat): + """Convert a normalized quaternion to a rotation matrix.""" + rot_tensor = jnp.sum( + np.reshape(QUAT_TO_ROT, (4, 4, 9)) * + normalized_quat[..., :, None, None] * + normalized_quat[..., None, :, None], + axis=(-3, -2)) + rot = jnp.moveaxis(rot_tensor, -1, 0) # Unstack. + return [[rot[0], rot[1], rot[2]], + [rot[3], rot[4], rot[5]], + [rot[6], rot[7], rot[8]]] + + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + return jnp.sum( + QUAT_MULTIPLY_BY_VEC * + quat[..., :, None, None] * + vec[..., None, :, None], + axis=(-3, -2)) + + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + return jnp.sum( + QUAT_MULTIPLY * + quat1[..., :, None, None] * + quat2[..., None, :, None], + axis=(-3, -2)) + + +def apply_rot_to_vec(rot, vec, unstack=False): + """Multiply rotation matrix by a vector.""" + if unstack: + x, y, z = [vec[:, i] for i in range(3)] + else: + x, y, z = vec + return [rot[0][0] * x + rot[0][1] * y + rot[0][2] * z, + rot[1][0] * x + rot[1][1] * y + rot[1][2] * z, + rot[2][0] * x + rot[2][1] * y + rot[2][2] * z] + + +def apply_inverse_rot_to_vec(rot, vec): + """Multiply the inverse of a rotation matrix by a vector.""" + # Inverse rotation is just transpose + return [rot[0][0] * vec[0] + rot[1][0] * vec[1] + rot[2][0] * vec[2], + rot[0][1] * vec[0] + rot[1][1] * vec[1] + rot[2][1] * vec[2], + rot[0][2] * vec[0] + rot[1][2] * vec[1] + rot[2][2] * vec[2]] + + +class QuatAffine(object): + """Affine transformation represented by quaternion and vector.""" + + def __init__(self, quaternion, translation, rotation=None, normalize=True, + unstack_inputs=False): + """Initialize from quaternion and translation. + + Args: + quaternion: Rotation represented by a quaternion, to be applied + before translation. Must be a unit quaternion unless normalize==True. + translation: Translation represented as a vector. + rotation: Same rotation as the quaternion, represented as a (..., 3, 3) + tensor. If None, rotation will be calculated from the quaternion. + normalize: If True, l2 normalize the quaternion on input. + unstack_inputs: If True, translation is a vector with last component 3 + """ + + if quaternion is not None: + assert quaternion.shape[-1] == 4 + + if unstack_inputs: + if rotation is not None: + rotation = [jnp.moveaxis(x, -1, 0) # Unstack. + for x in jnp.moveaxis(rotation, -2, 0)] # Unstack. + translation = jnp.moveaxis(translation, -1, 0) # Unstack. + + if normalize and quaternion is not None: + quaternion = quaternion / jnp.linalg.norm(quaternion, axis=-1, + keepdims=True) + + if rotation is None: + rotation = quat_to_rot(quaternion) + + self.quaternion = quaternion + self.rotation = [list(row) for row in rotation] + self.translation = list(translation) + + assert all(len(row) == 3 for row in self.rotation) + assert len(self.translation) == 3 + + def to_tensor(self): + return jnp.concatenate( + [self.quaternion] + + [jnp.expand_dims(x, axis=-1) for x in self.translation], + axis=-1) + + def apply_tensor_fn(self, tensor_fn): + """Return a new QuatAffine with tensor_fn applied (e.g. stop_gradient).""" + return QuatAffine( + tensor_fn(self.quaternion), + [tensor_fn(x) for x in self.translation], + rotation=[[tensor_fn(x) for x in row] for row in self.rotation], + normalize=False) + + def apply_rotation_tensor_fn(self, tensor_fn): + """Return a new QuatAffine with tensor_fn applied to the rotation part.""" + return QuatAffine( + tensor_fn(self.quaternion), + [x for x in self.translation], + rotation=[[tensor_fn(x) for x in row] for row in self.rotation], + normalize=False) + + def scale_translation(self, position_scale): + """Return a new quat affine with a different scale for translation.""" + + return QuatAffine( + self.quaternion, + [x * position_scale for x in self.translation], + rotation=[[x for x in row] for row in self.rotation], + normalize=False) + + @classmethod + def from_tensor(cls, tensor, normalize=False): + quaternion, tx, ty, tz = jnp.split(tensor, [4, 5, 6], axis=-1) + return cls(quaternion, + [tx[..., 0], ty[..., 0], tz[..., 0]], + normalize=normalize) + + def pre_compose(self, update): + """Return a new QuatAffine which applies the transformation update first. + + Args: + update: Length-6 vector. 3-vector of x, y, and z such that the quaternion + update is (1, x, y, z) and zero for the 3-vector is the identity + quaternion. 3-vector for translation concatenated. + + Returns: + New QuatAffine object. + """ + vector_quaternion_update, x, y, z = jnp.split(update, [3, 4, 5], axis=-1) + trans_update = [jnp.squeeze(x, axis=-1), + jnp.squeeze(y, axis=-1), + jnp.squeeze(z, axis=-1)] + + new_quaternion = (self.quaternion + + quat_multiply_by_vec(self.quaternion, + vector_quaternion_update)) + + trans_update = apply_rot_to_vec(self.rotation, trans_update) + new_translation = [ + self.translation[0] + trans_update[0], + self.translation[1] + trans_update[1], + self.translation[2] + trans_update[2]] + + return QuatAffine(new_quaternion, new_translation) + + def apply_to_point(self, point, extra_dims=0): + """Apply affine to a point. + + Args: + point: List of 3 tensors to apply affine. + extra_dims: Number of dimensions at the end of the transformed_point + shape that are not present in the rotation and translation. The most + common use is rotation N points at once with extra_dims=1 for use in a + network. + + Returns: + Transformed point after applying affine. + """ + rotation = self.rotation + translation = self.translation + for _ in range(extra_dims): + expand_fn = functools.partial(jnp.expand_dims, axis=-1) + rotation = jax.tree_map(expand_fn, rotation) + translation = jax.tree_map(expand_fn, translation) + + rot_point = apply_rot_to_vec(rotation, point) + return [ + rot_point[0] + translation[0], + rot_point[1] + translation[1], + rot_point[2] + translation[2]] + + def invert_point(self, transformed_point, extra_dims=0): + """Apply inverse of transformation to a point. + + Args: + transformed_point: List of 3 tensors to apply affine + extra_dims: Number of dimensions at the end of the transformed_point + shape that are not present in the rotation and translation. The most + common use is rotation N points at once with extra_dims=1 for use in a + network. + + Returns: + Transformed point after applying affine. + """ + rotation = self.rotation + translation = self.translation + for _ in range(extra_dims): + expand_fn = functools.partial(jnp.expand_dims, axis=-1) + rotation = jax.tree_map(expand_fn, rotation) + translation = jax.tree_map(expand_fn, translation) + + rot_point = [ + transformed_point[0] - translation[0], + transformed_point[1] - translation[1], + transformed_point[2] - translation[2]] + + return apply_inverse_rot_to_vec(rotation, rot_point) + + def __repr__(self): + return 'QuatAffine(%r, %r)' % (self.quaternion, self.translation) + + +def _multiply(a, b): + return jnp.stack([ + jnp.array([a[0][0]*b[0][0] + a[0][1]*b[1][0] + a[0][2]*b[2][0], + a[0][0]*b[0][1] + a[0][1]*b[1][1] + a[0][2]*b[2][1], + a[0][0]*b[0][2] + a[0][1]*b[1][2] + a[0][2]*b[2][2]]), + + jnp.array([a[1][0]*b[0][0] + a[1][1]*b[1][0] + a[1][2]*b[2][0], + a[1][0]*b[0][1] + a[1][1]*b[1][1] + a[1][2]*b[2][1], + a[1][0]*b[0][2] + a[1][1]*b[1][2] + a[1][2]*b[2][2]]), + + jnp.array([a[2][0]*b[0][0] + a[2][1]*b[1][0] + a[2][2]*b[2][0], + a[2][0]*b[0][1] + a[2][1]*b[1][1] + a[2][2]*b[2][1], + a[2][0]*b[0][2] + a[2][1]*b[1][2] + a[2][2]*b[2][2]])]) + + +def make_canonical_transform( + n_xyz: jnp.ndarray, + ca_xyz: jnp.ndarray, + c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns translation and rotation matrices to canonicalize residue atoms. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. + + Returns: + A tuple (translation, rotation) where: + translation is an array of shape [batch, 3] defining the translation. + rotation is an array of shape [batch, 3, 3] defining the rotation. + After applying the translation and rotation to all atoms in a residue: + * All atoms will be shifted so that CA is at the origin, + * All atoms will be rotated so that C is at the x-axis, + * All atoms will be shifted so that N is in the xy plane. + """ + assert len(n_xyz.shape) == 2, n_xyz.shape + assert n_xyz.shape[-1] == 3, n_xyz.shape + assert n_xyz.shape == ca_xyz.shape == c_xyz.shape, ( + n_xyz.shape, ca_xyz.shape, c_xyz.shape) + + # Place CA at the origin. + translation = -ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + # Place C on the x-axis. + c_x, c_y, c_z = [c_xyz[:, i] for i in range(3)] + # Rotate by angle c1 in the x-y plane (around the z-axis). + sin_c1 = -c_y / jnp.sqrt(1e-20 + c_x**2 + c_y**2) + cos_c1 = c_x / jnp.sqrt(1e-20 + c_x**2 + c_y**2) + zeros = jnp.zeros_like(sin_c1) + ones = jnp.ones_like(sin_c1) + # pylint: disable=bad-whitespace + c1_rot_matrix = jnp.stack([jnp.array([cos_c1, -sin_c1, zeros]), + jnp.array([sin_c1, cos_c1, zeros]), + jnp.array([zeros, zeros, ones])]) + + # Rotate by angle c2 in the x-z plane (around the y-axis). + sin_c2 = c_z / jnp.sqrt(1e-20 + c_x**2 + c_y**2 + c_z**2) + cos_c2 = jnp.sqrt(c_x**2 + c_y**2) / jnp.sqrt( + 1e-20 + c_x**2 + c_y**2 + c_z**2) + c2_rot_matrix = jnp.stack([jnp.array([cos_c2, zeros, sin_c2]), + jnp.array([zeros, ones, zeros]), + jnp.array([-sin_c2, zeros, cos_c2])]) + + c_rot_matrix = _multiply(c2_rot_matrix, c1_rot_matrix) + n_xyz = jnp.stack(apply_rot_to_vec(c_rot_matrix, n_xyz, unstack=True)).T + + # Place N in the x-y plane. + _, n_y, n_z = [n_xyz[:, i] for i in range(3)] + # Rotate by angle alpha in the y-z plane (around the x-axis). + sin_n = -n_z / jnp.sqrt(1e-20 + n_y**2 + n_z**2) + cos_n = n_y / jnp.sqrt(1e-20 + n_y**2 + n_z**2) + n_rot_matrix = jnp.stack([jnp.array([ones, zeros, zeros]), + jnp.array([zeros, cos_n, -sin_n]), + jnp.array([zeros, sin_n, cos_n])]) + # pylint: enable=bad-whitespace + + return (translation, + jnp.transpose(_multiply(n_rot_matrix, c_rot_matrix), [2, 0, 1])) + + +def make_transform_from_reference( + n_xyz: jnp.ndarray, + ca_xyz: jnp.ndarray, + c_xyz: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + atom positions in the non-standard way, the N atom will end up not at + [-0.527250, 1.359329, 0.0] but instead at [-0.527250, -1.359329, 0.0]. You + need to take care of such cases in your code. + + Args: + n_xyz: An array of shape [batch, 3] of nitrogen xyz coordinates. + ca_xyz: An array of shape [batch, 3] of carbon alpha xyz coordinates. + c_xyz: An array of shape [batch, 3] of carbon xyz coordinates. + + Returns: + A tuple (rotation, translation) where: + rotation is an array of shape [batch, 3, 3] defining the rotation. + translation is an array of shape [batch, 3] defining the translation. + After applying the translation and rotation to the reference backbone, + the coordinates will approximately equal to the input coordinates. + + The order of translation and rotation differs from make_canonical_transform + because the rotation from this function should be applied before the + translation, unlike make_canonical_transform. + """ + translation, rotation = make_canonical_transform(n_xyz, ca_xyz, c_xyz) + return np.transpose(rotation, (0, 2, 1)), -translation diff --git a/af_backprop/alphafold/model/r3.py b/af_backprop/alphafold/model/r3.py new file mode 100644 index 0000000000000000000000000000000000000000..1e775ab39e529c6086938adbb1d6c2cd3fb6cc8e --- /dev/null +++ b/af_backprop/alphafold/model/r3.py @@ -0,0 +1,320 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformations for 3D coordinates. + +This Module contains objects for representing Vectors (Vecs), Rotation Matrices +(Rots) and proper Rigid transformation (Rigids). These are represented as +named tuples with arrays for each entry, for example a set of +[N, M] points would be represented as a Vecs object with arrays of shape [N, M] +for x, y and z. + +This is being done to improve readability by making it very clear what objects +are geometric objects rather than relying on comments and array shapes. +Another reason for this is to avoid using matrix +multiplication primitives like matmul or einsum, on modern accelerator hardware +these can end up on specialized cores such as tensor cores on GPU or the MXU on +cloud TPUs, this often involves lower computational precision which can be +problematic for coordinate geometry. Also these cores are typically optimized +for larger matrices than 3 dimensional, this code is written to avoid any +unintended use of these cores on both GPUs and TPUs. +""" + +import collections +from typing import List +from alphafold.model import quat_affine +import jax.numpy as jnp +import tree + +# Array of 3-component vectors, stored as individual array for +# each component. +Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z']) + +# Array of 3x3 rotation matrices, stored as individual array for +# each component. +Rots = collections.namedtuple('Rots', ['xx', 'xy', 'xz', + 'yx', 'yy', 'yz', + 'zx', 'zy', 'zz']) +# Array of rigid 3D transformations, stored as array of rotations and +# array of translations. +Rigids = collections.namedtuple('Rigids', ['rot', 'trans']) + + +def squared_difference(x, y): + return jnp.square(x - y) + + +def invert_rigids(r: Rigids) -> Rigids: + """Computes group inverse of rigid transformations 'r'.""" + inv_rots = invert_rots(r.rot) + t = rots_mul_vecs(inv_rots, r.trans) + inv_trans = Vecs(-t.x, -t.y, -t.z) + return Rigids(inv_rots, inv_trans) + + +def invert_rots(m: Rots) -> Rots: + """Computes inverse of rotations 'm'.""" + return Rots(m.xx, m.yx, m.zx, + m.xy, m.yy, m.zy, + m.xz, m.yz, m.zz) + + +def rigids_from_3_points( + point_on_neg_x_axis: Vecs, # shape (...) + origin: Vecs, # shape (...) + point_on_xy_plane: Vecs, # shape (...) +) -> Rigids: # shape (...) + """Create Rigids from 3 points. + + Jumper et al. (2021) Suppl. Alg. 21 "rigidFrom3Points" + This creates a set of rigid transformations from 3 points by Gram Schmidt + orthogonalization. + + Args: + point_on_neg_x_axis: Vecs corresponding to points on the negative x axis + origin: Origin of resulting rigid transformations + point_on_xy_plane: Vecs corresponding to points in the xy plane + Returns: + Rigid transformations from global frame to local frames derived from + the input points. + """ + m = rots_from_two_vecs( + e0_unnormalized=vecs_sub(origin, point_on_neg_x_axis), + e1_unnormalized=vecs_sub(point_on_xy_plane, origin)) + + return Rigids(rot=m, trans=origin) + + +def rigids_from_list(l: List[jnp.ndarray]) -> Rigids: + """Converts flat list of arrays to rigid transformations.""" + assert len(l) == 12 + return Rigids(Rots(*(l[:9])), Vecs(*(l[9:]))) + + +def rigids_from_quataffine(a: quat_affine.QuatAffine) -> Rigids: + """Converts QuatAffine object to the corresponding Rigids object.""" + return Rigids(Rots(*tree.flatten(a.rotation)), + Vecs(*a.translation)) + + +def rigids_from_tensor4x4( + m: jnp.ndarray # shape (..., 4, 4) +) -> Rigids: # shape (...) + """Construct Rigids object from an 4x4 array. + + Here the 4x4 is representing the transformation in homogeneous coordinates. + + Args: + m: Array representing transformations in homogeneous coordinates. + Returns: + Rigids object corresponding to transformations m + """ + assert m.shape[-1] == 4 + assert m.shape[-2] == 4 + return Rigids( + Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], + m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], + m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]), + Vecs(m[..., 0, 3], m[..., 1, 3], m[..., 2, 3])) + + +def rigids_from_tensor_flat9( + m: jnp.ndarray # shape (..., 9) +) -> Rigids: # shape (...) + """Flat9 encoding: first two columns of rotation matrix + translation.""" + assert m.shape[-1] == 9 + e0 = Vecs(m[..., 0], m[..., 1], m[..., 2]) + e1 = Vecs(m[..., 3], m[..., 4], m[..., 5]) + trans = Vecs(m[..., 6], m[..., 7], m[..., 8]) + return Rigids(rot=rots_from_two_vecs(e0, e1), + trans=trans) + + +def rigids_from_tensor_flat12( + m: jnp.ndarray # shape (..., 12) +) -> Rigids: # shape (...) + """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats).""" + assert m.shape[-1] == 12 + x = jnp.moveaxis(m, -1, 0) # Unstack + return Rigids(Rots(*x[:9]), Vecs(*x[9:])) + + +def rigids_mul_rigids(a: Rigids, b: Rigids) -> Rigids: + """Group composition of Rigids 'a' and 'b'.""" + return Rigids( + rots_mul_rots(a.rot, b.rot), + vecs_add(a.trans, rots_mul_vecs(a.rot, b.trans))) + + +def rigids_mul_rots(r: Rigids, m: Rots) -> Rigids: + """Compose rigid transformations 'r' with rotations 'm'.""" + return Rigids(rots_mul_rots(r.rot, m), r.trans) + + +def rigids_mul_vecs(r: Rigids, v: Vecs) -> Vecs: + """Apply rigid transforms 'r' to points 'v'.""" + return vecs_add(rots_mul_vecs(r.rot, v), r.trans) + + +def rigids_to_list(r: Rigids) -> List[jnp.ndarray]: + """Turn Rigids into flat list, inverse of 'rigids_from_list'.""" + return list(r.rot) + list(r.trans) + + +def rigids_to_quataffine(r: Rigids) -> quat_affine.QuatAffine: + """Convert Rigids r into QuatAffine, inverse of 'rigids_from_quataffine'.""" + return quat_affine.QuatAffine( + quaternion=None, + rotation=[[r.rot.xx, r.rot.xy, r.rot.xz], + [r.rot.yx, r.rot.yy, r.rot.yz], + [r.rot.zx, r.rot.zy, r.rot.zz]], + translation=[r.trans.x, r.trans.y, r.trans.z]) + + +def rigids_to_tensor_flat9( + r: Rigids # shape (...) +) -> jnp.ndarray: # shape (..., 9) + """Flat9 encoding: first two columns of rotation matrix + translation.""" + return jnp.stack( + [r.rot.xx, r.rot.yx, r.rot.zx, r.rot.xy, r.rot.yy, r.rot.zy] + + list(r.trans), axis=-1) + + +def rigids_to_tensor_flat12( + r: Rigids # shape (...) +) -> jnp.ndarray: # shape (..., 12) + """Flat12 encoding: rotation matrix (9 floats) + translation (3 floats).""" + return jnp.stack(list(r.rot) + list(r.trans), axis=-1) + + +def rots_from_tensor3x3( + m: jnp.ndarray, # shape (..., 3, 3) +) -> Rots: # shape (...) + """Convert rotations represented as (3, 3) array to Rots.""" + assert m.shape[-1] == 3 + assert m.shape[-2] == 3 + return Rots(m[..., 0, 0], m[..., 0, 1], m[..., 0, 2], + m[..., 1, 0], m[..., 1, 1], m[..., 1, 2], + m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]) + + +def rots_from_two_vecs(e0_unnormalized: Vecs, e1_unnormalized: Vecs) -> Rots: + """Create rotation matrices from unnormalized vectors for the x and y-axes. + + This creates a rotation matrix from two vectors using Gram-Schmidt + orthogonalization. + + Args: + e0_unnormalized: vectors lying along x-axis of resulting rotation + e1_unnormalized: vectors lying in xy-plane of resulting rotation + Returns: + Rotations resulting from Gram-Schmidt procedure. + """ + # Normalize the unit vector for the x-axis, e0. + e0 = vecs_robust_normalize(e0_unnormalized) + + # make e1 perpendicular to e0. + c = vecs_dot_vecs(e1_unnormalized, e0) + e1 = Vecs(e1_unnormalized.x - c * e0.x, + e1_unnormalized.y - c * e0.y, + e1_unnormalized.z - c * e0.z) + e1 = vecs_robust_normalize(e1) + + # Compute e2 as cross product of e0 and e1. + e2 = vecs_cross_vecs(e0, e1) + + return Rots(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + +def rots_mul_rots(a: Rots, b: Rots) -> Rots: + """Composition of rotations 'a' and 'b'.""" + c0 = rots_mul_vecs(a, Vecs(b.xx, b.yx, b.zx)) + c1 = rots_mul_vecs(a, Vecs(b.xy, b.yy, b.zy)) + c2 = rots_mul_vecs(a, Vecs(b.xz, b.yz, b.zz)) + return Rots(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + +def rots_mul_vecs(m: Rots, v: Vecs) -> Vecs: + """Apply rotations 'm' to vectors 'v'.""" + return Vecs(m.xx * v.x + m.xy * v.y + m.xz * v.z, + m.yx * v.x + m.yy * v.y + m.yz * v.z, + m.zx * v.x + m.zy * v.y + m.zz * v.z) + + +def vecs_add(v1: Vecs, v2: Vecs) -> Vecs: + """Add two vectors 'v1' and 'v2'.""" + return Vecs(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z) + + +def vecs_dot_vecs(v1: Vecs, v2: Vecs) -> jnp.ndarray: + """Dot product of vectors 'v1' and 'v2'.""" + return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z + + +def vecs_cross_vecs(v1: Vecs, v2: Vecs) -> Vecs: + """Cross product of vectors 'v1' and 'v2'.""" + return Vecs(v1.y * v2.z - v1.z * v2.y, + v1.z * v2.x - v1.x * v2.z, + v1.x * v2.y - v1.y * v2.x) + + +def vecs_from_tensor(x: jnp.ndarray # shape (..., 3) + ) -> Vecs: # shape (...) + """Converts from tensor of shape (3,) to Vecs.""" + num_components = x.shape[-1] + assert num_components == 3 + return Vecs(x[..., 0], x[..., 1], x[..., 2]) + + +def vecs_robust_normalize(v: Vecs, epsilon: float = 1e-8) -> Vecs: + """Normalizes vectors 'v'. + + Args: + v: vectors to be normalized. + epsilon: small regularizer added to squared norm before taking square root. + Returns: + normalized vectors + """ + norms = vecs_robust_norm(v, epsilon) + return Vecs(v.x / norms, v.y / norms, v.z / norms) + + +def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray: + """Computes norm of vectors 'v'. + + Args: + v: vectors to be normalized. + epsilon: small regularizer added to squared norm before taking square root. + Returns: + norm of 'v' + """ + return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z) + epsilon) + + +def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs: + """Computes v1 - v2.""" + return Vecs(v1.x - v2.x, v1.y - v2.y, v1.z - v2.z) + + +def vecs_squared_distance(v1: Vecs, v2: Vecs) -> jnp.ndarray: + """Computes squared euclidean difference between 'v1' and 'v2'.""" + return (squared_difference(v1.x, v2.x) + + squared_difference(v1.y, v2.y) + + squared_difference(v1.z, v2.z)) + + +def vecs_to_tensor(v: Vecs # shape (...) + ) -> jnp.ndarray: # shape(..., 3) + """Converts 'v' to tensor with shape 3, inverse of 'vecs_from_tensor'.""" + return jnp.stack([v.x, v.y, v.z], axis=-1) diff --git a/af_backprop/alphafold/model/tf/__init__.py b/af_backprop/alphafold/model/tf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c520687f67754b0488690287f941854c8cf6133 --- /dev/null +++ b/af_backprop/alphafold/model/tf/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Alphafold model TensorFlow code.""" diff --git a/af_backprop/alphafold/model/tf/data_transforms.py b/af_backprop/alphafold/model/tf/data_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7af966ef4b7cb469f2b817a16ad42eea50f31e18 --- /dev/null +++ b/af_backprop/alphafold/model/tf/data_transforms.py @@ -0,0 +1,625 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data for AlphaFold.""" + +from alphafold.common import residue_constants +from alphafold.model.tf import shape_helpers +from alphafold.model.tf import shape_placeholders +from alphafold.model.tf import utils +import numpy as np +import tensorflow.compat.v1 as tf + +# Pylint gets confused by the curry1 decorator because it changes the number +# of arguments to the function. +# pylint:disable=no-value-for-parameter + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + + +def cast_64bit_ints(protein): + + for k, v in protein.items(): + if v.dtype == tf.int64: + protein[k] = tf.cast(v, tf.int32) + return protein + + +_MSA_FEATURE_NAMES = [ + 'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', + 'true_msa' +] + + +def make_seq_mask(protein): + protein['seq_mask'] = tf.ones( + shape_helpers.shape_list(protein['aatype']), dtype=tf.float32) + return protein + + +def make_template_mask(protein): + protein['template_mask'] = tf.ones( + shape_helpers.shape_list(protein['template_domain_names']), + dtype=tf.float32) + return protein + + +def curry1(f): + """Supply all arguments but the first.""" + + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +@curry1 +def add_distillation_flag(protein, distillation): + protein['is_distillation'] = tf.constant(float(distillation), + shape=[], + dtype=tf.float32) + return protein + + +def make_all_atom_aatype(protein): + protein['all_atom_aatype'] = protein['aatype'] + return protein + + +def fix_templates_aatype(protein): + """Fixes aatype encoding of templates.""" + # Map one-hot to indices. + protein['template_aatype'] = tf.argmax( + protein['template_aatype'], output_type=tf.int32, axis=-1) + # Map hhsearch-aatype to our aatype. + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = tf.constant(new_order_list, dtype=tf.int32) + protein['template_aatype'] = tf.gather(params=new_order, + indices=protein['template_aatype']) + return protein + + +def correct_msa_restypes(protein): + """Correct MSA restype to have the same order as residue_constants.""" + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = tf.constant(new_order_list, dtype=protein['msa'].dtype) + protein['msa'] = tf.gather(new_order, protein['msa'], axis=0) + + perm_matrix = np.zeros((22, 22), dtype=np.float32) + perm_matrix[range(len(new_order_list)), new_order_list] = 1. + + for k in protein: + if 'profile' in k: # Include both hhblits and psiblast profiles + num_dim = protein[k].shape.as_list()[-1] + assert num_dim in [20, 21, 22], ( + 'num_dim for %s out of expected range: %s' % (k, num_dim)) + protein[k] = tf.tensordot(protein[k], perm_matrix[:num_dim, :num_dim], 1) + return protein + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + protein['aatype'] = tf.argmax( + protein['aatype'], axis=-1, output_type=tf.int32) + for k in [ + 'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence', + 'superfamily', 'deletion_matrix', 'resolution', + 'between_segment_residues', 'residue_index', 'template_all_atom_masks']: + if k in protein: + final_dim = shape_helpers.shape_list(protein[k])[-1] + if isinstance(final_dim, int) and final_dim == 1: + protein[k] = tf.squeeze(protein[k], axis=-1) + + for k in ['seq_length', 'num_alignments']: + if k in protein: + protein[k] = protein[k][0] # Remove fake sequence dimension + return protein + + +def make_random_crop_to_size_seed(protein): + """Random seed for cropping residues and templates.""" + protein['random_crop_to_size_seed'] = utils.make_random_seed() + return protein + + +@curry1 +def randomly_replace_msa_with_unknown(protein, replace_proportion): + """Replace a proportion of the MSA with 'X'.""" + msa_mask = (tf.random.uniform(shape_helpers.shape_list(protein['msa'])) < + replace_proportion) + x_idx = 20 + gap_idx = 21 + msa_mask = tf.logical_and(msa_mask, protein['msa'] != gap_idx) + protein['msa'] = tf.where(msa_mask, + tf.ones_like(protein['msa']) * x_idx, + protein['msa']) + aatype_mask = ( + tf.random.uniform(shape_helpers.shape_list(protein['aatype'])) < + replace_proportion) + + protein['aatype'] = tf.where(aatype_mask, + tf.ones_like(protein['aatype']) * x_idx, + protein['aatype']) + return protein + + +@curry1 +def sample_msa(protein, max_seq, keep_extra): + """Sample MSA randomly, remaining sequences are stored as `extra_*`. + + Args: + protein: batch to sample msa from. + max_seq: number of sequences to sample. + keep_extra: When True sequences not sampled are put into fields starting + with 'extra_*'. + + Returns: + Protein with sampled msa. + """ + num_seq = tf.shape(protein['msa'])[0] + shuffled = tf.random_shuffle(tf.range(1, num_seq)) + index_order = tf.concat([[0], shuffled], axis=0) + num_sel = tf.minimum(max_seq, num_seq) + + sel_seq, not_sel_seq = tf.split(index_order, [num_sel, num_seq - num_sel]) + + for k in _MSA_FEATURE_NAMES: + if k in protein: + if keep_extra: + protein['extra_' + k] = tf.gather(protein[k], not_sel_seq) + protein[k] = tf.gather(protein[k], sel_seq) + + return protein + + +@curry1 +def crop_extra_msa(protein, max_extra_msa): + """MSA features are cropped so only `max_extra_msa` sequences are kept.""" + num_seq = tf.shape(protein['extra_msa'])[0] + num_sel = tf.minimum(max_extra_msa, num_seq) + select_indices = tf.random_shuffle(tf.range(0, num_seq))[:num_sel] + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + protein['extra_' + k] = tf.gather(protein['extra_' + k], select_indices) + + return protein + + +def delete_extra_msa(protein): + for k in _MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + del protein['extra_' + k] + return protein + + +@curry1 +def block_delete_msa(protein, config): + """Sample MSA by deleting contiguous blocks. + + Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion" + + Arguments: + protein: batch dict containing the msa + config: ConfigDict with parameters + + Returns: + updated protein + """ + num_seq = shape_helpers.shape_list(protein['msa'])[0] + block_num_seq = tf.cast( + tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block), + tf.int32) + + if config.randomize_num_blocks: + nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32) + else: + nb = config.num_blocks + + del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32) + del_blocks = del_block_starts[:, None] + tf.range(block_num_seq) + del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1) + del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0] + + # Make sure we keep the original sequence + sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None], + del_indices[None]) + keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0) + keep_indices = tf.concat([[0], keep_indices], axis=0) + + for k in _MSA_FEATURE_NAMES: + if k in protein: + protein[k] = tf.gather(protein[k], keep_indices) + + return protein + + +@curry1 +def nearest_neighbor_clusters(protein, gap_agreement_weight=0.): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask + weights = tf.concat([ + tf.ones(21), + gap_agreement_weight * tf.ones(1), + np.zeros(1)], 0) + + # Make agreement score as weighted Hamming distance + sample_one_hot = (protein['msa_mask'][:, :, None] * + tf.one_hot(protein['msa'], 23)) + extra_one_hot = (protein['extra_msa_mask'][:, :, None] * + tf.one_hot(protein['extra_msa'], 23)) + + num_seq, num_res, _ = shape_helpers.shape_list(sample_one_hot) + extra_num_seq, _, _ = shape_helpers.shape_list(extra_one_hot) + + # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) + # in an optimized fashion to avoid possible memory or computation blowup. + agreement = tf.matmul( + tf.reshape(extra_one_hot, [extra_num_seq, num_res * 23]), + tf.reshape(sample_one_hot * weights, [num_seq, num_res * 23]), + transpose_b=True) + + # Assign each sequence in the extra sequences to the closest MSA sample + protein['extra_cluster_assignment'] = tf.argmax( + agreement, axis=1, output_type=tf.int32) + + return protein + + +@curry1 +def summarize_clusters(protein): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = shape_helpers.shape_list(protein['msa'])[0] + def csum(x): + return tf.math.unsorted_segment_sum( + x, protein['extra_cluster_assignment'], num_seq) + + mask = protein['extra_msa_mask'] + mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center + + msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23)) + msa_sum += tf.one_hot(protein['msa'], 23) # Original sequence + protein['cluster_profile'] = msa_sum / mask_counts[:, :, None] + + del msa_sum + + del_sum = csum(mask * protein['extra_deletion_matrix']) + del_sum += protein['deletion_matrix'] # Original sequence + protein['cluster_deletion_mean'] = del_sum / mask_counts + del del_sum + + return protein + + +def make_msa_mask(protein): + """Mask features are all ones, but will later be zero-padded.""" + protein['msa_mask'] = tf.ones( + shape_helpers.shape_list(protein['msa']), dtype=tf.float32) + protein['msa_row_mask'] = tf.ones( + shape_helpers.shape_list(protein['msa'])[0], dtype=tf.float32) + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + """Create pseudo beta features.""" + is_gly = tf.equal(aatype, residue_constants.restype_order['G']) + ca_idx = residue_constants.atom_order['CA'] + cb_idx = residue_constants.atom_order['CB'] + pseudo_beta = tf.where( + tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :]) + + if all_atom_masks is not None: + pseudo_beta_mask = tf.where( + is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx]) + pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein, prefix=''): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ['', 'template_'] + protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = ( + pseudo_beta_fn( + protein['template_aatype' if prefix else 'all_atom_aatype'], + protein[prefix + 'all_atom_positions'], + protein['template_all_atom_masks' if prefix else 'all_atom_mask'])) + return protein + + +@curry1 +def add_constant_field(protein, key, value): + protein[key] = tf.convert_to_tensor(value) + return protein + + +def shaped_categorical(probs, epsilon=1e-10): + ds = shape_helpers.shape_list(probs) + num_classes = ds[-1] + counts = tf.random.categorical( + tf.reshape(tf.log(probs + epsilon), [-1, num_classes]), + 1, + dtype=tf.int32) + return tf.reshape(counts, ds[:-1]) + + +def make_hhblits_profile(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + + # Compute the profile for every residue (over all MSA sequences). + protein['hhblits_profile'] = tf.reduce_mean( + tf.one_hot(protein['msa'], 22), axis=0) + return protein + + +@curry1 +def make_masked_msa(protein, config, replace_fraction): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly + random_aa = tf.constant([0.05] * 20 + [0., 0.], dtype=tf.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * protein['hhblits_profile'] + + config.same_prob * tf.one_hot(protein['msa'], 22)) + + # Put all remaining probability on [MASK] which is a new column + pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob + assert mask_prob >= 0. + categorical_probs = tf.pad( + categorical_probs, pad_shapes, constant_values=mask_prob) + + sh = shape_helpers.shape_list(protein['msa']) + mask_position = tf.random.uniform(sh) < replace_fraction + + bert_msa = shaped_categorical(categorical_probs) + bert_msa = tf.where(mask_position, bert_msa, protein['msa']) + + # Mix real and masked MSA + protein['bert_mask'] = tf.cast(mask_position, tf.float32) + protein['true_msa'] = protein['msa'] + protein['msa'] = bert_msa + + return protein + + +@curry1 +def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size, + num_res, num_templates=0): + """Guess at the MSA and sequence dimensions to make fixed size.""" + + pad_size_map = { + NUM_RES: num_res, + NUM_MSA_SEQ: msa_cluster_size, + NUM_EXTRA_SEQ: extra_msa_size, + NUM_TEMPLATES: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. + if k == 'extra_cluster_assignment': + continue + shape = v.shape.as_list() + schema = shape_schema[k] + assert len(shape) == len(schema), ( + f'Rank mismatch between shape and shape schema for {k}: ' + f'{shape} vs {schema}') + pad_size = [ + pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema) + ] + padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)] + if padding: + protein[k] = tf.pad( + v, padding, name=f'pad_to_fixed_{k}') + protein[k].set_shape(pad_size) + + return protein + + +@curry1 +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + # Whether there is a domain break. Always zero for chains, but keeping + # for compatibility with domain datasets. + has_break = tf.clip_by_value( + tf.cast(protein['between_segment_residues'], tf.float32), + 0, 1) + aatype_1hot = tf.one_hot(protein['aatype'], 21, axis=-1) + + target_feat = [ + tf.expand_dims(has_break, axis=-1), + aatype_1hot, # Everyone gets the original sequence. + ] + + msa_1hot = tf.one_hot(protein['msa'], 23, axis=-1) + has_deletion = tf.clip_by_value(protein['deletion_matrix'], 0., 1.) + deletion_value = tf.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi) + + msa_feat = [ + msa_1hot, + tf.expand_dims(has_deletion, axis=-1), + tf.expand_dims(deletion_value, axis=-1), + ] + + if 'cluster_profile' in protein: + deletion_mean_value = ( + tf.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)) + msa_feat.extend([ + protein['cluster_profile'], + tf.expand_dims(deletion_mean_value, axis=-1), + ]) + + if 'extra_deletion_matrix' in protein: + protein['extra_has_deletion'] = tf.clip_by_value( + protein['extra_deletion_matrix'], 0., 1.) + protein['extra_deletion_value'] = tf.atan( + protein['extra_deletion_matrix'] / 3.) * (2. / np.pi) + + protein['msa_feat'] = tf.concat(msa_feat, axis=-1) + protein['target_feat'] = tf.concat(target_feat, axis=-1) + return protein + + +@curry1 +def select_feat(protein, feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +@curry1 +def crop_templates(protein, max_templates): + for k, v in protein.items(): + if k.startswith('template_'): + protein[k] = v[:max_templates] + return protein + + +@curry1 +def random_crop_to_size(protein, crop_size, max_templates, shape_schema, + subsample_templates=False): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + seq_length = protein['seq_length'] + if 'template_mask' in protein: + num_templates = tf.cast( + shape_helpers.shape_list(protein['template_mask'])[0], tf.int32) + else: + num_templates = tf.constant(0, dtype=tf.int32) + num_res_crop_size = tf.math.minimum(seq_length, crop_size) + + # Ensures that the cropping of residues and templates happens in the same way + # across ensembling iterations. + # Do not use for randomness that should vary in ensembling. + seed_maker = utils.SeedMaker(initial_seed=protein['random_crop_to_size_seed']) + + if subsample_templates: + templates_crop_start = tf.random.stateless_uniform( + shape=(), minval=0, maxval=num_templates + 1, dtype=tf.int32, + seed=seed_maker()) + else: + templates_crop_start = 0 + + num_templates_crop_size = tf.math.minimum( + num_templates - templates_crop_start, max_templates) + + num_res_crop_start = tf.random.stateless_uniform( + shape=(), minval=0, maxval=seq_length - num_res_crop_size + 1, + dtype=tf.int32, seed=seed_maker()) + + templates_select_indices = tf.argsort(tf.random.stateless_uniform( + [num_templates], seed=seed_maker())) + + for k, v in protein.items(): + if k not in shape_schema or ( + 'template' not in k and NUM_RES not in shape_schema[k]): + continue + + # randomly permute the templates before cropping them. + if k.startswith('template') and subsample_templates: + v = tf.gather(v, templates_select_indices) + + crop_sizes = [] + crop_starts = [] + for i, (dim_size, dim) in enumerate(zip(shape_schema[k], + shape_helpers.shape_list(v))): + is_num_res = (dim_size == NUM_RES) + if i == 0 and k.startswith('template'): + crop_size = num_templates_crop_size + crop_start = templates_crop_start + else: + crop_start = num_res_crop_start if is_num_res else 0 + crop_size = (num_res_crop_size if is_num_res else + (-1 if dim is None else dim)) + crop_sizes.append(crop_size) + crop_starts.append(crop_start) + protein[k] = tf.slice(v, crop_starts, crop_sizes) + + protein['seq_length'] = num_res_crop_size + return protein + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.] * 14) + + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + # create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein + residx_atom14_to_atom37 = tf.gather(restype_atom14_to_atom37, + protein['aatype']) + residx_atom14_mask = tf.gather(restype_atom14_mask, + protein['aatype']) + + protein['atom14_atom_exists'] = residx_atom14_mask + protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37 + + # create the gather indices for mapping back + residx_atom37_to_atom14 = tf.gather(restype_atom37_to_atom14, + protein['aatype']) + protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14 + + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = tf.gather(restype_atom37_mask, + protein['aatype']) + protein['atom37_atom_exists'] = residx_atom37_mask + + return protein diff --git a/af_backprop/alphafold/model/tf/input_pipeline.py b/af_backprop/alphafold/model/tf/input_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a9bc3a8aa15316aa88c3947120883be869331e --- /dev/null +++ b/af_backprop/alphafold/model/tf/input_pipeline.py @@ -0,0 +1,166 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Feature pre-processing input pipeline for AlphaFold.""" + +from alphafold.model.tf import data_transforms +from alphafold.model.tf import shape_placeholders +import tensorflow.compat.v1 as tf +import tree + +# Pylint gets confused by the curry1 decorator because it changes the number +# of arguments to the function. +# pylint:disable=no-value-for-parameter + + +NUM_RES = shape_placeholders.NUM_RES +NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ +NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ +NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES + + +def nonensembled_map_fns(data_config): + """Input pipeline functions which are not ensembled.""" + common_cfg = data_config.common + + map_fns = [ + data_transforms.correct_msa_restypes, + data_transforms.add_distillation_flag(False), + data_transforms.cast_64bit_ints, + data_transforms.squeeze_features, + # Keep to not disrupt RNG. + data_transforms.randomly_replace_msa_with_unknown(0.0), + data_transforms.make_seq_mask, + data_transforms.make_msa_mask, + # Compute the HHblits profile if it's not set. This has to be run before + # sampling the MSA. + data_transforms.make_hhblits_profile, + data_transforms.make_random_crop_to_size_seed, + ] + if common_cfg.use_templates: + map_fns.extend([ + data_transforms.fix_templates_aatype, + data_transforms.make_template_mask, + data_transforms.make_pseudo_beta('template_') + ]) + map_fns.extend([ + data_transforms.make_atom14_masks, + ]) + + return map_fns + + +def ensembled_map_fns(data_config): + """Input pipeline functions that can be ensembled and averaged.""" + common_cfg = data_config.common + eval_cfg = data_config.eval + + map_fns = [] + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates + else: + pad_msa_clusters = eval_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + map_fns.append( + data_transforms.sample_msa( + max_msa_clusters, + keep_extra=True)) + + if 'masked_msa' in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + map_fns.append( + data_transforms.make_masked_msa(common_cfg.masked_msa, + eval_cfg.masked_msa_replace_fraction)) + + if common_cfg.msa_cluster_features: + map_fns.append(data_transforms.nearest_neighbor_clusters()) + map_fns.append(data_transforms.summarize_clusters()) + + # Crop after creating the cluster profiles. + if max_extra_msa: + map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) + else: + map_fns.append(data_transforms.delete_extra_msa) + + map_fns.append(data_transforms.make_msa_feat()) + + crop_feats = dict(eval_cfg.feat) + + if eval_cfg.fixed_size: + map_fns.append(data_transforms.select_feat(list(crop_feats))) + map_fns.append(data_transforms.random_crop_to_size( + eval_cfg.crop_size, + eval_cfg.max_templates, + crop_feats, + eval_cfg.subsample_templates)) + map_fns.append(data_transforms.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + eval_cfg.crop_size, + eval_cfg.max_templates)) + else: + map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) + + return map_fns + + +def process_tensors_from_config(tensors, data_config): + """Apply filters and maps to an existing dataset, based on the config.""" + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_map_fns(data_config) + fn = compose(fns) + d['ensemble_index'] = i + return fn(d) + + eval_cfg = data_config.eval + tensors = compose( + nonensembled_map_fns( + data_config))( + tensors) + + tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0)) + num_ensemble = eval_cfg.num_ensemble + if data_config.common.resample_msa_in_recycling: + # Separate batch per ensembling & recycling step. + num_ensemble *= data_config.common.num_recycle + 1 + + if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1: + fn_output_signature = tree.map_structure( + tf.TensorSpec.from_tensor, tensors_0) + tensors = tf.map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + tf.range(num_ensemble), + parallel_iterations=1, + fn_output_signature=fn_output_signature) + else: + tensors = tree.map_structure(lambda x: x[None], + tensors_0) + return tensors + + +@data_transforms.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x diff --git a/af_backprop/alphafold/model/tf/protein_features.py b/af_backprop/alphafold/model/tf/protein_features.py new file mode 100644 index 0000000000000000000000000000000000000000..c78cfa5ea50baa63f9fabcb3bc7c7b66d10a1fa0 --- /dev/null +++ b/af_backprop/alphafold/model/tf/protein_features.py @@ -0,0 +1,129 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains descriptions of various protein features.""" +import enum +from typing import Dict, Optional, Sequence, Tuple, Union +from alphafold.common import residue_constants +import tensorflow.compat.v1 as tf + +# Type aliases. +FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]] + + +class FeatureType(enum.Enum): + ZERO_DIM = 0 # Shape [x] + ONE_DIM = 1 # Shape [num_res, x] + TWO_DIM = 2 # Shape [num_res, num_res, x] + MSA = 3 # Shape [msa_length, num_res, x] + + +# Placeholder values that will be replaced with their true value at runtime. +NUM_RES = "num residues placeholder" +NUM_SEQ = "length msa placeholder" +NUM_TEMPLATES = "num templates placeholder" +# Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders +# to be replaced with the number of residues and the number of sequences in the +# multiple sequence alignment, respectively. + + +FEATURES = { + #### Static features of a protein sequence #### + "aatype": (tf.float32, [NUM_RES, 21]), + "between_segment_residues": (tf.int64, [NUM_RES, 1]), + "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]), + "domain_name": (tf.string, [1]), + "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]), + "num_alignments": (tf.int64, [NUM_RES, 1]), + "residue_index": (tf.int64, [NUM_RES, 1]), + "seq_length": (tf.int64, [NUM_RES, 1]), + "sequence": (tf.string, [1]), + "all_atom_positions": (tf.float32, + [NUM_RES, residue_constants.atom_type_num, 3]), + "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]), + "resolution": (tf.float32, [1]), + "template_domain_names": (tf.string, [NUM_TEMPLATES]), + "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]), + "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]), + "template_all_atom_positions": (tf.float32, [ + NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3 + ]), + "template_all_atom_masks": (tf.float32, [ + NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1 + ]), +} + +FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()} +FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()} + + +def register_feature(name: str, + type_: tf.dtypes.DType, + shape_: Tuple[Union[str, int]]): + """Register extra features used in custom datasets.""" + FEATURES[name] = (type_, shape_) + FEATURE_TYPES[name] = type_ + FEATURE_SIZES[name] = shape_ + + +def shape(feature_name: str, + num_residues: int, + msa_length: int, + num_templates: Optional[int] = None, + features: Optional[FeaturesMetadata] = None): + """Get the shape for the given feature name. + + This is near identical to _get_tf_shape_no_placeholders() but with 2 + differences: + * This method does not calculate a single placeholder from the total number of + elements (eg given and size := 12, this won't deduce NUM_RES + must be 4) + * This method will work with tensors + + Args: + feature_name: String identifier for the feature. If the feature name ends + with "_unnormalized", this suffix is stripped off. + num_residues: The number of residues in the current domain - some elements + of the shape can be dynamic and will be replaced by this value. + msa_length: The number of sequences in the multiple sequence alignment, some + elements of the shape can be dynamic and will be replaced by this value. + If the number of alignments is unknown / not read, please pass None for + msa_length. + num_templates (optional): The number of templates in this tfexample. + features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES. + + Returns: + List of ints representation the tensor size. + + Raises: + ValueError: If a feature is requested but no concrete placeholder value is + given. + """ + features = features or FEATURES + if feature_name.endswith("_unnormalized"): + feature_name = feature_name[:-13] + + unused_dtype, raw_sizes = features[feature_name] + replacements = {NUM_RES: num_residues, + NUM_SEQ: msa_length} + + if num_templates is not None: + replacements[NUM_TEMPLATES] = num_templates + + sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] + for dimension in sizes: + if isinstance(dimension, str): + raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( + feature_name, raw_sizes, replacements)) + return sizes diff --git a/af_backprop/alphafold/model/tf/proteins_dataset.py b/af_backprop/alphafold/model/tf/proteins_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b1c038a41c6e276275a7904e748ea9e31e6083 --- /dev/null +++ b/af_backprop/alphafold/model/tf/proteins_dataset.py @@ -0,0 +1,166 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Datasets consisting of proteins.""" +from typing import Dict, Mapping, Optional, Sequence +from alphafold.model.tf import protein_features +import numpy as np +import tensorflow.compat.v1 as tf + +TensorDict = Dict[str, tf.Tensor] + + +def parse_tfexample( + raw_data: bytes, + features: protein_features.FeaturesMetadata, + key: Optional[str] = None) -> Dict[str, tf.train.Feature]: + """Read a single TF Example proto and return a subset of its features. + + Args: + raw_data: A serialized tf.Example proto. + features: A dictionary of features, mapping string feature names to a tuple + (dtype, shape). This dictionary should be a subset of + protein_features.FEATURES (or the dictionary itself for all features). + key: Optional string with the SSTable key of that tf.Example. This will be + added into features as a 'key' but only if requested in features. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + feature_map = { + k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) + for k, v in features.items() + } + parsed_features = tf.io.parse_single_example(raw_data, feature_map) + reshaped_features = parse_reshape_logic(parsed_features, features, key=key) + + return reshaped_features + + +def _first(tensor: tf.Tensor) -> tf.Tensor: + """Returns the 1st element - the input can be a tensor or a scalar.""" + return tf.reshape(tensor, shape=(-1,))[0] + + +def parse_reshape_logic( + parsed_features: TensorDict, + features: protein_features.FeaturesMetadata, + key: Optional[str] = None) -> TensorDict: + """Transforms parsed serial features to the correct shape.""" + # Find out what is the number of sequences and the number of alignments. + num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) + + if "num_alignments" in parsed_features: + num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) + else: + num_msa = 0 + + if "template_domain_names" in parsed_features: + num_templates = tf.cast( + tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) + else: + num_templates = 0 + + if key is not None and "key" in features: + parsed_features["key"] = [key] # Expand dims from () to (1,). + + # Reshape the tensors according to the sequence length and num alignments. + for k, v in parsed_features.items(): + new_shape = protein_features.shape( + feature_name=k, + num_residues=num_residues, + msa_length=num_msa, + num_templates=num_templates, + features=features) + new_shape_size = tf.constant(1, dtype=tf.int32) + for dim in new_shape: + new_shape_size *= tf.cast(dim, tf.int32) + + assert_equal = tf.assert_equal( + tf.size(v), new_shape_size, + name="assert_%s_shape_correct" % k, + message="The size of feature %s (%s) could not be reshaped " + "into %s" % (k, tf.size(v), new_shape)) + if "template" not in k: + # Make sure the feature we are reshaping is not empty. + assert_non_empty = tf.assert_greater( + tf.size(v), 0, name="assert_%s_non_empty" % k, + message="The feature %s is not set in the tf.Example. Either do not " + "request the feature or use a tf.Example that has the " + "feature set." % k) + with tf.control_dependencies([assert_non_empty, assert_equal]): + parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) + else: + with tf.control_dependencies([assert_equal]): + parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) + + return parsed_features + + +def _make_features_metadata( + feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: + """Makes a feature name to type and shape mapping from a list of names.""" + # Make sure these features are always read. + required_features = ["aatype", "sequence", "seq_length"] + feature_names = list(set(feature_names) | set(required_features)) + + features_metadata = {name: protein_features.FEATURES[name] + for name in feature_names} + return features_metadata + + +def create_tensor_dict( + raw_data: bytes, + features: Sequence[str], + key: Optional[str] = None, + ) -> TensorDict: + """Creates a dictionary of tensor features. + + Args: + raw_data: A serialized tf.Example proto. + features: A list of strings of feature names to be returned in the dataset. + key: Optional string with the SSTable key of that tf.Example. This will be + added into features as a 'key' but only if requested in features. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + features_metadata = _make_features_metadata(features) + return parse_tfexample(raw_data, features_metadata, key) + + +def np_to_tensor_dict( + np_example: Mapping[str, np.ndarray], + features: Sequence[str], + ) -> TensorDict: + """Creates dict of tensors from a dict of NumPy arrays. + + Args: + np_example: A dict of NumPy feature arrays. + features: A list of strings of feature names to be returned in the dataset. + + Returns: + A dictionary of features mapping feature names to features. Only the given + features are returned, all other ones are filtered out. + """ + features_metadata = _make_features_metadata(features) + tensor_dict = {k: tf.constant(v) for k, v in np_example.items() + if k in features_metadata} + + # Ensures shapes are as expected. Needed for setting size of empty features + # e.g. when no template hits were found. + tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) + return tensor_dict diff --git a/af_backprop/alphafold/model/tf/shape_helpers.py b/af_backprop/alphafold/model/tf/shape_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..be2926a63bce7ca5db3effe63d5264620aa1dcf8 --- /dev/null +++ b/af_backprop/alphafold/model/tf/shape_helpers.py @@ -0,0 +1,47 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with shapes of TensorFlow tensors.""" +import tensorflow.compat.v1 as tf + + +def shape_list(x): + """Return list of dimensions of a tensor, statically where possible. + + Like `x.shape.as_list()` but with tensors instead of `None`s. + + Args: + x: A tensor. + Returns: + A list with length equal to the rank of the tensor. The n-th element of the + list is an integer when that dimension is statically known otherwise it is + the n-th element of `tf.shape(x)`. + """ + x = tf.convert_to_tensor(x) + + # If unknown rank, return dynamic shape + if x.get_shape().dims is None: + return tf.shape(x) + + static = x.get_shape().as_list() + shape = tf.shape(x) + + ret = [] + for i in range(len(static)): + dim = static[i] + if dim is None: + dim = shape[i] + ret.append(dim) + return ret + diff --git a/af_backprop/alphafold/model/tf/shape_placeholders.py b/af_backprop/alphafold/model/tf/shape_placeholders.py new file mode 100644 index 0000000000000000000000000000000000000000..cffdeb5e1fa9691eb74680b8c9aeb8bab6123fa8 --- /dev/null +++ b/af_backprop/alphafold/model/tf/shape_placeholders.py @@ -0,0 +1,20 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Placeholder values for run-time varying dimension sizes.""" + +NUM_RES = 'num residues placeholder' +NUM_MSA_SEQ = 'msa placeholder' +NUM_EXTRA_SEQ = 'extra msa placeholder' +NUM_TEMPLATES = 'num templates placeholder' diff --git a/af_backprop/alphafold/model/tf/utils.py b/af_backprop/alphafold/model/tf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc40a2ceb2de1c2d56c17697393713804d7da350 --- /dev/null +++ b/af_backprop/alphafold/model/tf/utils.py @@ -0,0 +1,47 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for various components.""" +import tensorflow.compat.v1 as tf + + +def tf_combine_mask(*masks): + """Take the intersection of float-valued masks.""" + ret = 1 + for m in masks: + ret *= m + return ret + + +class SeedMaker(object): + """Return unique seeds.""" + + def __init__(self, initial_seed=0): + self.next_seed = initial_seed + + def __call__(self): + i = self.next_seed + self.next_seed += 1 + return i + +seed_maker = SeedMaker() + + +def make_random_seed(): + return tf.random.uniform([2], + tf.int32.min, + tf.int32.max, + tf.int32, + seed=seed_maker()) + diff --git a/af_backprop/alphafold/model/utils.py b/af_backprop/alphafold/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed5361e867a3f566afaa5aafeaac10863c4af8d --- /dev/null +++ b/af_backprop/alphafold/model/utils.py @@ -0,0 +1,81 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of JAX utility functions for use in protein folding.""" + +import collections +import numbers +from typing import Mapping + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +def final_init(config): + if config.zero_init: + return 'zeros' + else: + return 'linear' + + +def batched_gather(params, indices, axis=0, batch_dims=0): + """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" + take_fn = lambda p, i: jnp.take(p, i, axis=axis) + for _ in range(batch_dims): + take_fn = jax.vmap(take_fn) + return take_fn(params, indices) + + +def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): + """Masked mean.""" + if drop_mask_channel: + mask = mask[..., 0] + + mask_shape = mask.shape + value_shape = value.shape + + assert len(mask_shape) == len(value_shape) + + if isinstance(axis, numbers.Integral): + axis = [axis] + elif axis is None: + axis = list(range(len(mask_shape))) + assert isinstance(axis, collections.Iterable), ( + 'axis needs to be either an iterable, integer or "None"') + + broadcast_factor = 1. + for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + assert mask_size == value_size + + return (jnp.sum(mask * value, axis=axis) / + (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) + + +def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params: + """Convert a dictionary of NumPy arrays to Haiku parameters.""" + hk_params = {} + for path, array in params.items(): + scope, name = path.split('//') + if scope not in hk_params: + hk_params[scope] = {} + hk_params[scope][name] = jnp.array(array) + + return hk_params diff --git a/af_backprop/examples/AlphaFold_single.ipynb b/af_backprop/examples/AlphaFold_single.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..29bae829e9bc9be8a9cd27bed04dff0a8fcd718f --- /dev/null +++ b/af_backprop/examples/AlphaFold_single.ipynb @@ -0,0 +1,311 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "AlphaFold_single.ipynb", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "#AlphaFold - single sequence input\n", + "- WARNING - For DEMO and educational purposes only. \n", + "- For natural proteins you often need more than a single sequence to accurately predict the structure. See [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple-sequence-alignment. That being said, this notebook could potentially be useful for evaluating *de novo* designed proteins.\n" + ], + "metadata": { + "id": "VpfCw7IzVHXv" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Setup\n", + "from IPython.utils import io\n", + "import os,sys,re\n", + "import tensorflow as tf\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "with io.capture_output() as captured:\n", + " if not os.path.isdir(\"af_backprop\"):\n", + " %shell git clone -b beta https://github.com/sokrypton/af_backprop.git\n", + " %shell pip -q install biopython dm-haiku ml-collections py3Dmol\n", + " %shell wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\n", + " if not os.path.isdir(\"params\"):\n", + " %shell mkdir params\n", + " %shell curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + "\n", + "try:\n", + " # check if TPU is available\n", + " import jax.tools.colab_tpu\n", + " jax.tools.colab_tpu.setup_tpu()\n", + " print('Running on TPU')\n", + " DEVICE = \"tpu\"\n", + "except:\n", + " if jax.local_devices()[0].platform == 'cpu':\n", + " print(\"WARNING: no GPU detected, will be using CPU\")\n", + " DEVICE = \"cpu\"\n", + " else:\n", + " print('Running on GPU')\n", + " DEVICE = \"gpu\"\n", + " # disable GPU on tensorflow\n", + " tf.config.set_visible_devices([], 'GPU')\n", + "\n", + "sys.path.append('/content/af_backprop')\n", + "# import libraries\n", + "from utils import update_seq, update_aatype, get_plddt, get_pae\n", + "import colabfold as cf\n", + "from alphafold.common import protein\n", + "from alphafold.data import pipeline\n", + "from alphafold.model import data, config, model\n", + "from alphafold.common import residue_constants\n", + "\n", + "def clear_mem():\n", + " backend = jax.lib.xla_bridge.get_backend()\n", + " for buf in backend.live_buffers(): buf.delete()\n", + "\n", + "def setup_model(max_len, model_name=\"model_2_ptm\"):\n", + "\n", + " clear_mem()\n", + "\n", + " # setup model\n", + " cfg = config.model_config(\"model_5_ptm\")\n", + " cfg.model.num_recycle = 0\n", + " cfg.data.common.num_recycle = 0\n", + " cfg.data.eval.max_msa_clusters = 1\n", + " cfg.data.common.max_extra_msa = 1\n", + " cfg.data.eval.masked_msa_replace_fraction = 0\n", + " cfg.model.global_config.subbatch_size = None\n", + " model_params = data.get_model_haiku_params(model_name=model_name, data_dir=\".\")\n", + " model_runner = model.RunModel(cfg, model_params, is_training=False)\n", + "\n", + " seq = \"A\" * max_len\n", + " length = len(seq)\n", + " feature_dict = {\n", + " **pipeline.make_sequence_features(sequence=seq, description=\"none\", num_res=length),\n", + " **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0]*length]])\n", + " }\n", + " inputs = model_runner.process_features(feature_dict,random_seed=0)\n", + "\n", + " def runner(seq, opt):\n", + " # update sequence\n", + " inputs = opt[\"inputs\"]\n", + " inputs.update(opt[\"prev\"])\n", + " update_seq(seq, inputs)\n", + " update_aatype(inputs[\"target_feat\"][...,1:], inputs)\n", + "\n", + " # mask prediction\n", + " mask = seq.sum(-1)\n", + " inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n", + " inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n", + " inputs[\"residue_index\"] = jnp.where(mask==1,inputs[\"residue_index\"],0)\n", + "\n", + " # get prediction\n", + " key = jax.random.PRNGKey(0)\n", + " outputs = model_runner.apply(opt[\"params\"], key, inputs)\n", + "\n", + " prev = {\"init_msa_first_row\":outputs['representations']['msa_first_row'][None],\n", + " \"init_pair\":outputs['representations']['pair'][None],\n", + " \"init_pos\":outputs['structure_module']['final_atom_positions'][None]}\n", + " \n", + " aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n", + " \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n", + " \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n", + " \"inputs\":inputs, \"prev\":prev}\n", + " return aux\n", + "\n", + " return jax.jit(runner), {\"inputs\":inputs,\"params\":model_params}\n", + "\n", + "MAX_LEN = 50\n", + "RUNNER, OPT = setup_model(MAX_LEN)" + ], + "metadata": { + "cellView": "form", + "id": "24ybo88aBiSU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%%time\n", + "#@title Enter the amino acid sequence to fold ⬇️\n", + "\n", + "sequence = 'GGGGGGGGGGGGGGGGGGGG' #@param {type:\"string\"}\n", + "recycles = 0 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n", + "SEQ = re.sub(\"[^A-Z]\", \"\", sequence.upper())\n", + "LEN = len(SEQ)\n", + "if LEN > MAX_LEN:\n", + " print(\"recompiling...\")\n", + " MAX_LEN = LEN\n", + " RUNNER, OPT = setup_model(MAX_LEN)\n", + "\n", + "x = np.array([residue_constants.restype_order.get(aa,0) for aa in SEQ])\n", + "x = np.pad(x,[0,MAX_LEN-LEN],constant_values=-1)\n", + "x = jax.nn.one_hot(x,20)\n", + "\n", + "OPT[\"prev\"] = {'init_msa_first_row': np.zeros([1, MAX_LEN, 256]),\n", + " 'init_pair': np.zeros([1, MAX_LEN, MAX_LEN, 128]),\n", + " 'init_pos': np.zeros([1, MAX_LEN, 37, 3])}\n", + "\n", + "positions = []\n", + "plddts = []\n", + "for r in range(recycles+1):\n", + " outs = RUNNER(x, OPT)\n", + " outs = jax.tree_map(lambda x:np.asarray(x), outs)\n", + " positions.append(outs[\"prev\"][\"init_pos\"][0,:LEN])\n", + " plddts.append(outs[\"plddt\"][:LEN])\n", + " OPT[\"prev\"] = outs[\"prev\"]\n", + " if recycles > 0:\n", + " print(r, plddts[-1].mean())" + ], + "metadata": { + "cellView": "form", + "id": "cAoC4ar8G7ZH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Display 3D structure {run: \"auto\"}\n", + "color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n", + "show_sidechains = True #@param {type:\"boolean\"}\n", + "show_mainchains = False #@param {type:\"boolean\"}\n", + "#@markdown - TIP - hold mouse over aminoacid to get name and position number\n", + "\n", + "def save_pdb(outs, filename):\n", + " '''save pdb coordinates'''\n", + " p = {\"residue_index\":outs[\"inputs\"][\"residue_index\"][0][:LEN] + 1,\n", + " \"aatype\":outs[\"inputs\"][\"aatype\"].argmax(-1)[0][:LEN],\n", + " \"atom_positions\":outs[\"final_atom_positions\"][:LEN],\n", + " \"atom_mask\":outs[\"final_atom_mask\"][:LEN]}\n", + " b_factors = 100.0 * outs[\"plddt\"][:LEN,None] * p[\"atom_mask\"]\n", + " p = protein.Protein(**p,b_factors=b_factors)\n", + " pdb_lines = protein.to_pdb(p)\n", + " with open(filename, 'w') as f:\n", + " f.write(pdb_lines)\n", + "\n", + "save_pdb(outs,\"out.pdb\")\n", + "num_res = int(outs[\"inputs\"][\"aatype\"][0].sum())\n", + "\n", + "v = cf.show_pdb(\"out.pdb\", show_sidechains, show_mainchains, color,\n", + " color_HP=True, size=(800,480)) \n", + "v.setHoverable({},\n", + " True,\n", + " '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\" \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n", + " '''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n", + "v.show() \n", + "\n", + "if color == \"lDDT\":\n", + " cf.plot_plddt_legend().show() \n", + "if \"pae\" in outs:\n", + " cf.plot_confidence(outs[\"plddt\"][:LEN]*100, outs[\"pae\"][:LEN,:LEN]).show()\n", + "else:\n", + " cf.plot_confidence(outs[\"plddt\"][:LEN]*100).show()" + ], + "metadata": { + "cellView": "form", + "id": "-KbUGG4ZOp0J" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Animate\n", + "#@markdown - Animate trajectory if more than 0 recycle(s)\n", + "import matplotlib\n", + "from matplotlib import animation\n", + "import matplotlib.pyplot as plt\n", + "from IPython.display import HTML\n", + "\n", + "def make_animation(positions, plddts=None, line_w=2.0):\n", + "\n", + " def ca_align_to_last(positions):\n", + " def align(P, Q):\n", + " p = P - P.mean(0,keepdims=True)\n", + " q = Q - Q.mean(0,keepdims=True)\n", + " return p @ cf.kabsch(p,q)\n", + " \n", + " pos = positions[-1,:,1,:] - positions[-1,:,1,:].mean(0,keepdims=True)\n", + " best_2D_view = pos @ cf.kabsch(pos,pos,return_v=True)\n", + "\n", + " new_positions = []\n", + " for i in range(len(positions)):\n", + " new_positions.append(align(positions[i,:,1,:],best_2D_view))\n", + " return np.asarray(new_positions)\n", + "\n", + " # align all to last recycle\n", + " pos = ca_align_to_last(positions)\n", + "\n", + " fig, (ax1, ax2, ax3) = plt.subplots(1,3)\n", + " fig.subplots_adjust(top = 0.90, bottom = 0.10, right = 1, left = 0, hspace = 0, wspace = 0)\n", + " fig.set_figwidth(13)\n", + " fig.set_figheight(5)\n", + " fig.set_dpi(100)\n", + "\n", + " xy_min = pos[...,:2].min() - 1\n", + " xy_max = pos[...,:2].max() + 1\n", + "\n", + " for ax in [ax1,ax3]:\n", + " ax.set_xlim(xy_min, xy_max)\n", + " ax.set_ylim(xy_min, xy_max)\n", + " ax.axis(False)\n", + "\n", + " ims=[]\n", + " for k,(xyz,plddt) in enumerate(zip(pos,plddts)):\n", + " ims.append([])\n", + " im2 = ax2.plot(plddt, animated=True, color=\"black\")\n", + " tt1 = cf.add_text(\"colored by N->C\", ax1)\n", + " tt2 = cf.add_text(f\"recycle={k}\", ax2)\n", + " tt3 = cf.add_text(f\"pLDDT={plddt.mean():.3f}\", ax3)\n", + " ax2.set_xlabel(\"positions\")\n", + " ax2.set_ylabel(\"pLDDT\")\n", + " ax2.set_ylim(0,100)\n", + " ims[-1] += [cf.plot_pseudo_3D(xyz, ax=ax1, line_w=line_w)]\n", + " ims[-1] += [im2[0],tt1,tt2,tt3]\n", + " ims[-1] += [cf.plot_pseudo_3D(xyz, c=plddt, cmin=50, cmax=90, ax=ax3, line_w=line_w)]\n", + " \n", + " ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120)\n", + " plt.close()\n", + " return ani.to_html5_video()\n", + "\n", + "HTML(make_animation(np.asarray(positions),\n", + " np.asarray(plddts) * 100.0))" + ], + "metadata": { + "cellView": "form", + "id": "tdjdC0KFPjWw" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/af_backprop/examples/af_design.ipynb b/af_backprop/examples/af_design.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d98761563074bc02f10ab9f784d8d106e1b9e08f --- /dev/null +++ b/af_backprop/examples/af_design.ipynb @@ -0,0 +1,41 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "af_design.ipynb", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OA2k3sAYuiXe" + }, + "source": [ + "#AF Design\n", + "NOTE, updated version of this notebook has moved to: [ColabDesign](https://github.com/sokrypton/ColabDesign/tree/main/af)" + ] + } + ] +} \ No newline at end of file diff --git a/af_backprop/examples/fixbb_design.ipynb b/af_backprop/examples/fixbb_design.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..da07f36c6976fddfd2fb06f753d86ed3ac2ea2ed --- /dev/null +++ b/af_backprop/examples/fixbb_design.ipynb @@ -0,0 +1,29 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "fixbb_design.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "uLHIgB5QydoL" + }, + "source": [ + "This notebook has moved here: https://colab.research.google.com/github/sokrypton/af_backprop/blob/main/examples/af_design.ipynb" + ] + } + ] +} diff --git a/af_backprop/examples/sc_hall/1QJG.pdb b/af_backprop/examples/sc_hall/1QJG.pdb new file mode 100644 index 0000000000000000000000000000000000000000..6916c215e13c49d9510eddddf4b22ab7e56c913a --- /dev/null +++ b/af_backprop/examples/sc_hall/1QJG.pdb @@ -0,0 +1,1156 @@ +ATOM 1 N MET A 1 10.694 86.076 -17.884 1.00 84.30 A N +ATOM 2 CA MET A 1 9.356 86.709 -18.062 1.00 81.23 A C +ATOM 3 C MET A 1 8.548 85.831 -19.012 1.00 74.08 A C +ATOM 4 O MET A 1 8.958 85.627 -20.159 1.00 74.53 A O +ATOM 5 CB MET A 1 9.510 88.112 -18.660 1.00 88.90 A C +ATOM 6 CG MET A 1 8.245 88.955 -18.591 1.00 98.51 A C +ATOM 7 SD MET A 1 7.719 89.207 -16.879 1.00111.10 A S +ATOM 8 CE MET A 1 6.310 88.074 -16.761 1.00106.45 A C +ATOM 9 H1 MET A 1 11.109 85.938 -18.826 1.00 0.00 A H +ATOM 10 H2 MET A 1 11.317 86.619 -17.265 1.00 0.00 A H +ATOM 11 H3 MET A 1 10.557 85.138 -17.448 1.00 0.00 A H +ATOM 12 N ASN A 2 7.410 85.316 -18.548 1.00 62.40 A N +ATOM 13 CA ASN A 2 6.596 84.449 -19.389 1.00 47.87 A C +ATOM 14 C ASN A 2 5.205 84.954 -19.711 1.00 41.68 A C +ATOM 15 O ASN A 2 4.231 84.652 -19.021 1.00 41.06 A O +ATOM 16 CB ASN A 2 6.506 83.052 -18.794 1.00 41.73 A C +ATOM 17 CG ASN A 2 7.857 82.460 -18.510 1.00 34.33 A C +ATOM 18 ND2 ASN A 2 7.874 81.403 -17.729 1.00 31.37 A N +ATOM 19 OD1 ASN A 2 8.883 82.969 -18.959 1.00 37.21 A O +ATOM 20 H ASN A 2 7.083 85.483 -17.636 1.00 0.00 A H +ATOM 21 HD21 ASN A 2 7.016 81.088 -17.384 1.00 0.00 A H +ATOM 22 HD22 ASN A 2 8.733 80.991 -17.536 1.00 0.00 A H +ATOM 23 N THR A 3 5.127 85.753 -20.761 1.00 36.66 A N +ATOM 24 CA THR A 3 3.865 86.272 -21.226 1.00 34.05 A C +ATOM 25 C THR A 3 3.507 85.320 -22.359 1.00 34.07 A C +ATOM 26 O THR A 3 4.363 84.575 -22.833 1.00 36.58 A O +ATOM 27 CB THR A 3 4.029 87.713 -21.764 1.00 33.23 A C +ATOM 28 CG2 THR A 3 4.446 88.643 -20.655 1.00 34.82 A C +ATOM 29 OG1 THR A 3 5.029 87.746 -22.789 1.00 32.02 A O +ATOM 30 H THR A 3 5.922 86.005 -21.270 1.00 0.00 A H +ATOM 31 HG1 THR A 3 4.879 88.607 -23.212 1.00 0.00 A H +ATOM 32 N PRO A 4 2.236 85.282 -22.773 1.00 33.54 A N +ATOM 33 CA PRO A 4 1.860 84.384 -23.865 1.00 30.39 A C +ATOM 34 C PRO A 4 2.707 84.727 -25.085 1.00 29.97 A C +ATOM 35 O PRO A 4 3.253 83.850 -25.760 1.00 31.13 A O +ATOM 36 CB PRO A 4 0.404 84.759 -24.119 1.00 30.46 A C +ATOM 37 CG PRO A 4 -0.066 85.173 -22.778 1.00 33.24 A C +ATOM 38 CD PRO A 4 1.061 86.017 -22.276 1.00 34.34 A C +ATOM 39 N GLU A 5 2.848 86.023 -25.319 1.00 27.73 A N +ATOM 40 CA GLU A 5 3.607 86.541 -26.437 1.00 27.01 A C +ATOM 41 C GLU A 5 5.012 85.965 -26.481 1.00 24.21 A C +ATOM 42 O GLU A 5 5.488 85.561 -27.532 1.00 27.56 A O +ATOM 43 CB GLU A 5 3.690 88.070 -26.353 1.00 31.01 A C +ATOM 44 CG GLU A 5 2.342 88.828 -26.453 1.00 43.67 A C +ATOM 45 CD GLU A 5 1.448 88.719 -25.210 1.00 49.66 A C +ATOM 46 OE1 GLU A 5 1.955 88.464 -24.095 1.00 57.37 A O +ATOM 47 OE2 GLU A 5 0.222 88.898 -25.353 1.00 51.34 A O1- +ATOM 48 H GLU A 5 2.432 86.661 -24.705 1.00 0.00 A H +ATOM 49 N HIS A 6 5.662 85.907 -25.328 1.00 20.87 A N +ATOM 50 CA HIS A 6 7.030 85.407 -25.240 1.00 18.21 A C +ATOM 51 C HIS A 6 7.128 83.913 -25.504 1.00 16.76 A C +ATOM 52 O HIS A 6 8.016 83.462 -26.216 1.00 20.94 A O +ATOM 53 CB HIS A 6 7.644 85.758 -23.879 1.00 17.74 A C +ATOM 54 CG HIS A 6 8.966 85.097 -23.617 1.00 23.16 A C +ATOM 55 CD2 HIS A 6 9.288 84.014 -22.867 1.00 23.51 A C +ATOM 56 ND1 HIS A 6 10.148 85.543 -24.171 1.00 26.10 A N +ATOM 57 CE1 HIS A 6 11.138 84.764 -23.775 1.00 25.94 A C +ATOM 58 NE2 HIS A 6 10.643 83.829 -22.985 1.00 24.12 A N +ATOM 59 H HIS A 6 5.203 86.224 -24.518 1.00 0.00 A H +ATOM 60 HD1 HIS A 6 10.261 86.309 -24.778 1.00 0.00 A H +ATOM 61 HE2 HIS A 6 11.146 83.092 -22.596 1.00 0.00 A H +ATOM 62 N MET A 7 6.222 83.145 -24.924 1.00 14.80 A N +ATOM 63 CA MET A 7 6.251 81.718 -25.128 1.00 13.22 A C +ATOM 64 C MET A 7 5.959 81.421 -26.587 1.00 11.59 A C +ATOM 65 O MET A 7 6.500 80.493 -27.161 1.00 12.50 A O +ATOM 66 CB MET A 7 5.245 81.037 -24.208 1.00 13.28 A C +ATOM 67 CG MET A 7 5.517 81.314 -22.727 1.00 9.93 A C +ATOM 68 SD MET A 7 4.532 80.310 -21.631 1.00 15.44 A S +ATOM 69 CE MET A 7 3.047 81.335 -21.585 1.00 7.92 A C +ATOM 70 H MET A 7 5.517 83.549 -24.369 1.00 0.00 A H +ATOM 71 N THR A 8 5.132 82.245 -27.204 1.00 13.23 A N +ATOM 72 CA THR A 8 4.793 82.053 -28.600 1.00 14.43 A C +ATOM 73 C THR A 8 5.973 82.395 -29.514 1.00 18.05 A C +ATOM 74 O THR A 8 6.181 81.732 -30.534 1.00 19.37 A O +ATOM 75 CB THR A 8 3.577 82.897 -28.978 1.00 19.26 A C +ATOM 76 CG2 THR A 8 3.136 82.608 -30.403 1.00 20.22 A C +ATOM 77 OG1 THR A 8 2.500 82.579 -28.086 1.00 14.83 A O +ATOM 78 H THR A 8 4.729 82.990 -26.718 1.00 0.00 A H +ATOM 79 HG1 THR A 8 2.746 82.797 -27.187 1.00 0.00 A H +ATOM 80 N ALA A 9 6.751 83.416 -29.152 1.00 14.07 A N +ATOM 81 CA ALA A 9 7.911 83.806 -29.947 1.00 13.33 A C +ATOM 82 C ALA A 9 9.000 82.739 -29.838 1.00 12.50 A C +ATOM 83 O ALA A 9 9.697 82.462 -30.807 1.00 15.18 A O +ATOM 84 CB ALA A 9 8.423 85.131 -29.500 1.00 13.58 A C +ATOM 85 H ALA A 9 6.526 83.927 -28.344 1.00 0.00 A H +ATOM 86 N VAL A 10 9.118 82.116 -28.671 1.00 11.73 A N +ATOM 87 CA VAL A 10 10.090 81.055 -28.464 1.00 12.11 A C +ATOM 88 C VAL A 10 9.731 79.823 -29.295 1.00 13.25 A C +ATOM 89 O VAL A 10 10.614 79.217 -29.890 1.00 19.98 A O +ATOM 90 CB VAL A 10 10.191 80.674 -26.995 1.00 9.17 A C +ATOM 91 CG1 VAL A 10 10.931 79.370 -26.838 1.00 6.34 A C +ATOM 92 CG2 VAL A 10 10.917 81.772 -26.238 1.00 8.83 A C +ATOM 93 H VAL A 10 8.524 82.378 -27.934 1.00 0.00 A H +ATOM 94 N VAL A 11 8.451 79.453 -29.354 1.00 14.16 A N +ATOM 95 CA VAL A 11 8.036 78.296 -30.153 1.00 8.99 A C +ATOM 96 C VAL A 11 8.424 78.566 -31.597 1.00 13.66 A C +ATOM 97 O VAL A 11 8.974 77.707 -32.272 1.00 20.39 A O +ATOM 98 CB VAL A 11 6.512 78.064 -30.071 1.00 8.71 A C +ATOM 99 CG1 VAL A 11 6.051 77.051 -31.104 1.00 6.03 A C +ATOM 100 CG2 VAL A 11 6.137 77.589 -28.690 1.00 10.59 A C +ATOM 101 H VAL A 11 7.783 79.946 -28.829 1.00 0.00 A H +ATOM 102 N GLN A 12 8.171 79.789 -32.048 1.00 17.26 A N +ATOM 103 CA GLN A 12 8.478 80.205 -33.405 1.00 16.35 A C +ATOM 104 C GLN A 12 9.972 80.219 -33.722 1.00 14.80 A C +ATOM 105 O GLN A 12 10.369 79.891 -34.824 1.00 13.57 A O +ATOM 106 CB GLN A 12 7.856 81.566 -33.678 1.00 18.90 A C +ATOM 107 CG GLN A 12 6.344 81.525 -33.685 1.00 23.55 A C +ATOM 108 CD GLN A 12 5.732 82.912 -33.828 1.00 34.39 A C +ATOM 109 NE2 GLN A 12 6.071 83.814 -32.905 1.00 35.29 A N +ATOM 110 OE1 GLN A 12 4.959 83.169 -34.748 1.00 41.94 A O +ATOM 111 H GLN A 12 7.747 80.428 -31.439 1.00 0.00 A H +ATOM 112 HE21 GLN A 12 6.672 83.536 -32.191 1.00 0.00 A H +ATOM 113 HE22 GLN A 12 5.691 84.710 -33.003 1.00 0.00 A H +ATOM 114 N ARG A 13 10.804 80.604 -32.764 1.00 16.01 A N +ATOM 115 CA ARG A 13 12.240 80.609 -32.997 1.00 15.24 A C +ATOM 116 C ARG A 13 12.729 79.171 -33.079 1.00 15.94 A C +ATOM 117 O ARG A 13 13.676 78.860 -33.802 1.00 18.78 A O +ATOM 118 CB ARG A 13 12.970 81.328 -31.868 1.00 14.71 A C +ATOM 119 CG ARG A 13 12.781 82.814 -31.870 1.00 24.21 A C +ATOM 120 CD ARG A 13 13.525 83.449 -30.719 1.00 30.45 A C +ATOM 121 NE ARG A 13 13.429 84.899 -30.810 1.00 42.18 A N +ATOM 122 CZ ARG A 13 12.766 85.668 -29.951 1.00 46.62 A C +ATOM 123 NH1 ARG A 13 12.138 85.130 -28.912 1.00 52.17 A N1+ +ATOM 124 NH2 ARG A 13 12.700 86.979 -30.163 1.00 53.25 A N +ATOM 125 H ARG A 13 10.448 80.892 -31.898 1.00 0.00 A H +ATOM 126 HE ARG A 13 13.905 85.284 -31.580 1.00 0.00 A H +ATOM 127 HH11 ARG A 13 12.163 84.134 -28.780 1.00 0.00 A H +ATOM 128 HH12 ARG A 13 11.626 85.671 -28.239 1.00 0.00 A H +ATOM 129 HH21 ARG A 13 13.140 87.388 -30.967 1.00 0.00 A H +ATOM 130 HH22 ARG A 13 12.212 87.593 -29.540 1.00 0.00 A H +ATOM 131 N TYR A 14 12.069 78.299 -32.327 1.00 13.30 A N +ATOM 132 CA TYR A 14 12.402 76.890 -32.286 1.00 10.90 A C +ATOM 133 C TYR A 14 12.147 76.240 -33.647 1.00 13.50 A C +ATOM 134 O TYR A 14 12.987 75.533 -34.188 1.00 16.92 A O +ATOM 135 CB TYR A 14 11.538 76.227 -31.229 1.00 6.03 A C +ATOM 136 CG TYR A 14 11.674 74.726 -31.157 1.00 9.19 A C +ATOM 137 CD1 TYR A 14 12.809 74.151 -30.607 1.00 11.32 A C +ATOM 138 CD2 TYR A 14 10.662 73.885 -31.594 1.00 6.00 A C +ATOM 139 CE1 TYR A 14 12.937 72.787 -30.486 1.00 13.36 A C +ATOM 140 CE2 TYR A 14 10.787 72.508 -31.473 1.00 7.87 A C +ATOM 141 CZ TYR A 14 11.936 71.975 -30.912 1.00 7.44 A C +ATOM 142 OH TYR A 14 12.106 70.618 -30.754 1.00 8.52 A O +ATOM 143 H TYR A 14 11.336 78.619 -31.760 1.00 0.00 A H +ATOM 144 HH TYR A 14 11.279 70.202 -31.029 1.00 0.00 A H +ATOM 145 N VAL A 15 10.951 76.458 -34.176 1.00 17.30 A N +ATOM 146 CA VAL A 15 10.573 75.905 -35.460 1.00 16.50 A C +ATOM 147 C VAL A 15 11.522 76.437 -36.513 1.00 17.14 A C +ATOM 148 O VAL A 15 11.978 75.699 -37.366 1.00 20.10 A O +ATOM 149 CB VAL A 15 9.135 76.312 -35.818 1.00 14.14 A C +ATOM 150 CG1 VAL A 15 8.871 76.099 -37.278 1.00 14.20 A C +ATOM 151 CG2 VAL A 15 8.169 75.506 -35.018 1.00 11.49 A C +ATOM 152 H VAL A 15 10.306 76.993 -33.666 1.00 0.00 A H +ATOM 153 N ALA A 16 11.835 77.723 -36.422 1.00 15.98 A N +ATOM 154 CA ALA A 16 12.718 78.372 -37.381 1.00 16.15 A C +ATOM 155 C ALA A 16 14.137 77.859 -37.330 1.00 18.88 A C +ATOM 156 O ALA A 16 14.744 77.656 -38.380 1.00 26.26 A O +ATOM 157 CB ALA A 16 12.708 79.867 -37.179 1.00 17.59 A C +ATOM 158 H ALA A 16 11.451 78.264 -35.698 1.00 0.00 A H +ATOM 159 N ALA A 17 14.670 77.679 -36.119 1.00 12.81 A N +ATOM 160 CA ALA A 17 16.036 77.180 -35.919 1.00 12.41 A C +ATOM 161 C ALA A 17 16.144 75.741 -36.425 1.00 15.44 A C +ATOM 162 O ALA A 17 17.136 75.357 -37.021 1.00 19.19 A O +ATOM 163 CB ALA A 17 16.415 77.268 -34.457 1.00 9.24 A C +ATOM 164 H ALA A 17 14.132 77.888 -35.332 1.00 0.00 A H +ATOM 165 N LEU A 18 15.088 74.967 -36.223 1.00 17.31 A N +ATOM 166 CA LEU A 18 15.009 73.586 -36.680 1.00 15.47 A C +ATOM 167 C LEU A 18 15.056 73.590 -38.218 1.00 19.85 A C +ATOM 168 O LEU A 18 15.794 72.814 -38.837 1.00 20.17 A O +ATOM 169 CB LEU A 18 13.694 72.985 -36.174 1.00 15.33 A C +ATOM 170 CG LEU A 18 13.622 71.635 -35.445 1.00 15.11 A C +ATOM 171 CD1 LEU A 18 14.784 71.406 -34.532 1.00 9.98 A C +ATOM 172 CD2 LEU A 18 12.329 71.585 -34.666 1.00 11.37 A C +ATOM 173 H LEU A 18 14.330 75.333 -35.719 1.00 0.00 A H +ATOM 174 N ASN A 19 14.319 74.519 -38.822 1.00 19.25 A N +ATOM 175 CA ASN A 19 14.279 74.678 -40.282 1.00 18.67 A C +ATOM 176 C ASN A 19 15.583 75.183 -40.890 1.00 19.48 A C +ATOM 177 O ASN A 19 15.759 75.149 -42.112 1.00 26.00 A O +ATOM 178 CB ASN A 19 13.227 75.700 -40.684 1.00 16.40 A C +ATOM 179 CG ASN A 19 11.885 75.095 -40.902 1.00 19.81 A C +ATOM 180 ND2 ASN A 19 10.846 75.801 -40.448 1.00 19.07 A N +ATOM 181 OD1 ASN A 19 11.761 73.998 -41.471 1.00 21.76 A O +ATOM 182 H ASN A 19 13.762 75.104 -38.267 1.00 0.00 A H +ATOM 183 HD21 ASN A 19 11.004 76.649 -39.995 1.00 0.00 A H +ATOM 184 HD22 ASN A 19 9.959 75.414 -40.603 1.00 0.00 A H +ATOM 185 N ALA A 20 16.467 75.724 -40.070 1.00 15.92 A N +ATOM 186 CA ALA A 20 17.708 76.257 -40.599 1.00 13.83 A C +ATOM 187 C ALA A 20 18.952 75.465 -40.200 1.00 15.30 A C +ATOM 188 O ALA A 20 20.075 75.893 -40.484 1.00 20.93 A O +ATOM 189 CB ALA A 20 17.850 77.723 -40.180 1.00 9.03 A C +ATOM 190 H ALA A 20 16.279 75.791 -39.113 1.00 0.00 A H +ATOM 191 N GLY A 21 18.765 74.314 -39.559 1.00 14.60 A N +ATOM 192 CA GLY A 21 19.901 73.530 -39.112 1.00 6.97 A C +ATOM 193 C GLY A 21 20.786 74.355 -38.196 1.00 11.91 A C +ATOM 194 O GLY A 21 21.993 74.153 -38.146 1.00 19.27 A O +ATOM 195 H GLY A 21 17.855 73.991 -39.412 1.00 0.00 A H +ATOM 196 N ASP A 22 20.176 75.272 -37.455 1.00 10.80 A N +ATOM 197 CA ASP A 22 20.876 76.156 -36.540 1.00 13.90 A C +ATOM 198 C ASP A 22 20.938 75.487 -35.162 1.00 18.34 A C +ATOM 199 O ASP A 22 20.055 75.645 -34.319 1.00 21.73 A O +ATOM 200 CB ASP A 22 20.107 77.467 -36.460 1.00 15.62 A C +ATOM 201 CG ASP A 22 20.716 78.448 -35.481 1.00 21.13 A C +ATOM 202 OD1 ASP A 22 21.758 78.127 -34.874 1.00 24.96 A O +ATOM 203 OD2 ASP A 22 20.155 79.559 -35.333 1.00 22.13 A O1- +ATOM 204 H ASP A 22 19.204 75.373 -37.497 1.00 0.00 A H +ATOM 205 N LEU A 23 22.029 74.791 -34.912 1.00 16.13 A N +ATOM 206 CA LEU A 23 22.200 74.048 -33.687 1.00 12.88 A C +ATOM 207 C LEU A 23 22.297 74.892 -32.455 1.00 13.55 A C +ATOM 208 O LEU A 23 21.770 74.531 -31.414 1.00 16.43 A O +ATOM 209 CB LEU A 23 23.431 73.159 -33.812 1.00 22.08 A C +ATOM 210 CG LEU A 23 23.645 72.107 -32.746 1.00 22.16 A C +ATOM 211 CD1 LEU A 23 22.340 71.415 -32.440 1.00 23.03 A C +ATOM 212 CD2 LEU A 23 24.679 71.134 -33.248 1.00 22.78 A C +ATOM 213 H LEU A 23 22.746 74.826 -35.582 1.00 0.00 A H +ATOM 214 N ASP A 24 22.994 76.002 -32.553 1.00 14.90 A N +ATOM 215 CA ASP A 24 23.149 76.860 -31.407 1.00 19.66 A C +ATOM 216 C ASP A 24 21.920 77.692 -31.129 1.00 21.39 A C +ATOM 217 O ASP A 24 21.687 78.078 -29.993 1.00 24.19 A O +ATOM 218 CB ASP A 24 24.421 77.697 -31.539 1.00 21.55 A C +ATOM 219 CG ASP A 24 25.681 76.846 -31.380 1.00 26.43 A C +ATOM 220 OD1 ASP A 24 25.590 75.603 -31.526 1.00 33.14 A O +ATOM 221 OD2 ASP A 24 26.765 77.402 -31.111 1.00 34.14 A O1- +ATOM 222 H ASP A 24 23.414 76.281 -33.395 1.00 0.00 A H +ATOM 223 N GLY A 25 21.113 77.936 -32.157 1.00 24.38 A N +ATOM 224 CA GLY A 25 19.888 78.690 -31.980 1.00 19.16 A C +ATOM 225 C GLY A 25 18.894 77.848 -31.203 1.00 17.24 A C +ATOM 226 O GLY A 25 18.222 78.340 -30.311 1.00 19.60 A O +ATOM 227 H GLY A 25 21.353 77.634 -33.059 1.00 0.00 A H +ATOM 228 N ILE A 26 18.821 76.566 -31.518 1.00 10.29 A N +ATOM 229 CA ILE A 26 17.930 75.669 -30.819 1.00 12.55 A C +ATOM 230 C ILE A 26 18.297 75.585 -29.339 1.00 15.29 A C +ATOM 231 O ILE A 26 17.474 75.828 -28.461 1.00 17.57 A O +ATOM 232 CB ILE A 26 18.004 74.262 -31.406 1.00 9.32 A C +ATOM 233 CG1 ILE A 26 17.508 74.273 -32.836 1.00 8.47 A C +ATOM 234 CG2 ILE A 26 17.156 73.321 -30.597 1.00 13.37 A C +ATOM 235 CD1 ILE A 26 17.655 72.942 -33.496 1.00 12.60 A C +ATOM 236 H ILE A 26 19.377 76.251 -32.269 1.00 0.00 A H +ATOM 237 N VAL A 27 19.550 75.250 -29.072 1.00 17.20 A N +ATOM 238 CA VAL A 27 20.046 75.090 -27.710 1.00 14.33 A C +ATOM 239 C VAL A 27 20.013 76.328 -26.818 1.00 13.10 A C +ATOM 240 O VAL A 27 19.883 76.202 -25.606 1.00 18.39 A O +ATOM 241 CB VAL A 27 21.461 74.464 -27.718 1.00 12.93 A C +ATOM 242 CG1 VAL A 27 21.924 74.194 -26.304 1.00 17.80 A C +ATOM 243 CG2 VAL A 27 21.436 73.131 -28.520 1.00 14.81 A C +ATOM 244 H VAL A 27 20.169 75.108 -29.823 1.00 0.00 A H +ATOM 245 N ALA A 28 20.075 77.515 -27.406 1.00 13.76 A N +ATOM 246 CA ALA A 28 20.052 78.764 -26.637 1.00 12.16 A C +ATOM 247 C ALA A 28 18.720 78.974 -25.931 1.00 16.53 A C +ATOM 248 O ALA A 28 18.628 79.725 -24.963 1.00 18.18 A O +ATOM 249 CB ALA A 28 20.344 79.958 -27.556 1.00 10.61 A C +ATOM 250 H ALA A 28 20.146 77.563 -28.383 1.00 0.00 A H +ATOM 251 N LEU A 29 17.693 78.306 -26.443 1.00 16.50 A N +ATOM 252 CA LEU A 29 16.332 78.376 -25.940 1.00 12.57 A C +ATOM 253 C LEU A 29 16.119 77.660 -24.611 1.00 13.08 A C +ATOM 254 O LEU A 29 15.228 78.028 -23.840 1.00 18.83 A O +ATOM 255 CB LEU A 29 15.392 77.758 -26.981 1.00 10.87 A C +ATOM 256 CG LEU A 29 14.615 78.653 -27.953 1.00 16.58 A C +ATOM 257 CD1 LEU A 29 15.336 79.965 -28.233 1.00 17.44 A C +ATOM 258 CD2 LEU A 29 14.320 77.880 -29.244 1.00 12.46 A C +ATOM 259 H LEU A 29 17.855 77.733 -27.220 1.00 0.00 A H +ATOM 260 N PHE A 30 16.905 76.618 -24.363 1.00 15.03 A N +ATOM 261 CA PHE A 30 16.769 75.816 -23.153 1.00 13.36 A C +ATOM 262 C PHE A 30 17.547 76.305 -21.927 1.00 18.06 A C +ATOM 263 O PHE A 30 18.538 77.027 -22.044 1.00 20.40 A O +ATOM 264 CB PHE A 30 17.128 74.354 -23.463 1.00 13.54 A C +ATOM 265 CG PHE A 30 16.189 73.679 -24.468 1.00 13.97 A C +ATOM 266 CD1 PHE A 30 16.391 73.811 -25.833 1.00 12.41 A C +ATOM 267 CD2 PHE A 30 15.121 72.890 -24.032 1.00 11.09 A C +ATOM 268 CE1 PHE A 30 15.563 73.176 -26.733 1.00 13.85 A C +ATOM 269 CE2 PHE A 30 14.289 72.251 -24.930 1.00 8.95 A C +ATOM 270 CZ PHE A 30 14.509 72.392 -26.278 1.00 13.73 A C +ATOM 271 H PHE A 30 17.631 76.398 -24.981 1.00 0.00 A H +ATOM 272 N ALA A 31 17.020 75.999 -20.745 1.00 18.68 A N +ATOM 273 CA ALA A 31 17.674 76.353 -19.487 1.00 20.42 A C +ATOM 274 C ALA A 31 18.795 75.322 -19.367 1.00 20.67 A C +ATOM 275 O ALA A 31 18.672 74.215 -19.864 1.00 20.86 A O +ATOM 276 CB ALA A 31 16.705 76.222 -18.325 1.00 13.63 A C +ATOM 277 H ALA A 31 16.197 75.481 -20.726 1.00 0.00 A H +ATOM 278 N ASP A 32 19.868 75.649 -18.680 1.00 25.24 A N +ATOM 279 CA ASP A 32 20.975 74.718 -18.584 1.00 29.66 A C +ATOM 280 C ASP A 32 20.645 73.374 -17.954 1.00 26.34 A C +ATOM 281 O ASP A 32 21.213 72.337 -18.317 1.00 27.92 A O +ATOM 282 CB ASP A 32 22.141 75.424 -17.903 1.00 41.07 A C +ATOM 283 CG ASP A 32 22.543 76.697 -18.649 1.00 56.13 A C +ATOM 284 OD1 ASP A 32 21.696 77.629 -18.750 1.00 65.00 A O +ATOM 285 OD2 ASP A 32 23.672 76.741 -19.194 1.00 63.02 A O1- +ATOM 286 H ASP A 32 19.988 76.501 -18.211 1.00 0.00 A H +ATOM 287 N ASP A 33 19.654 73.378 -17.082 1.00 24.28 A N +ATOM 288 CA ASP A 33 19.230 72.174 -16.393 1.00 25.07 A C +ATOM 289 C ASP A 33 17.986 71.549 -17.017 1.00 25.39 A C +ATOM 290 O ASP A 33 17.418 70.592 -16.464 1.00 27.75 A O +ATOM 291 CB ASP A 33 18.975 72.494 -14.915 1.00 34.12 A C +ATOM 292 CG ASP A 33 18.076 73.713 -14.717 1.00 38.66 A C +ATOM 293 OD1 ASP A 33 18.398 74.816 -15.213 1.00 41.05 A O +ATOM 294 OD2 ASP A 33 17.043 73.574 -14.037 1.00 46.96 A O1- +ATOM 295 H ASP A 33 19.193 74.214 -16.864 1.00 0.00 A H +ATOM 296 N ALA A 34 17.600 72.044 -18.190 1.00 16.51 A N +ATOM 297 CA ALA A 34 16.417 71.557 -18.898 1.00 16.20 A C +ATOM 298 C ALA A 34 16.446 70.087 -19.302 1.00 18.11 A C +ATOM 299 O ALA A 34 17.509 69.450 -19.332 1.00 22.20 A O +ATOM 300 CB ALA A 34 16.163 72.412 -20.117 1.00 13.10 A C +ATOM 301 H ALA A 34 18.140 72.752 -18.594 1.00 0.00 A H +ATOM 302 N THR A 35 15.267 69.534 -19.566 1.00 17.22 A N +ATOM 303 CA THR A 35 15.164 68.143 -19.997 1.00 19.54 A C +ATOM 304 C THR A 35 14.404 68.107 -21.314 1.00 20.23 A C +ATOM 305 O THR A 35 13.650 69.036 -21.617 1.00 19.48 A O +ATOM 306 CB THR A 35 14.393 67.270 -18.995 1.00 19.13 A C +ATOM 307 CG2 THR A 35 15.303 66.727 -17.930 1.00 22.28 A C +ATOM 308 OG1 THR A 35 13.364 68.047 -18.385 1.00 33.43 A O +ATOM 309 H THR A 35 14.440 70.059 -19.505 1.00 0.00 A H +ATOM 310 HG1 THR A 35 13.777 68.790 -17.927 1.00 0.00 A H +ATOM 311 N VAL A 36 14.589 67.023 -22.066 1.00 17.96 A N +ATOM 312 CA VAL A 36 13.938 66.785 -23.362 1.00 13.42 A C +ATOM 313 C VAL A 36 13.458 65.335 -23.329 1.00 17.93 A C +ATOM 314 O VAL A 36 14.226 64.446 -22.964 1.00 21.00 A O +ATOM 315 CB VAL A 36 14.960 66.940 -24.551 1.00 15.38 A C +ATOM 316 CG1 VAL A 36 14.490 66.201 -25.779 1.00 9.13 A C +ATOM 317 CG2 VAL A 36 15.164 68.416 -24.891 1.00 6.36 A C +ATOM 318 H VAL A 36 15.199 66.329 -21.737 1.00 0.00 A H +ATOM 319 N GLU A 37 12.180 65.098 -23.604 1.00 17.48 A N +ATOM 320 CA GLU A 37 11.667 63.739 -23.638 1.00 16.27 A C +ATOM 321 C GLU A 37 11.080 63.616 -25.024 1.00 14.95 A C +ATOM 322 O GLU A 37 9.975 64.057 -25.277 1.00 14.81 A O +ATOM 323 CB GLU A 37 10.601 63.529 -22.583 1.00 21.55 A C +ATOM 324 CG GLU A 37 10.328 62.078 -22.314 1.00 25.91 A C +ATOM 325 CD GLU A 37 9.276 61.879 -21.261 1.00 27.26 A C +ATOM 326 OE1 GLU A 37 9.249 62.659 -20.294 1.00 36.29 A O +ATOM 327 OE2 GLU A 37 8.458 60.950 -21.407 1.00 32.31 A O1- +ATOM 328 H GLU A 37 11.555 65.842 -23.761 1.00 0.00 A H +ATOM 329 N ASN A 38 11.870 63.089 -25.948 1.00 19.33 A N +ATOM 330 CA ASN A 38 11.453 62.956 -27.337 1.00 18.26 A C +ATOM 331 C ASN A 38 11.866 61.577 -27.816 1.00 18.79 A C +ATOM 332 O ASN A 38 13.049 61.275 -27.911 1.00 23.11 A O +ATOM 333 CB ASN A 38 12.159 64.031 -28.170 1.00 15.03 A C +ATOM 334 CG ASN A 38 11.539 64.233 -29.541 1.00 13.27 A C +ATOM 335 ND2 ASN A 38 11.759 65.409 -30.106 1.00 20.00 A N +ATOM 336 OD1 ASN A 38 10.881 63.345 -30.090 1.00 14.68 A O +ATOM 337 H ASN A 38 12.755 62.739 -25.703 1.00 0.00 A H +ATOM 338 HD21 ASN A 38 12.289 66.102 -29.664 1.00 0.00 A H +ATOM 339 HD22 ASN A 38 11.345 65.573 -30.975 1.00 0.00 A H +ATOM 340 N PRO A 39 10.893 60.686 -28.029 1.00 19.61 A N +ATOM 341 CA PRO A 39 9.463 60.934 -27.834 1.00 20.71 A C +ATOM 342 C PRO A 39 9.023 60.632 -26.403 1.00 17.56 A C +ATOM 343 O PRO A 39 9.786 60.056 -25.612 1.00 15.92 A O +ATOM 344 CB PRO A 39 8.835 59.932 -28.790 1.00 24.55 A C +ATOM 345 CG PRO A 39 9.747 58.741 -28.617 1.00 19.39 A C +ATOM 346 CD PRO A 39 11.114 59.384 -28.689 1.00 20.94 A C +ATOM 347 N VAL A 40 7.792 61.022 -26.084 1.00 17.69 A N +ATOM 348 CA VAL A 40 7.211 60.751 -24.770 1.00 20.37 A C +ATOM 349 C VAL A 40 7.255 59.234 -24.574 1.00 21.71 A C +ATOM 350 O VAL A 40 6.793 58.456 -25.421 1.00 24.84 A O +ATOM 351 CB VAL A 40 5.745 61.274 -24.666 1.00 13.69 A C +ATOM 352 CG1 VAL A 40 4.933 60.444 -23.697 1.00 21.73 A C +ATOM 353 CG2 VAL A 40 5.755 62.678 -24.193 1.00 13.64 A C +ATOM 354 H VAL A 40 7.263 61.507 -26.755 1.00 0.00 A H +ATOM 355 N GLY A 41 7.833 58.831 -23.452 1.00 25.12 A N +ATOM 356 CA GLY A 41 7.974 57.423 -23.140 1.00 25.44 A C +ATOM 357 C GLY A 41 9.425 56.959 -23.121 1.00 23.88 A C +ATOM 358 O GLY A 41 9.692 55.840 -22.692 1.00 27.74 A O +ATOM 359 H GLY A 41 8.186 59.492 -22.810 1.00 0.00 A H +ATOM 360 N SER A 42 10.348 57.782 -23.626 1.00 26.93 A N +ATOM 361 CA SER A 42 11.775 57.450 -23.641 1.00 27.71 A C +ATOM 362 C SER A 42 12.519 58.104 -22.456 1.00 28.62 A C +ATOM 363 O SER A 42 11.940 58.880 -21.683 1.00 28.06 A O +ATOM 364 CB SER A 42 12.403 57.858 -24.979 1.00 28.05 A C +ATOM 365 OG SER A 42 12.277 59.247 -25.211 1.00 26.25 A O +ATOM 366 H SER A 42 10.065 58.650 -23.975 1.00 0.00 A H +ATOM 367 HG SER A 42 11.405 59.561 -24.962 1.00 0.00 A H +ATOM 368 N GLU A 43 13.798 57.785 -22.304 1.00 28.59 A N +ATOM 369 CA GLU A 43 14.583 58.337 -21.210 1.00 30.92 A C +ATOM 370 C GLU A 43 14.857 59.797 -21.502 1.00 28.51 A C +ATOM 371 O GLU A 43 15.324 60.152 -22.588 1.00 30.42 A O +ATOM 372 CB GLU A 43 15.922 57.600 -21.045 1.00 44.55 A C +ATOM 373 CG GLU A 43 16.019 56.223 -21.722 1.00 62.47 A C +ATOM 374 CD GLU A 43 16.565 56.284 -23.155 1.00 70.64 A C +ATOM 375 OE1 GLU A 43 15.895 56.851 -24.052 1.00 75.18 A O +ATOM 376 OE2 GLU A 43 17.675 55.758 -23.385 1.00 74.27 A O1- +ATOM 377 H GLU A 43 14.229 57.192 -22.953 1.00 0.00 A H +ATOM 378 N PRO A 44 14.533 60.675 -20.552 1.00 25.20 A N +ATOM 379 CA PRO A 44 14.754 62.113 -20.709 1.00 25.10 A C +ATOM 380 C PRO A 44 16.243 62.473 -20.872 1.00 26.61 A C +ATOM 381 O PRO A 44 17.105 61.850 -20.267 1.00 29.02 A O +ATOM 382 CB PRO A 44 14.195 62.679 -19.395 1.00 19.62 A C +ATOM 383 CG PRO A 44 13.110 61.749 -19.058 1.00 21.92 A C +ATOM 384 CD PRO A 44 13.750 60.402 -19.335 1.00 18.83 A C +ATOM 385 N ARG A 45 16.528 63.439 -21.742 1.00 27.43 A N +ATOM 386 CA ARG A 45 17.875 63.933 -21.964 1.00 25.28 A C +ATOM 387 C ARG A 45 17.876 65.090 -20.981 1.00 22.66 A C +ATOM 388 O ARG A 45 16.906 65.850 -20.926 1.00 25.12 A O +ATOM 389 CB ARG A 45 18.028 64.488 -23.386 1.00 32.63 A C +ATOM 390 CG ARG A 45 17.525 63.562 -24.488 1.00 42.81 A C +ATOM 391 CD ARG A 45 18.355 62.295 -24.560 1.00 50.48 A C +ATOM 392 NE ARG A 45 17.537 61.091 -24.696 1.00 57.00 A N +ATOM 393 CZ ARG A 45 17.594 60.248 -25.728 1.00 60.21 A C +ATOM 394 NH1 ARG A 45 18.433 60.483 -26.743 1.00 56.34 A N1+ +ATOM 395 NH2 ARG A 45 16.865 59.127 -25.706 1.00 64.07 A N +ATOM 396 H ARG A 45 15.796 63.843 -22.236 1.00 0.00 A H +ATOM 397 HE ARG A 45 16.892 60.897 -23.964 1.00 0.00 A H +ATOM 398 HH11 ARG A 45 19.032 61.291 -26.729 1.00 0.00 A H +ATOM 399 HH12 ARG A 45 18.503 59.869 -27.529 1.00 0.00 A H +ATOM 400 HH21 ARG A 45 16.294 58.876 -24.902 1.00 0.00 A H +ATOM 401 HH22 ARG A 45 16.850 58.425 -26.417 1.00 0.00 A H +ATOM 402 N SER A 46 18.932 65.219 -20.187 1.00 21.54 A N +ATOM 403 CA SER A 46 19.001 66.291 -19.192 1.00 19.49 A C +ATOM 404 C SER A 46 20.289 67.084 -19.323 1.00 15.15 A C +ATOM 405 O SER A 46 21.369 66.505 -19.426 1.00 21.08 A O +ATOM 406 CB SER A 46 18.885 65.686 -17.785 1.00 22.78 A C +ATOM 407 OG SER A 46 18.817 66.685 -16.770 1.00 30.92 A O +ATOM 408 H SER A 46 19.691 64.606 -20.248 1.00 0.00 A H +ATOM 409 HG SER A 46 19.658 67.169 -16.829 1.00 0.00 A H +ATOM 410 N GLY A 47 20.175 68.403 -19.383 1.00 9.93 A N +ATOM 411 CA GLY A 47 21.360 69.224 -19.488 1.00 15.35 A C +ATOM 412 C GLY A 47 21.740 69.608 -20.899 1.00 20.79 A C +ATOM 413 O GLY A 47 21.493 68.864 -21.845 1.00 22.57 A O +ATOM 414 H GLY A 47 19.297 68.836 -19.380 1.00 0.00 A H +ATOM 415 N THR A 48 22.368 70.774 -21.013 1.00 22.30 A N +ATOM 416 CA THR A 48 22.815 71.356 -22.272 1.00 27.44 A C +ATOM 417 C THR A 48 23.584 70.425 -23.210 1.00 26.51 A C +ATOM 418 O THR A 48 23.317 70.391 -24.411 1.00 28.00 A O +ATOM 419 CB THR A 48 23.674 72.587 -21.977 1.00 29.56 A C +ATOM 420 CG2 THR A 48 23.921 73.382 -23.235 1.00 36.98 A C +ATOM 421 OG1 THR A 48 22.983 73.417 -21.035 1.00 38.91 A O +ATOM 422 H THR A 48 22.518 71.316 -20.206 1.00 0.00 A H +ATOM 423 HG1 THR A 48 23.493 74.188 -20.725 1.00 0.00 A H +ATOM 424 N ALA A 49 24.546 69.686 -22.670 1.00 26.42 A N +ATOM 425 CA ALA A 49 25.350 68.768 -23.472 1.00 25.21 A C +ATOM 426 C ALA A 49 24.525 67.668 -24.103 1.00 22.40 A C +ATOM 427 O ALA A 49 24.779 67.304 -25.245 1.00 24.67 A O +ATOM 428 CB ALA A 49 26.454 68.160 -22.642 1.00 24.80 A C +ATOM 429 H ALA A 49 24.722 69.753 -21.705 1.00 0.00 A H +ATOM 430 N ALA A 50 23.552 67.136 -23.361 1.00 18.24 A N +ATOM 431 CA ALA A 50 22.692 66.059 -23.862 1.00 16.13 A C +ATOM 432 C ALA A 50 21.644 66.566 -24.832 1.00 15.85 A C +ATOM 433 O ALA A 50 21.205 65.829 -25.709 1.00 21.53 A O +ATOM 434 CB ALA A 50 22.018 65.338 -22.721 1.00 18.85 A C +ATOM 435 H ALA A 50 23.408 67.471 -22.450 1.00 0.00 A H +ATOM 436 N ILE A 51 21.209 67.807 -24.641 1.00 14.98 A N +ATOM 437 CA ILE A 51 20.213 68.425 -25.517 1.00 16.38 A C +ATOM 438 C ILE A 51 20.848 68.716 -26.892 1.00 17.29 A C +ATOM 439 O ILE A 51 20.271 68.383 -27.932 1.00 17.52 A O +ATOM 440 CB ILE A 51 19.589 69.701 -24.851 1.00 17.04 A C +ATOM 441 CG1 ILE A 51 18.731 69.286 -23.637 1.00 18.33 A C +ATOM 442 CG2 ILE A 51 18.750 70.483 -25.859 1.00 17.15 A C +ATOM 443 CD1 ILE A 51 18.279 70.417 -22.747 1.00 14.09 A C +ATOM 444 H ILE A 51 21.558 68.311 -23.875 1.00 0.00 A H +ATOM 445 N ARG A 52 22.065 69.256 -26.885 1.00 18.63 A N +ATOM 446 CA ARG A 52 22.802 69.549 -28.119 1.00 19.29 A C +ATOM 447 C ARG A 52 23.059 68.258 -28.919 1.00 19.10 A C +ATOM 448 O ARG A 52 22.758 68.180 -30.114 1.00 24.51 A O +ATOM 449 CB ARG A 52 24.129 70.235 -27.788 1.00 21.72 A C +ATOM 450 CG ARG A 52 24.843 70.737 -29.003 1.00 24.00 A C +ATOM 451 CD ARG A 52 26.227 71.218 -28.716 1.00 21.97 A C +ATOM 452 NE ARG A 52 26.898 71.544 -29.971 1.00 29.26 A N +ATOM 453 CZ ARG A 52 26.958 72.767 -30.486 1.00 30.33 A C +ATOM 454 NH1 ARG A 52 26.384 73.771 -29.843 1.00 28.11 A N1+ +ATOM 455 NH2 ARG A 52 27.608 72.995 -31.626 1.00 22.98 A N +ATOM 456 H ARG A 52 22.480 69.475 -26.023 1.00 0.00 A H +ATOM 457 HE ARG A 52 27.290 70.783 -30.464 1.00 0.00 A H +ATOM 458 HH11 ARG A 52 25.877 73.672 -28.996 1.00 0.00 A H +ATOM 459 HH12 ARG A 52 26.387 74.709 -30.252 1.00 0.00 A H +ATOM 460 HH21 ARG A 52 28.079 72.275 -32.140 1.00 0.00 A H +ATOM 461 HH22 ARG A 52 27.637 73.938 -31.999 1.00 0.00 A H +ATOM 462 N GLU A 53 23.544 67.223 -28.242 1.00 19.57 A N +ATOM 463 CA GLU A 53 23.818 65.938 -28.880 1.00 22.65 A C +ATOM 464 C GLU A 53 22.549 65.369 -29.537 1.00 21.36 A C +ATOM 465 O GLU A 53 22.580 64.901 -30.692 1.00 20.72 A O +ATOM 466 CB GLU A 53 24.465 64.968 -27.854 1.00 29.62 A C +ATOM 467 CG GLU A 53 23.751 63.618 -27.551 1.00 43.16 A C +ATOM 468 CD GLU A 53 24.079 62.493 -28.534 1.00 48.92 A C +ATOM 469 OE1 GLU A 53 25.097 62.583 -29.251 1.00 49.44 A O +ATOM 470 OE2 GLU A 53 23.311 61.508 -28.602 1.00 48.51 A O1- +ATOM 471 H GLU A 53 23.760 67.315 -27.289 1.00 0.00 A H +ATOM 472 N PHE A 54 21.426 65.473 -28.831 1.00 15.74 A N +ATOM 473 CA PHE A 54 20.163 64.981 -29.346 1.00 14.67 A C +ATOM 474 C PHE A 54 19.739 65.757 -30.596 1.00 14.61 A C +ATOM 475 O PHE A 54 19.359 65.181 -31.611 1.00 16.88 A O +ATOM 476 CB PHE A 54 19.076 65.093 -28.284 1.00 16.51 A C +ATOM 477 CG PHE A 54 17.700 64.832 -28.818 1.00 15.40 A C +ATOM 478 CD1 PHE A 54 17.260 63.526 -29.028 1.00 19.05 A C +ATOM 479 CD2 PHE A 54 16.869 65.889 -29.176 1.00 14.87 A C +ATOM 480 CE1 PHE A 54 16.018 63.272 -29.591 1.00 20.66 A C +ATOM 481 CE2 PHE A 54 15.625 65.655 -29.741 1.00 22.45 A C +ATOM 482 CZ PHE A 54 15.196 64.334 -29.952 1.00 21.58 A C +ATOM 483 H PHE A 54 21.466 65.860 -27.931 1.00 0.00 A H +ATOM 484 N TYR A 55 19.812 67.073 -30.538 1.00 15.57 A N +ATOM 485 CA TYR A 55 19.419 67.835 -31.699 1.00 15.97 A C +ATOM 486 C TYR A 55 20.386 67.726 -32.851 1.00 15.54 A C +ATOM 487 O TYR A 55 19.967 67.836 -34.002 1.00 17.97 A O +ATOM 488 CB TYR A 55 19.126 69.271 -31.335 1.00 16.17 A C +ATOM 489 CG TYR A 55 17.743 69.413 -30.737 1.00 19.24 A C +ATOM 490 CD1 TYR A 55 16.595 69.253 -31.523 1.00 16.59 A C +ATOM 491 CD2 TYR A 55 17.580 69.717 -29.384 1.00 21.43 A C +ATOM 492 CE1 TYR A 55 15.323 69.397 -30.969 1.00 17.69 A C +ATOM 493 CE2 TYR A 55 16.319 69.862 -28.830 1.00 24.02 A C +ATOM 494 CZ TYR A 55 15.200 69.702 -29.623 1.00 18.58 A C +ATOM 495 OH TYR A 55 13.971 69.854 -29.031 1.00 16.62 A O +ATOM 496 H TYR A 55 20.122 67.516 -29.720 1.00 0.00 A H +ATOM 497 HH TYR A 55 13.258 69.756 -29.685 1.00 0.00 A H +ATOM 498 N ALA A 56 21.664 67.470 -32.552 1.00 13.11 A N +ATOM 499 CA ALA A 56 22.676 67.316 -33.597 1.00 13.94 A C +ATOM 500 C ALA A 56 22.334 66.072 -34.401 1.00 14.20 A C +ATOM 501 O ALA A 56 22.427 66.070 -35.624 1.00 19.71 A O +ATOM 502 CB ALA A 56 24.075 67.195 -32.993 1.00 17.34 A C +ATOM 503 H ALA A 56 21.936 67.397 -31.617 1.00 0.00 A H +ATOM 504 N ASN A 57 21.911 65.022 -33.706 1.00 18.33 A N +ATOM 505 CA ASN A 57 21.521 63.772 -34.353 1.00 18.10 A C +ATOM 506 C ASN A 57 20.281 63.960 -35.190 1.00 22.69 A C +ATOM 507 O ASN A 57 20.120 63.306 -36.214 1.00 25.44 A O +ATOM 508 CB ASN A 57 21.177 62.712 -33.319 1.00 21.21 A C +ATOM 509 CG ASN A 57 22.358 61.920 -32.885 1.00 29.38 A C +ATOM 510 ND2 ASN A 57 22.662 61.977 -31.588 1.00 38.74 A N +ATOM 511 OD1 ASN A 57 23.002 61.241 -33.697 1.00 37.38 A O +ATOM 512 H ASN A 57 21.873 65.098 -32.727 1.00 0.00 A H +ATOM 513 HD21 ASN A 57 22.139 62.582 -31.013 1.00 0.00 A H +ATOM 514 HD22 ASN A 57 23.344 61.421 -31.155 1.00 0.00 A H +ATOM 515 N SER A 58 19.359 64.777 -34.691 1.00 20.80 A N +ATOM 516 CA SER A 58 18.097 65.041 -35.371 1.00 17.38 A C +ATOM 517 C SER A 58 18.311 65.791 -36.668 1.00 19.47 A C +ATOM 518 O SER A 58 17.671 65.508 -37.674 1.00 21.88 A O +ATOM 519 CB SER A 58 17.151 65.827 -34.449 1.00 17.00 A C +ATOM 520 OG SER A 58 16.738 65.021 -33.356 1.00 20.95 A O +ATOM 521 H SER A 58 19.507 65.199 -33.818 1.00 0.00 A H +ATOM 522 HG SER A 58 16.358 64.202 -33.683 1.00 0.00 A H +ATOM 523 N LEU A 59 19.245 66.726 -36.652 1.00 18.71 A N +ATOM 524 CA LEU A 59 19.536 67.509 -37.835 1.00 16.24 A C +ATOM 525 C LEU A 59 20.444 66.809 -38.853 1.00 13.82 A C +ATOM 526 O LEU A 59 20.937 67.438 -39.788 1.00 14.80 A O +ATOM 527 CB LEU A 59 20.108 68.852 -37.425 1.00 14.05 A C +ATOM 528 CG LEU A 59 19.145 69.698 -36.596 1.00 12.88 A C +ATOM 529 CD1 LEU A 59 19.927 70.882 -36.106 1.00 17.02 A C +ATOM 530 CD2 LEU A 59 17.959 70.167 -37.409 1.00 11.78 A C +ATOM 531 H LEU A 59 19.753 66.901 -35.830 1.00 0.00 A H +ATOM 532 N LYS A 60 20.658 65.514 -38.657 1.00 12.12 A N +ATOM 533 CA LYS A 60 21.456 64.686 -39.551 1.00 13.05 A C +ATOM 534 C LYS A 60 20.789 64.780 -40.927 1.00 16.30 A C +ATOM 535 O LYS A 60 21.464 64.851 -41.956 1.00 17.91 A O +ATOM 536 CB LYS A 60 21.391 63.235 -39.078 1.00 17.73 A C +ATOM 537 CG LYS A 60 22.692 62.486 -39.059 1.00 23.84 A C +ATOM 538 CD LYS A 60 23.522 62.797 -37.839 1.00 27.57 A C +ATOM 539 CE LYS A 60 23.596 61.594 -36.905 1.00 21.21 A C +ATOM 540 NZ LYS A 60 24.341 60.433 -37.453 1.00 23.59 A N1+ +ATOM 541 H LYS A 60 20.266 65.084 -37.873 1.00 0.00 A H +ATOM 542 HZ1 LYS A 60 23.931 60.114 -38.358 1.00 0.00 A H +ATOM 543 HZ2 LYS A 60 25.314 60.730 -37.672 1.00 0.00 A H +ATOM 544 HZ3 LYS A 60 24.354 59.650 -36.780 1.00 0.00 A H +ATOM 545 N LEU A 61 19.455 64.758 -40.925 1.00 19.54 A N +ATOM 546 CA LEU A 61 18.634 64.842 -42.131 1.00 19.62 A C +ATOM 547 C LEU A 61 18.053 66.240 -42.191 1.00 22.00 A C +ATOM 548 O LEU A 61 17.749 66.824 -41.156 1.00 23.73 A O +ATOM 549 CB LEU A 61 17.455 63.864 -42.056 1.00 21.85 A C +ATOM 550 CG LEU A 61 17.726 62.376 -41.884 1.00 26.24 A C +ATOM 551 CD1 LEU A 61 16.446 61.604 -42.034 1.00 22.46 A C +ATOM 552 CD2 LEU A 61 18.709 61.932 -42.935 1.00 32.42 A C +ATOM 553 H LEU A 61 18.984 64.713 -40.066 1.00 0.00 A H +ATOM 554 N PRO A 62 17.943 66.824 -43.396 1.00 22.42 A N +ATOM 555 CA PRO A 62 17.375 68.170 -43.485 1.00 22.72 A C +ATOM 556 C PRO A 62 15.894 68.092 -43.131 1.00 23.49 A C +ATOM 557 O PRO A 62 15.174 67.226 -43.638 1.00 22.02 A O +ATOM 558 CB PRO A 62 17.614 68.543 -44.945 1.00 26.38 A C +ATOM 559 CG PRO A 62 17.666 67.216 -45.647 1.00 24.89 A C +ATOM 560 CD PRO A 62 18.425 66.358 -44.707 1.00 21.44 A C +ATOM 561 N LEU A 63 15.442 68.975 -42.250 1.00 22.90 A N +ATOM 562 CA LEU A 63 14.058 68.935 -41.817 1.00 16.78 A C +ATOM 563 C LEU A 63 13.300 70.164 -42.231 1.00 17.38 A C +ATOM 564 O LEU A 63 13.857 71.258 -42.293 1.00 18.87 A O +ATOM 565 CB LEU A 63 13.998 68.836 -40.297 1.00 11.97 A C +ATOM 566 CG LEU A 63 14.823 67.776 -39.574 1.00 10.10 A C +ATOM 567 CD1 LEU A 63 14.699 68.081 -38.117 1.00 9.09 A C +ATOM 568 CD2 LEU A 63 14.360 66.334 -39.876 1.00 9.67 A C +ATOM 569 H LEU A 63 16.025 69.675 -41.895 1.00 0.00 A H +ATOM 570 N ALA A 64 12.027 69.964 -42.544 1.00 19.85 A N +ATOM 571 CA ALA A 64 11.115 71.045 -42.911 1.00 18.67 A C +ATOM 572 C ALA A 64 10.016 70.917 -41.862 1.00 14.97 A C +ATOM 573 O ALA A 64 9.274 69.926 -41.837 1.00 15.53 A O +ATOM 574 CB ALA A 64 10.548 70.830 -44.307 1.00 18.29 A C +ATOM 575 H ALA A 64 11.673 69.048 -42.532 1.00 0.00 A H +ATOM 576 N VAL A 65 9.992 71.857 -40.928 1.00 18.71 A N +ATOM 577 CA VAL A 65 9.012 71.827 -39.863 1.00 18.12 A C +ATOM 578 C VAL A 65 8.152 73.064 -39.940 1.00 14.97 A C +ATOM 579 O VAL A 65 8.641 74.129 -40.261 1.00 18.06 A O +ATOM 580 CB VAL A 65 9.699 71.646 -38.470 1.00 17.54 A C +ATOM 581 CG1 VAL A 65 11.135 72.002 -38.542 1.00 18.32 A C +ATOM 582 CG2 VAL A 65 9.028 72.462 -37.429 1.00 17.88 A C +ATOM 583 H VAL A 65 10.612 72.612 -40.980 1.00 0.00 A H +ATOM 584 N GLU A 66 6.855 72.918 -39.696 1.00 16.34 A N +ATOM 585 CA GLU A 66 5.952 74.061 -39.752 1.00 19.15 A C +ATOM 586 C GLU A 66 4.853 73.956 -38.722 1.00 15.51 A C +ATOM 587 O GLU A 66 4.336 72.874 -38.480 1.00 17.34 A O +ATOM 588 CB GLU A 66 5.283 74.137 -41.128 1.00 26.88 A C +ATOM 589 CG GLU A 66 4.292 73.007 -41.373 1.00 38.82 A C +ATOM 590 CD GLU A 66 3.409 73.196 -42.604 1.00 49.05 A C +ATOM 591 OE1 GLU A 66 3.507 74.262 -43.261 1.00 55.93 A O +ATOM 592 OE2 GLU A 66 2.612 72.266 -42.909 1.00 52.01 A O1- +ATOM 593 H GLU A 66 6.490 72.032 -39.486 1.00 0.00 A H +ATOM 594 N LEU A 67 4.489 75.078 -38.118 1.00 14.88 A N +ATOM 595 CA LEU A 67 3.388 75.087 -37.152 1.00 16.04 A C +ATOM 596 C LEU A 67 2.134 74.870 -37.982 1.00 13.87 A C +ATOM 597 O LEU A 67 2.047 75.397 -39.083 1.00 23.45 A O +ATOM 598 CB LEU A 67 3.296 76.443 -36.451 1.00 13.88 A C +ATOM 599 CG LEU A 67 4.275 76.657 -35.301 1.00 20.67 A C +ATOM 600 CD1 LEU A 67 4.330 78.143 -34.900 1.00 25.36 A C +ATOM 601 CD2 LEU A 67 3.868 75.767 -34.142 1.00 23.30 A C +ATOM 602 H LEU A 67 4.930 75.912 -38.366 1.00 0.00 A H +ATOM 603 N THR A 68 1.176 74.089 -37.495 1.00 15.39 A N +ATOM 604 CA THR A 68 -0.032 73.866 -38.280 1.00 13.51 A C +ATOM 605 C THR A 68 -1.275 74.480 -37.689 1.00 13.48 A C +ATOM 606 O THR A 68 -2.317 74.459 -38.310 1.00 14.79 A O +ATOM 607 CB THR A 68 -0.297 72.392 -38.512 1.00 9.45 A C +ATOM 608 CG2 THR A 68 0.755 71.809 -39.466 1.00 10.27 A C +ATOM 609 OG1 THR A 68 -0.247 71.711 -37.257 1.00 15.66 A O +ATOM 610 H THR A 68 1.276 73.696 -36.603 1.00 0.00 A H +ATOM 611 HG1 THR A 68 0.629 71.871 -36.880 1.00 0.00 A H +ATOM 612 N GLN A 69 -1.158 75.032 -36.489 1.00 12.31 A N +ATOM 613 CA GLN A 69 -2.273 75.656 -35.783 1.00 11.58 A C +ATOM 614 C GLN A 69 -1.700 76.685 -34.837 1.00 15.38 A C +ATOM 615 O GLN A 69 -0.498 76.728 -34.605 1.00 15.03 A O +ATOM 616 CB GLN A 69 -3.032 74.625 -34.946 1.00 9.10 A C +ATOM 617 CG GLN A 69 -4.205 73.982 -35.640 1.00 8.88 A C +ATOM 618 CD GLN A 69 -4.936 73.018 -34.760 1.00 14.81 A C +ATOM 619 NE2 GLN A 69 -5.731 72.163 -35.371 1.00 14.48 A N +ATOM 620 OE1 GLN A 69 -4.797 73.041 -33.527 1.00 24.74 A O +ATOM 621 H GLN A 69 -0.290 75.078 -36.036 1.00 0.00 A H +ATOM 622 HE21 GLN A 69 -5.800 72.199 -36.347 1.00 0.00 A H +ATOM 623 HE22 GLN A 69 -6.234 71.521 -34.825 1.00 0.00 A H +ATOM 624 N GLU A 70 -2.575 77.494 -34.267 1.00 17.67 A N +ATOM 625 CA GLU A 70 -2.175 78.510 -33.316 1.00 15.37 A C +ATOM 626 C GLU A 70 -1.456 77.860 -32.126 1.00 12.95 A C +ATOM 627 O GLU A 70 -1.678 76.684 -31.812 1.00 15.83 A O +ATOM 628 CB GLU A 70 -3.433 79.251 -32.846 1.00 15.22 A C +ATOM 629 CG GLU A 70 -4.542 78.332 -32.266 1.00 13.80 A C +ATOM 630 CD GLU A 70 -5.831 79.077 -31.979 1.00 11.85 A C +ATOM 631 OE1 GLU A 70 -5.932 79.770 -30.956 1.00 19.78 A O +ATOM 632 OE2 GLU A 70 -6.766 78.964 -32.783 1.00 15.86 A O1- +ATOM 633 H GLU A 70 -3.520 77.395 -34.476 1.00 0.00 A H +ATOM 634 N VAL A 71 -0.545 78.599 -31.514 1.00 13.61 A N +ATOM 635 CA VAL A 71 0.179 78.113 -30.348 1.00 12.54 A C +ATOM 636 C VAL A 71 -0.741 78.284 -29.115 1.00 13.85 A C +ATOM 637 O VAL A 71 -1.479 79.275 -29.003 1.00 11.75 A O +ATOM 638 CB VAL A 71 1.510 78.921 -30.124 1.00 10.04 A C +ATOM 639 CG1 VAL A 71 2.231 78.466 -28.840 1.00 9.76 A C +ATOM 640 CG2 VAL A 71 2.448 78.743 -31.301 1.00 6.51 A C +ATOM 641 H VAL A 71 -0.376 79.504 -31.849 1.00 0.00 A H +ATOM 642 N ARG A 72 -0.777 77.282 -28.245 1.00 11.83 A N +ATOM 643 CA ARG A 72 -1.580 77.367 -27.037 1.00 11.86 A C +ATOM 644 C ARG A 72 -0.613 77.779 -25.955 1.00 13.28 A C +ATOM 645 O ARG A 72 0.319 77.054 -25.665 1.00 20.35 A O +ATOM 646 CB ARG A 72 -2.187 76.014 -26.706 1.00 6.01 A C +ATOM 647 CG ARG A 72 -3.535 75.777 -27.333 1.00 9.72 A C +ATOM 648 CD ARG A 72 -3.509 75.793 -28.857 1.00 6.03 A C +ATOM 649 NE ARG A 72 -4.710 75.143 -29.389 1.00 7.77 A N +ATOM 650 CZ ARG A 72 -4.841 74.651 -30.617 1.00 6.68 A C +ATOM 651 NH1 ARG A 72 -3.855 74.736 -31.503 1.00 12.81 A N1+ +ATOM 652 NH2 ARG A 72 -5.928 73.971 -30.929 1.00 6.00 A N +ATOM 653 H ARG A 72 -0.231 76.481 -28.394 1.00 0.00 A H +ATOM 654 HE ARG A 72 -5.477 75.136 -28.778 1.00 0.00 A H +ATOM 655 HH11 ARG A 72 -2.995 75.220 -31.296 1.00 0.00 A H +ATOM 656 HH12 ARG A 72 -3.972 74.350 -32.426 1.00 0.00 A H +ATOM 657 HH21 ARG A 72 -6.650 73.797 -30.256 1.00 0.00 A H +ATOM 658 HH22 ARG A 72 -6.066 73.585 -31.852 1.00 0.00 A H +ATOM 659 N ALA A 73 -0.803 78.947 -25.375 1.00 11.78 A N +ATOM 660 CA ALA A 73 0.105 79.402 -24.342 1.00 14.51 A C +ATOM 661 C ALA A 73 -0.669 79.861 -23.090 1.00 18.31 A C +ATOM 662 O ALA A 73 -1.619 80.658 -23.182 1.00 23.63 A O +ATOM 663 CB ALA A 73 0.982 80.523 -24.896 1.00 8.71 A C +ATOM 664 H ALA A 73 -1.566 79.514 -25.615 1.00 0.00 A H +ATOM 665 N VAL A 74 -0.263 79.356 -21.924 1.00 18.62 A N +ATOM 666 CA VAL A 74 -0.920 79.698 -20.666 1.00 18.32 A C +ATOM 667 C VAL A 74 -0.087 79.243 -19.471 1.00 20.08 A C +ATOM 668 O VAL A 74 0.538 78.184 -19.518 1.00 22.78 A O +ATOM 669 CB VAL A 74 -2.351 79.043 -20.571 1.00 12.55 A C +ATOM 670 CG1 VAL A 74 -2.258 77.535 -20.280 1.00 8.98 A C +ATOM 671 CG2 VAL A 74 -3.185 79.729 -19.505 1.00 10.59 A C +ATOM 672 H VAL A 74 0.495 78.730 -21.882 1.00 0.00 A H +ATOM 673 N ALA A 75 -0.092 80.054 -18.411 1.00 22.66 A N +ATOM 674 CA ALA A 75 0.615 79.784 -17.147 1.00 22.52 A C +ATOM 675 C ALA A 75 1.994 79.147 -17.259 1.00 22.63 A C +ATOM 676 O ALA A 75 2.236 78.075 -16.699 1.00 33.25 A O +ATOM 677 CB ALA A 75 -0.264 78.951 -16.220 1.00 21.28 A C +ATOM 678 H ALA A 75 -0.602 80.881 -18.484 1.00 0.00 A H +ATOM 679 N ASN A 76 2.903 79.823 -17.952 1.00 20.58 A N +ATOM 680 CA ASN A 76 4.268 79.325 -18.133 1.00 22.59 A C +ATOM 681 C ASN A 76 4.377 78.008 -18.900 1.00 20.82 A C +ATOM 682 O ASN A 76 5.376 77.305 -18.773 1.00 21.38 A O +ATOM 683 CB ASN A 76 4.978 79.168 -16.784 1.00 22.74 A C +ATOM 684 CG ASN A 76 4.982 80.450 -15.981 1.00 27.62 A C +ATOM 685 ND2 ASN A 76 4.488 80.373 -14.765 1.00 27.46 A N +ATOM 686 OD1 ASN A 76 5.386 81.504 -16.462 1.00 32.04 A O +ATOM 687 H ASN A 76 2.639 80.681 -18.349 1.00 0.00 A H +ATOM 688 HD21 ASN A 76 4.146 79.496 -14.486 1.00 0.00 A H +ATOM 689 HD22 ASN A 76 4.489 81.171 -14.201 1.00 0.00 A H +ATOM 690 N GLU A 77 3.355 77.661 -19.671 1.00 15.14 A N +ATOM 691 CA GLU A 77 3.375 76.442 -20.455 1.00 11.27 A C +ATOM 692 C GLU A 77 2.881 76.771 -21.852 1.00 12.76 A C +ATOM 693 O GLU A 77 2.173 77.770 -22.038 1.00 11.79 A O +ATOM 694 CB GLU A 77 2.519 75.364 -19.797 1.00 6.01 A C +ATOM 695 CG GLU A 77 3.034 74.986 -18.430 1.00 11.87 A C +ATOM 696 CD GLU A 77 2.296 73.836 -17.819 1.00 12.81 A C +ATOM 697 OE1 GLU A 77 2.053 72.814 -18.511 1.00 17.80 A O +ATOM 698 OE2 GLU A 77 1.971 73.918 -16.625 1.00 23.63 A O1- +ATOM 699 H GLU A 77 2.539 78.194 -19.717 1.00 0.00 A H +ATOM 700 N ALA A 78 3.348 76.002 -22.838 1.00 7.70 A N +ATOM 701 CA ALA A 78 2.955 76.190 -24.235 1.00 6.02 A C +ATOM 702 C ALA A 78 2.866 74.833 -24.881 1.00 7.74 A C +ATOM 703 O ALA A 78 3.611 73.927 -24.526 1.00 10.37 A O +ATOM 704 CB ALA A 78 3.958 77.040 -24.964 1.00 6.00 A C +ATOM 705 H ALA A 78 3.971 75.276 -22.627 1.00 0.00 A H +ATOM 706 N ALA A 79 1.893 74.651 -25.762 1.00 11.83 A N +ATOM 707 CA ALA A 79 1.727 73.386 -26.475 1.00 10.41 A C +ATOM 708 C ALA A 79 1.440 73.791 -27.912 1.00 14.15 A C +ATOM 709 O ALA A 79 0.759 74.788 -28.157 1.00 13.66 A O +ATOM 710 CB ALA A 79 0.569 72.601 -25.901 1.00 6.06 A C +ATOM 711 H ALA A 79 1.275 75.385 -25.969 1.00 0.00 A H +ATOM 712 N PHE A 80 1.972 73.048 -28.872 1.00 11.10 A N +ATOM 713 CA PHE A 80 1.749 73.380 -30.271 1.00 9.14 A C +ATOM 714 C PHE A 80 1.715 72.121 -31.120 1.00 9.38 A C +ATOM 715 O PHE A 80 2.266 71.090 -30.733 1.00 10.43 A O +ATOM 716 CB PHE A 80 2.826 74.360 -30.796 1.00 8.83 A C +ATOM 717 CG PHE A 80 4.252 73.914 -30.536 1.00 11.32 A C +ATOM 718 CD1 PHE A 80 4.785 73.926 -29.238 1.00 8.20 A C +ATOM 719 CD2 PHE A 80 5.069 73.502 -31.575 1.00 13.44 A C +ATOM 720 CE1 PHE A 80 6.102 73.534 -28.993 1.00 9.03 A C +ATOM 721 CE2 PHE A 80 6.389 73.114 -31.334 1.00 12.24 A C +ATOM 722 CZ PHE A 80 6.905 73.129 -30.041 1.00 11.05 A C +ATOM 723 H PHE A 80 2.509 72.255 -28.638 1.00 0.00 A H +ATOM 724 N ALA A 81 1.016 72.214 -32.251 1.00 10.99 A N +ATOM 725 CA ALA A 81 0.884 71.124 -33.206 1.00 8.80 A C +ATOM 726 C ALA A 81 1.687 71.571 -34.428 1.00 7.95 A C +ATOM 727 O ALA A 81 1.658 72.740 -34.816 1.00 7.94 A O +ATOM 728 CB ALA A 81 -0.578 70.908 -33.551 1.00 10.70 A C +ATOM 729 H ALA A 81 0.592 73.066 -32.479 1.00 0.00 A H +ATOM 730 N PHE A 82 2.454 70.658 -34.999 1.00 10.87 A N +ATOM 731 CA PHE A 82 3.289 70.998 -36.139 1.00 10.92 A C +ATOM 732 C PHE A 82 3.673 69.736 -36.902 1.00 9.53 A C +ATOM 733 O PHE A 82 3.261 68.641 -36.537 1.00 15.36 A O +ATOM 734 CB PHE A 82 4.553 71.730 -35.642 1.00 9.06 A C +ATOM 735 CG PHE A 82 5.463 70.889 -34.772 1.00 16.63 A C +ATOM 736 CD1 PHE A 82 5.068 70.480 -33.495 1.00 13.23 A C +ATOM 737 CD2 PHE A 82 6.724 70.506 -35.225 1.00 13.14 A C +ATOM 738 CE1 PHE A 82 5.917 69.703 -32.700 1.00 9.87 A C +ATOM 739 CE2 PHE A 82 7.587 69.724 -34.427 1.00 10.48 A C +ATOM 740 CZ PHE A 82 7.184 69.326 -33.175 1.00 9.88 A C +ATOM 741 H PHE A 82 2.475 69.730 -34.669 1.00 0.00 A H +ATOM 742 N ILE A 83 4.364 69.879 -38.017 1.00 10.92 A N +ATOM 743 CA ILE A 83 4.817 68.690 -38.730 1.00 10.71 A C +ATOM 744 C ILE A 83 6.314 68.819 -38.947 1.00 13.47 A C +ATOM 745 O ILE A 83 6.840 69.938 -38.968 1.00 13.07 A O +ATOM 746 CB ILE A 83 4.110 68.484 -40.105 1.00 12.41 A C +ATOM 747 CG1 ILE A 83 4.151 69.749 -40.958 1.00 15.47 A C +ATOM 748 CG2 ILE A 83 2.702 67.968 -39.913 1.00 19.29 A C +ATOM 749 CD1 ILE A 83 5.425 69.913 -41.777 1.00 32.58 A C +ATOM 750 H ILE A 83 4.593 70.768 -38.367 1.00 0.00 A H +ATOM 751 N VAL A 84 6.995 67.678 -39.010 1.00 16.93 A N +ATOM 752 CA VAL A 84 8.429 67.609 -39.284 1.00 15.26 A C +ATOM 753 C VAL A 84 8.482 66.681 -40.487 1.00 12.30 A C +ATOM 754 O VAL A 84 8.094 65.516 -40.389 1.00 10.20 A O +ATOM 755 CB VAL A 84 9.225 66.946 -38.156 1.00 20.95 A C +ATOM 756 CG1 VAL A 84 10.699 66.907 -38.516 1.00 22.12 A C +ATOM 757 CG2 VAL A 84 9.028 67.695 -36.859 1.00 22.87 A C +ATOM 758 H VAL A 84 6.512 66.832 -38.875 1.00 0.00 A H +ATOM 759 N SER A 85 8.877 67.216 -41.637 1.00 17.81 A N +ATOM 760 CA SER A 85 8.956 66.428 -42.865 1.00 18.91 A C +ATOM 761 C SER A 85 10.412 66.233 -43.220 1.00 21.30 A C +ATOM 762 O SER A 85 11.228 67.122 -43.003 1.00 21.25 A O +ATOM 763 CB SER A 85 8.298 67.176 -44.023 1.00 17.72 A C +ATOM 764 OG SER A 85 7.306 68.083 -43.570 1.00 36.30 A O +ATOM 765 H SER A 85 9.138 68.145 -41.668 1.00 0.00 A H +ATOM 766 HG SER A 85 7.749 68.805 -43.105 1.00 0.00 A H +ATOM 767 N PHE A 86 10.722 65.086 -43.807 1.00 24.08 A N +ATOM 768 CA PHE A 86 12.080 64.760 -44.234 1.00 25.09 A C +ATOM 769 C PHE A 86 11.997 63.587 -45.209 1.00 29.80 A C +ATOM 770 O PHE A 86 10.937 62.971 -45.368 1.00 31.28 A O +ATOM 771 CB PHE A 86 12.968 64.396 -43.028 1.00 25.83 A C +ATOM 772 CG PHE A 86 12.464 63.227 -42.234 1.00 23.95 A C +ATOM 773 CD1 PHE A 86 11.511 63.411 -41.228 1.00 25.22 A C +ATOM 774 CD2 PHE A 86 12.887 61.940 -42.528 1.00 25.79 A C +ATOM 775 CE1 PHE A 86 10.975 62.322 -40.527 1.00 28.59 A C +ATOM 776 CE2 PHE A 86 12.362 60.842 -41.837 1.00 25.89 A C +ATOM 777 CZ PHE A 86 11.397 61.036 -40.832 1.00 26.91 A C +ATOM 778 H PHE A 86 10.021 64.413 -43.964 1.00 0.00 A H +ATOM 779 N GLU A 87 13.100 63.296 -45.885 1.00 37.39 A N +ATOM 780 CA GLU A 87 13.135 62.186 -46.819 1.00 43.62 A C +ATOM 781 C GLU A 87 14.269 61.285 -46.409 1.00 45.60 A C +ATOM 782 O GLU A 87 15.385 61.763 -46.198 1.00 45.45 A O +ATOM 783 CB GLU A 87 13.404 62.673 -48.239 1.00 49.85 A C +ATOM 784 CG GLU A 87 12.398 63.669 -48.769 1.00 60.90 A C +ATOM 785 CD GLU A 87 12.611 63.979 -50.240 1.00 64.35 A C +ATOM 786 OE1 GLU A 87 12.669 63.021 -51.050 1.00 68.83 A O +ATOM 787 OE2 GLU A 87 12.721 65.180 -50.583 1.00 65.74 A O1- +ATOM 788 H GLU A 87 13.916 63.827 -45.774 1.00 0.00 A H +ATOM 789 N TYR A 88 13.985 60.008 -46.191 1.00 51.49 A N +ATOM 790 CA TYR A 88 15.065 59.097 -45.859 1.00 60.92 A C +ATOM 791 C TYR A 88 15.253 58.240 -47.088 1.00 66.96 A C +ATOM 792 O TYR A 88 14.530 57.261 -47.297 1.00 68.37 A O +ATOM 793 CB TYR A 88 14.781 58.224 -44.639 1.00 62.44 A C +ATOM 794 CG TYR A 88 15.987 57.378 -44.266 1.00 68.24 A C +ATOM 795 CD1 TYR A 88 17.243 57.967 -44.093 1.00 67.81 A C +ATOM 796 CD2 TYR A 88 15.891 55.990 -44.133 1.00 71.27 A C +ATOM 797 CE1 TYR A 88 18.372 57.205 -43.803 1.00 70.22 A C +ATOM 798 CE2 TYR A 88 17.023 55.212 -43.841 1.00 71.36 A C +ATOM 799 CZ TYR A 88 18.259 55.832 -43.678 1.00 71.55 A C +ATOM 800 OH TYR A 88 19.386 55.092 -43.397 1.00 73.19 A O +ATOM 801 H TYR A 88 13.069 59.692 -46.317 1.00 0.00 A H +ATOM 802 HH TYR A 88 20.138 55.683 -43.299 1.00 0.00 A H +ATOM 803 N GLN A 89 16.189 58.671 -47.926 1.00 71.77 A N +ATOM 804 CA GLN A 89 16.519 57.997 -49.175 1.00 75.65 A C +ATOM 805 C GLN A 89 15.393 58.122 -50.206 1.00 75.76 A C +ATOM 806 O GLN A 89 14.710 57.141 -50.523 1.00 77.08 A O +ATOM 807 CB GLN A 89 16.888 56.524 -48.921 1.00 78.19 A C +ATOM 808 CG GLN A 89 17.898 56.313 -47.781 1.00 84.01 A C +ATOM 809 CD GLN A 89 18.912 57.452 -47.658 1.00 88.89 A C +ATOM 810 NE2 GLN A 89 20.064 57.298 -48.302 1.00 91.91 A N +ATOM 811 OE1 GLN A 89 18.651 58.462 -46.995 1.00 92.08 A O +ATOM 812 H GLN A 89 16.710 59.458 -47.651 1.00 0.00 A H +ATOM 813 HE21 GLN A 89 20.213 56.483 -48.823 1.00 0.00 A H +ATOM 814 HE22 GLN A 89 20.726 58.016 -48.216 1.00 0.00 A H +ATOM 815 N GLY A 90 15.184 59.355 -50.678 1.00 74.54 A N +ATOM 816 CA GLY A 90 14.160 59.650 -51.678 1.00 70.86 A C +ATOM 817 C GLY A 90 12.725 59.353 -51.272 1.00 68.55 A C +ATOM 818 O GLY A 90 11.787 59.539 -52.058 1.00 67.92 A O +ATOM 819 H GLY A 90 15.715 60.101 -50.336 1.00 0.00 A H +ATOM 820 N ARG A 91 12.564 58.894 -50.033 1.00 67.07 A N +ATOM 821 CA ARG A 91 11.269 58.543 -49.470 1.00 64.55 A C +ATOM 822 C ARG A 91 10.844 59.601 -48.444 1.00 56.77 A C +ATOM 823 O ARG A 91 11.498 59.765 -47.412 1.00 53.37 A O +ATOM 824 CB ARG A 91 11.381 57.157 -48.821 1.00 70.87 A C +ATOM 825 CG ARG A 91 10.072 56.536 -48.353 1.00 82.61 A C +ATOM 826 CD ARG A 91 9.843 56.764 -46.866 1.00 91.33 A C +ATOM 827 NE ARG A 91 10.904 56.179 -46.047 1.00 99.54 A N +ATOM 828 CZ ARG A 91 11.706 56.874 -45.244 1.00104.71 A C +ATOM 829 NH1 ARG A 91 11.584 58.194 -45.143 1.00106.46 A N1+ +ATOM 830 NH2 ARG A 91 12.619 56.239 -44.520 1.00108.62 A N +ATOM 831 H ARG A 91 13.335 58.777 -49.455 1.00 0.00 A H +ATOM 832 HE ARG A 91 11.009 55.209 -46.140 1.00 0.00 A H +ATOM 833 HH11 ARG A 91 10.861 58.681 -45.653 1.00 0.00 A H +ATOM 834 HH12 ARG A 91 12.138 58.762 -44.541 1.00 0.00 A H +ATOM 835 HH21 ARG A 91 12.708 55.244 -44.590 1.00 0.00 A H +ATOM 836 HH22 ARG A 91 13.232 56.706 -43.887 1.00 0.00 A H +ATOM 837 N LYS A 92 9.773 60.332 -48.745 1.00 49.39 A N +ATOM 838 CA LYS A 92 9.280 61.362 -47.838 1.00 44.20 A C +ATOM 839 C LYS A 92 8.488 60.782 -46.665 1.00 37.95 A C +ATOM 840 O LYS A 92 7.862 59.730 -46.768 1.00 36.08 A O +ATOM 841 CB LYS A 92 8.438 62.404 -48.580 1.00 47.71 A C +ATOM 842 CG LYS A 92 7.989 63.562 -47.681 1.00 54.34 A C +ATOM 843 CD LYS A 92 7.121 64.568 -48.414 1.00 61.58 A C +ATOM 844 CE LYS A 92 6.096 65.191 -47.477 1.00 66.71 A C +ATOM 845 NZ LYS A 92 5.102 64.179 -46.991 1.00 72.99 A N1+ +ATOM 846 H LYS A 92 9.328 60.189 -49.603 1.00 0.00 A H +ATOM 847 HZ1 LYS A 92 4.604 63.670 -47.752 1.00 0.00 A H +ATOM 848 HZ2 LYS A 92 5.431 63.405 -46.381 1.00 0.00 A H +ATOM 849 HZ3 LYS A 92 4.294 64.558 -46.437 1.00 0.00 A H +ATOM 850 N THR A 93 8.506 61.483 -45.548 1.00 31.62 A N +ATOM 851 CA THR A 93 7.819 61.026 -44.368 1.00 26.75 A C +ATOM 852 C THR A 93 7.462 62.279 -43.596 1.00 24.78 A C +ATOM 853 O THR A 93 8.219 63.251 -43.612 1.00 23.01 A O +ATOM 854 CB THR A 93 8.769 60.117 -43.588 1.00 30.50 A C +ATOM 855 CG2 THR A 93 8.131 59.574 -42.325 1.00 28.73 A C +ATOM 856 OG1 THR A 93 9.162 59.034 -44.440 1.00 27.75 A O +ATOM 857 H THR A 93 9.028 62.313 -45.474 1.00 0.00 A H +ATOM 858 HG1 THR A 93 8.374 58.813 -44.959 1.00 0.00 A H +ATOM 859 N VAL A 94 6.275 62.288 -43.000 1.00 21.14 A N +ATOM 860 CA VAL A 94 5.775 63.426 -42.232 1.00 21.95 A C +ATOM 861 C VAL A 94 5.327 62.965 -40.853 1.00 23.50 A C +ATOM 862 O VAL A 94 4.408 62.148 -40.745 1.00 25.33 A O +ATOM 863 CB VAL A 94 4.537 64.069 -42.904 1.00 19.68 A C +ATOM 864 CG1 VAL A 94 3.883 65.037 -41.942 1.00 16.72 A C +ATOM 865 CG2 VAL A 94 4.927 64.795 -44.163 1.00 15.77 A C +ATOM 866 H VAL A 94 5.699 61.526 -43.010 1.00 0.00 A H +ATOM 867 N VAL A 95 5.963 63.478 -39.801 1.00 20.80 A N +ATOM 868 CA VAL A 95 5.589 63.108 -38.432 1.00 14.02 A C +ATOM 869 C VAL A 95 4.834 64.298 -37.872 1.00 14.94 A C +ATOM 870 O VAL A 95 5.287 65.442 -38.003 1.00 16.50 A O +ATOM 871 CB VAL A 95 6.818 62.832 -37.558 1.00 12.89 A C +ATOM 872 CG1 VAL A 95 6.403 62.428 -36.191 1.00 13.40 A C +ATOM 873 CG2 VAL A 95 7.661 61.729 -38.181 1.00 18.43 A C +ATOM 874 H VAL A 95 6.682 64.134 -39.938 1.00 0.00 A H +ATOM 875 N ALA A 96 3.673 64.034 -37.281 1.00 13.39 A N +ATOM 876 CA ALA A 96 2.839 65.089 -36.704 1.00 13.84 A C +ATOM 877 C ALA A 96 2.736 64.955 -35.178 1.00 13.73 A C +ATOM 878 O ALA A 96 1.908 64.210 -34.647 1.00 14.82 A O +ATOM 879 CB ALA A 96 1.446 65.064 -37.339 1.00 10.54 A C +ATOM 880 H ALA A 96 3.360 63.116 -37.200 1.00 0.00 A H +ATOM 881 N PRO A 97 3.606 65.653 -34.449 1.00 13.05 A N +ATOM 882 CA PRO A 97 3.549 65.558 -32.998 1.00 10.55 A C +ATOM 883 C PRO A 97 2.922 66.810 -32.410 1.00 13.35 A C +ATOM 884 O PRO A 97 2.604 67.760 -33.140 1.00 11.58 A O +ATOM 885 CB PRO A 97 5.033 65.520 -32.609 1.00 6.00 A C +ATOM 886 CG PRO A 97 5.820 65.918 -33.910 1.00 7.23 A C +ATOM 887 CD PRO A 97 4.788 66.416 -34.870 1.00 10.87 A C +ATOM 888 N ILE A 98 2.685 66.767 -31.107 1.00 11.92 A N +ATOM 889 CA ILE A 98 2.211 67.929 -30.374 1.00 12.48 A C +ATOM 890 C ILE A 98 3.323 68.053 -29.320 1.00 10.15 A C +ATOM 891 O ILE A 98 3.655 67.071 -28.655 1.00 11.57 A O +ATOM 892 CB ILE A 98 0.796 67.718 -29.684 1.00 15.20 A C +ATOM 893 CG1 ILE A 98 -0.334 67.826 -30.724 1.00 8.52 A C +ATOM 894 CG2 ILE A 98 0.540 68.823 -28.631 1.00 9.28 A C +ATOM 895 CD1 ILE A 98 -1.721 67.689 -30.134 1.00 6.04 A C +ATOM 896 H ILE A 98 2.859 65.943 -30.595 1.00 0.00 A H +ATOM 897 N ASP A 99 3.986 69.198 -29.264 1.00 6.01 A N +ATOM 898 CA ASP A 99 5.028 69.386 -28.276 1.00 7.15 A C +ATOM 899 C ASP A 99 4.491 70.205 -27.127 1.00 10.60 A C +ATOM 900 O ASP A 99 3.639 71.081 -27.326 1.00 10.01 A O +ATOM 901 CB ASP A 99 6.205 70.150 -28.841 1.00 10.22 A C +ATOM 902 CG ASP A 99 7.175 69.277 -29.617 1.00 6.01 A C +ATOM 903 OD1 ASP A 99 6.924 68.080 -29.890 1.00 11.95 A O +ATOM 904 OD2 ASP A 99 8.219 69.846 -29.961 1.00 8.57 A O1- +ATOM 905 H ASP A 99 3.765 69.922 -29.882 1.00 0.00 A H +ATOM 906 HD2 ASP A 99 8.781 69.224 -30.451 1.00 0.00 A H +ATOM 907 N HIS A 100 5.043 69.955 -25.940 1.00 11.43 A N +ATOM 908 CA HIS A 100 4.678 70.666 -24.725 1.00 8.74 A C +ATOM 909 C HIS A 100 5.947 71.261 -24.104 1.00 12.40 A C +ATOM 910 O HIS A 100 6.892 70.547 -23.794 1.00 10.22 A O +ATOM 911 CB HIS A 100 3.998 69.713 -23.733 1.00 12.71 A C +ATOM 912 CG HIS A 100 3.659 70.350 -22.414 1.00 20.40 A C +ATOM 913 CD2 HIS A 100 3.150 71.568 -22.119 1.00 13.89 A C +ATOM 914 ND1 HIS A 100 3.837 69.711 -21.205 1.00 18.95 A N +ATOM 915 CE1 HIS A 100 3.447 70.507 -20.226 1.00 22.32 A C +ATOM 916 NE2 HIS A 100 3.028 71.641 -20.755 1.00 20.66 A N +ATOM 917 H HIS A 100 5.734 69.262 -25.919 1.00 0.00 A H +ATOM 918 HD1 HIS A 100 4.288 68.856 -21.040 1.00 0.00 A H +ATOM 919 HE2 HIS A 100 2.731 72.447 -20.266 1.00 0.00 A H +ATOM 920 N PHE A 101 5.970 72.579 -23.994 1.00 10.81 A N +ATOM 921 CA PHE A 101 7.086 73.315 -23.422 1.00 10.19 A C +ATOM 922 C PHE A 101 6.664 73.821 -22.050 1.00 12.64 A C +ATOM 923 O PHE A 101 5.489 74.119 -21.824 1.00 11.76 A O +ATOM 924 CB PHE A 101 7.395 74.560 -24.265 1.00 11.58 A C +ATOM 925 CG PHE A 101 8.232 74.303 -25.503 1.00 9.82 A C +ATOM 926 CD1 PHE A 101 8.433 73.019 -25.996 1.00 15.10 A C +ATOM 927 CD2 PHE A 101 8.847 75.374 -26.159 1.00 13.96 A C +ATOM 928 CE1 PHE A 101 9.241 72.813 -27.122 1.00 18.99 A C +ATOM 929 CE2 PHE A 101 9.655 75.174 -27.285 1.00 13.55 A C +ATOM 930 CZ PHE A 101 9.853 73.895 -27.763 1.00 8.83 A C +ATOM 931 H PHE A 101 5.212 73.095 -24.316 1.00 0.00 A H +ATOM 932 N ARG A 102 7.628 73.976 -21.154 1.00 13.07 A N +ATOM 933 CA ARG A 102 7.343 74.491 -19.835 1.00 13.20 A C +ATOM 934 C ARG A 102 8.472 75.491 -19.618 1.00 11.79 A C +ATOM 935 O ARG A 102 9.641 75.171 -19.825 1.00 13.69 A O +ATOM 936 CB ARG A 102 7.353 73.355 -18.820 1.00 20.65 A C +ATOM 937 CG ARG A 102 6.599 73.668 -17.560 1.00 31.29 A C +ATOM 938 CD ARG A 102 6.617 72.499 -16.595 1.00 43.38 A C +ATOM 939 NE ARG A 102 5.788 71.387 -17.057 1.00 46.48 A N +ATOM 940 CZ ARG A 102 4.617 71.049 -16.514 1.00 50.74 A C +ATOM 941 NH1 ARG A 102 4.133 71.744 -15.484 1.00 50.96 A N1+ +ATOM 942 NH2 ARG A 102 3.940 70.000 -16.978 1.00 53.28 A N +ATOM 943 H ARG A 102 8.549 73.733 -21.383 1.00 0.00 A H +ATOM 944 HE ARG A 102 6.158 70.882 -17.819 1.00 0.00 A H +ATOM 945 HH11 ARG A 102 4.637 72.549 -15.155 1.00 0.00 A H +ATOM 946 HH12 ARG A 102 3.251 71.575 -15.046 1.00 0.00 A H +ATOM 947 HH21 ARG A 102 4.295 69.403 -17.709 1.00 0.00 A H +ATOM 948 HH22 ARG A 102 3.074 69.714 -16.579 1.00 0.00 A H +ATOM 949 N PHE A 103 8.116 76.726 -19.302 1.00 9.98 A N +ATOM 950 CA PHE A 103 9.104 77.772 -19.129 1.00 15.39 A C +ATOM 951 C PHE A 103 9.374 78.094 -17.667 1.00 20.19 A C +ATOM 952 O PHE A 103 8.618 77.699 -16.771 1.00 20.52 A O +ATOM 953 CB PHE A 103 8.662 79.056 -19.847 1.00 12.56 A C +ATOM 954 CG PHE A 103 8.441 78.896 -21.334 1.00 11.89 A C +ATOM 955 CD1 PHE A 103 7.328 78.209 -21.827 1.00 16.99 A C +ATOM 956 CD2 PHE A 103 9.297 79.485 -22.239 1.00 14.06 A C +ATOM 957 CE1 PHE A 103 7.068 78.118 -23.204 1.00 12.90 A C +ATOM 958 CE2 PHE A 103 9.051 79.401 -23.605 1.00 16.35 A C +ATOM 959 CZ PHE A 103 7.923 78.711 -24.086 1.00 16.18 A C +ATOM 960 H PHE A 103 7.174 76.928 -19.137 1.00 0.00 A H +ATOM 961 N ASN A 104 10.470 78.806 -17.432 1.00 18.03 A N +ATOM 962 CA ASN A 104 10.837 79.196 -16.083 1.00 18.86 A C +ATOM 963 C ASN A 104 10.821 80.696 -15.968 1.00 21.20 A C +ATOM 964 O ASN A 104 10.576 81.383 -16.966 1.00 26.75 A O +ATOM 965 CB ASN A 104 12.205 78.639 -15.693 1.00 18.97 A C +ATOM 966 CG ASN A 104 13.339 79.145 -16.565 1.00 15.78 A C +ATOM 967 ND2 ASN A 104 13.099 80.160 -17.373 1.00 21.10 A N +ATOM 968 OD1 ASN A 104 14.432 78.598 -16.511 1.00 23.62 A O +ATOM 969 H ASN A 104 11.032 79.032 -18.204 1.00 0.00 A H +ATOM 970 HD21 ASN A 104 12.267 80.624 -17.550 1.00 0.00 A H +ATOM 971 HD22 ASN A 104 13.944 80.392 -17.810 1.00 0.00 A H +ATOM 972 N GLY A 105 11.125 81.204 -14.775 1.00 25.04 A N +ATOM 973 CA GLY A 105 11.128 82.641 -14.549 1.00 26.04 A C +ATOM 974 C GLY A 105 11.921 83.455 -15.559 1.00 26.75 A C +ATOM 975 O GLY A 105 11.496 84.537 -15.977 1.00 32.50 A O +ATOM 976 H GLY A 105 11.324 80.606 -14.024 1.00 0.00 A H +ATOM 977 N ALA A 106 13.060 82.922 -15.982 1.00 26.82 A N +ATOM 978 CA ALA A 106 13.918 83.599 -16.947 1.00 25.44 A C +ATOM 979 C ALA A 106 13.375 83.589 -18.381 1.00 26.75 A C +ATOM 980 O ALA A 106 13.932 84.242 -19.274 1.00 31.90 A O +ATOM 981 CB ALA A 106 15.297 82.985 -16.912 1.00 27.30 A C +ATOM 982 H ALA A 106 13.345 82.059 -15.621 1.00 0.00 A H +ATOM 983 N GLY A 107 12.278 82.876 -18.600 1.00 21.30 A N +ATOM 984 CA GLY A 107 11.710 82.811 -19.928 1.00 19.15 A C +ATOM 985 C GLY A 107 12.295 81.715 -20.793 1.00 16.57 A C +ATOM 986 O GLY A 107 11.977 81.637 -21.979 1.00 19.15 A O +ATOM 987 H GLY A 107 11.817 82.417 -17.881 1.00 0.00 A H +ATOM 988 N LYS A 108 13.149 80.880 -20.221 1.00 15.74 A N +ATOM 989 CA LYS A 108 13.755 79.785 -20.953 1.00 15.79 A C +ATOM 990 C LYS A 108 12.959 78.509 -20.722 1.00 16.00 A C +ATOM 991 O LYS A 108 12.232 78.380 -19.728 1.00 19.98 A O +ATOM 992 CB LYS A 108 15.221 79.576 -20.536 1.00 21.57 A C +ATOM 993 CG LYS A 108 16.163 80.753 -20.884 1.00 30.19 A C +ATOM 994 CD LYS A 108 17.600 80.313 -21.253 1.00 41.27 A C +ATOM 995 CE LYS A 108 18.449 79.827 -20.044 1.00 48.96 A C +ATOM 996 NZ LYS A 108 19.775 79.201 -20.448 1.00 62.20 A N1+ +ATOM 997 H LYS A 108 13.372 80.997 -19.288 1.00 0.00 A H +ATOM 998 HZ1 LYS A 108 19.585 78.387 -21.083 1.00 0.00 A H +ATOM 999 HZ2 LYS A 108 20.373 79.857 -20.978 1.00 0.00 A H +ATOM 1000 HZ3 LYS A 108 20.367 78.823 -19.657 1.00 0.00 A H +ATOM 1001 N VAL A 109 13.094 77.581 -21.661 1.00 12.86 A N +ATOM 1002 CA VAL A 109 12.422 76.295 -21.622 1.00 10.69 A C +ATOM 1003 C VAL A 109 13.176 75.384 -20.647 1.00 16.12 A C +ATOM 1004 O VAL A 109 14.391 75.199 -20.754 1.00 18.17 A O +ATOM 1005 CB VAL A 109 12.387 75.663 -23.067 1.00 6.00 A C +ATOM 1006 CG1 VAL A 109 11.695 74.319 -23.068 1.00 6.01 A C +ATOM 1007 CG2 VAL A 109 11.663 76.584 -24.017 1.00 7.37 A C +ATOM 1008 H VAL A 109 13.697 77.765 -22.416 1.00 0.00 A H +ATOM 1009 N VAL A 110 12.474 74.846 -19.661 1.00 14.02 A N +ATOM 1010 CA VAL A 110 13.125 73.959 -18.716 1.00 11.14 A C +ATOM 1011 C VAL A 110 12.697 72.532 -18.976 1.00 15.43 A C +ATOM 1012 O VAL A 110 13.341 71.588 -18.504 1.00 19.88 A O +ATOM 1013 CB VAL A 110 12.800 74.338 -17.246 1.00 14.86 A C +ATOM 1014 CG1 VAL A 110 13.421 75.672 -16.892 1.00 17.88 A C +ATOM 1015 CG2 VAL A 110 11.296 74.409 -17.032 1.00 18.60 A C +ATOM 1016 H VAL A 110 11.530 75.078 -19.560 1.00 0.00 A H +ATOM 1017 N SER A 111 11.650 72.372 -19.785 1.00 21.33 A N +ATOM 1018 CA SER A 111 11.111 71.050 -20.101 1.00 18.49 A C +ATOM 1019 C SER A 111 10.358 70.982 -21.431 1.00 19.05 A C +ATOM 1020 O SER A 111 9.522 71.837 -21.713 1.00 17.55 A O +ATOM 1021 CB SER A 111 10.160 70.630 -18.972 1.00 15.96 A C +ATOM 1022 OG SER A 111 9.363 69.523 -19.337 1.00 23.59 A O +ATOM 1023 H SER A 111 11.210 73.148 -20.188 1.00 0.00 A H +ATOM 1024 HG SER A 111 9.937 68.749 -19.487 1.00 0.00 A H +ATOM 1025 N MET A 112 10.665 69.991 -22.264 1.00 17.04 A N +ATOM 1026 CA MET A 112 9.929 69.842 -23.510 1.00 12.13 A C +ATOM 1027 C MET A 112 9.613 68.375 -23.688 1.00 9.89 A C +ATOM 1028 O MET A 112 10.415 67.518 -23.360 1.00 9.32 A O +ATOM 1029 CB MET A 112 10.668 70.447 -24.722 1.00 16.06 A C +ATOM 1030 CG MET A 112 11.581 69.539 -25.551 1.00 16.45 A C +ATOM 1031 SD MET A 112 10.783 68.225 -26.550 1.00 16.21 A S +ATOM 1032 CE MET A 112 11.040 68.840 -28.146 1.00 26.96 A C +ATOM 1033 H MET A 112 11.372 69.342 -22.047 1.00 0.00 A H +ATOM 1034 N ARG A 113 8.396 68.084 -24.121 1.00 8.79 A N +ATOM 1035 CA ARG A 113 7.982 66.701 -24.340 1.00 7.97 A C +ATOM 1036 C ARG A 113 7.285 66.645 -25.684 1.00 6.41 A C +ATOM 1037 O ARG A 113 6.378 67.434 -25.935 1.00 13.99 A O +ATOM 1038 CB ARG A 113 7.030 66.251 -23.233 1.00 9.02 A C +ATOM 1039 CG ARG A 113 7.625 66.370 -21.867 1.00 15.14 A C +ATOM 1040 CD ARG A 113 6.696 65.766 -20.873 1.00 16.20 A C +ATOM 1041 NE ARG A 113 6.778 64.314 -20.880 1.00 14.51 A N +ATOM 1042 CZ ARG A 113 5.727 63.509 -20.833 1.00 10.39 A C +ATOM 1043 NH1 ARG A 113 4.507 64.006 -20.784 1.00 13.94 A N1+ +ATOM 1044 NH2 ARG A 113 5.907 62.204 -20.788 1.00 13.69 A N +ATOM 1045 H ARG A 113 7.768 68.807 -24.323 1.00 0.00 A H +ATOM 1046 HE ARG A 113 7.691 63.913 -20.902 1.00 0.00 A H +ATOM 1047 HH11 ARG A 113 4.320 64.981 -20.767 1.00 0.00 A H +ATOM 1048 HH12 ARG A 113 3.746 63.347 -20.805 1.00 0.00 A H +ATOM 1049 HH21 ARG A 113 6.856 61.827 -20.804 1.00 0.00 A H +ATOM 1050 HH22 ARG A 113 5.170 61.521 -20.786 1.00 0.00 A H +ATOM 1051 N ALA A 114 7.729 65.724 -26.538 1.00 12.66 A N +ATOM 1052 CA ALA A 114 7.198 65.566 -27.880 1.00 6.59 A C +ATOM 1053 C ALA A 114 6.288 64.349 -27.911 1.00 6.02 A C +ATOM 1054 O ALA A 114 6.739 63.219 -27.744 1.00 11.50 A O +ATOM 1055 CB ALA A 114 8.349 65.421 -28.876 1.00 7.36 A C +ATOM 1056 H ALA A 114 8.438 65.113 -26.244 1.00 0.00 A H +ATOM 1057 N LEU A 115 5.000 64.590 -28.126 1.00 7.62 A N +ATOM 1058 CA LEU A 115 3.997 63.535 -28.149 1.00 10.84 A C +ATOM 1059 C LEU A 115 3.587 63.101 -29.541 1.00 13.62 A C +ATOM 1060 O LEU A 115 3.045 63.908 -30.288 1.00 17.93 A O +ATOM 1061 CB LEU A 115 2.762 64.036 -27.412 1.00 11.12 A C +ATOM 1062 CG LEU A 115 2.448 63.462 -26.042 1.00 13.33 A C +ATOM 1063 CD1 LEU A 115 1.773 64.522 -25.185 1.00 21.56 A C +ATOM 1064 CD2 LEU A 115 1.549 62.271 -26.233 1.00 19.09 A C +ATOM 1065 H LEU A 115 4.719 65.518 -28.285 1.00 0.00 A H +ATOM 1066 N PHE A 116 3.887 61.854 -29.901 1.00 17.39 A N +ATOM 1067 CA PHE A 116 3.512 61.279 -31.205 1.00 15.62 A C +ATOM 1068 C PHE A 116 3.773 59.782 -31.214 1.00 14.98 A C +ATOM 1069 O PHE A 116 4.750 59.330 -30.638 1.00 21.35 A O +ATOM 1070 CB PHE A 116 4.255 61.931 -32.388 1.00 12.89 A C +ATOM 1071 CG PHE A 116 5.742 61.674 -32.402 1.00 15.41 A C +ATOM 1072 CD1 PHE A 116 6.260 60.540 -33.006 1.00 18.05 A C +ATOM 1073 CD2 PHE A 116 6.627 62.574 -31.794 1.00 16.24 A C +ATOM 1074 CE1 PHE A 116 7.644 60.294 -33.007 1.00 19.50 A C +ATOM 1075 CE2 PHE A 116 8.000 62.335 -31.790 1.00 17.91 A C +ATOM 1076 CZ PHE A 116 8.509 61.189 -32.400 1.00 17.62 A C +ATOM 1077 H PHE A 116 4.395 61.271 -29.298 1.00 0.00 A H +ATOM 1078 N GLY A 117 2.874 59.021 -31.831 1.00 17.65 A N +ATOM 1079 CA GLY A 117 3.034 57.579 -31.914 1.00 20.71 A C +ATOM 1080 C GLY A 117 3.137 57.144 -33.370 1.00 23.09 A C +ATOM 1081 O GLY A 117 3.117 57.988 -34.263 1.00 26.53 A O +ATOM 1082 H GLY A 117 2.104 59.433 -32.250 1.00 0.00 A H +ATOM 1083 N GLU A 118 3.195 55.833 -33.617 1.00 29.70 A N +ATOM 1084 CA GLU A 118 3.312 55.274 -34.973 1.00 34.79 A C +ATOM 1085 C GLU A 118 2.274 55.834 -35.940 1.00 33.08 A C +ATOM 1086 O GLU A 118 2.573 56.087 -37.101 1.00 32.88 A O +ATOM 1087 CB GLU A 118 3.207 53.745 -34.939 1.00 46.74 A C +ATOM 1088 CG GLU A 118 1.967 53.236 -34.174 1.00 64.19 A C +ATOM 1089 CD GLU A 118 1.551 51.794 -34.510 1.00 73.18 A C +ATOM 1090 OE1 GLU A 118 2.285 51.093 -35.248 1.00 75.99 A O +ATOM 1091 OE2 GLU A 118 0.460 51.381 -34.038 1.00 78.00 A O1- +ATOM 1092 H GLU A 118 3.159 55.215 -32.856 1.00 0.00 A H +ATOM 1093 N LYS A 119 1.059 56.056 -35.451 1.00 31.94 A N +ATOM 1094 CA LYS A 119 -0.012 56.595 -36.285 1.00 28.74 A C +ATOM 1095 C LYS A 119 0.134 58.092 -36.630 1.00 21.50 A C +ATOM 1096 O LYS A 119 -0.639 58.627 -37.418 1.00 22.81 A O +ATOM 1097 CB LYS A 119 -1.382 56.269 -35.658 1.00 36.51 A C +ATOM 1098 CG LYS A 119 -1.695 54.759 -35.695 1.00 47.40 A C +ATOM 1099 CD LYS A 119 -3.014 54.351 -35.029 1.00 55.43 A C +ATOM 1100 CE LYS A 119 -3.265 52.816 -35.130 1.00 60.03 A C +ATOM 1101 NZ LYS A 119 -2.308 51.931 -34.351 1.00 67.90 A N1+ +ATOM 1102 H LYS A 119 0.889 55.839 -34.514 1.00 0.00 A H +ATOM 1103 HZ1 LYS A 119 -1.283 51.987 -34.576 1.00 0.00 A H +ATOM 1104 HZ2 LYS A 119 -2.367 52.066 -33.316 1.00 0.00 A H +ATOM 1105 HZ3 LYS A 119 -2.545 50.920 -34.392 1.00 0.00 A H +ATOM 1106 N ASN A 120 1.146 58.748 -36.069 1.00 18.19 A N +ATOM 1107 CA ASN A 120 1.410 60.162 -36.331 1.00 14.95 A C +ATOM 1108 C ASN A 120 2.621 60.253 -37.240 1.00 14.05 A C +ATOM 1109 O ASN A 120 3.229 61.300 -37.399 1.00 16.23 A O +ATOM 1110 CB ASN A 120 1.649 60.927 -35.022 1.00 16.83 A C +ATOM 1111 CG ASN A 120 0.500 60.768 -34.058 1.00 11.86 A C +ATOM 1112 ND2 ASN A 120 -0.704 61.061 -34.534 1.00 13.11 A N +ATOM 1113 OD1 ASN A 120 0.670 60.299 -32.930 1.00 15.82 A O +ATOM 1114 H ASN A 120 1.752 58.275 -35.475 1.00 0.00 A H +ATOM 1115 HD21 ASN A 120 -0.770 61.365 -35.461 1.00 0.00 A H +ATOM 1116 HD22 ASN A 120 -1.492 60.971 -33.959 1.00 0.00 A H +ATOM 1117 N ILE A 121 2.983 59.130 -37.829 1.00 15.46 A N +ATOM 1118 CA ILE A 121 4.110 59.084 -38.731 1.00 19.96 A C +ATOM 1119 C ILE A 121 3.507 58.777 -40.088 1.00 24.14 A C +ATOM 1120 O ILE A 121 2.828 57.778 -40.251 1.00 23.40 A O +ATOM 1121 CB ILE A 121 5.075 57.989 -38.307 1.00 17.39 A C +ATOM 1122 CG1 ILE A 121 5.643 58.351 -36.940 1.00 12.27 A C +ATOM 1123 CG2 ILE A 121 6.173 57.820 -39.330 1.00 17.35 A C +ATOM 1124 CD1 ILE A 121 6.604 57.357 -36.398 1.00 23.38 A C +ATOM 1125 H ILE A 121 2.500 58.289 -37.685 1.00 0.00 A H +ATOM 1126 N HIS A 122 3.705 59.662 -41.048 1.00 23.88 A N +ATOM 1127 CA HIS A 122 3.128 59.445 -42.357 1.00 36.19 A C +ATOM 1128 C HIS A 122 4.148 59.369 -43.482 1.00 46.30 A C +ATOM 1129 O HIS A 122 4.765 60.358 -43.840 1.00 47.68 A O +ATOM 1130 CB HIS A 122 2.085 60.519 -42.630 1.00 29.55 A C +ATOM 1131 CG HIS A 122 1.052 60.620 -41.556 1.00 29.07 A C +ATOM 1132 CD2 HIS A 122 0.190 59.698 -41.060 1.00 25.53 A C +ATOM 1133 ND1 HIS A 122 0.885 61.751 -40.788 1.00 26.56 A N +ATOM 1134 CE1 HIS A 122 -0.029 61.522 -39.862 1.00 25.43 A C +ATOM 1135 NE2 HIS A 122 -0.464 60.282 -40.005 1.00 23.16 A N +ATOM 1136 H HIS A 122 4.299 60.404 -40.884 1.00 0.00 A H +ATOM 1137 HD1 HIS A 122 1.376 62.596 -40.895 1.00 0.00 A H +ATOM 1138 HE2 HIS A 122 -1.100 59.859 -39.377 1.00 0.00 A H +ATOM 1139 N ALA A 123 4.395 58.165 -43.976 1.00 59.66 A N +ATOM 1140 CA ALA A 123 5.331 58.000 -45.074 1.00 73.75 A C +ATOM 1141 C ALA A 123 4.609 58.586 -46.276 1.00 82.21 A C +ATOM 1142 O ALA A 123 3.377 58.581 -46.320 1.00 85.95 A O +ATOM 1143 CB ALA A 123 5.643 56.528 -45.297 1.00 73.39 A C +ATOM 1144 H ALA A 123 3.918 57.391 -43.620 1.00 0.00 A H +ATOM 1145 N GLY A 124 5.368 59.086 -47.242 1.00 90.51 A N +ATOM 1146 CA GLY A 124 4.773 59.683 -48.421 1.00101.54 A C +ATOM 1147 C GLY A 124 4.023 60.951 -48.070 1.00108.75 A C +ATOM 1148 O GLY A 124 4.498 62.055 -48.351 1.00109.98 A O +ATOM 1149 H GLY A 124 6.323 59.002 -47.173 1.00 0.00 A H +ATOM 1150 N ALA A 125 2.855 60.788 -47.454 1.00114.55 A N +ATOM 1151 CA ALA A 125 2.012 61.902 -47.043 1.00119.60 A C +ATOM 1152 C ALA A 125 2.779 62.840 -46.123 1.00121.61 A C +ATOM 1153 O ALA A 125 2.439 64.040 -46.086 1.00122.70 A O +ATOM 1154 CB ALA A 125 0.768 61.380 -46.343 1.00121.44 A C +ATOM 1155 OXT ALA A 125 3.759 62.383 -45.497 1.00123.42 A O1- +ATOM 1156 H ALA A 125 2.534 59.892 -47.260 1.00 0.00 A H \ No newline at end of file diff --git a/af_backprop/examples/sc_hall/1QJS_starting.pdb b/af_backprop/examples/sc_hall/1QJS_starting.pdb new file mode 100644 index 0000000000000000000000000000000000000000..d1bf8bfe257eb87fbc27cd33b9f8769f116c0b28 --- /dev/null +++ b/af_backprop/examples/sc_hall/1QJS_starting.pdb @@ -0,0 +1,880 @@ +MODEL 1 +ATOM 1 N HIS A 1 -11.161 5.339 22.224 1.00 0.00 N +ATOM 2 CA HIS A 1 -9.750 5.488 21.883 1.00 0.00 C +ATOM 3 C HIS A 1 -9.362 4.571 20.728 1.00 0.00 C +ATOM 4 CB HIS A 1 -8.871 5.198 23.102 1.00 0.00 C +ATOM 5 O HIS A 1 -9.646 3.372 20.760 1.00 0.00 O +ATOM 6 CG HIS A 1 -9.124 6.115 24.256 1.00 0.00 C +ATOM 7 CD2 HIS A 1 -9.874 5.954 25.372 1.00 0.00 C +ATOM 8 ND1 HIS A 1 -8.571 7.375 24.340 1.00 0.00 N +ATOM 9 CE1 HIS A 1 -8.971 7.950 25.462 1.00 0.00 C +ATOM 10 NE2 HIS A 1 -9.762 7.109 26.106 1.00 0.00 N +ATOM 11 N CYS A 2 -8.947 5.187 19.589 1.00 0.00 N +ATOM 12 CA CYS A 2 -8.553 4.393 18.430 1.00 0.00 C +ATOM 13 C CYS A 2 -7.058 4.097 18.455 1.00 0.00 C +ATOM 14 CB CYS A 2 -8.916 5.119 17.135 1.00 0.00 C +ATOM 15 O CYS A 2 -6.262 4.934 18.882 1.00 0.00 O +ATOM 16 SG CYS A 2 -10.678 5.483 16.975 1.00 0.00 S +ATOM 17 N TYR A 3 -6.659 2.830 18.088 1.00 0.00 N +ATOM 18 CA TYR A 3 -5.257 2.439 18.175 1.00 0.00 C +ATOM 19 C TYR A 3 -4.640 2.308 16.787 1.00 0.00 C +ATOM 20 CB TYR A 3 -5.114 1.117 18.936 1.00 0.00 C +ATOM 21 O TYR A 3 -5.356 2.162 15.794 1.00 0.00 O +ATOM 22 CG TYR A 3 -6.148 0.084 18.559 1.00 0.00 C +ATOM 23 CD1 TYR A 3 -7.368 0.015 19.228 1.00 0.00 C +ATOM 24 CD2 TYR A 3 -5.907 -0.825 17.534 1.00 0.00 C +ATOM 25 CE1 TYR A 3 -8.323 -0.936 18.885 1.00 0.00 C +ATOM 26 CE2 TYR A 3 -6.855 -1.780 17.183 1.00 0.00 C +ATOM 27 OH TYR A 3 -9.000 -2.772 17.519 1.00 0.00 O +ATOM 28 CZ TYR A 3 -8.058 -1.828 17.863 1.00 0.00 C +ATOM 29 N ASN A 4 -3.262 2.608 16.714 1.00 0.00 N +ATOM 30 CA ASN A 4 -2.461 2.337 15.525 1.00 0.00 C +ATOM 31 C ASN A 4 -2.395 0.842 15.223 1.00 0.00 C +ATOM 32 CB ASN A 4 -1.051 2.909 15.684 1.00 0.00 C +ATOM 33 O ASN A 4 -2.342 0.021 16.140 1.00 0.00 O +ATOM 34 CG ASN A 4 -1.049 4.414 15.872 1.00 0.00 C +ATOM 35 ND2 ASN A 4 -0.034 4.925 16.557 1.00 0.00 N +ATOM 36 OD1 ASN A 4 -1.954 5.111 15.405 1.00 0.00 O +ATOM 37 N THR A 5 -2.544 0.524 13.919 1.00 0.00 N +ATOM 38 CA THR A 5 -2.446 -0.879 13.530 1.00 0.00 C +ATOM 39 C THR A 5 -1.345 -1.076 12.492 1.00 0.00 C +ATOM 40 CB THR A 5 -3.783 -1.399 12.970 1.00 0.00 C +ATOM 41 O THR A 5 -0.994 -0.143 11.767 1.00 0.00 O +ATOM 42 CG2 THR A 5 -4.888 -1.312 14.018 1.00 0.00 C +ATOM 43 OG1 THR A 5 -4.156 -0.612 11.832 1.00 0.00 O +ATOM 44 N HIS A 6 -0.770 -2.218 12.577 1.00 0.00 N +ATOM 45 CA HIS A 6 0.214 -2.702 11.616 1.00 0.00 C +ATOM 46 C HIS A 6 -0.116 -4.117 11.155 1.00 0.00 C +ATOM 47 CB HIS A 6 1.619 -2.661 12.220 1.00 0.00 C +ATOM 48 O HIS A 6 -0.128 -5.050 11.961 1.00 0.00 O +ATOM 49 CG HIS A 6 2.697 -3.055 11.261 1.00 0.00 C +ATOM 50 CD2 HIS A 6 3.413 -4.199 11.150 1.00 0.00 C +ATOM 51 ND1 HIS A 6 3.149 -2.218 10.264 1.00 0.00 N +ATOM 52 CE1 HIS A 6 4.099 -2.832 9.579 1.00 0.00 C +ATOM 53 NE2 HIS A 6 4.278 -4.036 10.097 1.00 0.00 N +ATOM 54 N GLU A 7 -0.382 -4.286 9.867 1.00 0.00 N +ATOM 55 CA GLU A 7 -0.811 -5.584 9.356 1.00 0.00 C +ATOM 56 C GLU A 7 0.045 -6.021 8.171 1.00 0.00 C +ATOM 57 CB GLU A 7 -2.287 -5.541 8.952 1.00 0.00 C +ATOM 58 O GLU A 7 0.361 -5.213 7.296 1.00 0.00 O +ATOM 59 CG GLU A 7 -3.240 -5.342 10.121 1.00 0.00 C +ATOM 60 CD GLU A 7 -4.701 -5.297 9.704 1.00 0.00 C +ATOM 61 OE1 GLU A 7 -5.587 -5.386 10.584 1.00 0.00 O +ATOM 62 OE2 GLU A 7 -4.963 -5.174 8.486 1.00 0.00 O +ATOM 63 N HIS A 8 0.419 -7.223 8.285 1.00 0.00 N +ATOM 64 CA HIS A 8 1.167 -7.903 7.234 1.00 0.00 C +ATOM 65 C HIS A 8 0.236 -8.691 6.317 1.00 0.00 C +ATOM 66 CB HIS A 8 2.218 -8.835 7.840 1.00 0.00 C +ATOM 67 O HIS A 8 -0.668 -9.384 6.790 1.00 0.00 O +ATOM 68 CG HIS A 8 3.038 -9.558 6.819 1.00 0.00 C +ATOM 69 CD2 HIS A 8 3.959 -9.106 5.936 1.00 0.00 C +ATOM 70 ND1 HIS A 8 2.951 -10.920 6.625 1.00 0.00 N +ATOM 71 CE1 HIS A 8 3.787 -11.275 5.663 1.00 0.00 C +ATOM 72 NE2 HIS A 8 4.410 -10.193 5.228 1.00 0.00 N +ATOM 73 N PHE A 9 0.399 -8.517 4.937 1.00 0.00 N +ATOM 74 CA PHE A 9 -0.393 -9.342 4.032 1.00 0.00 C +ATOM 75 C PHE A 9 0.429 -9.757 2.818 1.00 0.00 C +ATOM 76 CB PHE A 9 -1.652 -8.593 3.583 1.00 0.00 C +ATOM 77 O PHE A 9 1.455 -9.142 2.517 1.00 0.00 O +ATOM 78 CG PHE A 9 -1.366 -7.296 2.875 1.00 0.00 C +ATOM 79 CD1 PHE A 9 -1.153 -6.127 3.596 1.00 0.00 C +ATOM 80 CD2 PHE A 9 -1.311 -7.246 1.488 1.00 0.00 C +ATOM 81 CE1 PHE A 9 -0.889 -4.925 2.943 1.00 0.00 C +ATOM 82 CE2 PHE A 9 -1.047 -6.049 0.829 1.00 0.00 C +ATOM 83 CZ PHE A 9 -0.837 -4.889 1.558 1.00 0.00 C +ATOM 84 N ARG A 10 0.178 -10.978 2.316 1.00 0.00 N +ATOM 85 CA ARG A 10 0.973 -11.510 1.213 1.00 0.00 C +ATOM 86 C ARG A 10 0.149 -11.587 -0.068 1.00 0.00 C +ATOM 87 CB ARG A 10 1.525 -12.892 1.566 1.00 0.00 C +ATOM 88 O ARG A 10 -1.047 -11.885 -0.026 1.00 0.00 O +ATOM 89 CG ARG A 10 2.623 -12.868 2.618 1.00 0.00 C +ATOM 90 CD ARG A 10 3.229 -14.247 2.833 1.00 0.00 C +ATOM 91 NE ARG A 10 4.333 -14.501 1.913 1.00 0.00 N +ATOM 92 NH1 ARG A 10 5.449 -15.950 3.324 1.00 0.00 N +ATOM 93 NH2 ARG A 10 6.314 -15.462 1.257 1.00 0.00 N +ATOM 94 CZ ARG A 10 5.363 -15.304 2.167 1.00 0.00 C +ATOM 95 N LEU A 11 0.804 -11.094 -1.103 1.00 0.00 N +ATOM 96 CA LEU A 11 0.320 -11.410 -2.443 1.00 0.00 C +ATOM 97 C LEU A 11 0.991 -12.670 -2.980 1.00 0.00 C +ATOM 98 CB LEU A 11 0.572 -10.238 -3.395 1.00 0.00 C +ATOM 99 O LEU A 11 1.837 -13.263 -2.307 1.00 0.00 O +ATOM 100 CG LEU A 11 -0.077 -8.906 -3.016 1.00 0.00 C +ATOM 101 CD1 LEU A 11 0.400 -7.799 -3.952 1.00 0.00 C +ATOM 102 CD2 LEU A 11 -1.597 -9.023 -3.049 1.00 0.00 C +ATOM 103 N ASP A 12 0.341 -13.343 -4.077 1.00 0.00 N +ATOM 104 CA ASP A 12 0.855 -14.552 -4.714 1.00 0.00 C +ATOM 105 C ASP A 12 2.306 -14.367 -5.151 1.00 0.00 C +ATOM 106 CB ASP A 12 -0.011 -14.936 -5.915 1.00 0.00 C +ATOM 107 O ASP A 12 2.690 -14.792 -6.243 1.00 0.00 O +ATOM 108 CG ASP A 12 -1.397 -15.415 -5.519 1.00 0.00 C +ATOM 109 OD1 ASP A 12 -1.546 -16.017 -4.434 1.00 0.00 O +ATOM 110 OD2 ASP A 12 -2.347 -15.191 -6.300 1.00 0.00 O +ATOM 111 N ASP A 13 3.110 -13.435 -4.576 1.00 0.00 N +ATOM 112 CA ASP A 13 4.545 -13.254 -4.773 1.00 0.00 C +ATOM 113 C ASP A 13 5.321 -13.603 -3.505 1.00 0.00 C +ATOM 114 CB ASP A 13 4.851 -11.817 -5.199 1.00 0.00 C +ATOM 115 O ASP A 13 5.243 -12.885 -2.506 1.00 0.00 O +ATOM 116 CG ASP A 13 6.323 -11.586 -5.491 1.00 0.00 C +ATOM 117 OD1 ASP A 13 7.094 -12.568 -5.553 1.00 0.00 O +ATOM 118 OD2 ASP A 13 6.716 -10.412 -5.658 1.00 0.00 O +ATOM 119 N PRO A 14 6.029 -14.766 -3.543 1.00 0.00 N +ATOM 120 CA PRO A 14 6.658 -15.283 -2.325 1.00 0.00 C +ATOM 121 C PRO A 14 7.790 -14.392 -1.821 1.00 0.00 C +ATOM 122 CB PRO A 14 7.191 -16.652 -2.757 1.00 0.00 C +ATOM 123 O PRO A 14 8.147 -14.448 -0.641 1.00 0.00 O +ATOM 124 CG PRO A 14 7.284 -16.572 -4.246 1.00 0.00 C +ATOM 125 CD PRO A 14 6.262 -15.588 -4.736 1.00 0.00 C +ATOM 126 N TRP A 15 8.193 -13.454 -2.636 1.00 0.00 N +ATOM 127 CA TRP A 15 9.394 -12.710 -2.274 1.00 0.00 C +ATOM 128 C TRP A 15 9.039 -11.318 -1.760 1.00 0.00 C +ATOM 129 CB TRP A 15 10.341 -12.598 -3.472 1.00 0.00 C +ATOM 130 O TRP A 15 9.853 -10.666 -1.102 1.00 0.00 O +ATOM 131 CG TRP A 15 10.788 -13.920 -4.020 1.00 0.00 C +ATOM 132 CD1 TRP A 15 11.486 -14.892 -3.358 1.00 0.00 C +ATOM 133 CD2 TRP A 15 10.573 -14.413 -5.346 1.00 0.00 C +ATOM 134 CE2 TRP A 15 11.167 -15.692 -5.418 1.00 0.00 C +ATOM 135 CE3 TRP A 15 9.934 -13.896 -6.481 1.00 0.00 C +ATOM 136 NE1 TRP A 15 11.716 -15.961 -4.193 1.00 0.00 N +ATOM 137 CH2 TRP A 15 10.509 -15.934 -7.678 1.00 0.00 C +ATOM 138 CZ2 TRP A 15 11.141 -16.463 -6.583 1.00 0.00 C +ATOM 139 CZ3 TRP A 15 9.909 -14.665 -7.639 1.00 0.00 C +ATOM 140 N THR A 16 7.868 -10.913 -2.013 1.00 0.00 N +ATOM 141 CA THR A 16 7.495 -9.538 -1.702 1.00 0.00 C +ATOM 142 C THR A 16 6.582 -9.488 -0.480 1.00 0.00 C +ATOM 143 CB THR A 16 6.794 -8.866 -2.898 1.00 0.00 C +ATOM 144 O THR A 16 5.600 -10.229 -0.402 1.00 0.00 O +ATOM 145 CG2 THR A 16 6.483 -7.402 -2.602 1.00 0.00 C +ATOM 146 OG1 THR A 16 7.646 -8.939 -4.047 1.00 0.00 O +ATOM 147 N GLU A 17 6.987 -8.731 0.402 1.00 0.00 N +ATOM 148 CA GLU A 17 6.175 -8.520 1.596 1.00 0.00 C +ATOM 149 C GLU A 17 5.456 -7.175 1.545 1.00 0.00 C +ATOM 150 CB GLU A 17 7.039 -8.607 2.857 1.00 0.00 C +ATOM 151 O GLU A 17 6.060 -6.155 1.205 1.00 0.00 O +ATOM 152 CG GLU A 17 7.672 -9.973 3.076 1.00 0.00 C +ATOM 153 CD GLU A 17 8.498 -10.055 4.349 1.00 0.00 C +ATOM 154 OE1 GLU A 17 8.478 -11.113 5.019 1.00 0.00 O +ATOM 155 OE2 GLU A 17 9.170 -9.053 4.681 1.00 0.00 O +ATOM 156 N PHE A 18 4.157 -7.211 1.917 1.00 0.00 N +ATOM 157 CA PHE A 18 3.352 -5.996 1.948 1.00 0.00 C +ATOM 158 C PHE A 18 2.861 -5.707 3.362 1.00 0.00 C +ATOM 159 CB PHE A 18 2.160 -6.115 0.993 1.00 0.00 C +ATOM 160 O PHE A 18 2.413 -6.614 4.066 1.00 0.00 O +ATOM 161 CG PHE A 18 2.553 -6.268 -0.452 1.00 0.00 C +ATOM 162 CD1 PHE A 18 2.624 -5.160 -1.288 1.00 0.00 C +ATOM 163 CD2 PHE A 18 2.851 -7.520 -0.974 1.00 0.00 C +ATOM 164 CE1 PHE A 18 2.988 -5.299 -2.626 1.00 0.00 C +ATOM 165 CE2 PHE A 18 3.216 -7.666 -2.310 1.00 0.00 C +ATOM 166 CZ PHE A 18 3.282 -6.554 -3.134 1.00 0.00 C +ATOM 167 N TYR A 19 3.014 -4.483 3.756 1.00 0.00 N +ATOM 168 CA TYR A 19 2.556 -4.040 5.069 1.00 0.00 C +ATOM 169 C TYR A 19 1.598 -2.861 4.943 1.00 0.00 C +ATOM 170 CB TYR A 19 3.746 -3.653 5.952 1.00 0.00 C +ATOM 171 O TYR A 19 1.757 -2.016 4.059 1.00 0.00 O +ATOM 172 CG TYR A 19 4.807 -4.723 6.042 1.00 0.00 C +ATOM 173 CD1 TYR A 19 4.770 -5.685 7.049 1.00 0.00 C +ATOM 174 CD2 TYR A 19 5.849 -4.773 5.122 1.00 0.00 C +ATOM 175 CE1 TYR A 19 5.747 -6.671 7.138 1.00 0.00 C +ATOM 176 CE2 TYR A 19 6.831 -5.755 5.201 1.00 0.00 C +ATOM 177 OH TYR A 19 7.742 -7.672 6.294 1.00 0.00 O +ATOM 178 CZ TYR A 19 6.772 -6.698 6.211 1.00 0.00 C +ATOM 179 N ARG A 20 0.661 -2.887 5.726 1.00 0.00 N +ATOM 180 CA ARG A 20 -0.194 -1.722 5.928 1.00 0.00 C +ATOM 181 C ARG A 20 -0.080 -1.200 7.357 1.00 0.00 C +ATOM 182 CB ARG A 20 -1.652 -2.062 5.610 1.00 0.00 C +ATOM 183 O ARG A 20 -0.319 -1.939 8.314 1.00 0.00 O +ATOM 184 CG ARG A 20 -2.611 -0.897 5.793 1.00 0.00 C +ATOM 185 CD ARG A 20 -3.943 -1.349 6.375 1.00 0.00 C +ATOM 186 NE ARG A 20 -3.783 -1.914 7.712 1.00 0.00 N +ATOM 187 NH1 ARG A 20 -6.044 -1.973 8.182 1.00 0.00 N +ATOM 188 NH2 ARG A 20 -4.527 -2.709 9.734 1.00 0.00 N +ATOM 189 CZ ARG A 20 -4.785 -2.198 8.540 1.00 0.00 C +ATOM 190 N THR A 21 0.265 0.071 7.504 1.00 0.00 N +ATOM 191 CA THR A 21 0.341 0.733 8.802 1.00 0.00 C +ATOM 192 C THR A 21 -0.671 1.872 8.889 1.00 0.00 C +ATOM 193 CB THR A 21 1.756 1.278 9.069 1.00 0.00 C +ATOM 194 O THR A 21 -0.741 2.716 7.993 1.00 0.00 O +ATOM 195 CG2 THR A 21 1.849 1.911 10.454 1.00 0.00 C +ATOM 196 OG1 THR A 21 2.700 0.204 8.982 1.00 0.00 O +ATOM 197 N LEU A 22 -1.459 1.839 9.957 1.00 0.00 N +ATOM 198 CA LEU A 22 -2.458 2.875 10.198 1.00 0.00 C +ATOM 199 C LEU A 22 -2.082 3.722 11.409 1.00 0.00 C +ATOM 200 CB LEU A 22 -3.840 2.250 10.406 1.00 0.00 C +ATOM 201 O LEU A 22 -1.766 3.184 12.473 1.00 0.00 O +ATOM 202 CG LEU A 22 -5.013 3.224 10.530 1.00 0.00 C +ATOM 203 CD1 LEU A 22 -6.236 2.675 9.804 1.00 0.00 C +ATOM 204 CD2 LEU A 22 -5.331 3.493 11.997 1.00 0.00 C +ATOM 205 N ASN A 23 -2.010 5.058 11.207 1.00 0.00 N +ATOM 206 CA ASN A 23 -1.916 6.025 12.296 1.00 0.00 C +ATOM 207 C ASN A 23 -3.262 6.686 12.578 1.00 0.00 C +ATOM 208 CB ASN A 23 -0.859 7.086 11.982 1.00 0.00 C +ATOM 209 O ASN A 23 -3.700 7.558 11.826 1.00 0.00 O +ATOM 210 CG ASN A 23 -0.544 7.968 13.174 1.00 0.00 C +ATOM 211 ND2 ASN A 23 0.738 8.240 13.389 1.00 0.00 N +ATOM 212 OD1 ASN A 23 -1.445 8.400 13.897 1.00 0.00 O +ATOM 213 N ALA A 24 -3.964 6.163 13.657 1.00 0.00 N +ATOM 214 CA ALA A 24 -5.343 6.534 13.965 1.00 0.00 C +ATOM 215 C ALA A 24 -5.445 8.012 14.331 1.00 0.00 C +ATOM 216 CB ALA A 24 -5.885 5.669 15.101 1.00 0.00 C +ATOM 217 O ALA A 24 -6.448 8.664 14.033 1.00 0.00 O +ATOM 218 N ARG A 25 -4.362 8.538 14.892 1.00 0.00 N +ATOM 219 CA ARG A 25 -4.384 9.932 15.323 1.00 0.00 C +ATOM 220 C ARG A 25 -4.331 10.876 14.126 1.00 0.00 C +ATOM 221 CB ARG A 25 -3.219 10.220 16.271 1.00 0.00 C +ATOM 222 O ARG A 25 -5.124 11.815 14.034 1.00 0.00 O +ATOM 223 CG ARG A 25 -3.211 11.637 16.824 1.00 0.00 C +ATOM 224 CD ARG A 25 -2.139 11.820 17.889 1.00 0.00 C +ATOM 225 NE ARG A 25 -2.145 13.176 18.432 1.00 0.00 N +ATOM 226 NH1 ARG A 25 -0.364 12.832 19.863 1.00 0.00 N +ATOM 227 NH2 ARG A 25 -1.397 14.877 19.781 1.00 0.00 N +ATOM 228 CZ ARG A 25 -1.302 13.625 19.358 1.00 0.00 C +ATOM 229 N SER A 26 -3.376 10.637 13.145 1.00 0.00 N +ATOM 230 CA SER A 26 -3.171 11.513 11.996 1.00 0.00 C +ATOM 231 C SER A 26 -4.052 11.100 10.822 1.00 0.00 C +ATOM 232 CB SER A 26 -1.702 11.504 11.569 1.00 0.00 C +ATOM 233 O SER A 26 -4.116 11.801 9.810 1.00 0.00 O +ATOM 234 OG SER A 26 -1.312 10.211 11.138 1.00 0.00 O +ATOM 235 N LYS A 27 -4.836 10.033 11.095 1.00 0.00 N +ATOM 236 CA LYS A 27 -5.656 9.457 10.033 1.00 0.00 C +ATOM 237 C LYS A 27 -4.835 9.227 8.768 1.00 0.00 C +ATOM 238 CB LYS A 27 -6.850 10.362 9.726 1.00 0.00 C +ATOM 239 O LYS A 27 -5.263 9.587 7.669 1.00 0.00 O +ATOM 240 CG LYS A 27 -7.799 10.556 10.899 1.00 0.00 C +ATOM 241 CD LYS A 27 -8.515 9.260 11.258 1.00 0.00 C +ATOM 242 CE LYS A 27 -9.598 9.490 12.304 1.00 0.00 C +ATOM 243 NZ LYS A 27 -10.257 8.213 12.709 1.00 0.00 N +ATOM 244 N THR A 28 -3.684 8.620 8.927 1.00 0.00 N +ATOM 245 CA THR A 28 -2.783 8.326 7.817 1.00 0.00 C +ATOM 246 C THR A 28 -2.604 6.820 7.650 1.00 0.00 C +ATOM 247 CB THR A 28 -1.409 8.990 8.023 1.00 0.00 C +ATOM 248 O THR A 28 -2.497 6.089 8.637 1.00 0.00 O +ATOM 249 CG2 THR A 28 -0.480 8.706 6.848 1.00 0.00 C +ATOM 250 OG1 THR A 28 -1.584 10.407 8.148 1.00 0.00 O +ATOM 251 N CYS A 29 -2.718 6.348 6.470 1.00 0.00 N +ATOM 252 CA CYS A 29 -2.452 4.960 6.109 1.00 0.00 C +ATOM 253 C CYS A 29 -1.190 4.850 5.262 1.00 0.00 C +ATOM 254 CB CYS A 29 -3.639 4.366 5.352 1.00 0.00 C +ATOM 255 O CYS A 29 -1.018 5.597 4.297 1.00 0.00 O +ATOM 256 SG CYS A 29 -3.561 2.571 5.170 1.00 0.00 S +ATOM 257 N ILE A 30 -0.309 3.925 5.603 1.00 0.00 N +ATOM 258 CA ILE A 30 0.942 3.715 4.882 1.00 0.00 C +ATOM 259 C ILE A 30 0.989 2.290 4.335 1.00 0.00 C +ATOM 260 CB ILE A 30 2.166 3.984 5.786 1.00 0.00 C +ATOM 261 O ILE A 30 0.818 1.325 5.084 1.00 0.00 O +ATOM 262 CG1 ILE A 30 2.083 5.388 6.394 1.00 0.00 C +ATOM 263 CG2 ILE A 30 3.468 3.803 4.999 1.00 0.00 C +ATOM 264 CD1 ILE A 30 3.072 5.632 7.526 1.00 0.00 C +ATOM 265 N VAL A 31 1.098 2.123 3.077 1.00 0.00 N +ATOM 266 CA VAL A 31 1.317 0.825 2.449 1.00 0.00 C +ATOM 267 C VAL A 31 2.785 0.684 2.053 1.00 0.00 C +ATOM 268 CB VAL A 31 0.411 0.632 1.212 1.00 0.00 C +ATOM 269 O VAL A 31 3.310 1.498 1.289 1.00 0.00 O +ATOM 270 CG1 VAL A 31 0.662 -0.729 0.566 1.00 0.00 C +ATOM 271 CG2 VAL A 31 -1.059 0.779 1.601 1.00 0.00 C +ATOM 272 N THR A 32 3.455 -0.320 2.592 1.00 0.00 N +ATOM 273 CA THR A 32 4.880 -0.565 2.391 1.00 0.00 C +ATOM 274 C THR A 32 5.102 -1.860 1.616 1.00 0.00 C +ATOM 275 CB THR A 32 5.630 -0.631 3.734 1.00 0.00 C +ATOM 276 O THR A 32 4.464 -2.876 1.899 1.00 0.00 O +ATOM 277 CG2 THR A 32 7.129 -0.812 3.519 1.00 0.00 C +ATOM 278 OG1 THR A 32 5.405 0.583 4.461 1.00 0.00 O +ATOM 279 N VAL A 33 5.927 -1.787 0.613 1.00 0.00 N +ATOM 280 CA VAL A 33 6.436 -2.966 -0.081 1.00 0.00 C +ATOM 281 C VAL A 33 7.879 -3.231 0.342 1.00 0.00 C +ATOM 282 CB VAL A 33 6.351 -2.802 -1.615 1.00 0.00 C +ATOM 283 O VAL A 33 8.737 -2.352 0.230 1.00 0.00 O +ATOM 284 CG1 VAL A 33 6.835 -4.067 -2.321 1.00 0.00 C +ATOM 285 CG2 VAL A 33 4.922 -2.464 -2.036 1.00 0.00 C +ATOM 286 N ASP A 34 8.169 -4.392 0.805 1.00 0.00 N +ATOM 287 CA ASP A 34 9.479 -4.825 1.284 1.00 0.00 C +ATOM 288 C ASP A 34 10.014 -5.983 0.446 1.00 0.00 C +ATOM 289 CB ASP A 34 9.404 -5.231 2.757 1.00 0.00 C +ATOM 290 O ASP A 34 9.513 -7.106 0.537 1.00 0.00 O +ATOM 291 CG ASP A 34 10.762 -5.570 3.348 1.00 0.00 C +ATOM 292 OD1 ASP A 34 11.727 -5.774 2.581 1.00 0.00 O +ATOM 293 OD2 ASP A 34 10.866 -5.636 4.592 1.00 0.00 O +ATOM 294 N GLN A 35 11.073 -5.726 -0.286 1.00 0.00 N +ATOM 295 CA GLN A 35 11.675 -6.742 -1.144 1.00 0.00 C +ATOM 296 C GLN A 35 13.024 -7.197 -0.595 1.00 0.00 C +ATOM 297 CB GLN A 35 11.840 -6.212 -2.569 1.00 0.00 C +ATOM 298 O GLN A 35 13.897 -7.623 -1.353 1.00 0.00 O +ATOM 299 CG GLN A 35 10.522 -5.921 -3.274 1.00 0.00 C +ATOM 300 CD GLN A 35 9.870 -7.169 -3.838 1.00 0.00 C +ATOM 301 NE2 GLN A 35 8.547 -7.240 -3.747 1.00 0.00 N +ATOM 302 OE1 GLN A 35 10.551 -8.062 -4.353 1.00 0.00 O +ATOM 303 N THR A 36 13.298 -6.855 0.668 1.00 0.00 N +ATOM 304 CA THR A 36 14.573 -7.196 1.289 1.00 0.00 C +ATOM 305 C THR A 36 14.903 -8.670 1.071 1.00 0.00 C +ATOM 306 CB THR A 36 14.560 -6.888 2.797 1.00 0.00 C +ATOM 307 O THR A 36 16.070 -9.034 0.913 1.00 0.00 O +ATOM 308 CG2 THR A 36 15.919 -7.172 3.428 1.00 0.00 C +ATOM 309 OG1 THR A 36 14.233 -5.506 2.994 1.00 0.00 O +ATOM 310 N ASN A 37 13.791 -9.502 0.940 1.00 0.00 N +ATOM 311 CA ASN A 37 14.006 -10.939 0.811 1.00 0.00 C +ATOM 312 C ASN A 37 13.990 -11.380 -0.650 1.00 0.00 C +ATOM 313 CB ASN A 37 12.955 -11.713 1.610 1.00 0.00 C +ATOM 314 O ASN A 37 13.982 -12.577 -0.942 1.00 0.00 O +ATOM 315 CG ASN A 37 13.047 -11.453 3.101 1.00 0.00 C +ATOM 316 ND2 ASN A 37 11.898 -11.333 3.754 1.00 0.00 N +ATOM 317 OD1 ASN A 37 14.143 -11.360 3.661 1.00 0.00 O +ATOM 318 N ASN A 38 13.915 -10.430 -1.517 1.00 0.00 N +ATOM 319 CA ASN A 38 13.930 -10.725 -2.946 1.00 0.00 C +ATOM 320 C ASN A 38 15.353 -10.908 -3.466 1.00 0.00 C +ATOM 321 CB ASN A 38 13.216 -9.621 -3.730 1.00 0.00 C +ATOM 322 O ASN A 38 16.154 -9.972 -3.437 1.00 0.00 O +ATOM 323 CG ASN A 38 12.940 -10.011 -5.168 1.00 0.00 C +ATOM 324 ND2 ASN A 38 12.018 -9.303 -5.808 1.00 0.00 N +ATOM 325 OD1 ASN A 38 13.551 -10.942 -5.700 1.00 0.00 O +ATOM 326 N PRO A 39 15.718 -12.109 -3.794 1.00 0.00 N +ATOM 327 CA PRO A 39 17.079 -12.430 -4.229 1.00 0.00 C +ATOM 328 C PRO A 39 17.430 -11.806 -5.578 1.00 0.00 C +ATOM 329 CB PRO A 39 17.068 -13.958 -4.322 1.00 0.00 C +ATOM 330 O PRO A 39 18.598 -11.804 -5.975 1.00 0.00 O +ATOM 331 CG PRO A 39 15.632 -14.318 -4.529 1.00 0.00 C +ATOM 332 CD PRO A 39 14.778 -13.272 -3.872 1.00 0.00 C +ATOM 333 N GLN A 40 16.394 -11.251 -6.150 1.00 0.00 N +ATOM 334 CA GLN A 40 16.646 -10.746 -7.496 1.00 0.00 C +ATOM 335 C GLN A 40 17.358 -9.397 -7.452 1.00 0.00 C +ATOM 336 CB GLN A 40 15.338 -10.625 -8.278 1.00 0.00 C +ATOM 337 O GLN A 40 16.883 -8.459 -6.810 1.00 0.00 O +ATOM 338 CG GLN A 40 14.617 -11.951 -8.481 1.00 0.00 C +ATOM 339 CD GLN A 40 13.302 -11.796 -9.221 1.00 0.00 C +ATOM 340 NE2 GLN A 40 12.511 -12.864 -9.254 1.00 0.00 N +ATOM 341 OE1 GLN A 40 12.999 -10.727 -9.759 1.00 0.00 O +ATOM 342 N GLU A 41 18.658 -9.506 -7.750 1.00 0.00 N +ATOM 343 CA GLU A 41 19.475 -8.300 -7.847 1.00 0.00 C +ATOM 344 C GLU A 41 19.146 -7.510 -9.110 1.00 0.00 C +ATOM 345 CB GLU A 41 20.964 -8.656 -7.822 1.00 0.00 C +ATOM 346 O GLU A 41 18.572 -8.053 -10.056 1.00 0.00 O +ATOM 347 CG GLU A 41 21.428 -9.272 -6.510 1.00 0.00 C +ATOM 348 CD GLU A 41 22.928 -9.511 -6.457 1.00 0.00 C +ATOM 349 OE1 GLU A 41 23.442 -9.906 -5.385 1.00 0.00 O +ATOM 350 OE2 GLU A 41 23.595 -9.301 -7.494 1.00 0.00 O +ATOM 351 N ASN A 42 18.825 -6.245 -8.943 1.00 0.00 N +ATOM 352 CA ASN A 42 18.748 -5.279 -10.034 1.00 0.00 C +ATOM 353 C ASN A 42 17.322 -5.141 -10.561 1.00 0.00 C +ATOM 354 CB ASN A 42 19.697 -5.672 -11.168 1.00 0.00 C +ATOM 355 O ASN A 42 17.113 -4.986 -11.765 1.00 0.00 O +ATOM 356 CG ASN A 42 21.153 -5.643 -10.747 1.00 0.00 C +ATOM 357 ND2 ASN A 42 21.930 -6.604 -11.232 1.00 0.00 N +ATOM 358 OD1 ASN A 42 21.576 -4.764 -9.991 1.00 0.00 O +ATOM 359 N MET A 43 16.406 -5.433 -9.710 1.00 0.00 N +ATOM 360 CA MET A 43 15.006 -5.237 -10.074 1.00 0.00 C +ATOM 361 C MET A 43 14.398 -4.075 -9.295 1.00 0.00 C +ATOM 362 CB MET A 43 14.201 -6.513 -9.824 1.00 0.00 C +ATOM 363 O MET A 43 14.758 -3.842 -8.140 1.00 0.00 O +ATOM 364 CG MET A 43 14.599 -7.674 -10.720 1.00 0.00 C +ATOM 365 SD MET A 43 13.455 -9.102 -10.570 1.00 0.00 S +ATOM 366 CE MET A 43 13.902 -9.697 -8.916 1.00 0.00 C +ATOM 367 N GLY A 44 13.752 -3.201 -10.052 1.00 0.00 N +ATOM 368 CA GLY A 44 12.963 -2.172 -9.393 1.00 0.00 C +ATOM 369 C GLY A 44 11.483 -2.499 -9.336 1.00 0.00 C +ATOM 370 O GLY A 44 11.043 -3.506 -9.895 1.00 0.00 O +ATOM 371 N PHE A 45 10.800 -1.846 -8.459 1.00 0.00 N +ATOM 372 CA PHE A 45 9.356 -2.022 -8.356 1.00 0.00 C +ATOM 373 C PHE A 45 8.664 -0.687 -8.106 1.00 0.00 C +ATOM 374 CB PHE A 45 9.013 -3.009 -7.235 1.00 0.00 C +ATOM 375 O PHE A 45 9.316 0.302 -7.767 1.00 0.00 O +ATOM 376 CG PHE A 45 9.419 -2.537 -5.865 1.00 0.00 C +ATOM 377 CD1 PHE A 45 10.703 -2.769 -5.387 1.00 0.00 C +ATOM 378 CD2 PHE A 45 8.516 -1.860 -5.055 1.00 0.00 C +ATOM 379 CE1 PHE A 45 11.082 -2.334 -4.119 1.00 0.00 C +ATOM 380 CE2 PHE A 45 8.888 -1.422 -3.787 1.00 0.00 C +ATOM 381 CZ PHE A 45 10.170 -1.661 -3.321 1.00 0.00 C +ATOM 382 N ALA A 46 7.363 -0.706 -8.425 1.00 0.00 N +ATOM 383 CA ALA A 46 6.515 0.457 -8.177 1.00 0.00 C +ATOM 384 C ALA A 46 5.186 0.044 -7.550 1.00 0.00 C +ATOM 385 CB ALA A 46 6.271 1.225 -9.474 1.00 0.00 C +ATOM 386 O ALA A 46 4.642 -1.014 -7.875 1.00 0.00 O +ATOM 387 N ILE A 47 4.724 0.753 -6.545 1.00 0.00 N +ATOM 388 CA ILE A 47 3.375 0.616 -6.007 1.00 0.00 C +ATOM 389 C ILE A 47 2.622 1.935 -6.157 1.00 0.00 C +ATOM 390 CB ILE A 47 3.400 0.179 -4.525 1.00 0.00 C +ATOM 391 O ILE A 47 3.171 3.005 -5.882 1.00 0.00 O +ATOM 392 CG1 ILE A 47 4.194 1.185 -3.684 1.00 0.00 C +ATOM 393 CG2 ILE A 47 3.983 -1.231 -4.387 1.00 0.00 C +ATOM 394 CD1 ILE A 47 4.036 0.998 -2.182 1.00 0.00 C +ATOM 395 N MET A 48 1.411 1.770 -6.657 1.00 0.00 N +ATOM 396 CA MET A 48 0.618 2.953 -6.979 1.00 0.00 C +ATOM 397 C MET A 48 -0.783 2.847 -6.387 1.00 0.00 C +ATOM 398 CB MET A 48 0.533 3.148 -8.494 1.00 0.00 C +ATOM 399 O MET A 48 -1.425 1.800 -6.483 1.00 0.00 O +ATOM 400 CG MET A 48 -0.272 4.368 -8.911 1.00 0.00 C +ATOM 401 SD MET A 48 -0.347 4.571 -10.733 1.00 0.00 S +ATOM 402 CE MET A 48 1.088 3.577 -11.226 1.00 0.00 C +ATOM 403 N LEU A 49 -1.241 3.937 -5.762 1.00 0.00 N +ATOM 404 CA LEU A 49 -2.646 4.072 -5.392 1.00 0.00 C +ATOM 405 C LEU A 49 -3.494 4.445 -6.605 1.00 0.00 C +ATOM 406 CB LEU A 49 -2.813 5.127 -4.295 1.00 0.00 C +ATOM 407 O LEU A 49 -3.363 5.546 -7.144 1.00 0.00 O +ATOM 408 CG LEU A 49 -4.234 5.341 -3.771 1.00 0.00 C +ATOM 409 CD1 LEU A 49 -4.792 4.039 -3.206 1.00 0.00 C +ATOM 410 CD2 LEU A 49 -4.255 6.441 -2.715 1.00 0.00 C +ATOM 411 N ILE A 50 -4.427 3.567 -7.073 1.00 0.00 N +ATOM 412 CA ILE A 50 -5.183 3.724 -8.311 1.00 0.00 C +ATOM 413 C ILE A 50 -6.056 4.974 -8.227 1.00 0.00 C +ATOM 414 CB ILE A 50 -6.053 2.481 -8.603 1.00 0.00 C +ATOM 415 O ILE A 50 -6.663 5.247 -7.188 1.00 0.00 O +ATOM 416 CG1 ILE A 50 -5.166 1.261 -8.875 1.00 0.00 C +ATOM 417 CG2 ILE A 50 -6.995 2.746 -9.780 1.00 0.00 C +ATOM 418 CD1 ILE A 50 -5.936 -0.045 -9.017 1.00 0.00 C +ATOM 419 N ASP A 51 -6.069 5.759 -9.226 1.00 0.00 N +ATOM 420 CA ASP A 51 -6.921 6.925 -9.438 1.00 0.00 C +ATOM 421 C ASP A 51 -6.398 8.135 -8.667 1.00 0.00 C +ATOM 422 CB ASP A 51 -8.362 6.620 -9.025 1.00 0.00 C +ATOM 423 O ASP A 51 -7.171 9.018 -8.290 1.00 0.00 O +ATOM 424 CG ASP A 51 -9.006 5.537 -9.873 1.00 0.00 C +ATOM 425 OD1 ASP A 51 -8.719 5.462 -11.087 1.00 0.00 O +ATOM 426 OD2 ASP A 51 -9.809 4.753 -9.323 1.00 0.00 O +ATOM 427 N THR A 52 -5.116 8.064 -8.253 1.00 0.00 N +ATOM 428 CA THR A 52 -4.446 9.180 -7.595 1.00 0.00 C +ATOM 429 C THR A 52 -3.068 9.419 -8.204 1.00 0.00 C +ATOM 430 CB THR A 52 -4.305 8.933 -6.081 1.00 0.00 C +ATOM 431 O THR A 52 -2.650 8.695 -9.111 1.00 0.00 O +ATOM 432 CG2 THR A 52 -5.582 8.334 -5.500 1.00 0.00 C +ATOM 433 OG1 THR A 52 -3.218 8.029 -5.850 1.00 0.00 O +ATOM 434 N ASP A 53 -2.398 10.489 -7.806 1.00 0.00 N +ATOM 435 CA ASP A 53 -1.026 10.761 -8.225 1.00 0.00 C +ATOM 436 C ASP A 53 -0.028 10.312 -7.160 1.00 0.00 C +ATOM 437 CB ASP A 53 -0.842 12.251 -8.524 1.00 0.00 C +ATOM 438 O ASP A 53 1.125 10.750 -7.159 1.00 0.00 O +ATOM 439 CG ASP A 53 -1.641 12.717 -9.728 1.00 0.00 C +ATOM 440 OD1 ASP A 53 -1.728 11.972 -10.728 1.00 0.00 O +ATOM 441 OD2 ASP A 53 -2.186 13.841 -9.678 1.00 0.00 O +ATOM 442 N ILE A 54 -0.442 9.451 -6.327 1.00 0.00 N +ATOM 443 CA ILE A 54 0.406 8.993 -5.232 1.00 0.00 C +ATOM 444 C ILE A 54 1.056 7.662 -5.603 1.00 0.00 C +ATOM 445 CB ILE A 54 -0.395 8.850 -3.918 1.00 0.00 C +ATOM 446 O ILE A 54 0.362 6.682 -5.881 1.00 0.00 O +ATOM 447 CG1 ILE A 54 -1.083 10.174 -3.566 1.00 0.00 C +ATOM 448 CG2 ILE A 54 0.516 8.386 -2.778 1.00 0.00 C +ATOM 449 CD1 ILE A 54 -2.074 10.069 -2.414 1.00 0.00 C +ATOM 450 N TRP A 55 2.279 7.611 -5.677 1.00 0.00 N +ATOM 451 CA TRP A 55 2.996 6.392 -6.034 1.00 0.00 C +ATOM 452 C TRP A 55 4.426 6.427 -5.505 1.00 0.00 C +ATOM 453 CB TRP A 55 3.004 6.197 -7.552 1.00 0.00 C +ATOM 454 O TRP A 55 4.907 7.474 -5.065 1.00 0.00 O +ATOM 455 CG TRP A 55 3.649 7.321 -8.307 1.00 0.00 C +ATOM 456 CD1 TRP A 55 3.095 8.533 -8.612 1.00 0.00 C +ATOM 457 CD2 TRP A 55 4.970 7.333 -8.857 1.00 0.00 C +ATOM 458 CE2 TRP A 55 5.150 8.587 -9.482 1.00 0.00 C +ATOM 459 CE3 TRP A 55 6.020 6.405 -8.881 1.00 0.00 C +ATOM 460 NE1 TRP A 55 3.993 9.299 -9.318 1.00 0.00 N +ATOM 461 CH2 TRP A 55 7.351 8.010 -10.134 1.00 0.00 C +ATOM 462 CZ2 TRP A 55 6.341 8.936 -10.125 1.00 0.00 C +ATOM 463 CZ3 TRP A 55 7.203 6.755 -9.522 1.00 0.00 C +ATOM 464 N CYS A 56 5.000 5.260 -5.391 1.00 0.00 N +ATOM 465 CA CYS A 56 6.389 5.078 -4.984 1.00 0.00 C +ATOM 466 C CYS A 56 7.106 4.104 -5.911 1.00 0.00 C +ATOM 467 CB CYS A 56 6.464 4.574 -3.543 1.00 0.00 C +ATOM 468 O CYS A 56 6.576 3.038 -6.228 1.00 0.00 O +ATOM 469 SG CYS A 56 8.149 4.302 -2.954 1.00 0.00 S +ATOM 470 N MET A 57 8.257 4.518 -6.421 1.00 0.00 N +ATOM 471 CA MET A 57 9.125 3.657 -7.220 1.00 0.00 C +ATOM 472 C MET A 57 10.511 3.550 -6.593 1.00 0.00 C +ATOM 473 CB MET A 57 9.237 4.186 -8.651 1.00 0.00 C +ATOM 474 O MET A 57 11.049 4.539 -6.094 1.00 0.00 O +ATOM 475 CG MET A 57 7.924 4.169 -9.416 1.00 0.00 C +ATOM 476 SD MET A 57 8.115 4.724 -11.154 1.00 0.00 S +ATOM 477 CE MET A 57 9.877 4.372 -11.410 1.00 0.00 C +ATOM 478 N SER A 58 10.989 2.288 -6.576 1.00 0.00 N +ATOM 479 CA SER A 58 12.264 2.156 -5.879 1.00 0.00 C +ATOM 480 C SER A 58 13.082 0.996 -6.438 1.00 0.00 C +ATOM 481 CB SER A 58 12.038 1.955 -4.380 1.00 0.00 C +ATOM 482 O SER A 58 12.521 -0.003 -6.894 1.00 0.00 O +ATOM 483 OG SER A 58 13.273 1.808 -3.701 1.00 0.00 O +ATOM 484 N PHE A 59 14.446 1.264 -6.479 1.00 0.00 N +ATOM 485 CA PHE A 59 15.391 0.171 -6.677 1.00 0.00 C +ATOM 486 C PHE A 59 15.963 -0.298 -5.345 1.00 0.00 C +ATOM 487 CB PHE A 59 16.524 0.601 -7.614 1.00 0.00 C +ATOM 488 O PHE A 59 16.693 -1.291 -5.292 1.00 0.00 O +ATOM 489 CG PHE A 59 16.135 0.625 -9.067 1.00 0.00 C +ATOM 490 CD1 PHE A 59 16.399 -0.463 -9.890 1.00 0.00 C +ATOM 491 CD2 PHE A 59 15.504 1.736 -9.611 1.00 0.00 C +ATOM 492 CE1 PHE A 59 16.040 -0.444 -11.235 1.00 0.00 C +ATOM 493 CE2 PHE A 59 15.143 1.763 -10.955 1.00 0.00 C +ATOM 494 CZ PHE A 59 15.412 0.672 -11.766 1.00 0.00 C +ATOM 495 N ALA A 60 15.648 0.435 -4.357 1.00 0.00 N +ATOM 496 CA ALA A 60 15.954 0.014 -2.992 1.00 0.00 C +ATOM 497 C ALA A 60 14.951 -1.025 -2.500 1.00 0.00 C +ATOM 498 CB ALA A 60 15.969 1.219 -2.054 1.00 0.00 C +ATOM 499 O ALA A 60 13.911 -1.239 -3.128 1.00 0.00 O +ATOM 500 N PRO A 61 15.294 -1.798 -1.450 1.00 0.00 N +ATOM 501 CA PRO A 61 14.452 -2.918 -1.021 1.00 0.00 C +ATOM 502 C PRO A 61 13.119 -2.462 -0.431 1.00 0.00 C +ATOM 503 CB PRO A 61 15.308 -3.616 0.038 1.00 0.00 C +ATOM 504 O PRO A 61 12.211 -3.276 -0.247 1.00 0.00 O +ATOM 505 CG PRO A 61 16.273 -2.571 0.500 1.00 0.00 C +ATOM 506 CD PRO A 61 16.512 -1.612 -0.630 1.00 0.00 C +ATOM 507 N LEU A 62 12.946 -1.118 -0.268 1.00 0.00 N +ATOM 508 CA LEU A 62 11.732 -0.669 0.405 1.00 0.00 C +ATOM 509 C LEU A 62 11.131 0.538 -0.307 1.00 0.00 C +ATOM 510 CB LEU A 62 12.027 -0.321 1.866 1.00 0.00 C +ATOM 511 O LEU A 62 11.859 1.434 -0.741 1.00 0.00 O +ATOM 512 CG LEU A 62 12.417 -1.486 2.776 1.00 0.00 C +ATOM 513 CD1 LEU A 62 12.964 -0.965 4.101 1.00 0.00 C +ATOM 514 CD2 LEU A 62 11.223 -2.406 3.010 1.00 0.00 C +ATOM 515 N CYS A 63 9.795 0.579 -0.388 1.00 0.00 N +ATOM 516 CA CYS A 63 9.049 1.716 -0.916 1.00 0.00 C +ATOM 517 C CYS A 63 7.730 1.894 -0.174 1.00 0.00 C +ATOM 518 CB CYS A 63 8.783 1.534 -2.411 1.00 0.00 C +ATOM 519 O CYS A 63 7.066 0.914 0.166 1.00 0.00 O +ATOM 520 SG CYS A 63 8.129 3.012 -3.218 1.00 0.00 S +ATOM 521 N GLU A 64 7.273 3.137 0.060 1.00 0.00 N +ATOM 522 CA GLU A 64 6.045 3.385 0.811 1.00 0.00 C +ATOM 523 C GLU A 64 5.144 4.378 0.083 1.00 0.00 C +ATOM 524 CB GLU A 64 6.369 3.900 2.216 1.00 0.00 C +ATOM 525 O GLU A 64 5.631 5.312 -0.557 1.00 0.00 O +ATOM 526 CG GLU A 64 7.158 2.914 3.066 1.00 0.00 C +ATOM 527 CD GLU A 64 7.576 3.483 4.412 1.00 0.00 C +ATOM 528 OE1 GLU A 64 8.127 2.728 5.245 1.00 0.00 O +ATOM 529 OE2 GLU A 64 7.351 4.693 4.636 1.00 0.00 O +ATOM 530 N VAL A 65 3.891 4.163 0.164 1.00 0.00 N +ATOM 531 CA VAL A 65 2.869 5.126 -0.231 1.00 0.00 C +ATOM 532 C VAL A 65 2.061 5.553 0.992 1.00 0.00 C +ATOM 533 CB VAL A 65 1.931 4.547 -1.314 1.00 0.00 C +ATOM 534 O VAL A 65 1.446 4.719 1.660 1.00 0.00 O +ATOM 535 CG1 VAL A 65 0.787 5.515 -1.612 1.00 0.00 C +ATOM 536 CG2 VAL A 65 2.716 4.234 -2.587 1.00 0.00 C +ATOM 537 N LYS A 66 2.125 6.823 1.301 1.00 0.00 N +ATOM 538 CA LYS A 66 1.412 7.397 2.439 1.00 0.00 C +ATOM 539 C LYS A 66 0.230 8.244 1.976 1.00 0.00 C +ATOM 540 CB LYS A 66 2.357 8.240 3.295 1.00 0.00 C +ATOM 541 O LYS A 66 0.372 9.083 1.084 1.00 0.00 O +ATOM 542 CG LYS A 66 1.740 8.734 4.595 1.00 0.00 C +ATOM 543 CD LYS A 66 2.762 9.464 5.458 1.00 0.00 C +ATOM 544 CE LYS A 66 2.137 9.991 6.742 1.00 0.00 C +ATOM 545 NZ LYS A 66 3.131 10.723 7.582 1.00 0.00 N +ATOM 546 N PHE A 67 -0.965 8.044 2.601 1.00 0.00 N +ATOM 547 CA PHE A 67 -2.137 8.813 2.199 1.00 0.00 C +ATOM 548 C PHE A 67 -3.119 8.946 3.357 1.00 0.00 C +ATOM 549 CB PHE A 67 -2.826 8.158 0.999 1.00 0.00 C +ATOM 550 O PHE A 67 -3.043 8.196 4.332 1.00 0.00 O +ATOM 551 CG PHE A 67 -3.210 6.721 1.229 1.00 0.00 C +ATOM 552 CD1 PHE A 67 -2.288 5.700 1.029 1.00 0.00 C +ATOM 553 CD2 PHE A 67 -4.493 6.391 1.645 1.00 0.00 C +ATOM 554 CE1 PHE A 67 -2.641 4.369 1.241 1.00 0.00 C +ATOM 555 CE2 PHE A 67 -4.852 5.063 1.859 1.00 0.00 C +ATOM 556 CZ PHE A 67 -3.925 4.054 1.655 1.00 0.00 C +ATOM 557 N SER A 68 -3.986 9.931 3.229 1.00 0.00 N +ATOM 558 CA SER A 68 -5.005 10.179 4.243 1.00 0.00 C +ATOM 559 C SER A 68 -6.269 9.372 3.966 1.00 0.00 C +ATOM 560 CB SER A 68 -5.346 11.668 4.308 1.00 0.00 C +ATOM 561 O SER A 68 -6.585 9.080 2.811 1.00 0.00 O +ATOM 562 OG SER A 68 -4.210 12.428 4.685 1.00 0.00 O +ATOM 563 N TYR A 69 -6.917 8.883 5.050 1.00 0.00 N +ATOM 564 CA TYR A 69 -8.184 8.178 4.893 1.00 0.00 C +ATOM 565 C TYR A 69 -9.270 8.807 5.758 1.00 0.00 C +ATOM 566 CB TYR A 69 -8.024 6.697 5.251 1.00 0.00 C +ATOM 567 O TYR A 69 -8.972 9.565 6.684 1.00 0.00 O +ATOM 568 CG TYR A 69 -7.710 6.456 6.707 1.00 0.00 C +ATOM 569 CD1 TYR A 69 -6.406 6.554 7.185 1.00 0.00 C +ATOM 570 CD2 TYR A 69 -8.718 6.128 7.608 1.00 0.00 C +ATOM 571 CE1 TYR A 69 -6.112 6.331 8.526 1.00 0.00 C +ATOM 572 CE2 TYR A 69 -8.436 5.902 8.952 1.00 0.00 C +ATOM 573 OH TYR A 69 -6.847 5.784 10.729 1.00 0.00 O +ATOM 574 CZ TYR A 69 -7.132 6.006 9.401 1.00 0.00 C +ATOM 575 N ARG A 70 -10.583 8.696 5.262 1.00 0.00 N +ATOM 576 CA ARG A 70 -11.721 9.219 6.010 1.00 0.00 C +ATOM 577 C ARG A 70 -12.611 8.089 6.514 1.00 0.00 C +ATOM 578 CB ARG A 70 -12.536 10.183 5.146 1.00 0.00 C +ATOM 579 O ARG A 70 -12.681 7.024 5.897 1.00 0.00 O +ATOM 580 CG ARG A 70 -11.777 11.434 4.734 1.00 0.00 C +ATOM 581 CD ARG A 70 -12.666 12.410 3.977 1.00 0.00 C +ATOM 582 NE ARG A 70 -11.945 13.627 3.614 1.00 0.00 N +ATOM 583 NH1 ARG A 70 -13.751 14.639 2.588 1.00 0.00 N +ATOM 584 NH2 ARG A 70 -11.730 15.716 2.684 1.00 0.00 N +ATOM 585 CZ ARG A 70 -12.477 14.658 2.963 1.00 0.00 C +ATOM 586 N GLY A 71 -13.241 8.307 7.725 1.00 0.00 N +ATOM 587 CA GLY A 71 -14.135 7.313 8.296 1.00 0.00 C +ATOM 588 C GLY A 71 -13.408 6.233 9.075 1.00 0.00 C +ATOM 589 O GLY A 71 -12.192 6.309 9.263 1.00 0.00 O +ATOM 590 N MET A 72 -14.158 5.323 9.659 1.00 0.00 N +ATOM 591 CA MET A 72 -13.613 4.253 10.491 1.00 0.00 C +ATOM 592 C MET A 72 -13.139 3.085 9.632 1.00 0.00 C +ATOM 593 CB MET A 72 -14.656 3.771 11.500 1.00 0.00 C +ATOM 594 O MET A 72 -12.324 2.274 10.076 1.00 0.00 O +ATOM 595 CG MET A 72 -15.002 4.801 12.563 1.00 0.00 C +ATOM 596 SD MET A 72 -16.056 4.109 13.896 1.00 0.00 S +ATOM 597 CE MET A 72 -15.027 2.711 14.424 1.00 0.00 C +ATOM 598 N LYS A 73 -13.499 3.031 8.394 1.00 0.00 N +ATOM 599 CA LYS A 73 -13.158 1.998 7.420 1.00 0.00 C +ATOM 600 C LYS A 73 -13.157 2.558 6.001 1.00 0.00 C +ATOM 601 CB LYS A 73 -14.134 0.824 7.519 1.00 0.00 C +ATOM 602 O LYS A 73 -14.116 3.212 5.585 1.00 0.00 O +ATOM 603 CG LYS A 73 -13.805 -0.335 6.589 1.00 0.00 C +ATOM 604 CD LYS A 73 -14.814 -1.467 6.729 1.00 0.00 C +ATOM 605 CE LYS A 73 -14.536 -2.590 5.738 1.00 0.00 C +ATOM 606 NZ LYS A 73 -15.551 -3.681 5.839 1.00 0.00 N +ATOM 607 N ALA A 74 -12.007 2.391 5.220 1.00 0.00 N +ATOM 608 CA ALA A 74 -11.877 2.842 3.837 1.00 0.00 C +ATOM 609 C ALA A 74 -11.034 1.868 3.019 1.00 0.00 C +ATOM 610 CB ALA A 74 -11.266 4.241 3.788 1.00 0.00 C +ATOM 611 O ALA A 74 -10.053 1.315 3.520 1.00 0.00 O +ATOM 612 N MET A 75 -11.483 1.658 1.812 1.00 0.00 N +ATOM 613 CA MET A 75 -10.803 0.722 0.923 1.00 0.00 C +ATOM 614 C MET A 75 -10.086 1.463 -0.202 1.00 0.00 C +ATOM 615 CB MET A 75 -11.796 -0.283 0.337 1.00 0.00 C +ATOM 616 O MET A 75 -10.672 2.331 -0.851 1.00 0.00 O +ATOM 617 CG MET A 75 -12.428 -1.196 1.375 1.00 0.00 C +ATOM 618 SD MET A 75 -13.576 -2.419 0.630 1.00 0.00 S +ATOM 619 CE MET A 75 -15.015 -1.358 0.318 1.00 0.00 C +ATOM 620 N PHE A 76 -8.828 0.988 -0.456 1.00 0.00 N +ATOM 621 CA PHE A 76 -8.002 1.608 -1.486 1.00 0.00 C +ATOM 622 C PHE A 76 -7.388 0.550 -2.396 1.00 0.00 C +ATOM 623 CB PHE A 76 -6.897 2.459 -0.851 1.00 0.00 C +ATOM 624 O PHE A 76 -6.937 -0.495 -1.925 1.00 0.00 O +ATOM 625 CG PHE A 76 -7.413 3.553 0.044 1.00 0.00 C +ATOM 626 CD1 PHE A 76 -7.706 4.810 -0.471 1.00 0.00 C +ATOM 627 CD2 PHE A 76 -7.604 3.325 1.400 1.00 0.00 C +ATOM 628 CE1 PHE A 76 -8.184 5.825 0.355 1.00 0.00 C +ATOM 629 CE2 PHE A 76 -8.081 4.334 2.232 1.00 0.00 C +ATOM 630 CZ PHE A 76 -8.369 5.584 1.708 1.00 0.00 C +ATOM 631 N SER A 77 -7.380 0.793 -3.671 1.00 0.00 N +ATOM 632 CA SER A 77 -6.849 -0.138 -4.661 1.00 0.00 C +ATOM 633 C SER A 77 -5.443 0.262 -5.098 1.00 0.00 C +ATOM 634 CB SER A 77 -7.770 -0.206 -5.880 1.00 0.00 C +ATOM 635 O SER A 77 -5.182 1.436 -5.368 1.00 0.00 O +ATOM 636 OG SER A 77 -9.051 -0.690 -5.516 1.00 0.00 O +ATOM 637 N PHE A 78 -4.547 -0.700 -5.147 1.00 0.00 N +ATOM 638 CA PHE A 78 -3.150 -0.484 -5.504 1.00 0.00 C +ATOM 639 C PHE A 78 -2.766 -1.319 -6.720 1.00 0.00 C +ATOM 640 CB PHE A 78 -2.234 -0.822 -4.324 1.00 0.00 C +ATOM 641 O PHE A 78 -3.356 -2.373 -6.967 1.00 0.00 O +ATOM 642 CG PHE A 78 -2.257 0.204 -3.224 1.00 0.00 C +ATOM 643 CD1 PHE A 78 -1.255 1.162 -3.127 1.00 0.00 C +ATOM 644 CD2 PHE A 78 -3.281 0.210 -2.286 1.00 0.00 C +ATOM 645 CE1 PHE A 78 -1.274 2.113 -2.109 1.00 0.00 C +ATOM 646 CE2 PHE A 78 -3.307 1.158 -1.266 1.00 0.00 C +ATOM 647 CZ PHE A 78 -2.302 2.108 -1.179 1.00 0.00 C +ATOM 648 N ARG A 79 -1.852 -0.697 -7.408 1.00 0.00 N +ATOM 649 CA ARG A 79 -1.158 -1.385 -8.491 1.00 0.00 C +ATOM 650 C ARG A 79 0.318 -1.580 -8.159 1.00 0.00 C +ATOM 651 CB ARG A 79 -1.302 -0.609 -9.802 1.00 0.00 C +ATOM 652 O ARG A 79 1.017 -0.619 -7.830 1.00 0.00 O +ATOM 653 CG ARG A 79 -0.640 -1.283 -10.993 1.00 0.00 C +ATOM 654 CD ARG A 79 -0.783 -0.453 -12.262 1.00 0.00 C +ATOM 655 NE ARG A 79 -0.132 -1.094 -13.401 1.00 0.00 N +ATOM 656 NH1 ARG A 79 -0.537 0.641 -14.872 1.00 0.00 N +ATOM 657 NH2 ARG A 79 0.583 -1.230 -15.579 1.00 0.00 N +ATOM 658 CZ ARG A 79 -0.030 -0.560 -14.615 1.00 0.00 C +ATOM 659 N TYR A 80 0.823 -2.845 -8.184 1.00 0.00 N +ATOM 660 CA TYR A 80 2.220 -3.211 -7.979 1.00 0.00 C +ATOM 661 C TYR A 80 2.861 -3.665 -9.284 1.00 0.00 C +ATOM 662 CB TYR A 80 2.337 -4.317 -6.927 1.00 0.00 C +ATOM 663 O TYR A 80 2.344 -4.559 -9.959 1.00 0.00 O +ATOM 664 CG TYR A 80 3.724 -4.901 -6.812 1.00 0.00 C +ATOM 665 CD1 TYR A 80 4.021 -6.153 -7.347 1.00 0.00 C +ATOM 666 CD2 TYR A 80 4.741 -4.204 -6.168 1.00 0.00 C +ATOM 667 CE1 TYR A 80 5.297 -6.696 -7.242 1.00 0.00 C +ATOM 668 CE2 TYR A 80 6.020 -4.737 -6.057 1.00 0.00 C +ATOM 669 OH TYR A 80 7.554 -6.514 -6.490 1.00 0.00 O +ATOM 670 CZ TYR A 80 6.289 -5.982 -6.597 1.00 0.00 C +ATOM 671 N ILE A 81 3.977 -3.045 -9.726 1.00 0.00 N +ATOM 672 CA ILE A 81 4.663 -3.362 -10.974 1.00 0.00 C +ATOM 673 C ILE A 81 6.141 -3.626 -10.697 1.00 0.00 C +ATOM 674 CB ILE A 81 4.506 -2.226 -12.010 1.00 0.00 C +ATOM 675 O ILE A 81 6.795 -2.855 -9.991 1.00 0.00 O +ATOM 676 CG1 ILE A 81 3.023 -1.919 -12.245 1.00 0.00 C +ATOM 677 CG2 ILE A 81 5.205 -2.593 -13.323 1.00 0.00 C +ATOM 678 CD1 ILE A 81 2.772 -0.622 -13.001 1.00 0.00 C +ATOM 679 N MET A 82 6.624 -4.650 -11.272 1.00 0.00 N +ATOM 680 CA MET A 82 8.056 -4.932 -11.247 1.00 0.00 C +ATOM 681 C MET A 82 8.708 -4.553 -12.572 1.00 0.00 C +ATOM 682 CB MET A 82 8.309 -6.410 -10.943 1.00 0.00 C +ATOM 683 O MET A 82 8.136 -4.782 -13.639 1.00 0.00 O +ATOM 684 CG MET A 82 7.919 -6.821 -9.533 1.00 0.00 C +ATOM 685 SD MET A 82 8.465 -8.524 -9.121 1.00 0.00 S +ATOM 686 CE MET A 82 7.498 -9.472 -10.328 1.00 0.00 C +ATOM 687 N TYR A 83 9.885 -3.839 -12.450 1.00 0.00 N +ATOM 688 CA TYR A 83 10.575 -3.496 -13.688 1.00 0.00 C +ATOM 689 C TYR A 83 12.059 -3.831 -13.597 1.00 0.00 C +ATOM 690 CB TYR A 83 10.394 -2.010 -14.011 1.00 0.00 C +ATOM 691 O TYR A 83 12.607 -3.957 -12.499 1.00 0.00 O +ATOM 692 CG TYR A 83 10.659 -1.097 -12.838 1.00 0.00 C +ATOM 693 CD1 TYR A 83 9.644 -0.769 -11.943 1.00 0.00 C +ATOM 694 CD2 TYR A 83 11.924 -0.561 -12.624 1.00 0.00 C +ATOM 695 CE1 TYR A 83 9.883 0.073 -10.861 1.00 0.00 C +ATOM 696 CE2 TYR A 83 12.175 0.282 -11.546 1.00 0.00 C +ATOM 697 OH TYR A 83 11.392 1.426 -9.603 1.00 0.00 O +ATOM 698 CZ TYR A 83 11.150 0.592 -10.671 1.00 0.00 C +ATOM 699 N ASP A 84 12.579 -4.184 -14.769 1.00 0.00 N +ATOM 700 CA ASP A 84 14.004 -4.497 -14.807 1.00 0.00 C +ATOM 701 C ASP A 84 14.846 -3.224 -14.847 1.00 0.00 C +ATOM 702 CB ASP A 84 14.328 -5.379 -16.015 1.00 0.00 C +ATOM 703 O ASP A 84 14.307 -2.116 -14.805 1.00 0.00 O +ATOM 704 CG ASP A 84 14.207 -4.641 -17.337 1.00 0.00 C +ATOM 705 OD1 ASP A 84 14.151 -3.393 -17.336 1.00 0.00 O +ATOM 706 OD2 ASP A 84 14.166 -5.315 -18.390 1.00 0.00 O +ATOM 707 N GLN A 85 16.104 -3.107 -14.825 1.00 0.00 N +ATOM 708 CA GLN A 85 17.029 -1.983 -14.722 1.00 0.00 C +ATOM 709 C GLN A 85 16.939 -1.083 -15.951 1.00 0.00 C +ATOM 710 CB GLN A 85 18.463 -2.481 -14.540 1.00 0.00 C +ATOM 711 O GLN A 85 17.373 0.070 -15.916 1.00 0.00 O +ATOM 712 CG GLN A 85 18.972 -3.327 -15.700 1.00 0.00 C +ATOM 713 CD GLN A 85 20.353 -3.901 -15.446 1.00 0.00 C +ATOM 714 NE2 GLN A 85 20.790 -4.803 -16.318 1.00 0.00 N +ATOM 715 OE1 GLN A 85 21.023 -3.536 -14.474 1.00 0.00 O +ATOM 716 N ASN A 86 16.347 -1.719 -16.985 1.00 0.00 N +ATOM 717 CA ASN A 86 16.216 -0.932 -18.206 1.00 0.00 C +ATOM 718 C ASN A 86 14.876 -0.203 -18.262 1.00 0.00 C +ATOM 719 CB ASN A 86 16.389 -1.821 -19.439 1.00 0.00 C +ATOM 720 O ASN A 86 14.590 0.504 -19.229 1.00 0.00 O +ATOM 721 CG ASN A 86 17.770 -2.440 -19.524 1.00 0.00 C +ATOM 722 ND2 ASN A 86 17.829 -3.711 -19.903 1.00 0.00 N +ATOM 723 OD1 ASN A 86 18.777 -1.781 -19.250 1.00 0.00 O +ATOM 724 N GLY A 87 14.047 -0.451 -17.157 1.00 0.00 N +ATOM 725 CA GLY A 87 12.774 0.245 -17.074 1.00 0.00 C +ATOM 726 C GLY A 87 11.645 -0.491 -17.770 1.00 0.00 C +ATOM 727 O GLY A 87 10.551 0.053 -17.935 1.00 0.00 O +ATOM 728 N HIS A 88 11.979 -1.693 -18.203 1.00 0.00 N +ATOM 729 CA HIS A 88 10.969 -2.521 -18.853 1.00 0.00 C +ATOM 730 C HIS A 88 10.107 -3.246 -17.825 1.00 0.00 C +ATOM 731 CB HIS A 88 11.629 -3.533 -19.792 1.00 0.00 C +ATOM 732 O HIS A 88 10.623 -3.781 -16.841 1.00 0.00 O +ATOM 733 CG HIS A 88 12.426 -2.902 -20.889 1.00 0.00 C +ATOM 734 CD2 HIS A 88 13.745 -2.608 -20.974 1.00 0.00 C +ATOM 735 ND1 HIS A 88 11.862 -2.492 -22.078 1.00 0.00 N +ATOM 736 CE1 HIS A 88 12.803 -1.973 -22.849 1.00 0.00 C +ATOM 737 NE2 HIS A 88 13.954 -2.032 -22.202 1.00 0.00 N +ATOM 738 N ASP A 89 8.762 -3.058 -18.051 1.00 0.00 N +ATOM 739 CA ASP A 89 7.804 -3.799 -17.236 1.00 0.00 C +ATOM 740 C ASP A 89 8.001 -5.305 -17.390 1.00 0.00 C +ATOM 741 CB ASP A 89 6.371 -3.415 -17.610 1.00 0.00 C +ATOM 742 O ASP A 89 8.056 -5.818 -18.509 1.00 0.00 O +ATOM 743 CG ASP A 89 5.330 -4.062 -16.713 1.00 0.00 C +ATOM 744 OD1 ASP A 89 5.698 -4.883 -15.846 1.00 0.00 O +ATOM 745 OD2 ASP A 89 4.131 -3.750 -16.877 1.00 0.00 O +ATOM 746 N LEU A 90 8.334 -6.022 -16.320 1.00 0.00 N +ATOM 747 CA LEU A 90 8.513 -7.469 -16.350 1.00 0.00 C +ATOM 748 C LEU A 90 7.166 -8.183 -16.380 1.00 0.00 C +ATOM 749 CB LEU A 90 9.323 -7.935 -15.137 1.00 0.00 C +ATOM 750 O LEU A 90 7.109 -9.414 -16.328 1.00 0.00 O +ATOM 751 CG LEU A 90 10.771 -7.447 -15.061 1.00 0.00 C +ATOM 752 CD1 LEU A 90 11.421 -7.917 -13.764 1.00 0.00 C +ATOM 753 CD2 LEU A 90 11.563 -7.933 -16.269 1.00 0.00 C +ATOM 754 N CYS A 91 6.145 -7.647 -16.960 1.00 0.00 N +ATOM 755 CA CYS A 91 4.838 -8.217 -17.265 1.00 0.00 C +ATOM 756 C CYS A 91 4.203 -8.824 -16.019 1.00 0.00 C +ATOM 757 CB CYS A 91 4.958 -9.281 -18.356 1.00 0.00 C +ATOM 758 O CYS A 91 3.278 -9.632 -16.119 1.00 0.00 O +ATOM 759 SG CYS A 91 5.305 -8.606 -19.995 1.00 0.00 S +ATOM 760 N SER A 92 4.554 -8.421 -14.785 1.00 0.00 N +ATOM 761 CA SER A 92 3.890 -8.942 -13.594 1.00 0.00 C +ATOM 762 C SER A 92 3.226 -7.824 -12.798 1.00 0.00 C +ATOM 763 CB SER A 92 4.887 -9.688 -12.707 1.00 0.00 C +ATOM 764 O SER A 92 3.909 -6.970 -12.228 1.00 0.00 O +ATOM 765 OG SER A 92 5.438 -10.799 -13.394 1.00 0.00 O +ATOM 766 N GLN A 93 2.048 -7.488 -13.242 1.00 0.00 N +ATOM 767 CA GLN A 93 1.265 -6.515 -12.488 1.00 0.00 C +ATOM 768 C GLN A 93 0.295 -7.209 -11.535 1.00 0.00 C +ATOM 769 CB GLN A 93 0.499 -5.591 -13.435 1.00 0.00 C +ATOM 770 O GLN A 93 -0.357 -8.187 -11.908 1.00 0.00 O +ATOM 771 CG GLN A 93 1.388 -4.851 -14.425 1.00 0.00 C +ATOM 772 CD GLN A 93 0.606 -3.933 -15.345 1.00 0.00 C +ATOM 773 NE2 GLN A 93 0.928 -3.968 -16.633 1.00 0.00 N +ATOM 774 OE1 GLN A 93 -0.282 -3.198 -14.901 1.00 0.00 O +ATOM 775 N ILE A 94 0.338 -6.724 -10.308 1.00 0.00 N +ATOM 776 CA ILE A 94 -0.576 -7.263 -9.307 1.00 0.00 C +ATOM 777 C ILE A 94 -1.499 -6.156 -8.801 1.00 0.00 C +ATOM 778 CB ILE A 94 0.191 -7.903 -8.128 1.00 0.00 C +ATOM 779 O ILE A 94 -1.043 -5.052 -8.494 1.00 0.00 O +ATOM 780 CG1 ILE A 94 1.112 -9.019 -8.633 1.00 0.00 C +ATOM 781 CG2 ILE A 94 -0.785 -8.433 -7.074 1.00 0.00 C +ATOM 782 CD1 ILE A 94 2.092 -9.533 -7.587 1.00 0.00 C +ATOM 783 N PHE A 95 -2.760 -6.352 -8.842 1.00 0.00 N +ATOM 784 CA PHE A 95 -3.734 -5.442 -8.251 1.00 0.00 C +ATOM 785 C PHE A 95 -4.216 -5.966 -6.904 1.00 0.00 C +ATOM 786 CB PHE A 95 -4.925 -5.241 -9.194 1.00 0.00 C +ATOM 787 O PHE A 95 -4.517 -7.153 -6.765 1.00 0.00 O +ATOM 788 CG PHE A 95 -4.559 -4.610 -10.509 1.00 0.00 C +ATOM 789 CD1 PHE A 95 -4.613 -3.231 -10.674 1.00 0.00 C +ATOM 790 CD2 PHE A 95 -4.159 -5.395 -11.583 1.00 0.00 C +ATOM 791 CE1 PHE A 95 -4.274 -2.644 -11.891 1.00 0.00 C +ATOM 792 CE2 PHE A 95 -3.819 -4.816 -12.802 1.00 0.00 C +ATOM 793 CZ PHE A 95 -3.878 -3.440 -12.954 1.00 0.00 C +ATOM 794 N PHE A 96 -4.198 -5.106 -5.926 1.00 0.00 N +ATOM 795 CA PHE A 96 -4.673 -5.552 -4.622 1.00 0.00 C +ATOM 796 C PHE A 96 -5.353 -4.411 -3.874 1.00 0.00 C +ATOM 797 CB PHE A 96 -3.515 -6.110 -3.788 1.00 0.00 C +ATOM 798 O PHE A 96 -5.223 -3.246 -4.257 1.00 0.00 O +ATOM 799 CG PHE A 96 -2.419 -5.112 -3.528 1.00 0.00 C +ATOM 800 CD1 PHE A 96 -1.422 -4.893 -4.470 1.00 0.00 C +ATOM 801 CD2 PHE A 96 -2.386 -4.393 -2.340 1.00 0.00 C +ATOM 802 CE1 PHE A 96 -0.406 -3.970 -4.231 1.00 0.00 C +ATOM 803 CE2 PHE A 96 -1.375 -3.470 -2.094 1.00 0.00 C +ATOM 804 CZ PHE A 96 -0.385 -3.260 -3.041 1.00 0.00 C +ATOM 805 N THR A 97 -6.180 -4.778 -2.859 1.00 0.00 N +ATOM 806 CA THR A 97 -6.980 -3.835 -2.084 1.00 0.00 C +ATOM 807 C THR A 97 -6.479 -3.756 -0.645 1.00 0.00 C +ATOM 808 CB THR A 97 -8.469 -4.227 -2.092 1.00 0.00 C +ATOM 809 O THR A 97 -6.224 -4.784 -0.014 1.00 0.00 O +ATOM 810 CG2 THR A 97 -9.298 -3.253 -1.261 1.00 0.00 C +ATOM 811 OG1 THR A 97 -8.952 -4.218 -3.441 1.00 0.00 O +ATOM 812 N VAL A 98 -6.234 -2.578 -0.176 1.00 0.00 N +ATOM 813 CA VAL A 98 -5.879 -2.335 1.218 1.00 0.00 C +ATOM 814 C VAL A 98 -7.047 -1.667 1.940 1.00 0.00 C +ATOM 815 CB VAL A 98 -4.609 -1.463 1.337 1.00 0.00 C +ATOM 816 O VAL A 98 -7.631 -0.706 1.432 1.00 0.00 O +ATOM 817 CG1 VAL A 98 -4.398 -1.011 2.781 1.00 0.00 C +ATOM 818 CG2 VAL A 98 -3.388 -2.228 0.830 1.00 0.00 C +ATOM 819 N ILE A 99 -7.430 -2.236 3.089 1.00 0.00 N +ATOM 820 CA ILE A 99 -8.491 -1.681 3.923 1.00 0.00 C +ATOM 821 C ILE A 99 -7.881 -0.965 5.126 1.00 0.00 C +ATOM 822 CB ILE A 99 -9.472 -2.778 4.394 1.00 0.00 C +ATOM 823 O ILE A 99 -7.148 -1.571 5.911 1.00 0.00 O +ATOM 824 CG1 ILE A 99 -10.082 -3.502 3.188 1.00 0.00 C +ATOM 825 CG2 ILE A 99 -10.565 -2.179 5.283 1.00 0.00 C +ATOM 826 CD1 ILE A 99 -10.874 -4.750 3.551 1.00 0.00 C +ATOM 827 N CYS A 100 -8.094 0.384 5.239 1.00 0.00 N +ATOM 828 CA CYS A 100 -7.717 1.159 6.416 1.00 0.00 C +ATOM 829 C CYS A 100 -8.858 1.208 7.425 1.00 0.00 C +ATOM 830 CB CYS A 100 -7.315 2.579 6.018 1.00 0.00 C +ATOM 831 O CYS A 100 -9.811 1.971 7.255 1.00 0.00 O +ATOM 832 SG CYS A 100 -5.913 2.646 4.881 1.00 0.00 S +ATOM 833 N ARG A 101 -8.734 0.291 8.419 1.00 0.00 N +ATOM 834 CA ARG A 101 -9.778 0.171 9.433 1.00 0.00 C +ATOM 835 C ARG A 101 -9.235 0.503 10.819 1.00 0.00 C +ATOM 836 CB ARG A 101 -10.374 -1.239 9.427 1.00 0.00 C +ATOM 837 O ARG A 101 -8.234 -0.070 11.253 1.00 0.00 O +ATOM 838 CG ARG A 101 -11.479 -1.444 10.451 1.00 0.00 C +ATOM 839 CD ARG A 101 -12.048 -2.855 10.390 1.00 0.00 C +ATOM 840 NE ARG A 101 -13.253 -2.985 11.205 1.00 0.00 N +ATOM 841 NH1 ARG A 101 -13.729 -5.141 10.524 1.00 0.00 N +ATOM 842 NH2 ARG A 101 -15.098 -4.079 12.026 1.00 0.00 N +ATOM 843 CZ ARG A 101 -14.024 -4.068 11.250 1.00 0.00 C +ATOM 844 N GLU A 102 -9.902 1.446 11.474 1.00 0.00 N +ATOM 845 CA GLU A 102 -9.563 1.790 12.851 1.00 0.00 C +ATOM 846 C GLU A 102 -10.253 0.854 13.840 1.00 0.00 C +ATOM 847 CB GLU A 102 -9.940 3.242 13.152 1.00 0.00 C +ATOM 848 O GLU A 102 -11.402 0.458 13.630 1.00 0.00 O +ATOM 849 CG GLU A 102 -9.221 4.258 12.277 1.00 0.00 C +ATOM 850 CD GLU A 102 -9.635 5.694 12.559 1.00 0.00 C +ATOM 851 OE1 GLU A 102 -9.298 6.592 11.755 1.00 0.00 O +ATOM 852 OE2 GLU A 102 -10.302 5.922 13.593 1.00 0.00 O +ATOM 853 N TYR A 103 -9.489 0.409 14.757 1.00 0.00 N +ATOM 854 CA TYR A 103 -10.036 -0.359 15.869 1.00 0.00 C +ATOM 855 C TYR A 103 -10.129 0.495 17.128 1.00 0.00 C +ATOM 856 CB TYR A 103 -9.178 -1.598 16.141 1.00 0.00 C +ATOM 857 O TYR A 103 -9.113 0.975 17.636 1.00 0.00 O +ATOM 858 CG TYR A 103 -9.123 -2.564 14.982 1.00 0.00 C +ATOM 859 CD1 TYR A 103 -10.142 -3.491 14.774 1.00 0.00 C +ATOM 860 CD2 TYR A 103 -8.052 -2.553 14.095 1.00 0.00 C +ATOM 861 CE1 TYR A 103 -10.094 -4.384 13.709 1.00 0.00 C +ATOM 862 CE2 TYR A 103 -7.994 -3.442 13.027 1.00 0.00 C +ATOM 863 OH TYR A 103 -8.966 -5.234 11.786 1.00 0.00 O +ATOM 864 CZ TYR A 103 -9.018 -4.352 12.842 1.00 0.00 C +ATOM 865 N CYS A 104 -11.495 0.756 17.579 1.00 0.00 N +ATOM 866 CA CYS A 104 -11.696 1.681 18.689 1.00 0.00 C +ATOM 867 C CYS A 104 -12.302 0.967 19.891 1.00 0.00 C +ATOM 868 CB CYS A 104 -12.597 2.840 18.264 1.00 0.00 C +ATOM 869 O CYS A 104 -13.088 0.031 19.731 1.00 0.00 O +ATOM 870 SG CYS A 104 -12.045 3.679 16.763 1.00 0.00 S +ATOM 871 N CYS A 105 -11.819 1.196 21.115 1.00 0.00 N +ATOM 872 CA CYS A 105 -12.422 0.718 22.354 1.00 0.00 C +ATOM 873 C CYS A 105 -13.019 1.872 23.151 1.00 0.00 C +ATOM 874 CB CYS A 105 -11.386 -0.018 23.204 1.00 0.00 C +ATOM 875 O CYS A 105 -12.550 3.008 23.054 1.00 0.00 O +ATOM 876 SG CYS A 105 -10.010 1.023 23.738 1.00 0.00 S +TER 877 CYS A 105 +ENDMDL +END diff --git a/af_backprop/examples/sc_hall/README.md b/af_backprop/examples/sc_hall/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/af_backprop/examples/sc_hall/README.md @@ -0,0 +1 @@ + diff --git a/af_backprop/examples/sc_hall/semigreedy_refinement_4_models.ipynb b/af_backprop/examples/sc_hall/semigreedy_refinement_4_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e2f8e2a44fddd8381cd4ff36294d065e96d8b8a1 --- /dev/null +++ b/af_backprop/examples/sc_hall/semigreedy_refinement_4_models.ipynb @@ -0,0 +1,942 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "semigreedy_refinement_4_models.ipynb", + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyN4wuQkRswgF3n+yu1fsFUx", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5IOr3jQEvoe6", + "outputId": "63354ed1-48d3-4b35-e91c-20a0aade9e02" + }, + "source": [ + "%%bash\n", + "if [ ! -d af_backprop ]; then\n", + " git clone https://github.com/sokrypton/af_backprop.git\n", + " pip -q install dm-haiku py3Dmol biopython ml_collections\n", + "fi\n", + "if [ ! -d params ]; then\n", + " mkdir params\n", + " curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params\n", + "fi" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Cloning into 'af_backprop'...\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3Ym3Vie7v1Yb" + }, + "source": [ + "import os\n", + "import sys\n", + "sys.path.append('af_backprop')\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import py3Dmol\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "from jax.experimental.optimizers import adam\n", + "\n", + "from alphafold.common import protein\n", + "from alphafold.data import pipeline\n", + "from alphafold.model import data, config, model, modules\n", + "from alphafold.common import residue_constants\n", + "\n", + "from alphafold.model import all_atom\n", + "from alphafold.model import folding\n", + "\n", + "# custom functions\n", + "from alphafold.data import prep_inputs\n", + "from utils import *" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "shh_V1eswjrH" + }, + "source": [ + "# setup which model params to use\n", + "model_name = \"model_3_ptm\"\n", + "model_config = config.model_config(model_name)\n", + "\n", + "# enable checkpointing\n", + "model_config.model.global_config.use_remat = True\n", + "\n", + "# number of recycles\n", + "model_config.model.num_recycle = 3\n", + "model_config.data.common.num_recycle = 3\n", + "\n", + "# backprop through recycles\n", + "model_config.model.backprop_recycle = False\n", + "model_config.model.embeddings_and_evoformer.backprop_dgram = False\n", + "\n", + "# custom relative features (needed for insertion/deletion)\n", + "INDELS = False\n", + "model_config.model.embeddings_and_evoformer.custom_relative_features = INDELS\n", + "\n", + "# number of sequences\n", + "N = 1\n", + "model_config.data.eval.max_msa_clusters = N\n", + "model_config.data.common.max_extra_msa = 1\n", + "model_config.data.eval.masked_msa_replace_fraction = 0\n", + "\n", + "# dropout\n", + "model_config = set_dropout(model_config, 0.0)\n", + "\n", + "# setup model\n", + "model_params = [data.get_model_haiku_params(model_name=model_name, data_dir=\".\")]\n", + "model_runner = model.RunModel(model_config, model_params[0], is_training=True)\n", + "\n", + "# load the other models to sample during design.\n", + "for model_name in [\"model_1_ptm\",\"model_2_ptm\",\"model_5_ptm\",\"model_4_ptm\"]:\n", + " params = data.get_model_haiku_params(model_name, '.')\n", + " model_params.append({k: params[k] for k in model_runner.params.keys()})" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "XOB0t6dslBAD" + }, + "source": [ + "#################\n", + "# USER INPUT\n", + "#################\n", + "# native structure you want to pull active site from\n", + "pos_idx_ref = [13,37,98] # note: zero indexed\n", + "PDB_REF = \"af_backprop/examples/sc_hall/1QJG.pdb\"\n", + "\n", + "# starting structure (for random starting sequence, set PDB=None and LEN to desired length)\n", + "pos_idx = [74+5,32+5,7+5]\n", + "MODE = \"af_backprop/examples/sc_hall/1QJS_starting\"\n", + "PDB = f\"{MODE}.pdb\"\n", + "LEN = 105" + ], + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "yYiNDSxNwmVw" + }, + "source": [ + "# prep reference (native) features\n", + "OBJ_REF = protein.from_pdb_string(pdb_to_string(PDB_REF), chain_id=\"A\")\n", + "SEQ_REF = jax.nn.one_hot(OBJ_REF.aatype,20)\n", + "START_SEQ_REF = \"\".join([order_restype[a] for a in OBJ_REF.aatype])\n", + "\n", + "batch_ref = {'aatype': OBJ_REF.aatype,\n", + " 'all_atom_positions': OBJ_REF.atom_positions,\n", + " 'all_atom_mask': OBJ_REF.atom_mask}\n", + "batch_ref.update(all_atom.atom37_to_frames(**batch_ref))\n", + "batch_ref.update(prep_inputs.make_atom14_positions(batch_ref))\n", + "batch_ref[\"idx\"] = pos_idx_ref\n", + "\n", + "# prep starting (design) features\n", + "if PDB is not None:\n", + " OBJ = protein.from_pdb_string(pdb_to_string(PDB), chain_id=\"A\")\n", + " SEQ = jax.nn.one_hot(OBJ.aatype,20)\n", + " START_SEQ = \"\".join([order_restype[a] for a in OBJ.aatype])\n", + "\n", + " batch = {'aatype': OBJ.aatype,\n", + " 'all_atom_positions': OBJ.atom_positions,\n", + " 'all_atom_mask': OBJ.atom_mask}\n", + " batch.update(all_atom.atom37_to_frames(**batch))\n", + " batch.update(prep_inputs.make_atom14_positions(batch))\n", + "else:\n", + " SEQ = jnp.zeros(LEN).at[jnp.asarray(pos_idx)].set([OBJ_REF.aatype[i] for i in pos_idx_ref])\n", + " START_SEQ = \"\".join([order_restype[a] for a in SEQ])\n", + " SEQ = jax.nn.one_hot(SEQ,20)\n", + "\n", + "# prep input features\n", + "feature_dict = {\n", + " **pipeline.make_sequence_features(sequence=START_SEQ,description=\"none\",num_res=len(START_SEQ)),\n", + " **pipeline.make_msa_features(msas=[N*[START_SEQ]], deletion_matrices=[N*[[0]*len(START_SEQ)]]),\n", + "}\n", + "inputs = model_runner.process_features(feature_dict, random_seed=0)\n", + "\n", + "if N > 1:\n", + " inputs[\"msa_row_mask\"] = jnp.ones_like(inputs[\"msa_row_mask\"])\n", + " inputs[\"msa_mask\"] = jnp.ones_like(inputs[\"msa_mask\"])" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ADmZt232wr8O", + "outputId": "67056d1f-d0b4-4825-8ad7-13a799b4d905" + }, + "source": [ + "print([START_SEQ[i] for i in pos_idx])\n", + "print([START_SEQ_REF[i] for i in pos_idx_ref])" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "['Y', 'N', 'D']\n", + "['Y', 'N', 'D']\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GfdnXo9ywwVg" + }, + "source": [ + "def get_grad_fn(model_runner, inputs, pos_idx_ref, inc_backbone=False):\n", + " \n", + " def mod(params, key, model_params, opt):\n", + " pos_idx = opt[\"pos_idx\"]\n", + " pos_idx_ref = batch_ref[\"idx\"]\n", + " ############################\n", + " # set amino acid sequence\n", + " ############################\n", + " seq_logits = jax.random.permutation(key, params[\"msa\"])\n", + " seq_soft = jax.nn.softmax(seq_logits)\n", + " seq = jax.lax.stop_gradient(jax.nn.one_hot(seq_soft.argmax(-1),20) - seq_soft) + seq_soft\n", + " seq = seq.at[:,pos_idx,:].set(SEQ_REF[pos_idx_ref,:])\n", + "\n", + " oh_mask = opt[\"oh_mask\"][:,None]\n", + " pseudo_seq = oh_mask * seq + (1-oh_mask) * seq_logits\n", + "\n", + " inputs_mod = inputs.copy()\n", + " update_seq(pseudo_seq, inputs_mod, msa_input=(\"msa\" in params))\n", + "\n", + " if \"msa_mask\" in opt:\n", + " inputs_mod[\"msa_mask\"] = inputs_mod[\"msa_mask\"] * opt[\"msa_mask\"][None,:,None]\n", + " inputs_mod[\"msa_row_mask\"] = inputs_mod[\"msa_row_mask\"] * opt[\"msa_mask\"][None,:]\n", + " \n", + " ####################\n", + " # set sidechains identity\n", + " ####################\n", + " B,L = inputs_mod[\"aatype\"].shape[:2]\n", + " ALA = jax.nn.one_hot(residue_constants.restype_order[\"A\"],21)\n", + "\n", + " aatype = jnp.zeros((B,L,21)).at[...,:20].set(seq[0])\n", + " ala_mask = opt[\"ala_mask\"][:,None]\n", + " aatype_ala = jnp.zeros((B,L,21)).at[:].set(ALA)\n", + " aatype_ala = aatype_ala.at[:,pos_idx,:20].set(SEQ_REF[pos_idx_ref,:])\n", + " aatype_pseudo = ala_mask * aatype + (1-ala_mask) * aatype_ala\n", + " update_aatype(aatype_pseudo, inputs_mod)\n", + "\n", + " ############################################################\n", + " if model_runner.config.model.embeddings_and_evoformer.custom_relative_features:\n", + " # set positions\n", + " active_pos = jax.nn.sigmoid(params[\"active_pos\"])\n", + " active_pos = active_pos.at[jnp.asarray(pos_idx)].set(1.0)\n", + "\n", + " # hard constraint\n", + " active_pos = jax.lax.stop_gradient((active_pos > 0.5).astype(jnp.float32) - active_pos) + active_pos\n", + " \n", + " residue_idx = jax.lax.scan(lambda x,y:(x+y,x), 0, active_pos)[1]\n", + " offset = residue_idx[:, None] - residue_idx[None, :]\n", + " rel_pos = jax.nn.softmax(-jnp.square(offset[...,None] - jnp.arange(-32,33,dtype=jnp.float32)))\n", + "\n", + " inputs_mod[\"rel_pos\"] = jnp.tile(rel_pos[None],[B,1,1,1])\n", + " inputs_mod[\"seq_mask\"] = jnp.zeros_like(inputs_mod[\"seq_mask\"]).at[...,:].set(active_pos)\n", + " inputs_mod[\"msa_mask\"] = jnp.zeros_like(inputs_mod[\"msa_mask\"]).at[...,:].set(active_pos)\n", + "\n", + " inputs_mod[\"atom14_atom_exists\"] *= active_pos[None,:,None]\n", + " inputs_mod[\"atom37_atom_exists\"] *= active_pos[None,:,None]\n", + " inputs_mod[\"residx_atom14_to_atom37\"] *= active_pos[None,:,None,None]\n", + " inputs_mod[\"residx_atom37_to_atom14\"] *= active_pos[None,:,None,None]\n", + "\n", + " ############################################################\n", + " \n", + " # get output\n", + " outputs = model_runner.apply(model_params, key, inputs_mod)\n", + "\n", + " ###################\n", + " # structure loss\n", + " ###################\n", + " fape_loss = get_fape_loss_idx(batch_ref, outputs, pos_idx, model_config, backbone=inc_backbone, sidechain=True)\n", + " rmsd_loss = get_sidechain_rmsd_idx(batch_ref, outputs, pos_idx, model_config)\n", + " dgram_loss = get_dgram_loss_idx(batch_ref, outputs, pos_idx, model_config)\n", + "\n", + " losses = {\"fape\":fape_loss,\n", + " \"rmsd\":rmsd_loss,\n", + " \"dgram\":dgram_loss}\n", + "\n", + " if \"sc_weight_fape\" in opt: fape_loss *= opt[\"sc_weight_fape\"]\n", + " if \"sc_weight_rmsd\" in opt: rmsd_loss *= opt[\"sc_weight_rmsd\"]\n", + " if \"sc_weight_dgram\" in opt: dgram_loss *= opt[\"sc_weight_dgram\"]\n", + "\n", + " loss = (rmsd_loss + fape_loss + dgram_loss) * opt[\"sc_weight\"]\n", + " \n", + " ################### \n", + " # background loss\n", + " ###################\n", + " if \"conf_weight\" in opt:\n", + " pae = jax.nn.softmax(outputs[\"predicted_aligned_error\"][\"logits\"])\n", + " plddt = jax.nn.softmax(outputs['predicted_lddt']['logits'])\n", + " pae_loss = (pae * jnp.arange(pae.shape[-1])).sum(-1)\n", + " plddt_loss = (plddt * jnp.arange(plddt.shape[-1])[::-1]).sum(-1)\n", + "\n", + " if model_runner.config.model.embeddings_and_evoformer.custom_relative_features:\n", + " active_pos_mask = active_pos[:,None] * active_pos[None,:]\n", + " pae_loss = (pae_loss * active_pos_mask).sum() / (1e-8 + active_pos_mask.sum())\n", + " plddt_loss = (plddt_loss * active_pos).sum() / (1e-8 + active_pos.sum())\n", + " else:\n", + " pae_loss = pae_loss.mean()\n", + " plddt_loss = plddt_loss.mean()\n", + "\n", + " loss = loss + (pae_loss + plddt_loss) * opt[\"conf_weight\"]\n", + " losses[\"pae\"] = pae_loss\n", + " losses[\"plddt\"] = plddt_loss\n", + "\n", + " if \"rg_weight\" in opt:\n", + " ca_coords = outputs[\"structure_module\"][\"final_atom_positions\"][:,1,:]\n", + " rg_loss = jnp.sqrt(jnp.square(ca_coords - ca_coords.mean(0)).sum(-1).mean() + 1e-8)\n", + " loss = loss + rg_loss * opt[\"rg_weight\"]\n", + " losses[\"rg\"] = rg_loss\n", + " \n", + " if \"msa\" in params and \"ent_weight\" in opt:\n", + " seq_prf = seq.mean(0)\n", + " ent_loss = -(seq_prf * jnp.log(seq_prf + 1e-8)).sum(-1).mean()\n", + " loss = loss + ent_loss * opt[\"ent_weight\"]\n", + " losses[\"ent\"] = ent_loss\n", + " else:\n", + " ent_loss = 0\n", + "\n", + " outs = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n", + " \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"]}\n", + "\n", + " if model_runner.config.model.embeddings_and_evoformer.custom_relative_features:\n", + " outs[\"residue_idx\"] = residue_idx\n", + "\n", + " seq_ = seq[0] if \"msa\" in params else seq\n", + "\n", + " return loss, ({\"losses\":losses, \"outputs\":outs, \"seq\":seq_})\n", + " loss_fn = mod\n", + " grad_fn = jax.value_and_grad(mod, has_aux=True, argnums=0)\n", + " return loss_fn, grad_fn" + ], + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "vavxyvYJwyPC" + }, + "source": [ + "# gradient function (note for greedy search we won't be using grad_fn, only loss_fn)\n", + "loss_fn, grad_fn = get_grad_fn(model_runner, inputs, pos_idx_ref=pos_idx_ref)\n", + "loss_fn = jax.jit(loss_fn)\n", + "\n", + "# stack model params (we exclude the last model: model_4_ptm for validation)\n", + "model_params_multi = jax.tree_multimap(lambda *values: jnp.stack(values, axis=0), *model_params[:-1])\n", + "loss_fn_multi = jax.jit(jax.vmap(loss_fn,(None,None,0,None)))" + ], + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lsLk1lLCQNcw" + }, + "source": [ + "key = jax.random.PRNGKey(0)\n", + "L,A = len(START_SEQ),20\n", + "\n", + "pos_idx_ = jnp.asarray(pos_idx)\n", + "pos_idx_ref_ = jnp.asarray(pos_idx_ref)\n", + "\n", + "msa = SEQ[None]\n", + "params = {\"msa\":msa, \"active_pos\":jnp.ones(L)}" + ], + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "J7ZL2zc4t9I0" + }, + "source": [ + "def mut(params, indel=False):\n", + " L,A = params[\"msa\"].shape[-2:]\n", + " while True:\n", + " i = np.random.randint(L)\n", + " a = np.random.randint(A)\n", + " if i not in pos_idx and params[\"msa\"][0,i,a] == 0 and (params[\"active_pos\"][i] == 1 or indel):\n", + " break\n", + "\n", + " params_ = params.copy()\n", + " params_[\"msa\"] = params[\"msa\"].at[:,i,:].set(jnp.eye(A)[a])\n", + "\n", + " if indel:\n", + " state = -1 if params[\"active_pos\"][i] == 1 else 1\n", + " params_[\"active_pos\"] = params[\"active_pos\"].at[i].set(state)\n", + " return params_" + ], + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wWlPaQGB4Nq8" + }, + "source": [ + "multi-model refinement" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dqM44WHW3DMw", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "562d6eae-9592-44a4-e1db-b65df3931fa9" + }, + "source": [ + "oh_mask = jnp.ones((L,))\n", + "ala_mask = jnp.ones((L,))\n", + "msa_mask = jnp.ones((N,))\n", + "opt={\"oh_mask\":oh_mask,\n", + " \"msa_mask\":msa_mask,\n", + " \"ala_mask\":ala_mask,\n", + " \"sc_weight\":1.0,\n", + " \"sc_weight_rmsd\":1.0,\n", + " \"sc_weight_fape\":1.0,\n", + " \"sc_weight_dgram\":0.0,\n", + " \"conf_weight\":0.01,\n", + " \"pos_idx\":pos_idx_}\n", + "loss, outs = loss_fn_multi(params, key, model_params_multi, opt)\n", + "print(np.mean(loss),\n", + " np.mean(outs[\"losses\"][\"rmsd\"]),\n", + " np.mean(outs[\"losses\"][\"fape\"]))\n", + "\n", + "print(outs[\"losses\"][\"rmsd\"])" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "5.0831347 4.0395207 0.5736508\n", + "[0.32047477 6.103425 3.9624836 5.7717004 ]\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VAPNQZ4232bN", + "outputId": "5e2fefc9-ade5-406c-f860-f33bd296cf4a" + }, + "source": [ + "LOSS = np.mean(loss)\n", + "OVERALL_RMSD = np.mean(outs[\"losses\"][\"rmsd\"])\n", + "OVERALL_FAPE = np.mean(outs[\"losses\"][\"fape\"])\n", + "OVERALL_LOSS = LOSS\n", + "key = jax.random.PRNGKey(0)\n", + "for n in range(10):\n", + " params_ = params.copy()\n", + " buff_p,buff_l,buff_o = [],[],[]\n", + " for m in range(20):\n", + " key,subkey = jax.random.split(key)\n", + " do_indel = False #np.random.uniform() < 0.25\n", + " p = mut(params, indel=do_indel)\n", + " l,o = loss_fn_multi(p, subkey, model_params_multi, opt)\n", + " print(\"-----------\", m, np.mean(o[\"losses\"][\"rmsd\"]), list(o[\"losses\"][\"rmsd\"]))\n", + " buff_p.append(p); buff_l.append(l); buff_o.append(o)\n", + " best = np.argmin(np.asarray(buff_l).mean(-1))\n", + " params, LOSS, outs = buff_p[best], buff_l[best], buff_o[best]\n", + " LOSS = np.mean(LOSS)\n", + " RMSD = np.mean(outs[\"losses\"][\"rmsd\"])\n", + " FAPE = np.mean(outs[\"losses\"][\"fape\"])\n", + "\n", + " outs = jax.tree_map(lambda x: x[0], outs)\n", + " if RMSD < OVERALL_RMSD:\n", + " OVERALL_RMSD = RMSD\n", + " save_pdb(outs,f\"{MODE}_best_rmsd.pdb\")\n", + " if FAPE < OVERALL_FAPE:\n", + " OVERALL_FAPE = FAPE\n", + " save_pdb(outs,f\"{MODE}_best_fape.pdb\")\n", + " if LOSS < OVERALL_LOSS:\n", + " OVERALL_LOSS = LOSS\n", + " save_pdb(outs,f\"{MODE}_best_loss.pdb\")\n", + " print(n, LOSS, RMSD, FAPE, (params[\"active_pos\"] > 0).sum(), len(buff_l))" + ], + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "----------- 0 10.582149 [12.104791, 11.438794, 6.8327885, 11.95222]\n", + "----------- 1 5.7060966 [0.34404072, 7.8016953, 6.8321495, 7.8465023]\n", + "----------- 2 4.9283195 [0.38857314, 7.33674, 3.9765234, 8.011441]\n", + "----------- 3 4.596478 [1.2857033, 8.185179, 3.6270323, 5.2879977]\n", + "----------- 4 4.33639 [0.34355134, 7.6549, 3.8751266, 5.471982]\n", + "----------- 5 6.8581476 [7.102624, 8.735716, 4.742427, 6.851823]\n", + "----------- 6 3.9228725 [0.37756792, 6.210618, 3.7853143, 5.317991]\n", + "----------- 7 4.8257957 [0.3333986, 8.438012, 5.8835196, 4.648253]\n", + "----------- 8 4.437786 [0.32170418, 8.433742, 4.022462, 4.9732375]\n", + "----------- 9 4.3738766 [0.3662424, 8.010605, 3.854046, 5.264613]\n", + "----------- 10 5.4476604 [0.3221591, 7.0530643, 5.5048957, 8.910523]\n", + "----------- 11 4.311565 [0.33136475, 6.864394, 4.1079164, 5.9425855]\n", + "----------- 12 8.771259 [7.6971173, 9.20276, 8.589025, 9.596133]\n", + "----------- 13 4.458911 [0.32208133, 8.985687, 4.7378616, 3.7900143]\n", + "----------- 14 4.6013637 [0.35995653, 6.008712, 4.1167917, 7.919994]\n", + "----------- 15 4.414083 [0.323973, 8.743649, 3.9161198, 4.672591]\n", + "----------- 16 5.3490214 [1.0300151, 11.366785, 4.0967994, 4.902487]\n", + "----------- 17 4.038536 [0.33273467, 6.64519, 4.211386, 4.9648337]\n", + "----------- 18 4.9283195 [0.38857314, 7.33674, 3.9765234, 8.011441]\n", + "----------- 19 3.6465042 [0.344874, 5.7528996, 3.7142928, 4.7739506]\n", + "0 4.665577 3.6465042 0.57121277 105 20\n", + "----------- 0 4.8786836 [0.389769, 4.1099725, 6.0193267, 8.9956665]\n", + "----------- 1 3.5066729 [0.34728014, 6.085302, 3.764575, 3.8295348]\n", + "----------- 2 3.6940656 [0.3797003, 6.742977, 3.6415355, 4.01205]\n", + "----------- 3 3.7695255 [0.36818364, 5.561354, 4.4751167, 4.6734476]\n", + "----------- 4 3.6419036 [0.3596261, 5.3361835, 3.9025295, 4.9692755]\n", + "----------- 5 3.8179073 [0.3272733, 5.84101, 4.1292744, 4.9740715]\n", + "----------- 6 4.6439056 [0.345108, 6.285658, 8.125, 3.8198566]\n", + "----------- 7 3.8682199 [1.9818479, 6.9291024, 2.6291928, 3.9327354]\n", + "----------- 8 4.152326 [0.35060146, 5.0279503, 5.719198, 5.5115542]\n", + "----------- 9 8.809828 [12.948995, 5.485416, 7.6189466, 9.185956]\n", + "----------- 10 4.5837007 [2.7283385, 6.2634826, 4.454889, 4.8880925]\n", + "----------- 11 4.010665 [0.42540643, 8.533676, 4.0781097, 3.0054672]\n", + "----------- 12 3.7557948 [0.3431681, 4.354364, 5.114262, 5.211385]\n", + "----------- 13 3.7458181 [0.368222, 5.6166067, 4.855059, 4.1433854]\n", + "----------- 14 3.7996607 [0.3551456, 5.5467305, 4.9242435, 4.372523]\n", + "----------- 15 3.6419036 [0.3596261, 5.3361835, 3.9025295, 4.9692755]\n", + "----------- 16 3.7063732 [0.3709019, 5.440703, 4.383684, 4.6302032]\n", + "----------- 17 4.4927187 [2.5542479, 5.6653757, 3.2255049, 6.5257473]\n", + "----------- 18 2.923112 [0.4098379, 4.310831, 6.3609443, 0.6108337]\n", + "----------- 19 4.8385973 [1.1260694, 7.660815, 4.391685, 6.1758204]\n", + "1 3.8471584 2.923112 0.47667003 105 20\n", + "----------- 0 2.4071047 [0.40594643, 4.268774, 4.386216, 0.567483]\n", + "----------- 1 3.5637014 [3.4049003, 5.7517843, 4.466902, 0.631219]\n", + "----------- 2 6.385169 [0.42317274, 7.9488277, 10.441073, 6.7276015]\n", + "----------- 3 7.003173 [1.0159206, 8.392572, 10.478087, 8.12611]\n", + "----------- 4 4.4137754 [0.44963905, 4.5486503, 5.6105924, 7.0462213]\n", + "----------- 5 4.247386 [0.46088308, 5.267075, 4.02893, 7.2326555]\n", + "----------- 6 6.511106 [0.62076414, 10.929417, 7.5565777, 6.937666]\n", + "----------- 7 6.4648147 [0.94779783, 8.850026, 9.21556, 6.8458743]\n", + "----------- 8 2.9776573 [0.4029673, 5.3560576, 5.5093575, 0.64224577]\n", + "----------- 9 3.6655078 [0.37931573, 3.7486086, 5.3787656, 5.155341]\n", + "----------- 10 3.8851724 [0.4140052, 4.2403994, 7.3400993, 3.5461855]\n", + "----------- 11 3.404581 [1.050032, 7.233591, 4.872441, 0.46226132]\n", + "----------- 12 2.5674534 [0.4273828, 3.6238961, 5.5727496, 0.64578414]\n", + "----------- 13 4.3858285 [0.46913218, 6.663731, 5.7499824, 4.6604686]\n", + "----------- 14 4.559425 [0.4063991, 4.701339, 8.109741, 5.0202203]\n", + "----------- 15 3.7580328 [0.3951192, 5.44158, 8.7475815, 0.44785148]\n", + "----------- 16 2.678008 [0.4227554, 4.3767376, 5.2928696, 0.6196703]\n", + "----------- 17 3.5162039 [0.45585072, 5.0549345, 7.816131, 0.7378993]\n", + "----------- 18 2.1885293 [0.4061838, 3.8336825, 3.8553498, 0.65890104]\n", + "----------- 19 2.6114652 [0.46717995, 3.6555567, 5.7019725, 0.6211518]\n", + "2 3.067666 2.1885293 0.46251404 105 20\n", + "----------- 0 2.3537211 [0.4172569, 4.672165, 3.645054, 0.68040824]\n", + "----------- 1 6.565819 [0.46917838, 9.24697, 13.639761, 2.9073648]\n", + "----------- 2 2.5486968 [0.42084527, 5.1198244, 4.0136366, 0.64048123]\n", + "----------- 3 2.265966 [0.44895777, 4.072321, 3.9046526, 0.63793194]\n", + "----------- 4 6.415695 [5.689722, 7.4149528, 6.209097, 6.349009]\n", + "----------- 5 6.701787 [6.1952324, 4.0250883, 9.226227, 7.360599]\n", + "----------- 6 4.2596745 [0.4309035, 6.793351, 5.283859, 4.530585]\n", + "----------- 7 4.2896414 [0.4131705, 6.1410913, 5.336551, 5.267752]\n", + "----------- 8 4.533273 [0.41316566, 3.5603185, 8.388473, 5.7711363]\n", + "----------- 9 2.173087 [0.48648486, 3.431612, 4.1570277, 0.61722285]\n", + "----------- 10 2.343117 [0.41405722, 4.579, 3.7360349, 0.6433763]\n", + "----------- 11 2.2203813 [0.408508, 4.2504196, 3.6253512, 0.5972458]\n", + "----------- 12 2.8376088 [0.40539894, 3.930316, 3.6116252, 3.4030957]\n", + "----------- 13 2.9735963 [0.45622826, 6.1875687, 4.7524834, 0.4981052]\n", + "----------- 14 3.9415567 [0.48253584, 5.8193073, 6.4667287, 2.9976552]\n", + "----------- 15 2.789737 [0.39956462, 7.160851, 2.9988146, 0.5997168]\n", + "----------- 16 5.3915973 [0.41455936, 7.5174804, 8.774411, 4.859938]\n", + "----------- 17 3.1227882 [0.40277773, 4.0853357, 4.670654, 3.3323848]\n", + "----------- 18 3.9933128 [0.5662083, 6.5397005, 5.448346, 3.418997]\n", + "----------- 19 8.413696 [5.8865013, 10.086903, 11.466493, 6.2148895]\n", + "3 3.0549202 2.173087 0.46643674 105 20\n", + "----------- 0 3.0471392 [0.47183362, 5.662226, 5.3606963, 0.6938004]\n", + "----------- 1 4.0131655 [0.5177209, 4.983051, 6.9292545, 3.6226368]\n", + "----------- 2 2.17895 [0.50002897, 3.2442615, 4.345606, 0.6259041]\n", + "----------- 3 2.374694 [0.4852461, 5.912502, 2.445409, 0.65562004]\n", + "----------- 4 2.3062282 [0.44955295, 3.5542312, 4.5708838, 0.65024453]\n", + "----------- 5 6.3752103 [0.5182901, 10.698274, 7.579367, 6.70491]\n", + "----------- 6 4.6237783 [0.5973701, 5.4147644, 3.7537475, 8.729232]\n", + "----------- 7 3.6459582 [0.6986359, 6.070299, 3.8660624, 3.9488356]\n", + "----------- 8 4.2973356 [0.48809233, 2.9792452, 10.25743, 3.464575]\n", + "----------- 9 3.4267287 [0.47858143, 8.201494, 4.4400992, 0.5867403]\n", + "----------- 10 6.5054493 [0.4609542, 8.021693, 7.273257, 10.265895]\n", + "----------- 11 6.4646187 [0.45118427, 11.614124, 8.662091, 5.131075]\n", + "----------- 12 3.44925 [0.5102441, 6.75787, 5.6903915, 0.83849436]\n", + "----------- 13 1.5510874 [0.46765023, 3.1209612, 1.937722, 0.67801625]\n", + "----------- 14 4.1438828 [0.5995776, 8.073297, 4.360333, 3.5423234]\n", + "----------- 15 3.0242546 [0.4761192, 6.872156, 4.1430902, 0.6056532]\n", + "----------- 16 3.354655 [0.5910091, 7.9251766, 4.2511578, 0.65127695]\n", + "----------- 17 4.737821 [0.52790046, 12.361472, 4.349435, 1.7124759]\n", + "----------- 18 3.3077118 [0.48165753, 2.8231888, 5.756186, 4.169815]\n", + "----------- 19 3.858794 [0.46834785, 6.2006445, 5.3235, 3.442683]\n", + "4 2.3836753 1.5510874 0.40621892 105 20\n", + "----------- 0 2.408711 [0.59259933, 2.994746, 5.3717113, 0.6757873]\n", + "----------- 1 3.583911 [0.59484565, 3.0236213, 6.1144195, 4.602757]\n", + "----------- 2 3.7031112 [1.9154848, 4.1394815, 3.6726053, 5.084873]\n", + "----------- 3 4.259982 [0.6372889, 6.2120976, 6.1206055, 4.0699363]\n", + "----------- 4 6.560093 [0.4791305, 12.672071, 6.238699, 6.850472]\n", + "----------- 5 4.8324957 [0.47789723, 7.418234, 6.5990267, 4.8348246]\n", + "----------- 6 4.253159 [0.6549075, 7.017164, 7.2769494, 2.0636148]\n", + "----------- 7 3.104456 [0.45488828, 3.344442, 4.431807, 4.186686]\n", + "----------- 8 5.6592755 [3.4889212, 6.648672, 9.099052, 3.4004571]\n", + "----------- 9 6.041602 [2.0915234, 7.53962, 5.7716804, 8.763584]\n", + "----------- 10 2.513018 [0.986735, 3.3085477, 4.969964, 0.7868239]\n", + "----------- 11 2.6217945 [0.45438075, 3.1789834, 6.181294, 0.6725199]\n", + "----------- 12 3.941772 [0.6494507, 3.9241767, 5.559955, 5.6335053]\n", + "----------- 13 6.465238 [7.7704196, 9.203311, 5.048996, 3.8382263]\n", + "----------- 14 9.766651 [9.576719, 12.081644, 6.587141, 10.821101]\n", + "----------- 15 4.38824 [0.86634266, 5.365814, 6.151191, 5.1696105]\n", + "----------- 16 4.7873554 [1.5247656, 7.3221726, 5.155653, 5.146831]\n", + "----------- 17 3.2191634 [0.45180723, 6.635691, 5.1962996, 0.5928558]\n", + "----------- 18 1.0649618 [0.4786889, 2.5209634, 0.5535031, 0.70669174]\n", + "----------- 19 2.9940672 [0.46470752, 3.1425028, 4.3368435, 4.0322146]\n", + "5 1.7999133 1.0649618 0.35217547 105 20\n", + "----------- 0 2.9137554 [0.5797037, 3.3810425, 2.3903174, 5.3039575]\n", + "----------- 1 1.3706349 [0.60660183, 3.5293875, 0.5786312, 0.76791924]\n", + "----------- 2 3.2712543 [0.6478559, 6.8537245, 5.073406, 0.51003104]\n", + "----------- 3 2.3383627 [0.5436503, 2.9409494, 3.0435867, 2.8252642]\n", + "----------- 4 2.665322 [0.47297964, 4.5350113, 5.0160937, 0.63720375]\n", + "----------- 5 1.1749109 [0.49457368, 3.0391412, 0.5314649, 0.6344639]\n", + "----------- 6 4.515865 [3.1955972, 7.380846, 4.317868, 3.1691482]\n", + "----------- 7 3.0060768 [0.48280758, 6.405368, 0.59157217, 4.544559]\n", + "----------- 8 2.8049726 [0.47285232, 2.9449687, 0.5356863, 7.266383]\n", + "----------- 9 1.2798908 [0.741102, 2.8399186, 0.77493376, 0.7636087]\n", + "----------- 10 1.2820382 [0.69970554, 3.0904496, 0.63325167, 0.70474607]\n", + "----------- 11 2.962519 [0.4278536, 4.5380526, 6.266224, 0.61794597]\n", + "----------- 12 3.182825 [0.4819531, 8.55946, 0.8070526, 2.8828354]\n", + "----------- 13 3.605903 [0.4913268, 6.5359335, 4.999935, 2.396417]\n", + "----------- 14 2.740058 [0.54361457, 2.5836465, 5.5235553, 2.3094163]\n", + "----------- 15 6.772599 [6.7603326, 8.049665, 5.5924, 6.6879997]\n", + "----------- 16 1.4681029 [1.3279539, 3.2857857, 0.6039699, 0.6547022]\n", + "----------- 17 5.723872 [5.6197824, 6.501586, 7.962329, 2.8117895]\n", + "----------- 18 8.821382 [6.989175, 9.054461, 11.158386, 8.083503]\n", + "----------- 19 4.487301 [0.47633642, 7.56953, 7.7259836, 2.1773546]\n", + "6 1.9241389 1.1749109 0.35637453 105 20\n", + "----------- 0 5.563302 [0.5416731, 8.03251, 5.338444, 8.340583]\n", + "----------- 1 5.9103003 [0.49172208, 7.6001053, 7.972028, 7.577347]\n", + "----------- 2 2.5772943 [0.5054023, 6.534628, 2.565927, 0.7032203]\n", + "----------- 3 3.149089 [2.9514372, 2.9636126, 4.853099, 1.8282076]\n", + "----------- 4 5.3910804 [0.5104481, 7.3227873, 8.840729, 4.8903584]\n", + "----------- 5 2.8454852 [0.519192, 6.0013995, 2.2649152, 2.5964339]\n", + "----------- 6 5.9070616 [0.47018862, 10.039518, 5.3783193, 7.7402215]\n", + "----------- 7 3.656711 [0.4781471, 2.7352362, 6.1736174, 5.239844]\n", + "----------- 8 8.381494 [8.087517, 7.3926554, 10.748937, 7.2968645]\n", + "----------- 9 5.4170775 [3.1999342, 7.563527, 2.8614478, 8.043402]\n", + "----------- 10 2.9442503 [0.77243835, 6.1304655, 3.8895233, 0.98457426]\n", + "----------- 11 5.116531 [0.5584655, 6.541425, 9.796945, 3.5692887]\n", + "----------- 12 3.4091916 [0.5013325, 4.0858235, 3.5191205, 5.53049]\n", + "----------- 13 3.9140825 [0.5068759, 7.5968504, 2.8186543, 4.73395]\n", + "----------- 14 4.0137873 [0.5880374, 6.1081376, 4.78733, 4.5716434]\n", + "----------- 15 3.6466699 [0.48415077, 7.088579, 3.2573853, 3.7565641]\n", + "----------- 16 2.868372 [0.4403102, 3.0626845, 4.2559824, 3.7145107]\n", + "----------- 17 6.698191 [2.7048624, 8.386596, 7.608087, 8.09322]\n", + "----------- 18 2.0177314 [0.49044654, 6.0936966, 0.6455321, 0.84125084]\n", + "----------- 19 4.445796 [0.49084687, 9.170093, 2.4752064, 5.647039]\n", + "7 2.854764 2.0177314 0.37624437 105 20\n", + "----------- 0 1.0931559 [0.49172053, 2.6824176, 0.5764521, 0.62203354]\n", + "----------- 1 4.773432 [1.517017, 10.158834, 5.3810177, 2.0368586]\n", + "----------- 2 1.613029 [0.9292759, 3.1221318, 1.5021861, 0.89852214]\n", + "----------- 3 2.6226964 [1.8922995, 2.640792, 3.1363952, 2.8212986]\n", + "----------- 4 3.0730336 [0.49155927, 7.6547313, 2.3609092, 1.7849343]\n", + "----------- 5 6.935523 [7.507216, 7.1102643, 6.5306544, 6.5939574]\n", + "----------- 6 4.0209823 [0.46951586, 7.1055694, 3.2235124, 5.2853317]\n", + "----------- 7 5.5829983 [7.1345463, 3.6441324, 5.4739, 6.0794134]\n", + "----------- 8 3.455373 [0.48668343, 3.0806482, 2.7548976, 7.499263]\n", + "----------- 9 4.1937394 [2.2798746, 7.402249, 6.465774, 0.62705946]\n", + "----------- 10 6.605816 [9.141533, 5.757394, 6.2058573, 5.31848]\n", + "----------- 11 3.533581 [0.4979934, 4.1526065, 5.30537, 4.1783543]\n", + "----------- 12 3.5613365 [0.49206477, 6.739398, 2.5934587, 4.4204245]\n", + "----------- 13 3.6620011 [3.5900655, 3.119967, 7.2964883, 0.6414829]\n", + "----------- 14 2.321041 [0.5420785, 2.5902116, 0.5915193, 5.5603547]\n", + "----------- 15 6.608162 [0.69685817, 11.084266, 11.961578, 2.6899447]\n", + "----------- 16 6.5043054 [3.85887, 7.975657, 8.475497, 5.707197]\n", + "----------- 17 5.5829983 [7.1345463, 3.6441324, 5.4739, 6.0794134]\n", + "----------- 18 4.681854 [4.05151, 2.8608103, 5.5413437, 6.273751]\n", + "----------- 19 3.0863926 [0.4749892, 7.594475, 2.212759, 2.0633478]\n", + "8 1.8283346 1.0931559 0.35270125 105 20\n", + "----------- 0 7.171954 [6.307287, 8.116448, 8.80696, 5.4571204]\n", + "----------- 1 2.5624762 [0.51132023, 3.0274832, 2.454945, 4.256156]\n", + "----------- 2 3.8772728 [0.48861453, 7.6562514, 5.951251, 1.4129744]\n", + "----------- 3 3.6076546 [0.47928056, 7.586958, 0.70100075, 5.663379]\n", + "----------- 4 6.3786244 [5.0353436, 8.022774, 5.9298387, 6.5265427]\n", + "----------- 5 2.8925982 [1.2213205, 2.6918755, 1.9224159, 5.7347803]\n", + "----------- 6 3.9868257 [1.578602, 6.206632, 0.95245385, 7.2096148]\n", + "----------- 7 2.8703656 [0.52009696, 2.3494925, 7.980505, 0.6313675]\n", + "----------- 8 2.6793237 [0.49006915, 8.2979965, 0.5241353, 1.4050932]\n", + "----------- 9 1.0168215 [0.46446496, 2.638379, 0.547157, 0.4172848]\n", + "----------- 10 3.8984096 [0.48657328, 7.467347, 6.977629, 0.66208917]\n", + "----------- 11 1.1310002 [0.49976665, 2.7513857, 0.5728694, 0.69997895]\n", + "----------- 12 1.6584501 [0.4734316, 2.7602808, 2.7501695, 0.6499184]\n", + "----------- 13 7.321125 [0.61161333, 10.463595, 6.7591476, 11.450143]\n", + "----------- 14 6.9232044 [4.7297215, 9.2294235, 7.205658, 6.5280156]\n", + "----------- 15 3.6137753 [0.48802245, 2.9645853, 5.6860676, 5.3164253]\n", + "----------- 16 4.172943 [0.49637416, 8.188139, 2.3736038, 5.633655]\n", + "----------- 17 2.35454 [0.4830108, 7.8507514, 0.53617835, 0.54821944]\n", + "----------- 18 1.6584501 [0.4734316, 2.7602808, 2.7501695, 0.6499184]\n", + "----------- 19 2.228596 [0.53493524, 3.9120338, 0.61403877, 3.8533762]\n", + "9 1.7148367 1.0168215 0.3432635 105 20\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nMcwzOU7M2tg", + "outputId": "db49d356-7c0c-4541-e6b8-ff66cd05e7b9" + }, + "source": [ + "for n in range(300):\n", + " params_ = params.copy()\n", + " buff_p,buff_l,buff_o = [],[],[]\n", + " for _ in range(20):\n", + " key,subkey = jax.random.split(key)\n", + " do_indel = (INDELS and np.random.uniform() < 0.25)\n", + " p = mut(params, indel=do_indel)\n", + " l,o = loss_fn_multi(p, subkey, model_params_multi, opt)\n", + " buff_p.append(p); buff_l.append(l); buff_o.append(o)\n", + " if np.mean(l) < LOSS: break\n", + " best = np.argmin(np.asarray(buff_l).mean(-1))\n", + " params, LOSS, outs = buff_p[best], buff_l[best], buff_o[best]\n", + " LOSS = np.mean(LOSS)\n", + " RMSD = np.mean(outs[\"losses\"][\"rmsd\"])\n", + " FAPE = np.mean(outs[\"losses\"][\"fape\"])\n", + "\n", + " outs = jax.tree_map(lambda x: x[0], outs)\n", + " if RMSD < OVERALL_RMSD:\n", + " OVERALL_RMSD = RMSD\n", + " save_pdb(outs,f\"{MODE}_best_rmsd.pdb\")\n", + " if FAPE < OVERALL_FAPE:\n", + " OVERALL_FAPE = FAPE\n", + " save_pdb(outs,f\"{MODE}_best_fape.pdb\")\n", + " if LOSS < OVERALL_LOSS:\n", + " OVERALL_LOSS = LOSS\n", + " save_pdb(outs,f\"{MODE}_best_loss.pdb\")\n", + " l4,o4 = loss_fn(params, subkey, model_params[-1], opt)\n", + " print(n, LOSS, RMSD, FAPE, (params[\"active_pos\"] > 0).sum(), len(buff_l), o4[\"losses\"][\"rmsd\"])" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "0 1.6684058 0.97566986 0.33917558 105 13 6.570461\n", + "1 1.7213771 1.0266101 0.34318787 105 20 0.5521997\n", + "2 1.7309557 1.0141158 0.35105625 105 20 2.0901546\n", + "3 1.6896502 0.9780132 0.34791833 105 2 3.075305\n", + "4 1.2836819 0.63118124 0.2795934 105 9 3.0850897\n", + "5 1.0221133 0.47783357 0.26446223 105 2 0.58060026\n", + "6 1.0106819 0.471611 0.2663944 105 17 0.579923\n", + "7 0.9707656 0.44111815 0.26466346 105 8 0.5299034\n", + "8 0.95611095 0.46036988 0.26778838 105 2 0.4459814\n", + "9 0.94456077 0.44426697 0.2588649 105 7 0.51522875\n", + "10 0.92737645 0.41471896 0.26362437 105 11 0.4391624\n", + "11 0.9707828 0.43690652 0.2637179 105 20 3.0559409\n", + "12 1.0268421 0.48600835 0.27266476 105 20 1.6863861\n", + "13 1.000179 0.45579082 0.26858282 105 9 1.6001016\n", + "14 1.0144405 0.46453598 0.2691497 105 20 3.1118512\n", + "15 0.9921475 0.45157805 0.26792496 105 6 3.04773\n", + "16 1.0790019 0.51705253 0.27363735 105 20 1.599278\n", + "17 0.9422333 0.44149858 0.26067227 105 5 0.49108532\n", + "18 0.95633763 0.4589308 0.26509196 105 20 0.46267003\n", + "19 0.91339755 0.4367942 0.26139438 105 2 0.4516135\n", + "20 0.9490597 0.44927266 0.25715116 105 20 1.9399705\n", + "21 0.92148614 0.44103473 0.25952393 105 10 0.5380588\n", + "22 0.9095851 0.44128704 0.26291054 105 5 0.52902484\n", + "23 0.9001045 0.43394685 0.26216513 105 11 0.507325\n", + "24 0.935823 0.44728646 0.26444185 105 20 0.5153226\n", + "25 0.92956334 0.4358675 0.25304818 105 1 0.50120604\n", + "26 0.93669957 0.44484594 0.25761187 105 20 0.5072669\n", + "27 0.92630804 0.4272276 0.25512138 105 17 0.50545853\n", + "28 0.90342486 0.40582314 0.25182638 105 7 0.42194045\n", + "29 0.9023535 0.3999255 0.2542632 105 18 0.5808631\n", + "30 0.8987059 0.39197487 0.25395167 105 4 0.698982\n", + "31 0.89763 0.398355 0.2523234 105 8 0.7569108\n", + "32 0.8851971 0.39835936 0.2517715 105 10 0.40553972\n", + "33 0.87716496 0.39584976 0.2536103 105 18 0.43690002\n", + "34 0.8672372 0.39135563 0.2522359 105 19 0.46205065\n", + "35 0.86573654 0.40179157 0.24927694 105 5 0.43356445\n", + "36 0.86034834 0.39331198 0.24897593 105 5 0.41553757\n", + "37 0.85704434 0.3940274 0.24887043 105 15 0.42463255\n", + "38 0.85812235 0.39109135 0.24861696 105 20 1.5544825\n", + "39 0.8563244 0.38625145 0.2548745 105 2 0.46328327\n", + "40 0.85393006 0.38619673 0.25590825 105 1 0.65370554\n", + "41 0.8434426 0.37833232 0.25496525 105 3 0.6442375\n", + "42 0.84568226 0.3767385 0.25646347 105 20 0.45805952\n", + "43 0.8283565 0.36935914 0.25602132 105 16 0.43450886\n", + "44 0.8259764 0.3721248 0.2545429 105 9 0.44600332\n", + "45 0.8283949 0.37862396 0.25495532 105 20 0.4076944\n", + "46 0.8460327 0.39139068 0.2576192 105 20 0.41672665\n", + "47 0.85488296 0.38974053 0.25892863 105 20 0.4298103\n", + "48 0.88770354 0.41136163 0.26714876 105 20 0.5820464\n", + "49 0.850071 0.38693905 0.2618707 105 19 0.45211282\n", + "50 0.87110484 0.40780997 0.25945115 105 20 0.61626506\n", + "51 0.85944444 0.4052241 0.26254568 105 2 0.58453935\n", + "52 0.8510729 0.39752704 0.25937673 105 17 0.42348662\n", + "53 0.85884845 0.39929175 0.26002827 105 20 0.42013463\n", + "54 0.8492986 0.39085233 0.2594124 105 20 0.4298053\n", + "55 0.8502816 0.39487204 0.26038072 105 20 0.45310912\n", + "56 0.8514839 0.396228 0.25912333 105 20 0.43081018\n", + "57 0.84856915 0.3966666 0.25921145 105 1 0.42670247\n", + "58 0.8441243 0.39192587 0.2584624 105 1 0.43249077\n", + "59 0.8388046 0.39230713 0.25792277 105 9 0.4469592\n", + "60 0.8518801 0.4112948 0.26125896 105 20 0.42879182\n", + "61 0.8495097 0.40249667 0.2577355 105 5 0.41180924\n", + "62 0.85121775 0.40240282 0.25863898 105 20 0.4375655\n", + "63 0.84820503 0.39923126 0.26354426 105 4 0.4226636\n", + "64 0.8493221 0.39893606 0.26149994 105 20 0.55794966\n", + "65 0.8680916 0.41174316 0.26457256 105 20 0.52679855\n", + "66 0.8690653 0.4129597 0.26212114 105 20 0.40440193\n", + "67 0.8550012 0.3974012 0.25566924 105 2 0.61957014\n", + "68 0.8269218 0.3811624 0.25568238 105 2 0.5771368\n", + "69 0.82573533 0.38210166 0.2538772 105 3 0.3971571\n", + "70 0.8257129 0.3830118 0.2558058 105 1 0.37792858\n", + "71 0.82546926 0.38395628 0.2544651 105 12 0.6366667\n", + "72 0.843058 0.39773583 0.25683063 105 20 0.47362173\n", + "73 0.82866305 0.3778933 0.25027376 105 14 0.80487704\n", + "74 0.82513636 0.3715942 0.24893484 105 11 0.62339985\n", + "75 0.7955326 0.36016932 0.24934904 105 14 0.4288529\n", + "76 0.7816151 0.35555938 0.2501549 105 4 0.37923443\n", + "77 0.79697263 0.37449193 0.2569723 105 20 0.38451034\n", + "78 0.79643744 0.37195808 0.2569681 105 8 0.3864982\n", + "79 0.79316473 0.37074983 0.25291863 105 5 0.3956389\n", + "80 0.79527825 0.37851173 0.255054 105 20 0.41184312\n", + "81 0.80367917 0.38413125 0.2587978 105 20 0.38481775\n", + "82 0.80195075 0.38171855 0.25619864 105 19 0.38471878\n", + "83 0.79585415 0.37876022 0.25607145 105 3 0.38093916\n", + "84 0.79297435 0.36952016 0.25707933 105 13 0.38803926\n", + "85 0.7860149 0.36533368 0.25712195 105 5 0.39711148\n", + "86 0.78051245 0.36108324 0.25746024 105 15 0.4006365\n", + "87 0.7777182 0.3615132 0.2557983 105 2 0.4549559\n", + "88 0.7739917 0.3581917 0.25417912 105 10 0.37290683\n", + "89 0.774044 0.3581766 0.2542063 105 20 0.39311743\n", + "90 0.7884384 0.37267008 0.2572489 105 20 0.3724221\n", + "91 0.7724658 0.36318746 0.253304 105 15 0.46365902\n", + "92 0.7829318 0.36188036 0.25197 105 20 0.73234975\n", + "93 0.7881139 0.3734092 0.25267625 105 20 0.4020775\n", + "94 0.7744282 0.36122805 0.25204524 105 14 0.83379424\n", + "95 0.7931042 0.36954057 0.2545379 105 20 0.37395406\n", + "96 0.8209839 0.38647577 0.25685614 105 20 0.37840688\n", + "97 0.84438217 0.38822842 0.2526749 105 20 7.4663424\n", + "98 0.84945333 0.39228576 0.25577798 105 20 2.3648822\n", + "99 0.8253926 0.3976271 0.2493584 105 3 1.4438024\n", + "100 0.81372565 0.38538033 0.2464183 105 20 1.6098862\n", + "101 0.8090967 0.3809868 0.25952134 105 10 0.5458334\n", + "102 0.81226456 0.38264075 0.26079932 105 20 1.8315157\n", + "103 0.80271786 0.38448375 0.2629453 105 5 1.6406913\n", + "104 0.7911093 0.37916207 0.26077878 105 11 1.5000534\n", + "105 0.7949822 0.3803891 0.25980854 105 20 2.5178695\n", + "106 0.79229176 0.3745945 0.26065707 105 14 1.8893733\n", + "107 0.78998435 0.3735939 0.26069322 105 7 0.53828716\n", + "108 0.79045296 0.3755261 0.25611162 105 20 0.49430197\n", + "109 0.79229045 0.37573683 0.256774 105 20 0.5186832\n", + "110 0.7838174 0.3678024 0.256809 105 19 0.45888138\n", + "111 0.78609425 0.3710606 0.25731885 105 20 0.44479498\n", + "112 0.7958678 0.3732134 0.25617442 105 20 0.45889676\n", + "113 0.78859544 0.37344506 0.25625542 105 14 0.42856166\n", + "114 0.7868535 0.37011522 0.25506067 105 8 0.5194302\n", + "115 0.79009044 0.3758374 0.25532395 105 20 0.48774284\n", + "116 0.7797909 0.36544997 0.25534868 105 1 0.4631959\n", + "117 0.784542 0.36659348 0.25564831 105 20 0.44635403\n", + "118 0.7873805 0.3621328 0.25646555 105 20 0.56861454\n", + "119 0.7807896 0.36947665 0.25755095 105 11 0.5583043\n", + "120 0.78442085 0.37247407 0.2567057 105 20 0.5963035\n", + "121 0.7767699 0.37258375 0.25616622 105 16 0.62414336\n", + "122 0.7766647 0.36842233 0.25467855 105 4 0.5682627\n", + "123 0.7774848 0.36421263 0.25515088 105 20 0.62787694\n", + "124 0.77390444 0.36387503 0.25549126 105 12 0.6466514\n", + "125 0.7843654 0.3667697 0.25571302 105 20 0.66524696\n", + "126 0.77834857 0.3649481 0.25825095 105 12 0.5677394\n", + "127 0.76237154 0.3616419 0.2528265 105 7 0.6747014\n", + "128 0.762183 0.361956 0.25155586 105 19 0.67311424\n", + "129 0.77545625 0.3703637 0.25180376 105 20 0.5765867\n", + "130 0.77570856 0.36720365 0.25119123 105 20 0.5863473\n", + "131 0.7715299 0.3634779 0.24999247 105 12 0.5599069\n", + "132 0.7726184 0.3661586 0.25005862 105 20 0.57318836\n", + "133 0.77815336 0.3716631 0.25104138 105 20 0.37932914\n", + "134 0.7815921 0.36497754 0.25097612 105 20 0.4990302\n", + "135 0.77946424 0.3639908 0.25069007 105 1 0.41056108\n", + "136 0.7800069 0.36707234 0.25195867 105 20 0.39035815\n", + "137 0.7680088 0.35747153 0.25032467 105 4 0.5359447\n", + "138 0.7680603 0.3517444 0.24826044 105 20 0.53611857\n", + "139 0.7804315 0.3633203 0.24894942 105 20 0.48859638\n", + "140 0.77528805 0.3598613 0.24950016 105 5 0.5153854\n", + "141 0.7719278 0.3596449 0.25640878 105 6 0.63598144\n", + "142 0.7580904 0.35209888 0.2516064 105 9 0.6285438\n", + "143 0.7646548 0.35849902 0.25306997 105 20 0.6293292\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/af_backprop/examples/sc_hall/trdesign_sub.py b/af_backprop/examples/sc_hall/trdesign_sub.py new file mode 100644 index 0000000000000000000000000000000000000000..43f0ef2170ee69da0d30928521b755c83ad2a79a --- /dev/null +++ b/af_backprop/examples/sc_hall/trdesign_sub.py @@ -0,0 +1,193 @@ +import tensorflow as tf +import tensorflow.keras.backend as K +import tensorflow.compat.v1 as tf1 +import tensorflow.compat.v1.keras.backend as K1 +tf1.disable_eager_execution() + +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Input, Conv2D, Activation, Dense, Lambda, Layer, Concatenate + +def get_TrR_weights(filename): + weights = [np.squeeze(w) for w in np.load(filename, allow_pickle=True)] + # remove weights for beta-beta pairing + del weights[-4:-2] + return weights + +def get_TrR(blocks=12, trainable=False, weights=None, name="TrR"): + ex = {"trainable":trainable} + # custom layer(s) + class PSSM(Layer): + # modified from MRF to only output tiled 1D features + def __init__(self, diag=0.4, use_entropy=False): + super(PSSM, self).__init__() + self.diag = diag + self.use_entropy = use_entropy + def call(self, inputs): + x,y = inputs + _,_,L,A = [tf.shape(y)[k] for k in range(4)] + with tf.name_scope('1d_features'): + # sequence + x_i = x[0,0,:,:20] + # pssm + f_i = y[0,0] + # entropy + if self.use_entropy: + h_i = K.sum(-f_i * K.log(f_i + 1e-8), axis=-1, keepdims=True) + else: + h_i = tf.zeros((L,1)) + # tile and combined 1D features + feat_1D = tf.concat([x_i,f_i,h_i], axis=-1) + feat_1D_tile_A = tf.tile(feat_1D[:,None,:], [1,L,1]) + feat_1D_tile_B = tf.tile(feat_1D[None,:,:], [L,1,1]) + + with tf.name_scope('2d_features'): + ic = self.diag * tf.eye(L*A) + ic = tf.reshape(ic,(L,A,L,A)) + ic = tf.transpose(ic,(0,2,1,3)) + ic = tf.reshape(ic,(L,L,A*A)) + i0 = tf.zeros([L,L,1]) + feat_2D = tf.concat([ic,i0], axis=-1) + + feat = tf.concat([feat_1D_tile_A, feat_1D_tile_B, feat_2D],axis=-1) + return tf.reshape(feat, [1,L,L,442+2*42]) + + class instance_norm(Layer): + def __init__(self, axes=(1,2),trainable=True): + super(instance_norm, self).__init__() + self.axes = axes + self.trainable = trainable + def build(self, input_shape): + self.beta = self.add_weight(name='beta',shape=(input_shape[-1],), + initializer='zeros',trainable=self.trainable) + self.gamma = self.add_weight(name='gamma',shape=(input_shape[-1],), + initializer='ones',trainable=self.trainable) + def call(self, inputs): + mean, variance = tf.nn.moments(inputs, self.axes, keepdims=True) + return tf.nn.batch_normalization(inputs, mean, variance, self.beta, self.gamma, 1e-6) + + ## INPUT ## + inputs = Input((None,None,21),batch_size=1) + A = PSSM()([inputs,inputs]) + A = Dense(64, **ex)(A) + A = instance_norm(**ex)(A) + A = Activation("elu")(A) + + ## RESNET ## + def resnet(X, dilation=1, filters=64, win=3): + Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(X) + Y = instance_norm(**ex)(Y) + Y = Activation("elu")(Y) + Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(Y) + Y = instance_norm(**ex)(Y) + return Activation("elu")(X+Y) + + for _ in range(blocks): + for dilation in [1,2,4,8,16]: + A = resnet(A, dilation) + A = resnet(A, dilation=1) + + ## OUTPUT ## + A_input = Input((None,None,64)) + p_theta = Dense(25, activation="softmax", **ex)(A_input) + p_phi = Dense(13, activation="softmax", **ex)(A_input) + A_sym = Lambda(lambda x: (x + tf.transpose(x,[0,2,1,3]))/2)(A_input) + p_dist = Dense(37, activation="softmax", **ex)(A_sym) + p_omega = Dense(25, activation="softmax", **ex)(A_sym) + A_model = Model(A_input,Concatenate()([p_theta,p_phi,p_dist,p_omega])) + + ## MODEL ## + model = Model(inputs, A_model(A),name=name) + if weights is not None: model.set_weights(weights) + return model + +def get_TrR_model(L=None, exclude_theta=False, use_idx=False, use_bkg=False, models_path="models"): + def gather_idx(x): + idx = x[1][0] + return tf.gather(tf.gather(x[0],idx,axis=-2),idx,axis=-3) + + def get_cce_loss(x, eps=1e-8, only_dist=False): + if only_dist: + true_x = split_feat(x[0])["dist"] + pred_x = split_feat(x[1])["dist"] + loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2]) + return loss * 4 + + elif exclude_theta: + true_x = split_feat(x[0]) + pred_x = split_feat(x[1]) + true_x = tf.concat([true_x[k] for k in ["phi","dist","omega"]],-1) + pred_x = tf.concat([pred_x[k] for k in ["phi","dist","omega"]],-1) + loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2]) + return loss * 4/3 + + else: + return -tf.reduce_mean(tf.reduce_sum(x[0]*tf.math.log(x[1] + eps),-1),[-1,-2]) + + def get_bkg_loss(x, eps=1e-8): + return -tf.reduce_mean(tf.reduce_sum(x[1]*(tf.math.log(x[1]+eps)-tf.math.log(x[0]+eps)),-1),[-1,-2]) + + def prep_seq(x_logits): + x_soft = tf.nn.softmax(x_logits,-1) + x_hard = tf.one_hot(tf.argmax(x_logits,-1),20) + x = tf.stop_gradient(x_hard - x_soft) + x_soft + x = tf.pad(x,[[0,0],[0,0],[0,1]]) + return x[None] + + I_seq_logits = Input((L,20),name="seq_logits") + seq = Lambda(prep_seq,name="seq")(I_seq_logits) + I_true = Input((L,L,100),name="true") + if use_bkg: + I_bkg = Input((L,L,100),name="bkg") + if use_idx: + I_idx = Input((None,),dtype=tf.int32,name="idx") + I_idx_true = Input((None,),dtype=tf.int32,name="idx_true") + + pred = [] + for nam in ["xaa","xab","xac","xad","xae"]: + print(nam) + TrR = get_TrR(weights=get_TrR_weights(f"{models_path}/model_{nam}.npy"),name=nam) + pred.append(TrR(seq)) + pred = sum(pred)/len(pred) + + if use_idx: + pred_sub = Lambda(gather_idx, name="pred_sub")([pred,I_idx]) + true_sub = Lambda(gather_idx, name="true_sub")([I_true,I_idx_true]) + else: + pred_sub = pred + true_sub = I_true + + cce_loss = Lambda(get_cce_loss,name="cce_loss")([true_sub, pred_sub]) + if use_bkg: + bkg_loss = Lambda(get_bkg_loss,name="bkg_loss")([I_bkg, pred]) + loss = Lambda(lambda x: x[0]+0.1*x[1])([cce_loss,bkg_loss]) + else: + loss = cce_loss + grad = Lambda(lambda x: tf.gradients(x[0],x[1]), name="grad")([loss,I_seq_logits]) + + # setup model + inputs = [I_seq_logits, I_true] + outputs = [cce_loss] + if use_bkg: + inputs += [I_bkg] + outputs += [bkg_loss] + if use_idx: inputs += [I_idx, I_idx_true] + model = Model(intputs, outputs + [grad, pred], name="TrR_model") + + TrR_model(seq, true, **kwargs): + i = [seq[None],true[None]] + if use_bkg: + i += [kwargs["bkg"][None]] + if use_idx: + pos_idx = kwargs["pos_idx"] + if "pos_idx_ref" not in kwargs or kwargs["pos_idx_ref"] is None: + pos_idx_ref = pos_idx + else: + pos_idx_ref = kwargs["pos_idx_ref"] + i += [pos_idx[None],pos_idx_ref[None]] + + *o = model.predict(i) + r = {"cce_loss":o[0][0],"grad":o[-1][0],"pred":o[-2][0]} + if use_bkg: r["bkg_loss"] = o[1][0] + return r + + return TrR_model diff --git a/af_backprop/setup.py b/af_backprop/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e0fcaeba78f7c4e78ebe57b95138c91a0b7f59 --- /dev/null +++ b/af_backprop/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup, find_packages +setup( + name='af_backprop', + version='0.0.0', + packages=find_packages(), + install_requires=[ + 'absl-py', + 'biopython', + 'chex', + 'dm-haiku', + 'dm-tree', + 'docker', + 'immutabledict', + 'jax', + 'ml-collections', + 'numpy', + 'pandas', + 'scipy', + 'tensorflow', + ], +) diff --git a/af_backprop/utils.py b/af_backprop/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3b298b782eafcebf6174263f893ac0b0f17eb3 --- /dev/null +++ b/af_backprop/utils.py @@ -0,0 +1,492 @@ +import jax +import jax.numpy as jnp +import tensorflow as tf +tf.config.set_visible_devices([], 'GPU') + +import numpy as np +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.model import model +from alphafold.model import folding +from alphafold.model import all_atom +from alphafold.model.tf import shape_placeholders + +####################### +# reshape inputs +####################### +def make_fixed_size(feat, model_runner, length, batch_axis=True): + '''pad input features''' + cfg = model_runner.config + if batch_axis: + shape_schema = {k:[None]+v for k,v in dict(cfg.data.eval.feat).items()} + else: + shape_schema = {k:v for k,v in dict(cfg.data.eval.feat).items()} + + pad_size_map = { + shape_placeholders.NUM_RES: length, + shape_placeholders.NUM_MSA_SEQ: cfg.data.eval.max_msa_clusters, + shape_placeholders.NUM_EXTRA_SEQ: cfg.data.common.max_extra_msa, + shape_placeholders.NUM_TEMPLATES: cfg.data.eval.max_templates + } + for k, v in feat.items(): + # Don't transfer this to the accelerator. + if k == 'extra_cluster_assignment': + continue + shape = list(v.shape) + schema = shape_schema[k] + assert len(shape) == len(schema), ( + f'Rank mismatch between shape and shape schema for {k}: ' + f'{shape} vs {schema}') + pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] + padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)] + if padding: + feat[k] = tf.pad(v, padding, name=f'pad_to_fixed_{k}') + feat[k].set_shape(pad_size) + return {k:np.asarray(v) for k,v in feat.items()} + +######################### +# rmsd +######################### +def jnp_rmsdist(true, pred): + return _np_rmsdist(true, pred) + +def jnp_rmsd(true, pred, add_dist=False): + rmsd = _np_rmsd(true, pred) + if add_dist: rmsd = (rmsd + _np_rmsdist(true, pred))/2 + return rmsd + +def jnp_kabsch_w(a, b, weights): + return _np_kabsch(a * weights[:,None], b) + +def jnp_rmsd_w(true, pred, weights): + p = true - (true * weights[:,None]).sum(0,keepdims=True)/weights.sum() + q = pred - (pred * weights[:,None]).sum(0,keepdims=True)/weights.sum() + p = p @ _np_kabsch(p * weights[:,None], q) + return jnp.sqrt((weights*jnp.square(p-q).sum(-1)).sum()/weights.sum() + 1e-8) + +def get_rmsd_loss_w(batch, outputs, copies=1): + weights = batch["all_atom_mask"][:,1] + true = batch["all_atom_positions"][:,1,:] + pred = outputs["structure_module"]["final_atom_positions"][:,1,:] + if copies == 1: + return jnp_rmsd_w(true, pred, weights) + else: + # TODO add support for weights + I = copies - 1 + L = true.shape[0] // copies + p = true - true[:L].mean(0) + q = pred - pred[:L].mean(0) + p = p @ _np_kabsch(p[:L], q[:L]) + rm = jnp.square(p[:L]-q[:L]).sum(-1).mean() + p,q = p[L:].reshape(I,1,L,-1),q[L:].reshape(1,I,L,-1) + rm += jnp.square(p-q).sum(-1).mean(-1).min(-1).sum() + return jnp.sqrt(rm / copies) + +#################### +# confidence metrics +#################### +def get_plddt(outputs): + logits = outputs["predicted_lddt"]["logits"] + num_bins = logits.shape[-1] + bin_width = 1.0 / num_bins + bin_centers = jnp.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) + probs = jax.nn.softmax(logits, axis=-1) + return jnp.sum(probs * bin_centers[None, :], axis=-1) + +def get_pae(outputs): + prob = jax.nn.softmax(outputs["predicted_aligned_error"]["logits"],-1) + breaks = outputs["predicted_aligned_error"]["breaks"] + step = breaks[1]-breaks[0] + bin_centers = breaks + step/2 + bin_centers = jnp.append(bin_centers,bin_centers[-1]+step) + return (prob*bin_centers).sum(-1) + +#################### +# loss functions +#################### +def get_rmsd_loss(batch, outputs): + true = batch["all_atom_positions"][:,1,:] + pred = outputs["structure_module"]["final_atom_positions"][:,1,:] + return _np_rmsd(true,pred) + +def _distogram_log_loss(logits, bin_edges, batch, num_bins, copies=1): + """Log loss of a distogram.""" + pos,mask = batch['pseudo_beta'],batch['pseudo_beta_mask'] + sq_breaks = jnp.square(bin_edges) + dist2 = jnp.square(pos[:,None] - pos[None,:]).sum(-1,keepdims=True) + true_bins = jnp.sum(dist2 > sq_breaks, axis=-1) + true = jax.nn.one_hot(true_bins, num_bins) + + if copies == 1: + errors = -(true * jax.nn.log_softmax(logits)).sum(-1) + sq_mask = mask[:,None] * mask[None,:] + avg_error = (errors * sq_mask).sum()/(1e-6 + sq_mask.sum()) + return avg_error + else: + # TODO add support for masks + L = pos.shape[0] // copies + I = copies - 1 + true_, pred_ = true[:L,:L], logits[:L,:L] + errors = -(true_ * jax.nn.log_softmax(pred_)).sum(-1) + avg_error = errors.mean() + + true_, pred_ = true[:L,L:], logits[:L,L:] + true_, pred_ = true_.reshape(L,I,1,L,-1), pred_.reshape(L,1,I,L,-1) + errors = -(true_ * jax.nn.log_softmax(pred_)).sum(-1) + avg_error += errors.mean((0,-1)).min(-1).sum() + + return avg_error / copies + +def get_dgram_loss(batch, outputs, model_config, logits=None, copies=1): + # get cb features (ca in case of glycine) + pb, pb_mask = model.modules.pseudo_beta_fn(batch["aatype"], + batch["all_atom_positions"], + batch["all_atom_mask"]) + if logits is None: logits = outputs["distogram"]["logits"] + dgram_loss = _distogram_log_loss(logits, + outputs["distogram"]["bin_edges"], + batch={"pseudo_beta":pb,"pseudo_beta_mask":pb_mask}, + num_bins=model_config.model.heads.distogram.num_bins, + copies=copies) + return dgram_loss + +def get_fape_loss(batch, outputs, model_config, use_clamped_fape=False): + sub_batch = jax.tree_map(lambda x: x, batch) + sub_batch["use_clamped_fape"] = use_clamped_fape + loss = {"loss":0.0} + folding.backbone_loss(loss, sub_batch, outputs["structure_module"], model_config.model.heads.structure_module) + return loss["loss"] + +#################### +# loss functions (restricted to idx and/or sidechains) +#################### +def get_dgram_loss_idx(batch, outputs, idx, model_config): + idx_ref = batch["idx"] + pb, pb_mask = model.modules.pseudo_beta_fn(batch["aatype"][idx_ref], + batch["all_atom_positions"][idx_ref], + batch["all_atom_mask"][idx_ref]) + + dgram_loss = model.modules._distogram_log_loss(outputs["distogram"]["logits"][:,idx][idx,:], + outputs["distogram"]["bin_edges"], + batch={"pseudo_beta":pb,"pseudo_beta_mask":pb_mask}, + num_bins=model_config.model.heads.distogram.num_bins) + return dgram_loss["loss"] + +def get_fape_loss_idx(batch, outputs, idx, model_config, backbone=False, sidechain=True, use_clamped_fape=False): + idx_ref = batch["idx"] + + sub_batch = batch.copy() + sub_batch.pop("idx") + sub_batch = jax.tree_map(lambda x: x[idx_ref,...],sub_batch) + sub_batch["use_clamped_fape"] = use_clamped_fape + + value = jax.tree_map(lambda x: x, outputs["structure_module"]) + loss = {"loss":0.0} + + if sidechain: + value.update(folding.compute_renamed_ground_truth(sub_batch, value['final_atom14_positions'][idx,...])) + value['sidechains']['frames'] = jax.tree_map(lambda x: x[:,idx,:], value["sidechains"]["frames"]) + value['sidechains']['atom_pos'] = jax.tree_map(lambda x: x[:,idx,:], value["sidechains"]["atom_pos"]) + loss.update(folding.sidechain_loss(sub_batch, value, model_config.model.heads.structure_module)) + + if backbone: + value["traj"] = value["traj"][...,idx,:] + folding.backbone_loss(loss, sub_batch, value, model_config.model.heads.structure_module) + + return loss["loss"] + +def get_sc_rmsd(true_pos, pred_pos, aa_ident, atoms_to_exclude=None): + + if atoms_to_exclude is None: atoms_to_exclude = ["N","C","O"] + + # collect atom indices + idx,idx_alt = [],[] + for n,a in enumerate(aa_ident): + aa = idx_to_resname[a] + atoms = set(residue_constants.residue_atoms[aa]) + atoms14 = residue_constants.restype_name_to_atom14_names[aa] + swaps = residue_constants.residue_atom_renaming_swaps.get(aa,{}) + swaps.update({v:k for k,v in swaps.items()}) + for atom in atoms.difference(atoms_to_exclude): + idx.append(n * 14 + atoms14.index(atom)) + if atom in swaps: + idx_alt.append(n * 14 + atoms14.index(swaps[atom])) + else: + idx_alt.append(idx[-1]) + idx, idx_alt = np.asarray(idx), np.asarray(idx_alt) + + # select atoms + T, P = true_pos.reshape(-1,3)[idx], pred_pos.reshape(-1,3)[idx] + + # select non-ambigious atoms + non_amb = idx == idx_alt + t, p = T[non_amb], P[non_amb] + + # align non-ambigious atoms + aln = _np_kabsch(t-t.mean(0), p-p.mean(0)) + T,P = (T-t.mean(0)) @ aln, P-p.mean(0) + P_alt = pred_pos.reshape(-1,3)[idx_alt]-p.mean(0) + + # compute rmsd + msd = jnp.minimum(jnp.square(T-P).sum(-1),jnp.square(T-P_alt).sum(-1)).mean() + return jnp.sqrt(msd + 1e-8) + +def get_sidechain_rmsd_idx(batch, outputs, idx, model_config, include_ca=True): + idx_ref = batch["idx"] + true_aa_idx = batch["aatype"][idx_ref] + true_pos = all_atom.atom37_to_atom14(batch["all_atom_positions"],batch)[idx_ref,:,:] + pred_pos = outputs["structure_module"]["final_atom14_positions"][idx,:,:] + bb_atoms_to_exclude = ["N","C","O"] if include_ca else ["N","CA","C","O"] + + return get_sc_rmsd(true_pos, pred_pos, true_aa_idx, bb_atoms_to_exclude) + +################################################################################# +################################################################################# +################################################################################# + +def _np_len_pw(x, use_jax=True): + '''compute pairwise distance''' + _np = jnp if use_jax else np + + x_norm = _np.square(x).sum(-1) + xx = _np.einsum("...ia,...ja->...ij",x,x) + sq_dist = x_norm[...,:,None] + x_norm[...,None,:] - 2 * xx + + # due to precision errors the values can sometimes be negative + if use_jax: sq_dist = jax.nn.relu(sq_dist) + else: sq_dist[sq_dist < 0] = 0 + + # return euclidean pairwise distance matrix + return _np.sqrt(sq_dist + 1e-8) + +def _np_rmsdist(true, pred, use_jax=True): + '''compute RMSD of distance matrices''' + _np = jnp if use_jax else np + t = _np_len_pw(true, use_jax=use_jax) + p = _np_len_pw(pred, use_jax=use_jax) + return _np.sqrt(_np.square(t-p).mean() + 1e-8) + +def _np_kabsch(a, b, return_v=False, use_jax=True): + '''get alignment matrix for two sets of coodinates''' + _np = jnp if use_jax else np + ab = a.swapaxes(-1,-2) @ b + u, s, vh = _np.linalg.svd(ab, full_matrices=False) + flip = _np.linalg.det(u @ vh) < 0 + u_ = _np.where(flip, -u[...,-1].T, u[...,-1].T).T + if use_jax: u = u.at[...,-1].set(u_) + else: u[...,-1] = u_ + return u if return_v else (u @ vh) + +def _np_rmsd(true, pred, use_jax=True): + '''compute RMSD of coordinates after alignment''' + _np = jnp if use_jax else np + p = true - true.mean(-2,keepdims=True) + q = pred - pred.mean(-2,keepdims=True) + p = p @ _np_kabsch(p, q, use_jax=use_jax) + return _np.sqrt(_np.square(p-q).sum(-1).mean(-1) + 1e-8) + +def _np_norm(x, axis=-1, keepdims=True, eps=1e-8, use_jax=True): + '''compute norm of vector''' + _np = jnp if use_jax else np + return _np.sqrt(_np.square(x).sum(axis,keepdims=keepdims) + 1e-8) + +def _np_len(a, b, use_jax=True): + '''given coordinates a-b, return length or distance''' + return _np_norm(a-b, use_jax=use_jax) + +def _np_ang(a, b, c, use_acos=False, use_jax=True): + '''given coordinates a-b-c, return angle''' + _np = jnp if use_jax else np + norm = lambda x: _np_norm(x, use_jax=use_jax) + ba, bc = b-a, b-c + cos_ang = (ba * bc).sum(-1,keepdims=True) / (norm(ba) * norm(bc)) + # note the derivative at acos(-1 or 1) is inf, to avoid nans we use cos(ang) + if use_acos: return _np.arccos(cos_ang) + else: return cos_ang + +def _np_dih(a, b, c, d, use_atan2=False, standardize=False, use_jax=True): + '''given coordinates a-b-c-d, return dihedral''' + _np = jnp if use_jax else np + normalize = lambda x: x/_np_norm(x, use_jax=use_jax) + ab, bc, cd = normalize(a-b), normalize(b-c), normalize(c-d) + n1,n2 = _np.cross(ab, bc), _np.cross(bc, cd) + sin_ang = (_np.cross(n1, bc) * n2).sum(-1,keepdims=True) + cos_ang = (n1 * n2).sum(-1,keepdims=True) + if use_atan2: + return _np.arctan2(sin_ang, cos_ang) + else: + angs = _np.concatenate([sin_ang, cos_ang],-1) + if standardize: return normalize(angs) + else: return angs + +def _np_extend(a,b,c, L,A,D, use_jax=True): + ''' + given coordinates a-b-c, + c-d (L)ength, b-c-d (A)ngle, and a-b-c-d (D)ihedral + return 4th coordinate d + ''' + _np = jnp if use_jax else np + normalize = lambda x: x/_np_norm(x, use_jax=use_jax) + bc = normalize(b-c) + n = normalize(_np.cross(b-a, bc)) + return c + sum([L * _np.cos(A) * bc, + L * _np.sin(A) * _np.cos(D) * _np.cross(n, bc), + L * _np.sin(A) * _np.sin(D) * -n]) + +def _np_get_cb(N,CA,C, use_jax=True): + '''compute CB placement from N, CA, C''' + return _np_extend(C, N, CA, 1.522, 1.927, -2.143, use_jax=use_jax) + +def _np_get_6D(all_atom_positions, all_atom_mask=None, use_jax=True): + '''get 6D features (see TrRosetta paper)''' + + # get CB coordinate + atom_idx = {k:residue_constants.atom_order[k] for k in ["N","CA","C"]} + out = {k:all_atom_positions[...,i,:] for k,i in atom_idx.items()} + out["CB"] = _np_get_cb(**out, use_jax=use_jax) + + if all_atom_mask is not None: + idx = np.fromiter(atom_idx.values(),int) + out["CB_mask"] = all_atom_mask[...,idx].prod(-1) + + # get pairwise features + N,A,B = (out[k] for k in ["N","CA","CB"]) + j = {"use_jax":use_jax} + out.update({"dist": _np_len_pw(B,**j), + "phi": _np_ang(A[...,:,None,:],B[...,:,None,:],B[...,None,:,:],**j), + "omega": _np_dih(A[...,:,None,:],B[...,:,None,:],B[...,None,:,:],A[...,None,:,:],**j), + "theta": _np_dih(N[...,:,None,:],A[...,:,None,:],B[...,:,None,:],B[...,None,:,:],**j), + }) + return out + +#################### +# 6D loss (see TrRosetta paper) +#################### +def _np_get_6D_loss(true, pred, mask=None, use_theta=True, use_dist=False, use_jax=True): + _np = jnp if use_jax else np + + f = {"T":_np_get_6D(true, mask, use_jax=use_jax), + "P":_np_get_6D(pred, use_jax=use_jax)} + + for k in f: f[k]["dist"] /= 10.0 + + keys = ["omega","phi"] + if use_theta: keys.append("theta") + if use_dist: keys.append("dist") + sq_diff = sum([_np.square(f["T"][k]-f["P"][k]).sum(-1) for k in keys]) + + mask = _np.ones(true.shape[0]) if mask is None else f["T"]["CB_mask"] + mask = mask[:,None] * mask[None,:] + loss = (sq_diff * mask).sum((-1,-2)) / mask.sum((-1,-2)) + + return _np.sqrt(loss + 1e-8).mean() + +def get_6D_loss(batch, outputs, **kwargs): + true = batch["all_atom_positions"] + pred = outputs["structure_module"]["final_atom_positions"] + mask = batch["all_atom_mask"] + return _np_get_6D_loss(true, pred, mask, **kwargs) + +################################################################################# +################################################################################# +################################################################################# + +#################### +# update sequence +#################### +def soft_seq(seq_logits, temp=1.0, hard=True): + seq_soft = jax.nn.softmax(seq_logits / temp) + if hard: + seq_hard = jax.nn.one_hot(seq_soft.argmax(-1),20) + return jax.lax.stop_gradient(seq_hard - seq_soft) + seq_soft + else: + return seq_soft + +def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, msa_input=None): + '''update the sequence features''' + + if seq_1hot is None: seq_1hot = seq + if seq_pssm is None: seq_pssm = seq + msa_feat = jnp.zeros_like(inputs["msa_feat"]).at[...,0:20].set(seq_1hot).at[...,25:45].set(seq_pssm) + if seq.ndim == 3: + target_feat = jnp.zeros_like(inputs["target_feat"]).at[...,1:21].set(seq[0]) + else: + target_feat = jnp.zeros_like(inputs["target_feat"]).at[...,1:21].set(seq) + + inputs.update({"target_feat":target_feat,"msa_feat":msa_feat}) + +def update_aatype(aatype, inputs): + if jnp.issubdtype(aatype.dtype, jnp.integer): + inputs.update({"aatype":aatype, + "atom14_atom_exists":residue_constants.restype_atom14_mask[aatype], + "atom37_atom_exists":residue_constants.restype_atom37_mask[aatype], + "residx_atom14_to_atom37":residue_constants.restype_atom14_to_atom37[aatype], + "residx_atom37_to_atom14":residue_constants.restype_atom37_to_atom14[aatype]}) + else: + restype_atom14_to_atom37 = jax.nn.one_hot(residue_constants.restype_atom14_to_atom37,37) + restype_atom37_to_atom14 = jax.nn.one_hot(residue_constants.restype_atom37_to_atom14,14) + inputs.update({"aatype":aatype, + "atom14_atom_exists":jnp.einsum("...a,am->...m", aatype, residue_constants.restype_atom14_mask), + "atom37_atom_exists":jnp.einsum("...a,am->...m", aatype, residue_constants.restype_atom37_mask), + "residx_atom14_to_atom37":jnp.einsum("...a,abc->...bc", aatype, restype_atom14_to_atom37), + "residx_atom37_to_atom14":jnp.einsum("...a,abc->...bc", aatype, restype_atom37_to_atom14)}) + +#################### +# utils +#################### + +def pdb_to_string(pdb_file): + lines = [] + for line in open(pdb_file,"r"): + if line[:6] == "HETATM" and line[17:20] == "MSE": + line = "ATOM "+line[6:17]+"MET"+line[20:] + if line[:4] == "ATOM": + lines.append(line) + return "".join(lines) + +def save_pdb(outs, filename="tmp.pdb"): + seq = outs["seq"].argmax(-1) + while seq.ndim > 1: seq = seq[0] + b_factors = np.zeros_like(outs["outputs"]['final_atom_mask']) + p = protein.Protein( + aatype=seq, + atom_positions=outs["outputs"]["final_atom_positions"], + atom_mask=outs["outputs"]['final_atom_mask'], + residue_index=jnp.arange(len(seq))+1, + b_factors=b_factors) + pdb_lines = protein.to_pdb(p) + with open(filename, 'w') as f: + f.write(pdb_lines) + +order_restype = {v: k for k, v in residue_constants.restype_order.items()} +idx_to_resname = dict((v,k) for k,v in residue_constants.resname_to_idx.items()) +template_aa_map = np.eye(20)[[residue_constants.HHBLITS_AA_TO_ID[order_restype[i]] for i in range(20)]].T + +########################### +# MISC +########################### +jalview_color_list = {"Clustal": ["#80a0f0","#f01505","#00ff00","#c048c0","#f08080","#00ff00","#c048c0","#f09048","#15a4a4","#80a0f0","#80a0f0","#f01505","#80a0f0","#80a0f0","#ffff00","#00ff00","#00ff00","#80a0f0","#15a4a4","#80a0f0"], + "Zappo": ["#ffafaf","#6464ff","#00ff00","#ff0000","#ffff00","#00ff00","#ff0000","#ff00ff","#6464ff","#ffafaf","#ffafaf","#6464ff","#ffafaf","#ffc800","#ff00ff","#00ff00","#00ff00","#ffc800","#ffc800","#ffafaf"], + "Taylor": ["#ccff00","#0000ff","#cc00ff","#ff0000","#ffff00","#ff00cc","#ff0066","#ff9900","#0066ff","#66ff00","#33ff00","#6600ff","#00ff00","#00ff66","#ffcc00","#ff3300","#ff6600","#00ccff","#00ffcc","#99ff00"], + "Hydrophobicity": ["#ad0052","#0000ff","#0c00f3","#0c00f3","#c2003d","#0c00f3","#0c00f3","#6a0095","#1500ea","#ff0000","#ea0015","#0000ff","#b0004f","#cb0034","#4600b9","#5e00a1","#61009e","#5b00a4","#4f00b0","#f60009","#0c00f3","#680097","#0c00f3"], + "Helix Propensity": ["#e718e7","#6f906f","#1be41b","#778877","#23dc23","#926d92","#ff00ff","#00ff00","#758a75","#8a758a","#ae51ae","#a05fa0","#ef10ef","#986798","#00ff00","#36c936","#47b847","#8a758a","#21de21","#857a85","#49b649","#758a75","#c936c9"], + "Strand Propensity": ["#5858a7","#6b6b94","#64649b","#2121de","#9d9d62","#8c8c73","#0000ff","#4949b6","#60609f","#ecec13","#b2b24d","#4747b8","#82827d","#c2c23d","#2323dc","#4949b6","#9d9d62","#c0c03f","#d3d32c","#ffff00","#4343bc","#797986","#4747b8"], + "Turn Propensity": ["#2cd3d3","#708f8f","#ff0000","#e81717","#a85757","#3fc0c0","#778888","#ff0000","#708f8f","#00ffff","#1ce3e3","#7e8181","#1ee1e1","#1ee1e1","#f60909","#e11e1e","#738c8c","#738c8c","#9d6262","#07f8f8","#f30c0c","#7c8383","#5ba4a4"], + "Buried Index": ["#00a35c","#00fc03","#00eb14","#00eb14","#0000ff","#00f10e","#00f10e","#009d62","#00d52a","#0054ab","#007b84","#00ff00","#009768","#008778","#00e01f","#00d52a","#00db24","#00a857","#00e619","#005fa0","#00eb14","#00b649","#00f10e"]} + + +########################### +# to be deprecated functions +########################### +def set_dropout(model_config, dropout=0.0): + model_config.model.embeddings_and_evoformer.evoformer.msa_row_attention_with_pair_bias.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.evoformer.triangle_attention_ending_node.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.evoformer.triangle_attention_starting_node.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_incoming.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.evoformer.triangle_multiplication_outgoing.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.template.template_pair_stack.triangle_attention_ending_node.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.template.template_pair_stack.triangle_attention_starting_node.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_incoming.dropout_rate = dropout + model_config.model.embeddings_and_evoformer.template.template_pair_stack.triangle_multiplication_outgoing.dropout_rate = dropout + model_config.model.heads.structure_module.dropout = dropout + return model_config diff --git a/app.py b/app.py index 712ed467fbe50313f976e7635b5fcabd547ef3eb..45aa81c0b213a2f364c769004a4fa652070391a8 100644 --- a/app.py +++ b/app.py @@ -16,6 +16,7 @@ import copy import torch.nn as nn import torch.nn.functional as F import random +import os import os.path from protein_mpnn_utils import ( loss_nll, @@ -32,19 +33,103 @@ from protein_mpnn_utils import ( from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN import plotly.express as px import urllib +import jax.numpy as jnp +import tensorflow as tf -if "/home/user/app/alphafold" not in sys.path: - sys.path.append("/home/user/app/alphafold") +if "/home/user/app/af_backprop" not in sys.path: + sys.path.append("/home/user/app/af_backprop") +from utils import * + +# import libraries +import colabfold as cf from alphafold.common import protein from alphafold.data import pipeline -from alphafold.data import templates -from alphafold.model import data -from alphafold.model import config -from alphafold.model import model +from alphafold.model import data, config, model +from alphafold.common import residue_constants + + import plotly.graph_objects as go import ray +import re + +import numpy as np +import jax + +tf.config.set_visible_devices([], "GPU") + + +def chain_break(idx_res, Ls, length=200): + # Minkyung's code + # add big enough number to residue index to indicate chain breaks + L_prev = 0 + for L_i in Ls[:-1]: + idx_res[L_prev + L_i :] += length + L_prev += L_i + return idx_res + + +def setup_model(seq, model_name="model_1_ptm"): + + # setup model + cfg = config.model_config("model_1_ptm") + cfg.model.num_recycle = 0 + cfg.data.common.num_recycle = 0 + cfg.data.eval.max_msa_clusters = 1 + cfg.data.common.max_extra_msa = 1 + cfg.data.eval.masked_msa_replace_fraction = 0 + cfg.model.global_config.subbatch_size = None + model_params = data.get_model_haiku_params(model_name=model_name, data_dir=".") + model_runner = model.RunModel(cfg, model_params, is_training=False) + Ls = [len(s) for s in seq.split("/")] + + seq = re.sub("[^A-Z]", "", seq.upper()) + length = len(seq) + feature_dict = { + **pipeline.make_sequence_features( + sequence=seq, description="none", num_res=length + ), + **pipeline.make_msa_features(msas=[[seq]], deletion_matrices=[[[0] * length]]), + } + feature_dict["residue_index"] = chain_break(feature_dict["residue_index"], Ls) + inputs = model_runner.process_features(feature_dict, random_seed=0) + + def runner(seq, opt): + # update sequence + inputs = opt["inputs"] + inputs.update(opt["prev"]) + update_seq(seq, inputs) + update_aatype(inputs["target_feat"][..., 1:], inputs) + + # mask prediction + mask = seq.sum(-1) + inputs["seq_mask"] = inputs["seq_mask"].at[:].set(mask) + inputs["msa_mask"] = inputs["msa_mask"].at[:].set(mask) + inputs["residue_index"] = jnp.where(mask == 1, inputs["residue_index"], 0) + + # get prediction + key = jax.random.PRNGKey(0) + outputs = model_runner.apply(opt["params"], key, inputs) + + prev = { + "init_msa_first_row": outputs["representations"]["msa_first_row"][None], + "init_pair": outputs["representations"]["pair"][None], + "init_pos": outputs["structure_module"]["final_atom_positions"][None], + } + + aux = { + "final_atom_positions": outputs["structure_module"]["final_atom_positions"], + "final_atom_mask": outputs["structure_module"]["final_atom_mask"], + "plddt": get_plddt(outputs), + "pae": get_pae(outputs), + "inputs": inputs, + "prev": prev, + } + return aux + + return jax.jit(runner), {"inputs": inputs, "params": model_params} + def make_tied_positions_for_homomers(pdb_dict_list): my_dict = {} @@ -63,51 +148,47 @@ def make_tied_positions_for_homomers(pdb_dict_list): return my_dict -def mk_mock_template(query_sequence): - """create blank template""" - ln = len(query_sequence) - output_templates_sequence = "-" * ln - templates_all_atom_positions = np.zeros( - (ln, templates.residue_constants.atom_type_num, 3) - ) - templates_all_atom_masks = np.zeros((ln, templates.residue_constants.atom_type_num)) - templates_aatype = templates.residue_constants.sequence_to_onehot( - output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID - ) - template_features = { - "template_all_atom_positions": templates_all_atom_positions[None], - "template_all_atom_masks": templates_all_atom_masks[None], - "template_aatype": np.array(templates_aatype)[None], - "template_domain_names": [f"none".encode()], - } - return template_features - - -def align_structures(pdb1, pdb2): +def renumber(struc): + """Renumber residues consecutively and remove all hetero residues""" + resid = 0 + residue_to_remove = [] + chain_to_remove = [] + for model in struc: + for chain in model: + for i, residue in enumerate(chain.get_residues()): + res_id = list(residue.id) + res_id[1] = resid + resid += 1 + residue.id = tuple(res_id) + if residue.id[0] != " ": + residue_to_remove.append((chain.id, residue.id)) + if len(chain) == 0: + chain_to_remove.append(chain.id) + for residue in residue_to_remove: + struc[0][residue[0]].detach_child(residue[1]) + + for chain in chain_to_remove: + model.detach_child(chain) + return struc + + +def align_structures(pdb1, pdb2, lenRes): + """Take two structure and superimpose pdb1 on pdb2""" import Bio.PDB - # Select what residues numbers you wish to align - # and put them in a list - # TODO Get residues from PDB file - atoms_to_be_aligned = range(start_id, end_id + 1) + # We use all residues + atoms_to_be_aligned = range(0, lenRes) - # Start the parser pdb_parser = Bio.PDB.PDBParser(QUIET=True) - # Get the structures - ref_structure = pdb_parser.get_structure("reference", pdb1) - sample_structure = pdb_parser.get_structure("samle", pdb2) - + ref_structure = pdb_parser.get_structure("samle", pdb2) + sample_structure = renumber(pdb_parser.get_structure("reference", pdb1)) # Use the first model in the pdb-files for alignment - # Change the number 0 if you want to align to another structure ref_model = ref_structure[0] sample_model = sample_structure[0] - - # Make a list of the atoms (in the structures) you wish to align. - # In this case we use CA atoms whose index is in the specified range + # Make a list of the atoms (in the structures) to align. ref_atoms = [] sample_atoms = [] - # Iterate of all chains in the model in order to find all residues for ref_chain in ref_model: # Iterate of all residues in each model in order to find proper atoms @@ -116,8 +197,6 @@ def align_structures(pdb1, pdb2): if ref_res.get_id()[1] in atoms_to_be_aligned: # Append CA atom to list ref_atoms.append(ref_res["CA"]) - - # Do the same for the sample structure for sample_chain in sample_model: for sample_res in sample_chain: if sample_res.get_id()[1] in atoms_to_be_aligned: @@ -131,67 +210,59 @@ def align_structures(pdb1, pdb2): io = Bio.PDB.PDBIO() io.set_structure(sample_structure) io.save(f"{pdb1}_aligned.pdb") - return super_imposer.rms - - -def predict_structure(prefix, feature_dict, model_runners, random_seed=0): - """Predicts structure using AlphaFold for the given sequence.""" - - # Run the models. - # currently we only run model1 - plddts = {} - for model_name, model_runner in model_runners.items(): - processed_feature_dict = model_runner.process_features( - feature_dict, random_seed=random_seed - ) - prediction_result = model_runner.predict(processed_feature_dict) - b_factors = ( - prediction_result["plddt"][:, None] - * prediction_result["structure_module"]["final_atom_mask"] - ) - unrelaxed_protein = protein.from_prediction( - processed_feature_dict, prediction_result, b_factors - ) - unrelaxed_pdb_path = f"/home/user/app/{prefix}_unrelaxed_{model_name}.pdb" - plddts[model_name] = prediction_result["plddt"] + return super_imposer.rms, f"{pdb1}_aligned.pdb" - print(f"{model_name} {plddts[model_name].mean()}") - with open(unrelaxed_pdb_path, "w") as f: - f.write(protein.to_pdb(unrelaxed_protein)) - return plddts +def save_pdb(outs, filename, LEN): + """save pdb coordinates""" + p = { + "residue_index": outs["inputs"]["residue_index"][0][:LEN], + "aatype": outs["inputs"]["aatype"].argmax(-1)[0][:LEN], + "atom_positions": outs["final_atom_positions"][:LEN], + "atom_mask": outs["final_atom_mask"][:LEN], + } + b_factors = 100.0 * outs["plddt"][:LEN, None] * p["atom_mask"] + p = protein.Protein(**p, b_factors=b_factors) + pdb_lines = protein.to_pdb(p) + with open(filename, "w") as f: + f.write(pdb_lines) + print(os.listdir(), os.getcwd()) @ray.remote(num_gpus=1, max_calls=1) -def run_alphafold(startsequence): - model_runners = {} - models = ["model_1"] # ,"model_2","model_3","model_4","model_5"] - for model_name in models: - model_config = config.model_config(model_name) - model_config.data.eval.num_ensemble = 1 - model_params = data.get_model_haiku_params( - model_name=model_name, data_dir="/home/user/app/" - ) - model_runner = model.RunModel(model_config, model_params) - model_runners[model_name] = model_runner - query_sequence = startsequence.replace("\n", "") - - feature_dict = { - **pipeline.make_sequence_features( - sequence=query_sequence, description="none", num_res=len(query_sequence) - ), - **pipeline.make_msa_features( - msas=[[query_sequence]], deletion_matrices=[[[0] * len(query_sequence)]] - ), - **mk_mock_template(query_sequence), +def run_alphafold(sequence): + recycles = 3 + RUNNER, OPT = setup_model(sequence) + + SEQ = re.sub("[^A-Z]", "", sequence.upper()) + MAX_LEN = len(SEQ) + LEN = len(SEQ) + + x = np.array([residue_constants.restype_order.get(aa, -1) for aa in SEQ]) + x = np.pad(x, [0, MAX_LEN - LEN], constant_values=-1) + x = jax.nn.one_hot(x, 20) + + OPT["prev"] = { + "init_msa_first_row": np.zeros([1, MAX_LEN, 256]), + "init_pair": np.zeros([1, MAX_LEN, MAX_LEN, 128]), + "init_pos": np.zeros([1, MAX_LEN, 37, 3]), } - print(feature_dict["residue_index"]) - plddts = predict_structure("test", feature_dict, model_runners) - print("AF2 done") - return plddts["model_1"] + + positions = [] + plddts = [] + for r in range(recycles + 1): + outs = RUNNER(x, OPT) + outs = jax.tree_map(lambda x: np.asarray(x), outs) + positions.append(outs["prev"]["init_pos"][0, :LEN]) + plddts.append(outs["plddt"][:LEN]) + OPT["prev"] = outs["prev"] + if recycles > 0: + print(r, plddts[-1].mean()) + save_pdb(outs, "out.pdb", LEN) + num_res = int(outs["inputs"]["aatype"][0].sum()) + return outs["plddt"], outs["pae"], num_res -print("Cuda available", torch.cuda.is_available()) device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") model_name = "v_48_020" # ProteinMPNN model name: v_48_002, v_48_010, v_48_020, v_48_030, v_32_002, v_32_010; v_32_020, v_32_030; v_48_010=version with 48 edges 0.10A noise backbone_noise = 0.00 # Standard deviation of Gaussian noise to add to backbone atoms @@ -225,11 +296,6 @@ model.load_state_dict(checkpoint["model_state_dict"]) model.eval() -import re - -import numpy as np - - def get_pdb(pdb_code="", filepath=""): if pdb_code is None or pdb_code == "": return filepath.name @@ -556,31 +622,38 @@ def update(inp, file, designed_chain, fixed_chain, homomer, num_seqs, sampling_t fig_tadjusted, gr.File.update(value="all_log_probs_concat.csv", visible=True), gr.File.update(value="all_probs_concat.csv", visible=True), + pdb_path, ) -def update_AF(startsequence): +def update_AF(startsequence, pdb): # # run alphafold using ray - plddts = ray.get(run_alphafold.remote(startsequence)) - print(plddts) + plddts, pae, num_res = ray.get(run_alphafold.remote(startsequence)) x = np.arange(10) - plotAF = go.Figure( + plotAF_plddt = go.Figure( data=go.Scatter( x=np.arange(len(plddts)), y=plddts, hovertemplate="pLDDT: %{y:.2f}
Residue index: %{x}", ) ) - plotAF.update_layout( + plotAF_plddt.update_layout( title="pLDDT", xaxis_title="Residue index", yaxis_title="pLDDT", height=500, template="simple_white", ) - return molecule(f"test_unrelaxed_model_1.pdb"), plotAF + + plotAF_pae = px.imshow( + pae, + labels=dict(x="Scored residue", y="Aligned residue", color=""), + ) + plotAF_pae.update_layout(title="Predicted Aligned Error", template="simple_white") + + return molecule(pdb, "af_backprop/out.pdb", num_res), plotAF_plddt, plotAF_pae def read_mol(molpath): @@ -592,8 +665,12 @@ def read_mol(molpath): return mol -def molecule(pdb): +def molecule(pdb, afpdb, num_res): + + rms, aligned_pdb = align_structures(pdb, afpdb, num_res) + mol = read_mol(pdb) + pred_mol = read_mol(aligned_pdb) x = ( """ @@ -677,9 +754,13 @@ select{ viewer = $3Dmol.createViewer( element, config ); viewer.ui.initiateUI(); let data = `""" + + pred_mol + + """` + let pdb = `""" + mol + """` viewer.addModel( data, "pdb" ); + viewer.addModel( pdb, "pdb" ); //AlphaFold code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96 let colorAlpha = function (atom) { if (atom.b < 50) { @@ -721,7 +802,9 @@ select{ } }); $("#download").click(function () { - download("gradioFold_model1.pdb", data); + download(\"""" + + aligned_pdb + + """\", data); }) }); function download(filename, text) { @@ -833,16 +916,17 @@ with proteinMPNN: plot_tadjusted = gr.Plot() all_probs = gr.File(visible=False) with gr.TabItem("Structure validation w/ AF2"): - gr.Markdown("Coming soon") - # with gr.Row(): - # chosen_seq = gr.Textbox( - # label="Copy and paste a sequence for validation" - # ) - # btnAF = gr.Button("Run AF2 on sequence") - # with gr.Row(): - # mol = gr.HTML() - # plotAF = gr.Plot(label="pLDDT") - + # gr.Markdown("Coming soon") + with gr.Row(): + chosen_seq = gr.Textbox( + label="Copy and paste a sequence for validation" + ) + btnAF = gr.Button("Run AF2 on sequence") + mol = gr.HTML() + with gr.Row(): + plotAF_plddt = gr.Plot(label="pLDDT") + plotAF_pae = gr.Plot(label="PAE") + file = gr.Variable() btn.click( fn=update, inputs=[ @@ -854,13 +938,13 @@ with proteinMPNN: num_seqs, sampling_temp, ], - outputs=[out, plot, plot_tadjusted, all_log_probs, all_probs], + outputs=[out, plot, plot_tadjusted, all_log_probs, all_probs, file], + ) + btnAF.click( + fn=update_AF, + inputs=[chosen_seq, file], + outputs=[mol, plotAF_plddt, plotAF_pae], ) - # btnAF.click( - # fn=update_AF, - # inputs=[chosen_seq], - # outputs=[mol, plotAF], - # ) examples.click(fn=set_examples, inputs=examples, outputs=examples.components) gr.Markdown( """Citation: **Robust deep learning based protein sequence design using ProteinMPNN**
@@ -869,6 +953,6 @@ bioRxiv 2022.06.03.494563; doi: [10.1101/2022.06.03.494563](https://doi.org/10.1 ) -ray.init(runtime_env={"working_dir": "./alphafold"}) +ray.init(runtime_env={"working_dir": "./af_backprop"}) proteinMPNN.launch(share=True) diff --git a/colabfold.py b/colabfold.py new file mode 100644 index 0000000000000000000000000000000000000000..78253f160ce1b98aa0f7255486615c7246c9980e --- /dev/null +++ b/colabfold.py @@ -0,0 +1,691 @@ +# fmt: off + +############################################ +# imports +############################################ +import jax +import requests +import hashlib +import tarfile +import time +import pickle +import os +import re + +import random +import tqdm.notebook + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import matplotlib.patheffects +from matplotlib import collections as mcoll + +try: + import py3Dmol +except: + pass + +from string import ascii_uppercase,ascii_lowercase + +pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00", + "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200", + "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f", + "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c", + "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"] + +pymol_cmap = matplotlib.colors.ListedColormap(pymol_color_list) +alphabet_list = list(ascii_uppercase+ascii_lowercase) + +aatypes = set('ACDEFGHIKLMNPQRSTVWY') + + +########################################### +# control gpu/cpu memory usage +########################################### +def rm(x): + '''remove data from device''' + jax.tree_util.tree_map(lambda y: y.device_buffer.delete(), x) + +def to(x,device="cpu"): + '''move data to device''' + d = jax.devices(device)[0] + return jax.tree_util.tree_map(lambda y:jax.device_put(y,d), x) + +def clear_mem(device="gpu"): + '''remove all data from device''' + backend = jax.lib.xla_bridge.get_backend(device) + for buf in backend.live_buffers(): buf.delete() + +########################################## +# call mmseqs2 +########################################## + +TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' + +def run_mmseqs2(x, prefix, use_env=True, use_filter=True, + use_templates=False, filter=None, host_url="https://a3m.mmseqs.com"): + + def submit(seqs, mode, N=101): + n,query = N,"" + for seq in seqs: + query += f">{n}\n{seq}\n" + n += 1 + + res = requests.post(f'{host_url}/ticket/msa', data={'q':query,'mode': mode}) + try: out = res.json() + except ValueError: out = {"status":"UNKNOWN"} + return out + + def status(ID): + res = requests.get(f'{host_url}/ticket/{ID}') + try: out = res.json() + except ValueError: out = {"status":"UNKNOWN"} + return out + + def download(ID, path): + res = requests.get(f'{host_url}/result/download/{ID}') + with open(path,"wb") as out: out.write(res.content) + + # process input x + seqs = [x] if isinstance(x, str) else x + + # compatibility to old option + if filter is not None: + use_filter = filter + + # setup mode + if use_filter: + mode = "env" if use_env else "all" + else: + mode = "env-nofilter" if use_env else "nofilter" + + # define path + path = f"{prefix}_{mode}" + if not os.path.isdir(path): os.mkdir(path) + + # call mmseqs2 api + tar_gz_file = f'{path}/out.tar.gz' + N,REDO = 101,True + + # deduplicate and keep track of order + seqs_unique = sorted(list(set(seqs))) + Ms = [N+seqs_unique.index(seq) for seq in seqs] + + # lets do it! + if not os.path.isfile(tar_gz_file): + TIME_ESTIMATE = 150 * len(seqs_unique) + with tqdm.notebook.tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: + while REDO: + pbar.set_description("SUBMIT") + + # Resubmit job until it goes through + out = submit(seqs_unique, mode, N) + while out["status"] in ["UNKNOWN","RATELIMIT"]: + # resubmit + time.sleep(5 + random.randint(0,5)) + out = submit(seqs_unique, mode, N) + + if out["status"] == "ERROR": + raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.') + + if out["status"] == "MAINTENANCE": + raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.') + + # wait for job to finish + ID,TIME = out["id"],0 + pbar.set_description(out["status"]) + while out["status"] in ["UNKNOWN","RUNNING","PENDING"]: + t = 5 + random.randint(0,5) + time.sleep(t) + out = status(ID) + pbar.set_description(out["status"]) + if out["status"] == "RUNNING": + TIME += t + pbar.update(n=t) + #if TIME > 900 and out["status"] != "COMPLETE": + # # something failed on the server side, need to resubmit + # N += 1 + # break + + if out["status"] == "COMPLETE": + if TIME < TIME_ESTIMATE: + pbar.update(n=(TIME_ESTIMATE-TIME)) + REDO = False + + # Download results + download(ID, tar_gz_file) + + # prep list of a3m files + a3m_files = [f"{path}/uniref.a3m"] + if use_env: a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m") + + # extract a3m files + if not os.path.isfile(a3m_files[0]): + with tarfile.open(tar_gz_file) as tar_gz: + tar_gz.extractall(path) + + # templates + if use_templates: + templates = {} + print("seq\tpdb\tcid\tevalue") + for line in open(f"{path}/pdb70.m8","r"): + p = line.rstrip().split() + M,pdb,qid,e_value = p[0],p[1],p[2],p[10] + M = int(M) + if M not in templates: templates[M] = [] + templates[M].append(pdb) + if len(templates[M]) <= 20: + print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}") + + template_paths = {} + for k,TMPL in templates.items(): + TMPL_PATH = f"{prefix}_{mode}/templates_{k}" + if not os.path.isdir(TMPL_PATH): + os.mkdir(TMPL_PATH) + TMPL_LINE = ",".join(TMPL[:20]) + os.system(f"curl -s https://a3m-templates.mmseqs.com/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/") + os.system(f"cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex") + os.system(f"touch {TMPL_PATH}/pdb70_cs219.ffdata") + template_paths[k] = TMPL_PATH + + # gather a3m lines + a3m_lines = {} + for a3m_file in a3m_files: + update_M,M = True,None + for line in open(a3m_file,"r"): + if len(line) > 0: + if "\x00" in line: + line = line.replace("\x00","") + update_M = True + if line.startswith(">") and update_M: + M = int(line[1:].rstrip()) + update_M = False + if M not in a3m_lines: a3m_lines[M] = [] + a3m_lines[M].append(line) + + # return results + a3m_lines = ["".join(a3m_lines[n]) for n in Ms] + + if use_templates: + template_paths_ = [] + for n in Ms: + if n not in template_paths: + template_paths_.append(None) + print(f"{n-N}\tno_templates_found") + else: + template_paths_.append(template_paths[n]) + template_paths = template_paths_ + + if isinstance(x, str): + return (a3m_lines[0], template_paths[0]) if use_templates else a3m_lines[0] + else: + return (a3m_lines, template_paths) if use_templates else a3m_lines + + +######################################################################### +# utils +######################################################################### +def get_hash(x): + return hashlib.sha1(x.encode()).hexdigest() + +def homooligomerize(msas, deletion_matrices, homooligomer=1): + if homooligomer == 1: + return msas, deletion_matrices + else: + new_msas = [] + new_mtxs = [] + for o in range(homooligomer): + for msa,mtx in zip(msas, deletion_matrices): + num_res = len(msa[0]) + L = num_res * o + R = num_res * (homooligomer-(o+1)) + new_msas.append(["-"*L+s+"-"*R for s in msa]) + new_mtxs.append([[0]*L+m+[0]*R for m in mtx]) + return new_msas, new_mtxs + +# keeping typo for cross-compatibility +def homooliomerize(msas, deletion_matrices, homooligomer=1): + return homooligomerize(msas, deletion_matrices, homooligomer=homooligomer) + +def homooligomerize_heterooligomer(msas, deletion_matrices, lengths, homooligomers): + ''' + ----- inputs ----- + msas: list of msas + deletion_matrices: list of deletion matrices + lengths: list of lengths for each component in complex + homooligomers: list of number of homooligomeric copies for each component + ----- outputs ----- + (msas, deletion_matrices) + ''' + if max(homooligomers) == 1: + return msas, deletion_matrices + + elif len(homooligomers) == 1: + return homooligomerize(msas, deletion_matrices, homooligomers[0]) + + else: + frag_ij = [[0,lengths[0]]] + for length in lengths[1:]: + j = frag_ij[-1][-1] + frag_ij.append([j,j+length]) + + # for every msa + mod_msas, mod_mtxs = [],[] + for msa, mtx in zip(msas, deletion_matrices): + mod_msa, mod_mtx = [],[] + # for every sequence + for n,(s,m) in enumerate(zip(msa,mtx)): + # split sequence + _s,_m,_ok = [],[],[] + for i,j in frag_ij: + _s.append(s[i:j]); _m.append(m[i:j]) + _ok.append(max([o != "-" for o in _s[-1]])) + + if n == 0: + # if first query sequence + mod_msa.append("".join([x*h for x,h in zip(_s,homooligomers)])) + mod_mtx.append(sum([x*h for x,h in zip(_m,homooligomers)],[])) + + elif sum(_ok) == 1: + # elif one fragment: copy each fragment to every homooligomeric copy + a = _ok.index(True) + for h_a in range(homooligomers[a]): + _blank_seq = [["-"*l]*h for l,h in zip(lengths,homooligomers)] + _blank_mtx = [[[0]*l]*h for l,h in zip(lengths,homooligomers)] + _blank_seq[a][h_a] = _s[a] + _blank_mtx[a][h_a] = _m[a] + mod_msa.append("".join(["".join(x) for x in _blank_seq])) + mod_mtx.append(sum([sum(x,[]) for x in _blank_mtx],[])) + else: + # else: copy fragment pair to every homooligomeric copy pair + for a in range(len(lengths)-1): + if _ok[a]: + for b in range(a+1,len(lengths)): + if _ok[b]: + for h_a in range(homooligomers[a]): + for h_b in range(homooligomers[b]): + _blank_seq = [["-"*l]*h for l,h in zip(lengths,homooligomers)] + _blank_mtx = [[[0]*l]*h for l,h in zip(lengths,homooligomers)] + for c,h_c in zip([a,b],[h_a,h_b]): + _blank_seq[c][h_c] = _s[c] + _blank_mtx[c][h_c] = _m[c] + mod_msa.append("".join(["".join(x) for x in _blank_seq])) + mod_mtx.append(sum([sum(x,[]) for x in _blank_mtx],[])) + mod_msas.append(mod_msa) + mod_mtxs.append(mod_mtx) + return mod_msas, mod_mtxs + +def chain_break(idx_res, Ls, length=200): + # Minkyung's code + # add big enough number to residue index to indicate chain breaks + L_prev = 0 + for L_i in Ls[:-1]: + idx_res[L_prev+L_i:] += length + L_prev += L_i + return idx_res + +################################################## +# plotting +################################################## + +def plot_plddt_legend(dpi=100): + thresh = ['plDDT:','Very low (<50)','Low (60)','OK (70)','Confident (80)','Very high (>90)'] + plt.figure(figsize=(1,0.1),dpi=dpi) + ######################################## + for c in ["#FFFFFF","#FF0000","#FFFF00","#00FF00","#00FFFF","#0000FF"]: + plt.bar(0, 0, color=c) + plt.legend(thresh, frameon=False, + loc='center', ncol=6, + handletextpad=1, + columnspacing=1, + markerscale=0.5,) + plt.axis(False) + return plt + +def plot_ticks(Ls): + Ln = sum(Ls) + L_prev = 0 + for L_i in Ls[:-1]: + L = L_prev + L_i + L_prev += L_i + plt.plot([0,Ln],[L,L],color="black") + plt.plot([L,L],[0,Ln],color="black") + ticks = np.cumsum([0]+Ls) + ticks = (ticks[1:] + ticks[:-1])/2 + plt.yticks(ticks,alphabet_list[:len(ticks)]) + +def plot_confidence(plddt, pae=None, Ls=None, dpi=100): + use_ptm = False if pae is None else True + if use_ptm: + plt.figure(figsize=(10,3), dpi=dpi) + plt.subplot(1,2,1); + else: + plt.figure(figsize=(5,3), dpi=dpi) + plt.title('Predicted lDDT') + plt.plot(plddt) + if Ls is not None: + L_prev = 0 + for L_i in Ls[:-1]: + L = L_prev + L_i + L_prev += L_i + plt.plot([L,L],[0,100],color="black") + plt.ylim(0,100) + plt.ylabel('plDDT') + plt.xlabel('position') + if use_ptm: + plt.subplot(1,2,2);plt.title('Predicted Aligned Error') + Ln = pae.shape[0] + plt.imshow(pae,cmap="bwr",vmin=0,vmax=30,extent=(0, Ln, Ln, 0)) + if Ls is not None and len(Ls) > 1: plot_ticks(Ls) + plt.colorbar() + plt.xlabel('Scored residue') + plt.ylabel('Aligned residue') + return plt + +def plot_msas(msas, ori_seq=None, sort_by_seqid=True, deduplicate=True, dpi=100, return_plt=True): + ''' + plot the msas + ''' + if ori_seq is None: ori_seq = msas[0][0] + seqs = ori_seq.replace("/","").split(":") + seqs_dash = ori_seq.replace(":","").split("/") + + Ln = np.cumsum(np.append(0,[len(seq) for seq in seqs])) + Ln_dash = np.cumsum(np.append(0,[len(seq) for seq in seqs_dash])) + Nn,lines = [],[] + for msa in msas: + msa_ = set(msa) if deduplicate else msa + if len(msa_) > 0: + Nn.append(len(msa_)) + msa_ = np.asarray([list(seq) for seq in msa_]) + gap_ = msa_ != "-" + qid_ = msa_ == np.array(list("".join(seqs))) + gapid = np.stack([gap_[:,Ln[i]:Ln[i+1]].max(-1) for i in range(len(seqs))],-1) + seqid = np.stack([qid_[:,Ln[i]:Ln[i+1]].mean(-1) for i in range(len(seqs))],-1).sum(-1) / (gapid.sum(-1) + 1e-8) + non_gaps = gap_.astype(np.float) + non_gaps[non_gaps == 0] = np.nan + if sort_by_seqid: + lines.append(non_gaps[seqid.argsort()]*seqid[seqid.argsort(),None]) + else: + lines.append(non_gaps[::-1] * seqid[::-1,None]) + + Nn = np.cumsum(np.append(0,Nn)) + lines = np.concatenate(lines,0) + + if return_plt: + plt.figure(figsize=(8,5),dpi=dpi) + plt.title("Sequence coverage") + plt.imshow(lines, + interpolation='nearest', aspect='auto', + cmap="rainbow_r", vmin=0, vmax=1, origin='lower', + extent=(0, lines.shape[1], 0, lines.shape[0])) + for i in Ln[1:-1]: + plt.plot([i,i],[0,lines.shape[0]],color="black") + for i in Ln_dash[1:-1]: + plt.plot([i,i],[0,lines.shape[0]],"--",color="black") + for j in Nn[1:-1]: + plt.plot([0,lines.shape[1]],[j,j],color="black") + + plt.plot((np.isnan(lines) == False).sum(0), color='black') + plt.xlim(0,lines.shape[1]) + plt.ylim(0,lines.shape[0]) + plt.colorbar(label="Sequence identity to query") + plt.xlabel("Positions") + plt.ylabel("Sequences") + if return_plt: return plt + +def read_pdb_renum(pdb_filename, Ls=None): + if Ls is not None: + L_init = 0 + new_chain = {} + for L,c in zip(Ls, alphabet_list): + new_chain.update({i:c for i in range(L_init,L_init+L)}) + L_init += L + + n,pdb_out = 1,[] + resnum_,chain_ = 1,"A" + for line in open(pdb_filename,"r"): + if line[:4] == "ATOM": + chain = line[21:22] + resnum = int(line[22:22+5]) + if resnum != resnum_ or chain != chain_: + resnum_,chain_ = resnum,chain + n += 1 + if Ls is None: pdb_out.append("%s%4i%s" % (line[:22],n,line[26:])) + else: pdb_out.append("%s%s%4i%s" % (line[:21],new_chain[n-1],n,line[26:])) + return "".join(pdb_out) + +def show_pdb(pred_output_path, show_sidechains=False, show_mainchains=False, + color="lDDT", chains=None, Ls=None, vmin=50, vmax=90, + color_HP=False, size=(800,480)): + + if chains is None: + chains = 1 if Ls is None else len(Ls) + + view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1]) + view.addModel(read_pdb_renum(pred_output_path, Ls),'pdb') + if color == "lDDT": + view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':vmin,'max':vmax}}}) + elif color == "rainbow": + view.setStyle({'cartoon': {'color':'spectrum'}}) + elif color == "chain": + for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list): + view.setStyle({'chain':chain},{'cartoon': {'color':color}}) + if show_sidechains: + BB = ['C','O','N'] + HP = ["ALA","GLY","VAL","ILE","LEU","PHE","MET","PRO","TRP","CYS","TYR"] + if color_HP: + view.addStyle({'and':[{'resn':HP},{'atom':BB,'invert':True}]}, + {'stick':{'colorscheme':"yellowCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':HP,'invert':True},{'atom':BB,'invert':True}]}, + {'stick':{'colorscheme':"whiteCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]}, + {'sphere':{'colorscheme':"yellowCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]}, + {'stick':{'colorscheme':"yellowCarbon",'radius':0.3}}) + else: + view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]}, + {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]}, + {'sphere':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]}, + {'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + if show_mainchains: + BB = ['C','O','N','CA'] + view.addStyle({'atom':BB},{'stick':{'colorscheme':f"WhiteCarbon",'radius':0.3}}) + view.zoomTo() + return view + +def plot_plddts(plddts, Ls=None, dpi=100, fig=True): + if fig: plt.figure(figsize=(8,5),dpi=100) + plt.title("Predicted lDDT per position") + for n,plddt in enumerate(plddts): + plt.plot(plddt,label=f"rank_{n+1}") + if Ls is not None: + L_prev = 0 + for L_i in Ls[:-1]: + L = L_prev + L_i + L_prev += L_i + plt.plot([L,L],[0,100],color="black") + plt.legend() + plt.ylim(0,100) + plt.ylabel("Predicted lDDT") + plt.xlabel("Positions") + return plt + +def plot_paes(paes, Ls=None, dpi=100, fig=True): + num_models = len(paes) + if fig: plt.figure(figsize=(3*num_models,2), dpi=dpi) + for n,pae in enumerate(paes): + plt.subplot(1,num_models,n+1) + plt.title(f"rank_{n+1}") + Ln = pae.shape[0] + plt.imshow(pae,cmap="bwr",vmin=0,vmax=30,extent=(0, Ln, Ln, 0)) + if Ls is not None and len(Ls) > 1: plot_ticks(Ls) + plt.colorbar() + return plt + +def plot_adjs(adjs, Ls=None, dpi=100, fig=True): + num_models = len(adjs) + if fig: plt.figure(figsize=(3*num_models,2), dpi=dpi) + for n,adj in enumerate(adjs): + plt.subplot(1,num_models,n+1) + plt.title(f"rank_{n+1}") + Ln = adj.shape[0] + plt.imshow(adj,cmap="binary",vmin=0,vmax=1,extent=(0, Ln, Ln, 0)) + if Ls is not None and len(Ls) > 1: plot_ticks(Ls) + plt.colorbar() + return plt + +def plot_dists(dists, Ls=None, dpi=100, fig=True): + num_models = len(dists) + if fig: plt.figure(figsize=(3*num_models,2), dpi=dpi) + for n,dist in enumerate(dists): + plt.subplot(1,num_models,n+1) + plt.title(f"rank_{n+1}") + Ln = dist.shape[0] + plt.imshow(dist,extent=(0, Ln, Ln, 0)) + if Ls is not None and len(Ls) > 1: plot_ticks(Ls) + plt.colorbar() + return plt + +########################################################################## +########################################################################## + +def kabsch(a, b, weights=None, return_v=False): + a = np.asarray(a) + b = np.asarray(b) + if weights is None: weights = np.ones(len(b)) + else: weights = np.asarray(weights) + B = np.einsum('ji,jk->ik', weights[:, None] * a, b) + u, s, vh = np.linalg.svd(B) + if np.linalg.det(u @ vh) < 0: u[:, -1] = -u[:, -1] + if return_v: return u + else: return u @ vh + +def plot_pseudo_3D(xyz, c=None, ax=None, chainbreak=5, + cmap="gist_rainbow", line_w=2.0, + cmin=None, cmax=None, zmin=None, zmax=None): + + def rescale(a,amin=None,amax=None): + a = np.copy(a) + if amin is None: amin = a.min() + if amax is None: amax = a.max() + a[a < amin] = amin + a[a > amax] = amax + return (a - amin)/(amax - amin) + + # make segments + xyz = np.asarray(xyz) + seg = np.concatenate([xyz[:-1,None,:],xyz[1:,None,:]],axis=-2) + seg_xy = seg[...,:2] + seg_z = seg[...,2].mean(-1) + ord = seg_z.argsort() + + # set colors + if c is None: c = np.arange(len(seg))[::-1] + else: c = (c[1:] + c[:-1])/2 + c = rescale(c,cmin,cmax) + + if isinstance(cmap, str): + if cmap == "gist_rainbow": c *= 0.75 + colors = matplotlib.cm.get_cmap(cmap)(c) + else: + colors = cmap(c) + + if chainbreak is not None: + dist = np.linalg.norm(xyz[:-1] - xyz[1:], axis=-1) + colors[...,3] = (dist < chainbreak).astype(np.float) + + # add shade/tint based on z-dimension + z = rescale(seg_z,zmin,zmax)[:,None] + tint, shade = z/3, (z+2)/3 + colors[:,:3] = colors[:,:3] + (1 - colors[:,:3]) * tint + colors[:,:3] = colors[:,:3] * shade + + set_lim = False + if ax is None: + fig, ax = plt.subplots() + fig.set_figwidth(5) + fig.set_figheight(5) + set_lim = True + else: + fig = ax.get_figure() + if ax.get_xlim() == (0,1): + set_lim = True + + if set_lim: + xy_min = xyz[:,:2].min() - line_w + xy_max = xyz[:,:2].max() + line_w + ax.set_xlim(xy_min,xy_max) + ax.set_ylim(xy_min,xy_max) + + ax.set_aspect('equal') + + # determine linewidths + width = fig.bbox_inches.width * ax.get_position().width + linewidths = line_w * 72 * width / np.diff(ax.get_xlim()) + + lines = mcoll.LineCollection(seg_xy[ord], colors=colors[ord], linewidths=linewidths, + path_effects=[matplotlib.patheffects.Stroke(capstyle="round")]) + + return ax.add_collection(lines) + +def add_text(text, ax): + return plt.text(0.5, 1.01, text, horizontalalignment='center', + verticalalignment='bottom', transform=ax.transAxes) + +def plot_protein(protein=None, pos=None, plddt=None, Ls=None, dpi=100, best_view=True, line_w=2.0): + + if protein is not None: + pos = np.asarray(protein.atom_positions[:,1,:]) + plddt = np.asarray(protein.b_factors[:,0]) + + # get best view + if best_view: + if plddt is not None: + weights = plddt/100 + pos = pos - (pos * weights[:,None]).sum(0,keepdims=True) / weights.sum() + pos = pos @ kabsch(pos, pos, weights, return_v=True) + else: + pos = pos - pos.mean(0,keepdims=True) + pos = pos @ kabsch(pos, pos, return_v=True) + + if plddt is not None: + fig, (ax1, ax2) = plt.subplots(1,2) + fig.set_figwidth(6); fig.set_figheight(3) + ax = [ax1, ax2] + else: + fig, ax1 = plt.subplots(1,1) + fig.set_figwidth(3); fig.set_figheight(3) + ax = [ax1] + + fig.set_dpi(dpi) + fig.subplots_adjust(top = 0.9, bottom = 0.1, right = 1, left = 0, hspace = 0, wspace = 0) + + xy_min = pos[...,:2].min() - line_w + xy_max = pos[...,:2].max() + line_w + for a in ax: + a.set_xlim(xy_min, xy_max) + a.set_ylim(xy_min, xy_max) + a.axis(False) + + if Ls is None or len(Ls) == 1: + # color N->C + c = np.arange(len(pos))[::-1] + plot_pseudo_3D(pos, line_w=line_w, ax=ax1) + add_text("colored by N→C", ax1) + else: + # color by chain + c = np.concatenate([[n]*L for n,L in enumerate(Ls)]) + if len(Ls) > 40: plot_pseudo_3D(pos, c=c, line_w=line_w, ax=ax1) + else: plot_pseudo_3D(pos, c=c, cmap=pymol_cmap, cmin=0, cmax=39, line_w=line_w, ax=ax1) + add_text("colored by chain", ax1) + + if plddt is not None: + # color by pLDDT + plot_pseudo_3D(pos, c=plddt, cmin=50, cmax=90, line_w=line_w, ax=ax2) + add_text("colored by pLDDT", ax2) + + return fig diff --git a/requirements.txt b/requirements.txt index 60954ce7a5aaff0420dcf289cfa6b4f44b731dad..5d4a21ae541ac37ef5ac6613776bf3373427f7d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -transformers==4.19.2 absl-py==0.13.0 biopython==1.79 chex==0.0.7 @@ -6,16 +5,17 @@ dm-haiku==0.0.5 dm-tree==0.1.6 docker==5.0.0 immutabledict==2.0.0 -jax[cuda]<0.3.0 -jaxlib==0.1.76 +jax[cuda]==0.3.8 +jaxlib==0.3.7 ml-collections==0.1.0 numpy==1.19.5 pandas==1.3.4 scipy==1.7.0 -tensorflow-gpu==2.5.0 +tensorflow torch plotly GPUtil ray +tqdm protobuf<4 -f https://storage.googleapis.com/jax-releases/jax_releases.html