File size: 11,754 Bytes
3f0529e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import math
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
import numpy as np
from scipy.spatial import transform
from scipy.stats import special_ortho_group
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from typing import Sequence, Tuple, List
from esm.data import BatchConverter
def load_structure(fpath, chain=None):
    """
    Load a PDB or mmCIF file and keep only the backbone atoms of the requested chains.

    Args:
        fpath: filepath to either pdb or cif file (str or path-like); the
            format is chosen from the file name suffix ('cif' or 'pdb')
        chain: the chain id, or a list/tuple of chain ids, to load; None loads
            every chain found in the file
    Returns:
        biotite.structure.AtomArray with the backbone atoms (as selected by
        biotite's ``filter_backbone``) of the selected chains, from model 1
    Raises:
        ValueError: if the file suffix is unsupported, if no chains remain
            after backbone filtering, or if a requested chain is absent
    """
    # Accept pathlib.Path and other path-likes as well as plain strings.
    fpath = str(fpath)
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        # Without this branch `structure` would be unbound below and the call
        # would die with an UnboundLocalError instead of a useful message.
        raise ValueError(
            f'Unsupported structure file {fpath!r}: expected a .pdb or .cif file')
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, (list, tuple)):
        chain_ids = list(chain)
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f'Chain {chain_id} not found in input file')
    # Vectorized per-atom membership test (replaces a Python loop over atoms).
    chain_filter = np.isin(structure.chain_id, chain_ids)
    structure = structure[chain_filter]
    return structure
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Pull backbone coordinates and the one-letter sequence out of a structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates; an atom
              absent from a residue is filled with NaN
            - seq is the extracted sequence, one letter per residue
    """
    backbone_xyz = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    three_letter_names = get_residues(structure)[1]
    one_letter_codes = (
        ProteinSequence.convert_letter_3to1(name) for name in three_letter_names
    )
    return backbone_xyz, ''.join(one_letter_codes)
def load_coords(fpath, chain):
    """
    Convenience wrapper: load a structure file and extract its backbone.

    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id, a list of chain ids, or None for all chains;
            forwarded unchanged to load_structure
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    return extract_coords_from_structure(load_structure(fpath, chain))
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Gather, for every residue, the coordinates of the named atoms.

    Example for atoms argument: ["N", "CA", "C"]

    Args:
        atoms: atom names to collect, in output order
        struct: structure to read from
    Returns:
        Per-residue array where each residue contributes len(atoms) x 3
        coordinates; an atom missing from a residue is filled with NaN.
    Raises:
        RuntimeError: if a residue contains the same atom name more than once
    """
    def _pick_atoms(residue, axis=None):
        # hits[i, j] is True when atom i of the residue is named atoms[j]
        hits = np.stack([residue.atom_name == wanted for wanted in atoms], axis=1)
        per_name_count = hits.sum(0)
        if not np.all(per_name_count <= np.ones(hits.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax picks the matching atom; for an absent name it returns 0,
        # which is overwritten with NaN right after.
        picked = residue[hits.argmax(0)].coord
        picked[per_name_count == 0] = float("nan")
        return picked
    return biotite.structure.apply_residue_wise(struct, struct, _pick_atoms)
def get_sequence_loss(model, alphabet, coords, seq):
    """
    Per-position cross-entropy of ``seq`` under ``model`` given backbone ``coords``.

    Args:
        model: inverse-folding model, called as
            ``model.forward(coords, padding_mask, confidence, prev_output_tokens)``
            and returning a pair whose first element is the logits
        alphabet: alphabet used to tokenize ``seq``; must expose ``padding_idx``
        coords: L x 3 x 3 backbone coordinates for a single protein
        seq: sequence string of length L
    Returns:
        Tuple (loss, target_padding_mask) of 1-D numpy arrays for the single
        batch element: ``loss`` is the unreduced cross entropy per target
        token, ``target_padding_mask`` is True where the target is padding.
    """
    device = next(model.parameters()).device
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, seq)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(
        batch, device=device)
    # Teacher forcing: feed tokens[:-1], predict tokens[1:]. `tokens` is
    # already on `device` (the batch converter moved it), so no extra copy.
    prev_output_tokens = tokens[:, :-1]
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    # The loss is returned as detached numpy, so building an autograd graph
    # for this forward pass would only waste memory.
    with torch.no_grad():
        logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)
        # NOTE(review): assumes logits are (batch, vocab, length), the layout
        # F.cross_entropy expects for sequence targets — confirm against model.
        loss = F.cross_entropy(logits, target, reduction='none')
    loss = loss[0].cpu().detach().numpy()
    target_padding_mask = target_padding_mask[0].cpu().numpy()
    return loss, target_padding_mask
def score_sequence(model, alphabet, coords, seq):
    """
    Average log-likelihood of ``seq`` conditioned on backbone ``coords``.

    Returns:
        Tuple (ll_fullseq, ll_withcoord):
            - ll_fullseq: mean negative loss over all non-padding targets
            - ll_withcoord: mean negative loss over only the positions whose
              coordinates are all finite (masked/missing residues excluded)
    """
    per_token_loss, pad = get_sequence_loss(model, alphabet, coords, seq)
    keep = ~pad
    ll_fullseq = -np.sum(per_token_loss * keep) / np.sum(keep)
    # Any non-finite coordinate marks the residue as masked.
    # NOTE(review): assumes the loss has one entry per residue so it aligns
    # elementwise with this mask.
    has_coords = np.all(np.isfinite(coords), axis=(-1, -2))
    ll_withcoord = -np.sum(per_token_loss * has_coords) / np.sum(has_coords)
    return ll_fullseq, ll_withcoord
def get_encoder_output(model, alphabet, coords):
    """
    Run only the structure encoder and return its per-residue output.

    Args:
        model: model exposing ``parameters()`` and ``encoder.forward``
        alphabet: alphabet handed to CoordBatchConverter
        coords: L x 3 x 3 backbone coordinates for a single protein
    Returns:
        Encoder output for the single batch element, with the first and last
        positions (the padded ends added by CoordBatchConverter) removed.
    """
    target_device = next(model.parameters()).device
    converter = CoordBatchConverter(alphabet)
    batched_coords, confidence, _strs, _tokens, padding_mask = converter(
        [(coords, None, None)], device=target_device)
    encoded = model.encoder.forward(
        batched_coords, padding_mask, confidence, return_all_hiddens=False)
    # remove beginning and end (bos and eos tokens); the trailing 0 selects
    # the sole batch element — presumably the layout is (length, batch, dim),
    # confirm against the encoder.
    return encoded['encoder_out'][0][1:-1, 0]
def rotate(v, R):
    """
    Rotates a vector by a rotation matrix.

    For every channel c this computes the row-vector product v[..., c, :] @ R.

    Args:
        v: 3D vector, tensor of shape (length x batch_size x channels x 3)
        R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)
    Returns:
        Rotated version of v by rotation matrix R, with the same shape as v.
    """
    # Broadcast to (..., channels, 3, 3): entry [c, i, j] = v[c, i] * R[i, j],
    # then contract over i.
    products = v[..., :, :, None] * R[..., None, :, :]
    return products.sum(dim=-2)
def get_rotation_frames(coords):
    """
    Build a per-residue orthonormal frame from the N, CA, C positions.

    Args:
        coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
            where the third dimension is in order of N, CA, C
    Returns:
        Local relative rotation frames in shape (batch_size x length x 3 x 3);
        the rows of each 3 x 3 frame are the axes e1, e2, e3.
    """
    n_xyz, ca_xyz, c_xyz = coords[:, :, 0], coords[:, :, 1], coords[:, :, 2]
    # First axis: unit vector from CA towards C.
    e1 = normalize(c_xyz - ca_xyz, dim=-1)
    # Second axis: CA->N with its e1 component projected out (Gram-Schmidt).
    to_n = n_xyz - ca_xyz
    e2 = normalize(to_n - e1 * (e1 * to_n).sum(dim=-1, keepdim=True), dim=-1)
    # Third axis completes the basis via the cross product.
    e3 = torch.cross(e1, e2, dim=-1)
    return torch.stack((e1, e2, e3), dim=-2)
def nan_to_num(ts, val=0.0):
    """
    Replace every non-finite entry of ``ts`` (NaN, +inf and -inf) with ``val``.

    Returns a new tensor with the same dtype and device as ``ts``.
    """
    fill = torch.tensor(val, dtype=ts.dtype, device=ts.device)
    return torch.where(torch.isfinite(ts), ts, fill)
def rbf(values, v_min, v_max, n_bins=16):
    """
    Returns RBF encodings in a new dimension at the end.

    Each value is compared against ``n_bins`` evenly spaced centers on
    [v_min, v_max] with the kernel exp(-((x - c) / std) ** 2), where
    std = (v_max - v_min) / n_bins.

    Args:
        values: tensor of any shape
        v_min, v_max: first and last center
        n_bins: number of centers, i.e. size of the new trailing dimension
    Returns:
        Tensor of shape values.shape + (n_bins,)
    """
    rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
    # (1, ..., 1, n_bins) so the centers broadcast against values[..., None].
    rbf_centers = rbf_centers.view([1] * values.dim() + [-1])
    rbf_std = (v_max - v_min) / n_bins
    z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
    return torch.exp(-z ** 2)
def norm(tensor, dim, eps=1e-8, keepdim=False):
    """
    Returns the L2 norm of ``tensor`` along ``dim``.

    ``eps`` is added under the square root, so with the default eps > 0 the
    result is strictly positive even for an all-zero vector.
    """
    squared_sum = torch.square(tensor).sum(dim=dim, keepdim=keepdim)
    return (squared_sum + eps).sqrt()
def normalize(tensor, dim=-1):
    """
    Scale ``tensor`` to unit L2 length along ``dim``.

    Any non-finite entries produced by the division (NaN or inf) are
    replaced with 0 afterwards.
    """
    unit = tensor / norm(tensor, dim=dim, keepdim=True)
    return nan_to_num(unit)
class CoordBatchConverter(BatchConverter):
    """
    BatchConverter that also batches backbone coordinates and per-residue
    confidence scores alongside the tokenized sequences.
    """

    def __call__(self, raw_batch: Sequence[Tuple[Sequence, object, str]], device=None):
        """
        Args:
            raw_batch: List (or iterable) of tuples (coords, confidence, seq)
                In each tuple,
                    coords: list of floats, shape L x 3 x 3
                    confidence: list of floats, shape L; or scalar (Python or
                        numpy number); or None (treated as 1.0)
                    seq: string of length L; or None (replaced by 'X' * L)
            device: optional torch device; when given, coords, confidence and
                tokens are moved to it
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: BoolTensor of shape batch_size x L
            coords, confidence and padding_mask carry one extra position at
            each end of every protein (legacy padding below), so their length
            dimension is max(L) + 2.
        """
        # BatchConverter reads `alphabet.cls_idx`; the inverse-folding models
        # use "<cath>" there (presumably the BOS slot — confirm in esm.data).
        # Note this mutates the shared alphabet object on every call.
        self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")
        batch = []
        for coords, confidence, seq in raw_batch:
            if confidence is None:
                confidence = 1.
            if isinstance(confidence, (float, int, np.number)):
                # Broadcast a scalar confidence to every residue. np.number
                # covers numpy scalars such as np.float32, which are not
                # Python float subclasses and would otherwise crash in F.pad.
                confidence = [float(confidence)] * len(coords)
            if seq is None:
                seq = 'X' * len(coords)
            batch.append(((coords, confidence), seq))

        coords_and_confidence, strs, tokens = super().__call__(batch)

        # pad beginning and end of each protein due to legacy reasons
        # (coords with inf, confidence with -1)
        coords = [
            F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)
            for cd, _ in coords_and_confidence
        ]
        confidence = [
            F.pad(torch.tensor(cf), (1, 1), value=-1.)
            for _, cf in coords_and_confidence
        ]

        # Batch padding uses NaN, distinct from the inf end padding above, so
        # that it can be recovered as `padding_mask` below.
        coords = self.collate_dense_tensors(coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
        if device is not None:
            coords = coords.to(device)
            confidence = confidence.to(device)
            tokens = tokens.to(device)
        # NOTE(review): only the first atom's x coordinate is inspected, so a
        # residue whose input coords are NaN there (e.g. a missing atom filled
        # with NaN) is flagged as padding too; inf-masked residues are not.
        padding_mask = torch.isnan(coords[:, :, 0, 0])
        # True where the coordinate sum is finite, i.e. (barring overflow)
        # every coordinate at that position is finite.
        coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
        # Keep confidence where coords are finite, 0 where they are not, and
        # -1 at padding positions.
        confidence = confidence * coord_mask + (-1.) * padding_mask
        return coords, confidence, strs, tokens, padding_mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
                floats in shape L x 3 x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                    length L to describe the confidence scores for the backbone
                    with values between 0. and 1.
            seq_list: either None or a list of strings
            device: optional torch device forwarded to __call__
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: BoolTensor of shape batch_size x L
        """
        batch_size = len(coords_list)
        if confidence_list is None:
            confidence_list = [None] * batch_size
        if seq_list is None:
            seq_list = [None] * batch_size
        raw_batch = zip(coords_list, confidence_list, seq_list)
        return self.__call__(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11,       ...,           d_1K),
             (d_21,       ...,           d_2K),
             ...,
             (d_N1,       ...,           d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})

        Padding positions are filled with ``pad_v``; the result uses the dtype
        of the first sample and the (single, shared) device of all samples.
        An empty list yields an empty ``torch.Tensor()``.

        Raises:
            RuntimeError: if samples differ in rank or live on different devices
        """
        if len(samples) == 0:
            return torch.Tensor()
        if len(set(x.dim() for x in samples)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {[x.dim() for x in samples]}"
            )
        devices = {x.device for x in samples}
        if len(devices) != 1:
            # Previously surfaced as an opaque tuple-unpacking error.
            raise RuntimeError(f"Samples are on multiple devices: {devices}")
        (device,) = devices
        max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
        result = torch.empty(
            len(samples), *max_shape, dtype=samples[0].dtype, device=device
        )
        result.fill_(pad_v)
        for i, t in enumerate(samples):
            # Copy each sample into the leading corner of its slot.
            result[i][tuple(slice(0, k) for k in t.shape)] = t
        return result
|