|
|
|
|
|
|
|
|
|
|
|
import json |
|
import math |
|
|
|
import biotite.structure |
|
from biotite.structure.io import pdbx, pdb |
|
from biotite.structure.residues import get_residues |
|
from biotite.structure import filter_backbone |
|
from biotite.structure import get_chains |
|
from biotite.sequence import ProteinSequence |
|
import numpy as np |
|
from scipy.spatial import transform |
|
from scipy.stats import special_ortho_group |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.utils.data as data |
|
from typing import Sequence, Tuple, List |
|
|
|
from esm.data import BatchConverter |
|
|
|
|
|
def load_structure(fpath, chain=None):
    """
    Load the backbone atoms of the selected chains from a PDB or mmCIF file.

    Args:
        fpath: filepath (str or path-like) to either pdb or cif file; the
            format is picked from the end of the file name, ignoring case
        chain: the chain id or list/tuple of chain ids to load; None keeps
            every chain found in the file
    Returns:
        biotite.structure.AtomArray
    Raises:
        ValueError: if the file name ends in neither 'cif' nor 'pdb', if no
            chain is left after backbone filtering, or if a requested chain
            is not present
    """
    path = str(fpath)
    lowered = path.lower()
    if lowered.endswith('cif'):
        with open(path) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif lowered.endswith('pdb'):
        with open(path) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        # Without this branch `structure` would be unbound further down and
        # the caller would see an unrelated error.
        raise ValueError(
            f'Unsupported structure file {path!r}: expected a name ending in '
            "'cif' or 'pdb'")

    # Keep only the atoms selected by biotite's backbone filter.
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]

    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')

    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, (list, tuple)):
        chain_ids = list(chain)
    else:
        chain_ids = [chain]

    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')

    # Vectorised membership test on the per-atom chain annotation.
    chain_filter = np.isin(structure.chain_id, chain_ids)
    return structure[chain_filter]
|
|
|
|
|
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Pull per-residue backbone coordinates and the one-letter sequence out of
    a structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    backbone_coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    # get_residues returns a pair; its second entry holds the residue names.
    three_letter_names = get_residues(structure)[1]
    one_letter_codes = map(ProteinSequence.convert_letter_3to1, three_letter_names)
    return backbone_coords, ''.join(one_letter_codes)
|
|
|
|
|
def load_coords(fpath, chain):
    """
    Read a structure file and return its backbone coordinates and sequence.

    Thin convenience wrapper: loads the structure with load_structure and
    hands it to extract_coords_from_structure.

    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    return extract_coords_from_structure(load_structure(fpath, chain))
|
|
|
|
|
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Collect, for every residue, the coordinates of the requested atom names.

    Example for atoms argument: ["N", "CA", "C"]

    A requested atom that is absent from a residue gets NaN coordinates.
    A RuntimeError is raised when a residue holds more than one atom with
    the same requested name.
    """
    def pick_named_atoms(residue, axis=None):
        # `axis` is accepted for the callback signature but is not used.
        # Column j marks which atoms of this residue are named atoms[j].
        name_matches = np.stack(
            [residue.atom_name == wanted for wanted in atoms], axis=1)
        hits_per_name = name_matches.sum(0)
        if np.any(hits_per_name > 1):
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax returns 0 for an all-False column; those rows are
        # overwritten with NaN right below.
        first_hit = name_matches.argmax(0)
        picked = residue[first_hit].coord
        picked[hits_per_name == 0] = float("nan")
        return picked

    return biotite.structure.apply_residue_wise(struct, struct, pick_named_atoms)
|
|
|
|
|
def get_sequence_loss(model, alphabet, coords, seq):
    """
    Compute the per-position cross-entropy of a sequence given a backbone.

    Args:
        model: module whose forward accepts
            (coords, padding_mask, confidence, prev_output_tokens) and
            returns a tuple whose first entry is the logits
        alphabet: alphabet handed to CoordBatchConverter; its padding_idx
            marks padded target positions
        coords: backbone coordinates of a single structure
        seq: the sequence to score
    Returns:
        Tuple (loss, target_padding_mask), both numpy arrays for the single
        batch item
            - loss is the unreduced cross-entropy at each target position
            - target_padding_mask is True where the target token is padding
    """
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, seq)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(
        batch, device=device)

    # Teacher forcing: feed tokens[:-1], predict tokens[1:]. `tokens` was
    # already moved to `device` by the batch converter.
    prev_output_tokens = tokens[:, :-1]
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)

    # Only numpy values leave this function, so building an autograd graph
    # would just waste memory.
    with torch.no_grad():
        logits, _ = model.forward(
            coords, padding_mask, confidence, prev_output_tokens)
        loss = F.cross_entropy(logits, target, reduction='none')

    loss = loss[0].cpu().numpy()
    target_padding_mask = target_padding_mask[0].cpu().numpy()
    return loss, target_padding_mask
