# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import math

import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
import numpy as np
from scipy.spatial import transform
from scipy.stats import special_ortho_group
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from typing import Sequence, Tuple, List

from esm.data import BatchConverter


def load_structure(fpath, chain=None):
    """
    Loads the backbone atoms of the first model of a structure file.

    Args:
        fpath: filepath to either pdb or cif file (selected by the path
            suffix: ``cif`` or ``pdb``)
        chain: the chain id or list/tuple of chain ids to load; ``None``
            keeps every chain found in the file
    Returns:
        biotite.structure.AtomArray restricted to backbone atoms of the
        requested chains
    Raises:
        ValueError: if the path has neither suffix, if no chains remain
            after backbone filtering, or if a requested chain is absent
    """
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        # Without this branch `structure` would be unbound below and the
        # caller would see an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported structure file {fpath!r}: "
            "expected a path ending in 'cif' or 'pdb'"
        )
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, (list, tuple)):
        chain_ids = list(chain)
    else:
        chain_ids = [chain]
    # Use a dedicated loop name so the `chain` argument is not clobbered.
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')
    chain_filter = [a.chain_id in chain_ids for a in structure]
    structure = structure[chain_filter]
    return structure


def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    # get_residues returns (ids, names); only the 3-letter names are needed.
    residue_identities = get_residues(structure)[1]
    seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return coords, seq


def load_coords(fpath, chain):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    structure = load_structure(fpath, chain)
    return extract_coords_from_structure(structure)


