|
from typing import List, Dict, Any |
|
from openfold.utils import rigid_utils as ru |
|
from data import residue_constants |
|
import numpy as np |
|
import collections |
|
import string |
|
import io

import pickle
|
import os |
|
import torch |
|
from torch_scatter import scatter_add, scatter |
|
from Bio.PDB.Chain import Chain |
|
from data import protein |
|
import dataclasses |
|
from Bio import PDB |
|
|
|
Rigid = ru.Rigid |
|
Protein = protein.Protein |
|
|
|
|
|
ALPHANUMERIC = string.ascii_letters + string.digits + ' ' |
|
CHAIN_TO_INT = { |
|
chain_char: i for i, chain_char in enumerate(ALPHANUMERIC) |
|
} |
|
INT_TO_CHAIN = { |
|
i: chain_char for i, chain_char in enumerate(ALPHANUMERIC) |
|
} |
|
|
|
NM_TO_ANG_SCALE = 10.0 |
|
ANG_TO_NM_SCALE = 1 / NM_TO_ANG_SCALE |
|
|
|
CHAIN_FEATS = [ |
|
'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors' |
|
] |
|
|
|
to_numpy = lambda x: x.detach().cpu().numpy() |
|
aatype_to_seq = lambda aatype: ''.join([ |
|
residue_constants.restypes_with_x[x] for x in aatype]) |
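
# A minimal usage sketch (values illustrative): with AlphaFold's restype
# ordering, indices 0, 1, 2 decode to 'A', 'R', 'N'.
#   aatype_to_seq(np.array([0, 1, 2]))  # -> 'ARN'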
|
|
|
|
|
class CPU_Unpickler(pickle.Unpickler): |
|
"""Pytorch pickle loading workaround. |
|
|
|
https://github.com/pytorch/pytorch/issues/16797 |
|
""" |
|
def find_class(self, module, name): |
|
if module == 'torch.storage' and name == '_load_from_bytes': |
|
return lambda b: torch.load(io.BytesIO(b), map_location='cpu') |
|
        else:
            return super().find_class(module, name)
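
# A minimal usage sketch ('ckpt.pkl' is an illustrative path): load a pickle
# containing CUDA tensors onto the CPU.
#   with open('ckpt.pkl', 'rb') as f:
#       data = CPU_Unpickler(f).load()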
|
|
|
|
|
def create_rigid(rots, trans):
    """Assemble an openfold Rigid from rotation matrices and translations."""
    rots = ru.Rotation(rot_mats=rots)
|
return Rigid(rots=rots, trans=trans) |
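
# A minimal usage sketch (shapes assumed): build a batch of identity frames.
#   rots = torch.eye(3)[None].repeat(8, 1, 1)
#   trans = torch.zeros(8, 3)
#   frames = create_rigid(rots, trans)  # Rigid with batch shape [8]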
|
|
|
|
|
def batch_align_structures(pos_1, pos_2, mask=None):
    """Align batched structures pos_1 onto pos_2, both of shape [B, N, 3],
    with an optional [B, N] mask selecting which positions participate."""
    if pos_1.shape != pos_2.shape:
        raise ValueError('pos_1 and pos_2 must have the same shape.')
    if pos_1.ndim != 3:
        raise ValueError(f'Expected inputs to have shape [B, N, 3], got {pos_1.shape}')
|
num_batch = pos_1.shape[0] |
|
device = pos_1.device |
|
batch_indices = ( |
|
torch.ones(*pos_1.shape[:2], device=device, dtype=torch.int64) |
|
* torch.arange(num_batch, device=device)[:, None] |
|
) |
|
flat_pos_1 = pos_1.reshape(-1, 3) |
|
flat_pos_2 = pos_2.reshape(-1, 3) |
|
flat_batch_indices = batch_indices.reshape(-1) |
|
    if mask is None:
        # No mask provided: treat every position as valid.
        mask = torch.ones(*pos_1.shape[:2], device=device)
|
|
|
    flat_mask = mask.reshape(-1).bool()

    # Delegate to the sparse Kabsch alignment. Note that the reshape back to
    # [B, N, 3] assumes every batch element keeps the same number of
    # unmasked positions.
    aligned_pos_1, aligned_pos_2, align_rots = align_structures(
        flat_pos_1[flat_mask], flat_batch_indices[flat_mask], flat_pos_2[flat_mask])
    aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3)
    aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3)
|
return aligned_pos_1, aligned_pos_2, align_rots |
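
# A minimal usage sketch (tensors illustrative): align two random batches of
# backbone traces and recover the per-batch rotations.
#   pos_1 = torch.randn(4, 64, 3)
#   pos_2 = torch.randn(4, 64, 3)
#   aligned_1, aligned_2, rots = batch_align_structures(pos_1, pos_2)
#   # aligned_1, aligned_2: [4, 64, 3]; rots: [4, 3, 3]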
|
|
|
|
|
|
|
def adjust_oxygen_pos(
    atom_37: torch.Tensor, pos_is_known: torch.Tensor = None
) -> torch.Tensor:
|
""" |
|
Imputes the position of the oxygen atom on the backbone by using adjacent frame information. |
|
Specifically, we say that the oxygen atom is in the plane created by the Calpha and C from the |
|
    current frame and the nitrogen of the next frame. The oxygen is then placed 1.23 Angstroms
    (the C=O bond length) away from the C in the current frame, in the direction away from the
    Ca-C-N triangle.
|
|
|
    For cases where the next frame is not available, for example at the C-terminus or when the
    next frame is missing from the data, we place the oxygen in the same plane as the
    N-Ca-C of the current frame, pointing in the same direction as the average of the
    Ca->C and Ca->N vectors.
|
|
|
Args: |
|
atom_37 (torch.Tensor): (N, 37, 3) tensor of positions of the backbone atoms in atom_37 ordering |
|
which is ['N', 'CA', 'C', 'CB', 'O', ...] |
|
        pos_is_known (torch.Tensor): (N,) mask for known residues.

    Returns:
        The (N, 37, 3) tensor with oxygen positions imputed (modified in place).
    """
|
|
|
N = atom_37.shape[0] |
|
    assert atom_37.shape == (N, 37, 3)

    # Unit vector from CA to C within each non-terminal residue; the 1e-7
    # guards against division by zero at masked (all-zero) positions.
    calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / (
        torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7
    )

    # Unit vector from the following residue's N to the current residue's C.
    nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / (
        torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7
    )

    # The oxygen sits along the normalized bisector of these two directions,
    # pointing away from the Ca-C-N triangle.
    carbonyl_to_oxygen: torch.Tensor = calpha_to_carbonyl + nitrogen_to_carbonyl
    carbonyl_to_oxygen = carbonyl_to_oxygen / (
        torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7
    )

    # Place O at the C=O bond length (1.23 A) from the carbonyl carbon.
    atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23
|
|
|
|
|
|
|
|
|
    # Fallback directions for residues without a next frame: stay in the
    # N-CA-C plane of the current residue.
    calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / (
        torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
    )

    calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / (
        torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
    )

    # Point along the average of the Ca->C and Ca->N directions.
    carbonyl_to_oxygen_term: torch.Tensor = (
        calpha_to_carbonyl_term + calpha_to_nitrogen_term
    )
    carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
        torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7
    )
|
|
|
|
|
|
|
|
|
|
|
    if pos_is_known is None:
        pos_is_known = torch.ones((atom_37.shape[0],), dtype=torch.int64, device=atom_37.device)

    # A residue needs the fallback placement when its successor is missing;
    # the appended True covers the C-terminal residue.
    next_res_gone: torch.Tensor = ~pos_is_known.bool()
    next_res_gone = torch.cat(
        [next_res_gone, torch.ones((1,), device=pos_is_known.device).bool()], dim=0
    )
    next_res_gone = next_res_gone[1:]

    # Overwrite the oxygen for residues without a valid next frame.
    atom_37[next_res_gone, 4, :] = (
        atom_37[next_res_gone, 2, :]
        + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23
    )
|
|
|
return atom_37 |
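
# A minimal usage sketch (shapes assumed): impute oxygens for a 10-residue
# backbone where the last two residues are unknown.
#   atom37 = torch.randn(10, 37, 3)
#   known = torch.tensor([1] * 8 + [0] * 2)
#   atom37 = adjust_oxygen_pos(atom37, pos_is_known=known)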
|
|
|
|
|
def write_pkl( |
|
save_path: str, pkl_data: Any, create_dir: bool = False, use_torch=False): |
|
"""Serialize data into a pickle file.""" |
|
if create_dir: |
|
os.makedirs(os.path.dirname(save_path), exist_ok=True) |
|
if use_torch: |
|
torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL) |
|
else: |
|
with open(save_path, 'wb') as handle: |
|
pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL) |
|
|
|
|
|
def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None): |
|
"""Read data from a pickle file.""" |
|
try: |
|
if use_torch: |
|
return torch.load(read_path, map_location=map_location) |
|
else: |
|
with open(read_path, 'rb') as handle: |
|
return pickle.load(handle) |
|
except Exception as e: |
|
try: |
|
with open(read_path, 'rb') as handle: |
|
return CPU_Unpickler(handle).load() |
|
except Exception as e2: |
|
if verbose: |
|
print(f'Failed to read {read_path}. First error: {e}\n Second error: {e2}') |
|
            raise e
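
# A minimal round-trip sketch ('/tmp/feats.pkl' is an illustrative path):
#   write_pkl('/tmp/feats.pkl', {'x': np.arange(3)})
#   feats = read_pkl('/tmp/feats.pkl')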
|
|
|
|
|
def chain_str_to_int(chain_str: str):
    """Map a (possibly multi-character) chain ID string to an integer."""
    chain_int = 0
|
if len(chain_str) == 1: |
|
return CHAIN_TO_INT[chain_str] |
|
for i, chain_char in enumerate(chain_str): |
|
chain_int += CHAIN_TO_INT[chain_char] + (i * len(ALPHANUMERIC)) |
|
return chain_int |
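
# A minimal usage sketch: single characters index into ALPHANUMERIC directly
# (lowercase letters come first), so 'a' -> 0 and 'A' -> 26; multi-character
# IDs add an offset per character position.
#   chain_str_to_int('a')  # -> 0
#   chain_str_to_int('A')  # -> 26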
|
|
|
|
|
def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center atom positions on the CA centroid, apply scaling, and add
    backbone mask and CA position features."""
    ca_idx = residue_constants.atom_order['CA']
|
chain_feats['bb_mask'] = chain_feats['atom_mask'][:, ca_idx] |
|
bb_pos = chain_feats['atom_positions'][:, ca_idx] |
|
bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5) |
|
centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :] |
|
scaled_pos = centered_pos / scale_factor |
|
chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None] |
|
chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx] |
|
return chain_feats |
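
# A minimal usage sketch (arrays illustrative): center a 5-residue chain's
# atom positions on the CA centroid.
#   feats = {
#       'atom_positions': np.random.randn(5, 37, 3),
#       'atom_mask': np.ones((5, 37)),
#   }
#   feats = parse_chain_feats(feats)  # adds 'bb_mask' and 'bb_positions'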
|
|
|
|
|
def concat_np_features( |
|
np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool): |
|
"""Performs a nested concatenation of feature dicts. |
|
|
|
Args: |
|
np_dicts: list of dicts with the same structure. |
|
Each dict must have the same keys and numpy arrays as the values. |
|
add_batch_dim: whether to add a batch dimension to each feature. |
|
|
|
Returns: |
|
A single dict with all the features concatenated. |
|
""" |
|
combined_dict = collections.defaultdict(list) |
|
for chain_dict in np_dicts: |
|
for feat_name, feat_val in chain_dict.items(): |
|
if add_batch_dim: |
|
feat_val = feat_val[None] |
|
combined_dict[feat_name].append(feat_val) |
|
|
|
for feat_name, feat_vals in combined_dict.items(): |
|
combined_dict[feat_name] = np.concatenate(feat_vals, axis=0) |
|
return combined_dict |
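
# A minimal usage sketch: stack per-example features along a new batch axis.
#   a = {'x': np.zeros((3,))}
#   b = {'x': np.ones((3,))}
#   batched = concat_np_features([a, b], add_batch_dim=True)
#   # batched['x'].shape == (2, 3)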
|
|
|
|
|
def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor: |
|
""" |
|
Move the molecule center to zero for sparse position tensors. |
|
|
|
Args: |
|
pos: [N, 3] batch positions of atoms in the molecule in sparse batch format. |
|
batch_indexes: [N] batch index for each atom in sparse batch format. |
|
|
|
Returns: |
|
pos: [N, 3] zero-centered batch positions of atoms in the molecule in sparse batch format. |
|
""" |
|
assert len(pos.shape) == 2 and pos.shape[-1] == 3, "pos must have shape [N, 3]" |
|
|
|
means = scatter(pos, batch_indexes, dim=0, reduce="mean") |
|
return pos - means[batch_indexes] |
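
# A minimal usage sketch: two molecules of three atoms each in sparse format.
#   pos = torch.randn(6, 3)
#   batch = torch.tensor([0, 0, 0, 1, 1, 1])
#   centered = center_zero(pos, batch)  # each molecule's mean is now ~0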
|
|
|
|
|
@torch.no_grad() |
|
def align_structures( |
|
batch_positions: torch.Tensor, |
|
batch_indices: torch.Tensor, |
|
reference_positions: torch.Tensor, |
|
broadcast_reference: bool = False, |
|
): |
|
""" |
|
Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation. This uses the |
|
sparse formulation of pytorch geometric. If the ChemGraph is composed of a single system, then |
|
the reference can be given as a single structure and broadcasted. Returns the structure |
|
coordinates shifted to the geometric center and the batch structures rotated to match the |
|
reference structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No permutation of |
|
atoms is carried out. |
|
|
|
Args: |
|
batch_positions (Tensor): Batch of structures (e.g. from ChemGraph) which should be aligned |
|
to a reference. |
|
batch_indices (Tensor): Index tensor mapping each node / atom in batch to the respective |
|
system (e.g. batch attribute of ChemGraph batch). |
|
reference_positions (Tensor): Reference structure. Can either be a batch of structures or a |
|
single structure. In the second case, broadcasting is possible if the input batch is |
|
composed exclusively of this structure. |
|
broadcast_reference (bool, optional): If reference batch contains only a single structure, |
|
broadcast this structure to match the ChemGraph batch. Defaults to False. |
|
|
|
Returns: |
|
Tuple[torch.Tensor, torch.Tensor]: Tensors containing the centered positions of batch |
|
structures rotated into the reference and the centered reference batch. |
|
|
|
References |
|
---------- |
|
.. [kabsch_align1] Lawrence, Bernal, Witzgall: |
|
A purely algebraic justification of the Kabsch-Umeyama algorithm. |
|
Journal of research of the National Institute of Standards and Technology, 124, 1. 2019. |
|
""" |
|
|
|
|
|
|
|
|
|
if batch_positions.shape[0] != reference_positions.shape[0]: |
|
if broadcast_reference: |
|
|
|
|
|
|
|
num_molecules = int(torch.max(batch_indices) + 1) |
|
reference_positions = reference_positions.repeat(num_molecules, 1) |
|
else: |
|
raise ValueError("Mismatch in batch dimensions.") |
|
|
|
|
|
    # Center both batches at the per-system geometric mean.
    batch_positions = center_zero(batch_positions, batch_indices)
    reference_positions = center_zero(reference_positions, batch_indices)
|
|
|
|
|
    # Per-system 3x3 cross-covariance matrices, accumulated over atoms.
    cov = scatter_add(
        batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0
    )
|
|
|
|
|
    # Batched SVD of the covariance matrices.
    u, _, v_t = torch.linalg.svd(cov)

    u_t = u.transpose(1, 2)
    v = v_t.transpose(1, 2)
|
|
|
|
|
|
|
    # Correct the sign so the result is a proper rotation (det = +1), not a
    # reflection.
    sign_correction = torch.sign(torch.linalg.det(torch.bmm(v, u_t)))
    u_t[:, 2, :] = u_t[:, 2, :] * sign_correction[:, None]

    # Kabsch rotation matrix, one per system.
    rotation_matrices = torch.bmm(v, u_t)
|
|
|
|
|
|
|
    # Cast back to the positions' dtype and rotate each centered position
    # into the reference frame.
    rotation_matrices = rotation_matrices.type(batch_positions.dtype)

    batch_positions_rotated = torch.bmm(
        batch_positions[:, None, :],
        rotation_matrices[batch_indices],
    ).squeeze(1)
|
|
|
return batch_positions_rotated, reference_positions, rotation_matrices |
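
# A minimal usage sketch (sparse-batch format assumed): align a noisy copy of
# a single structure back onto the original.
#   ref = torch.randn(8, 3)
#   batch = torch.zeros(8, dtype=torch.long)
#   noisy = ref + 0.01 * torch.randn(8, 3)
#   rotated, centered_ref, rots = align_structures(noisy, batch, ref)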
|
|
|
|
|
def parse_pdb_feats( |
|
pdb_name: str, |
|
pdb_path: str, |
|
        scale_factor=1.,
        chain_id='A',
    ):
|
""" |
|
Args: |
|
pdb_name: name of PDB to parse. |
|
pdb_path: path to PDB file to read. |
|
        scale_factor: factor to scale atom positions.
        chain_id: chain ID(s) to parse: a single chain ID string, a list of
            chain IDs, or None to parse all chains.
|
Returns: |
|
Dict with CHAIN_FEATS features extracted from PDB with specified |
|
preprocessing. |
|
""" |
|
parser = PDB.PDBParser(QUIET=True) |
|
structure = parser.get_structure(pdb_name, pdb_path) |
|
struct_chains = { |
|
chain.id: chain |
|
for chain in structure.get_chains()} |
|
|
|
def _process_chain_id(x): |
|
chain_prot = process_chain(struct_chains[x], x) |
|
        chain_dict = dataclasses.asdict(chain_prot)

        # Keep only the features used downstream.
        feat_dict = {x: chain_dict[x] for x in CHAIN_FEATS}
        return parse_chain_feats(
            feat_dict, scale_factor=scale_factor)
|
|
|
if isinstance(chain_id, str): |
|
return _process_chain_id(chain_id) |
|
elif isinstance(chain_id, list): |
|
return { |
|
x: _process_chain_id(x) for x in chain_id |
|
} |
|
elif chain_id is None: |
|
return { |
|
x: _process_chain_id(x) for x in struct_chains |
|
} |
|
else: |
|
raise ValueError(f'Unrecognized chain list {chain_id}') |
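
# A minimal usage sketch ('my.pdb' is an illustrative path):
#   feats = parse_pdb_feats('my_pdb', 'my.pdb', chain_id='A')
#   feats['bb_positions'].shape  # (num_res, 3)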
|
|
|
def rigid_transform_3D(A, B, verbose=False):
    """Compute the rigid transform (R, t) that best maps the [N, 3] point set
    A onto B in the least-squares sense, using the Kabsch algorithm."""
    assert A.shape == B.shape
    # Work with 3xN arrays internally.
    A = A.T
    B = B.T
|
|
|
num_rows, num_cols = A.shape |
|
if num_rows != 3: |
|
raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") |
|
|
|
num_rows, num_cols = B.shape |
|
if num_rows != 3: |
|
raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") |
|
|
|
|
|
    # Centroid of each point set (column-wise mean).
    centroid_A = np.mean(A, axis=1)
    centroid_B = np.mean(B, axis=1)

    # Ensure centroids are 3x1 column vectors.
    centroid_A = centroid_A.reshape(-1, 1)
    centroid_B = centroid_B.reshape(-1, 1)

    # Center both point sets on their centroids.
    Am = A - centroid_A
    Bm = B - centroid_B
|
|
|
    # Cross-covariance matrix of the centered point sets.
    H = Am @ np.transpose(Bm)

    # Find the optimal rotation from the SVD of the covariance matrix.
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T
|
|
|
|
|
    # Special reflection case: flip the sign of the last row of Vt so that
    # R is a proper rotation (det = +1).
    reflection_detected = False
    if np.linalg.det(R) < 0:
        if verbose:
            print("det(R) < 0, reflection detected! Correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T
        reflection_detected = True
|
|
|
    # Translation that maps the rotated centroid of A onto the centroid of B.
    t = -R @ centroid_A + centroid_B
    optimal_A = R @ A + t
|
|
|
return optimal_A.T, R, t, reflection_detected |
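
# A minimal usage sketch (points illustrative): fit one point cloud onto a
# rigidly transformed copy.
#   A = np.random.randn(10, 3)
#   B = A + np.array([1.0, 2.0, 3.0])
#   A_fit, R, t, reflected = rigid_transform_3D(A, B)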
|
|
|
def process_chain(chain: Chain, chain_id: str) -> Protein: |
|
"""Convert a PDB chain object into a AlphaFold Protein instance. |
|
|
|
Forked from alphafold.common.protein.from_pdb_string |
|
|
|
WARNING: All non-standard residue types will be converted into UNK. All |
|
non-standard atoms will be ignored. |
|
|
|
Took out lines 94-97 which don't allow insertions in the PDB. |
|
Sabdab uses insertions for the chothia numbering so we need to allow them. |
|
|
|
Took out lines 110-112 since that would mess up CDR numbering. |
|
|
|
Args: |
|
chain: Instance of Biopython's chain class. |
|
|
|
Returns: |
|
Protein object with protein features. |
|
""" |
|
atom_positions = [] |
|
aatype = [] |
|
atom_mask = [] |
|
residue_index = [] |
|
b_factors = [] |
|
chain_ids = [] |
|
for res in chain: |
|
res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') |
|
restype_idx = residue_constants.restype_order.get( |
|
res_shortname, residue_constants.restype_num) |
|
pos = np.zeros((residue_constants.atom_type_num, 3)) |
|
mask = np.zeros((residue_constants.atom_type_num,)) |
|
res_b_factors = np.zeros((residue_constants.atom_type_num,)) |
|
for atom in res: |
|
if atom.name not in residue_constants.atom_types: |
|
continue |
|
pos[residue_constants.atom_order[atom.name]] = atom.coord |
|
mask[residue_constants.atom_order[atom.name]] = 1. |
|
            res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
|
aatype.append(restype_idx) |
|
atom_positions.append(pos) |
|
atom_mask.append(mask) |
|
residue_index.append(res.id[1]) |
|
b_factors.append(res_b_factors) |
|
chain_ids.append(chain_id) |
|
|
|
return Protein( |
|
atom_positions=np.array(atom_positions), |
|
atom_mask=np.array(atom_mask), |
|
aatype=np.array(aatype), |
|
residue_index=np.array(residue_index), |
|
chain_index=np.array(chain_ids), |
|
b_factors=np.array(b_factors)) |
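
# A minimal usage sketch ('my.pdb' is an illustrative path): process the
# first chain of a parsed structure.
#   structure = PDB.PDBParser(QUIET=True).get_structure('x', 'my.pdb')
#   chain = next(structure.get_chains())
#   prot = process_chain(chain, chain.id)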
|
|