|
|
|
|
|
def score_sequence(model, alphabet, coords, seq):
    """
    Score a sequence against a backbone as average log-likelihoods.

    Returns:
        Tuple (ll_fullseq, ll_withcoord)
            - ll_fullseq is the negated mean loss over all non-padding
              target positions
            - ll_withcoord is the negated mean loss over positions whose
              coordinates are all finite
    """
    per_position_loss, is_padding = get_sequence_loss(
        model, alphabet, coords, seq)

    def negated_masked_mean(mask):
        # Average the loss over the positions selected by `mask`, negated.
        return -np.sum(per_position_loss * mask) / np.sum(mask)

    ll_fullseq = negated_masked_mean(~is_padding)

    # A position counts as having coordinates only if every value is finite.
    has_coords = np.all(np.isfinite(coords), axis=(-1, -2))
    ll_withcoord = negated_masked_mean(has_coords)
    return ll_fullseq, ll_withcoord
|
|
|
|
|
def get_encoder_output(model, alphabet, coords):
    """
    Run the model's encoder on a single backbone and return its output.

    The batch converter adds one placeholder position at each end of the
    structure; those two positions are sliced off before returning.

    Args:
        model: module exposing an `encoder` whose forward accepts
            (coords, padding_mask, confidence, return_all_hiddens=...)
        alphabet: alphabet handed to CoordBatchConverter
        coords: backbone coordinates of a single structure
    Returns:
        The first tensor under the encoder's 'encoder_out' key, restricted
        to positions 1..-1 of its first axis and index 0 of its second axis
        (presumably length x batch x channels -- confirm against the encoder).
    """
    device = next(model.parameters()).device
    converter = CoordBatchConverter(alphabet)
    # No confidence and no sequence: the converter fills in its defaults.
    coords, confidence, _strs, _tokens, padding_mask = converter(
        [(coords, None, None)], device=device)
    encoder_out = model.encoder.forward(
        coords, padding_mask, confidence, return_all_hiddens=False)
    first_output = encoder_out['encoder_out'][0]
    return first_output[1:-1, 0]
|
|
|
|
|
def rotate(v, R):
    """
    Rotates a vector by a rotation matrix.

    Computes out[..., c, k] = sum_j v[..., c, j] * R[..., j, k], i.e. each
    vector is treated as a row vector multiplied by R from the right.

    Args:
        v: 3D vector, tensor of shape (length x batch_size x channels x 3)
        R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)

    Returns:
        Rotated version of v by rotation matrix R.
    """
    # Add a channel axis to R and a trailing axis to v so they broadcast to
    # (..., channels, 3, 3); the contraction runs over the second-to-last axis.
    per_channel_R = R[..., None, :, :]
    column_v = v[..., None]
    return (column_v * per_channel_R).sum(dim=-2)
|
|
|
|
|
def get_rotation_frames(coords):
    """
    Returns a local rotation frame defined by N, CA, C positions.

    The frame is built by Gram-Schmidt: the first axis points from CA to C,
    the second is the CA-to-N direction with its component along the first
    axis removed, and the third is their cross product.

    Args:
        coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
        where the third dimension is in order of N, CA, C

    Returns:
        Local relative rotation frames in shape (batch_size x length x 3 x 3)
    """
    ca = coords[:, :, 1]
    ca_to_c = coords[:, :, 2] - ca
    ca_to_n = coords[:, :, 0] - ca

    e1 = normalize(ca_to_c, dim=-1)
    # Length of the projection of CA->N onto e1, kept broadcastable.
    along_e1 = torch.sum(e1 * ca_to_n, dim=-1, keepdim=True)
    e2 = normalize(ca_to_n - e1 * along_e1, dim=-1)
    e3 = torch.cross(e1, e2, dim=-1)

    # The three axes become the rows of each 3 x 3 frame.
    return torch.stack([e1, e2, e3], dim=-2)
|
|
|
|
|
def nan_to_num(ts, val=0.0):
    """
    Replaces every non-finite entry of a tensor (NaN as well as +/-inf)
    with a fixed value; finite entries are returned unchanged.
    """
    # Scalar fill tensor matching the input's dtype and device.
    fill = torch.tensor(val, dtype=ts.dtype, device=ts.device)
    keep = torch.isfinite(ts)
    return torch.where(keep, ts, fill)
|
|
|
|
|
def rbf(values, v_min, v_max, n_bins=16):
    """
    Returns RBF encodings in a new dimension at the end.

    Args:
        values: tensor of any shape
        v_min: position of the first RBF center
        v_max: position of the last RBF center
        n_bins: number of evenly spaced centers between v_min and v_max
    Returns:
        Tensor of shape values.shape + (n_bins,) holding
        exp(-((value - center) / width) ** 2) with
        width = (v_max - v_min) / n_bins.
    """
    rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
    # Shape the centers as (1, ..., 1, n_bins) so they broadcast against the
    # trailing axis added to `values` below.
    rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
    rbf_std = (v_max - v_min) / n_bins
    z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
    return torch.exp(-z ** 2)