def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Collects, for every residue, the coordinates of the named atoms.

    Example for atoms argument: ["N", "CA", "C"]

    Atoms missing from a residue get NaN coordinates.

    Raises:
        RuntimeError: if a residue holds more than one atom with one of the
            requested names
    """
    def filterfn(s, axis=None):
        # `axis` is accepted but unused; presumably it is part of the
        # calling convention of apply_residue_wise -- confirm against biotite.
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        # Number of atoms in this residue matching each requested name.
        counts = filters.sum(0)
        if not np.all(counts <= 1):
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax yields index 0 for names with no match; those rows are
        # overwritten with NaN right below.
        index = filters.argmax(0)
        coords = s[index].coord
        coords[counts == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)


def get_sequence_loss(model, alphabet, coords, seq):
    """
    Computes the per-token cross-entropy of `seq` given backbone `coords`.

    Returns:
        Tuple (loss, target_padding_mask) as numpy arrays for the single
        example in the batch; the mask is True where the target token equals
        ``alphabet.padding_idx``.
    """
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, seq)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(
        batch, device=device)

    # Teacher forcing: feed tokens[:-1] and predict tokens[1:].
    prev_output_tokens = tokens[:, :-1].to(device)
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)
    loss = F.cross_entropy(logits, target, reduction='none')
    loss = loss[0].cpu().detach().numpy()
    target_padding_mask = target_padding_mask[0].cpu().numpy()
    return loss, target_padding_mask


def score_sequence(model, alphabet, coords, seq):
    """
    Returns average log-likelihoods of `seq` given `coords`.

    Returns:
        Tuple (ll_fullseq, ll_withcoord)
            - ll_fullseq averages over all non-padding target positions
            - ll_withcoord averages only over residues whose coordinates
              are all finite
    """
    loss, target_padding_mask = get_sequence_loss(model, alphabet, coords, seq)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask)
    # Also calculate average when excluding masked portions
    coord_mask = np.all(np.isfinite(coords), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord


def get_encoder_output(model, alphabet, coords):
    """
    Runs the model encoder on a single backbone and returns its output with
    the first and last positions stripped.
    """
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, None)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(
        batch, device=device)
    encoder_out = model.encoder.forward(coords, padding_mask, confidence,
            return_all_hiddens=False)
    # remove beginning and end (bos and eos tokens)
    return encoder_out['encoder_out'][0][1:-1, 0]


def rotate(v, R):
    """
    Rotates a vector by a rotation matrix.

    Args:
        v: 3D vector, tensor of shape (length x batch_size x channels x 3)
        R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)

    Returns:
        Rotated version of v by rotation matrix R.
    """
    R = R.unsqueeze(-3)
    v = v.unsqueeze(-1)
    return torch.sum(v * R, dim=-2)


def get_rotation_frames(coords):
    """
    Returns a local rotation frame defined by N, CA, C positions.

    Args:
        coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
        where the third dimension is in order of N, CA, C

    Returns:
        Local relative rotation frames in shape
        (batch_size x length x 3 x 3)
    """
    v1 = coords[:, :, 2] - coords[:, :, 1]
    v2 = coords[:, :, 0] - coords[:, :, 1]
    e1 = normalize(v1, dim=-1)
    # Gram-Schmidt step: drop the component of v2 that lies along e1.
    u2 = v2 - e1 * torch.sum(e1 * v2, dim=-1, keepdim=True)
    e2 = normalize(u2, dim=-1)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack([e1, e2, e3], dim=-2)
    return R


def nan_to_num(ts, val=0.0):
    """
    Replaces nans in tensor with a fixed value.

    Every non-finite entry (NaN as well as +/-inf) is replaced, since the
    test used is ``~torch.isfinite``.
    """
    val = torch.tensor(val, dtype=ts.dtype, device=ts.device)
    return torch.where(~torch.isfinite(ts), val, ts)


def rbf(values, v_min, v_max, n_bins=16):
    """
    Returns RBF encodings in a new dimension at the end.

    Args:
        values: tensor of arbitrary shape
        v_min: center of the first bin
        v_max: center of the last bin
        n_bins: number of evenly spaced centers

    Returns:
        Tensor of shape ``values.shape + (n_bins,)``.
    """
    rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
    # Shape the centers as (1, ..., 1, n_bins) so they broadcast over values.
    rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
    rbf_std = (v_max - v_min) / n_bins
    v_expand = torch.unsqueeze(values, -1)
    z = (v_expand - rbf_centers) / rbf_std
    return torch.exp(-z ** 2)


def norm(tensor, dim, eps=1e-8, keepdim=False):
    """
    Returns L2 norm along a dimension.

    `eps` is added under the square root, so the result is never exactly 0.
    """
    return torch.sqrt(
            torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim) + eps)


def normalize(tensor, dim=-1):
    """
    Normalizes a tensor along a dimension after removing nans.
    """
    return nan_to_num(
            torch.div(tensor, norm(tensor, dim=dim, keepdim=True))
    )


class CoordBatchConverter(BatchConverter):
    """
    Batch converter that collates backbone coordinates and confidence scores
    alongside the token batch produced by the parent BatchConverter.
    """

    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
        """
        Args:
            raw_batch: List of tuples (coords, confidence, seq)
            In each tuple,
                coords: list of floats, shape L x 3 x 3
                confidence: list of floats, shape L; or scalar float; or None
                seq: string of length L
        Returns:
            coords: Tensor of shape batch_size x (L + 2) x 3 x 3; the two
                extra positions are the inf padding added at both ends
            confidence: Tensor of shape batch_size x (L + 2)
            strs: list of strings
            tokens: token tensor as returned by the parent BatchConverter
            padding_mask: boolean Tensor of shape batch_size x (L + 2),
                True at batch-padding positions
        """
        # NOTE(review): looking up the empty string looks like a token name
        # was lost from this literal -- confirm against the upstream source.
        self.alphabet.cls_idx = self.alphabet.get_idx("")
        batch = []
        for coords, confidence, seq in raw_batch:
            if confidence is None:
                confidence = 1.
            if isinstance(confidence, (float, int)):
                # Broadcast a scalar confidence to every residue.
                confidence = [float(confidence)] * len(coords)
            if seq is None:
                seq = 'X' * len(coords)
            batch.append(((coords, confidence), seq))

        coords_and_confidence, strs, tokens = super().__call__(batch)

        # pad beginning and end of each protein due to legacy reasons
        coords = [
            F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)
            for cd, _ in coords_and_confidence
        ]
        confidence = [
            F.pad(torch.tensor(cf), (1, 1), value=-1.)
            for _, cf in coords_and_confidence
        ]
        # Batch padding uses NaN (not inf) so it can be told apart from the
        # per-protein inf padding when building padding_mask below.
        coords = self.collate_dense_tensors(coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
        if device is not None:
            coords = coords.to(device)
            confidence = confidence.to(device)
            tokens = tokens.to(device)
        padding_mask = torch.isnan(coords[:, :, 0, 0])
        coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
        # Zero the confidence wherever coordinates are non-finite, and set it
        # to -1 on batch-padding positions.
        confidence = confidence * coord_mask + (-1.) * padding_mask
        return coords, confidence, strs, tokens, padding_mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
            floats in shape L x 3 x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                length L to describe the confidence scores for the backbone
                with values between 0. and 1.
            seq_list: either None or a list of strings
        Returns:
            The same 5-tuple as ``__call__``:
            (coords, confidence, strs, tokens, padding_mask)
        """
        batch_size = len(coords_list)
        if confidence_list is None:
            confidence_list = [None] * batch_size
        if seq_list is None:
            seq_list = [None] * batch_size
        raw_batch = zip(coords_list, confidence_list, seq_list)
        return self(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11,       ...,           d_1K),
             (d_21,       ...,           d_2K),
             ...,
             (d_N1,       ...,           d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})

        Args:
            samples: list of tensors sharing rank, dtype and device
            pad_v: fill value for positions not covered by a sample
        Raises:
            RuntimeError: if the samples do not all have the same rank
        """
        if len(samples) == 0:
            return torch.Tensor()
        if len(set(x.dim() for x in samples)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {[x.dim() for x in samples]}"
            )
        (device,) = tuple(set(x.device for x in samples))  # assumes all on same device
        max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
        result = torch.empty(
            len(samples), *max_shape, dtype=samples[0].dtype, device=device
        )
        result.fill_(pad_v)
        for i, t in enumerate(samples):
            result_i = result[i]
            # Copy each sample into the leading corner of its padded slot.
            result_i[tuple(slice(0, k) for k in t.shape)] = t
        return result