|
|
|
|
|
def norm(tensor, dim, eps=1e-8, keepdim=False):
    """
    Returns L2 norm along a dimension.

    `eps` is added to the sum of squares before the square root is taken,
    so an all-zero slice yields sqrt(eps) rather than 0.
    """
    sum_of_squares = torch.square(tensor).sum(dim=dim, keepdim=keepdim)
    return torch.sqrt(sum_of_squares + eps)
|
|
|
|
|
def normalize(tensor, dim=-1):
    """
    Scales a tensor to unit L2 norm along a dimension.

    The tensor is divided by its norm (kept as a broadcastable axis) and any
    non-finite entry of the quotient is then replaced via nan_to_num.
    """
    lengths = norm(tensor, dim=dim, keepdim=True)
    return nan_to_num(tensor / lengths)
|
|
|
|
|
class CoordBatchConverter(BatchConverter):
    """
    Batch converter for (coords, confidence, seq) triples.

    Sequences are tokenized by the parent converter; coordinates and
    confidences are padded and collated into dense tensors next to them.
    """

    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
        """
        Args:
            raw_batch: List of tuples (coords, confidence, seq)
            In each tuple,
                coords: list of floats, shape L x 3 x 3
                confidence: list of floats, shape L; or scalar float; or None
                seq: string of length L
            device: optional device the returned tensors are moved to
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        # Side effect: the alphabet's cls token is re-pointed at "<cath>".
        self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")

        prepared = []
        for backbone, conf, sequence in raw_batch:
            if conf is None:
                conf = 1.
            if isinstance(conf, (float, int)):
                # A scalar confidence is repeated for every residue.
                conf = [float(conf)] * len(backbone)
            if sequence is None:
                sequence = 'X' * len(backbone)
            # (coords, confidence) travels through the parent converter in
            # the label slot.
            prepared.append(((backbone, conf), sequence))

        coords_and_confidence, strs, tokens = super().__call__(prepared)

        # Add one placeholder position before and after every structure:
        # inf for coordinates, -1 for confidence (presumably to line up with
        # the special tokens added by the parent converter -- confirm).
        padded_coords = []
        padded_confidence = []
        for cd, cf in coords_and_confidence:
            padded_coords.append(
                F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf))
            padded_confidence.append(
                F.pad(torch.tensor(cf), (1, 1), value=-1.))

        # Batch padding uses NaN for coordinates so that it can be told
        # apart from the inf placeholders above.
        coords = self.collate_dense_tensors(padded_coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(padded_confidence, pad_v=-1.)

        if device is not None:
            coords, confidence, tokens = (
                tensor.to(device) for tensor in (coords, confidence, tokens))

        # Padding: the first coordinate value of the position is NaN.
        padding_mask = torch.isnan(coords[:, :, 0, 0])
        # Usable position: the sum of all nine coordinate values is finite.
        coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
        # Confidence is zeroed where coordinates are unusable and lowered by
        # one on padding positions.
        confidence = confidence * coord_mask + (-1.) * padding_mask
        return coords, confidence, strs, tokens, padding_mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
            floats in shape L x 3 x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                length L to describe the confidence scores for the backbone
                with values between 0. and 1.
            seq_list: either None or a list of strings
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        # One None per structure stands in for a list that was not supplied.
        placeholders = [None] * len(coords_list)
        confidences = placeholders if confidence_list is None else confidence_list
        sequences = placeholders if seq_list is None else seq_list
        raw_batch = zip(coords_list, confidences, sequences)
        return self.__call__(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11,       ...,           d_1K),
             (d_21,       ...,           d_2K),
             ...,
             (d_N1,       ...,           d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})

        Every position not covered by a sample holds pad_v. An empty list
        yields an empty tensor; samples of differing rank raise RuntimeError.
        """
        if len(samples) == 0:
            return torch.Tensor()

        ranks = [sample.dim() for sample in samples]
        if len(set(ranks)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {ranks}"
            )

        # Unpacking a one-element set: fails if the samples span devices.
        (device,) = {sample.device for sample in samples}
        padded_shape = [
            max(sizes) for sizes in zip(*(sample.shape for sample in samples))
        ]
        collated = torch.empty(
            len(samples), *padded_shape, dtype=samples[0].dtype, device=device
        ).fill_(pad_v)

        for idx, sample in enumerate(samples):
            # Write each sample into the leading corner of its slot.
            corner = tuple(slice(0, size) for size in sample.shape)
            collated[idx][corner] = sample
        return collated
|
|