Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# | |
# Portions of this file were adapted from the open source code for the following | |
# two papers: | |
# | |
# Ingraham, J., Garg, V., Barzilay, R., & Jaakkola, T. (2019). Generative | |
# models for graph-based protein design. Advances in Neural Information | |
# Processing Systems, 32. | |
# | |
# Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020). | |
# Learning from Protein Structure with Geometric Vector Perceptrons. In | |
# International Conference on Learning Representations. | |
# | |
# MIT License | |
# | |
# Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
# | |
# ================================================================ | |
# The below license applies to the portions of the code (parts of | |
# src/datasets.py and src/models.py) adapted from Ingraham, et al. | |
# ================================================================ | |
# | |
# MIT License | |
# | |
# Copyright (c) 2019 John Ingraham, Vikas Garg, Regina Barzilay, Tommi Jaakkola | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .gvp_utils import flatten_graph
from .gvp_modules import GVP, LayerNorm
from .util import normalize, norm, nan_to_num, rbf
class GVPInputFeaturizer(nn.Module):
    """Stateless featurization helpers for building GVP protein graphs.

    Every helper is a ``staticmethod``. They are called both through the class
    (``GVPInputFeaturizer._dist(...)``) and through subclass instances
    (``self.get_node_features(...)``). Without ``@staticmethod``, the instance
    form would bind ``self`` to the first positional parameter.
    """

    @staticmethod
    def get_node_features(coords, coord_mask, with_coord_mask=True):
        """Build per-residue scalar and vector node features.

        Args:
            coords: backbone coordinates indexed ``coords[batch, residue, atom]``;
                atoms 0, 1, 2 are read as N, CA, C.
            coord_mask: per-residue mask, appended as one extra scalar channel
                when ``with_coord_mask`` is true.
            with_coord_mask: whether to append ``coord_mask`` as a feature.

        Returns:
            ``(node_scalar_features, node_vector_features)``.
            - Scalars: cos/sin of the three backbone dihedrals, plus the mask
              channel when requested.
            - Vectors: the forward/backward CA orientations stacked with the
              N-CA-C derived "sidechain" direction.
        """
        # scalar features
        node_scalar_features = GVPInputFeaturizer._dihedrals(coords)
        if with_coord_mask:
            node_scalar_features = torch.cat([
                node_scalar_features,
                coord_mask.float().unsqueeze(-1),
            ], dim=-1)
        # vector features
        X_ca = coords[:, :, 1]
        orientations = GVPInputFeaturizer._orientations(X_ca)
        sidechains = GVPInputFeaturizer._sidechains(coords)
        node_vector_features = torch.cat(
            [orientations, sidechains.unsqueeze(-2)], dim=-2)
        return node_scalar_features, node_vector_features

    @staticmethod
    def _orientations(X):
        """Normalized vectors from each position to its next and previous neighbor.

        Args:
            X: positions indexed along dim 1 (e.g. CA coordinates).

        Returns:
            ``forward`` and ``backward`` vectors stacked on a new
            second-to-last axis. ``forward`` is zero at the last position and
            ``backward`` is zero at the first, where no neighbor exists.
        """
        forward = normalize(X[:, 1:] - X[:, :-1])
        backward = normalize(X[:, :-1] - X[:, 1:])
        # Add one zero row on dim -2 (the position axis) at the end / start.
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    @staticmethod
    def _sidechains(X):
        """Direction vector derived from each residue's N, CA, C frame.

        Atoms 0, 1, 2 of ``X[:, :, atom]`` are taken as N, CA (origin) and C.
        The result combines the normalized N/C bisector with the normal of the
        N-CA-C plane, using the fixed weights sqrt(1/3) and sqrt(2/3).
        Presumably this is an idealized C-beta direction (TODO confirm).
        """
        n, origin, c = X[:, :, 0], X[:, :, 1], X[:, :, 2]
        c, n = normalize(c - origin), normalize(n - origin)
        bisector = normalize(c + n)
        perp = normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Compute cos/sin of the three backbone dihedrals for every residue.

        The first three atoms of each residue (``X[:, :, :3]``) are flattened
        into one continuous backbone trace, and a torsion angle is computed for
        every four consecutive atoms. The angle sequence is then zero-padded
        (one angle at the start, two at the end) and reshaped to three angles
        per residue. Padded angles are 0, so their features are cos=1, sin=0.

        Args:
            X: coordinates indexed ``X[batch, residue, atom, xyz]``.
            eps: clamp margin keeping ``acos`` away from +/-1.

        Returns:
            Tensor of shape (batch, residues, 6): the cos of the 3 angles
            followed by their sin.
        """
        X = torch.flatten(X[:, :, :3], 1, 2)
        bsz = X.shape[0]
        dX = X[:, 1:] - X[:, :-1]
        U = normalize(dX, dim=-1)
        u_2 = U[:, :-2]
        u_1 = U[:, 1:-1]
        u_0 = U[:, 2:]
        # Backbone normals
        n_2 = normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [bsz, -1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], -1)
        return D_features

    @staticmethod
    def _positional_embeddings(edge_index,
                               num_embeddings=None,
                               num_positional_embeddings=16,
                               period_range=(2, 1000)):
        """Sinusoidal encoding of the index offset between edge endpoints.

        Args:
            edge_index: indices whose leading axis holds (source, destination);
                the encoded offset is ``edge_index[0] - edge_index[1]``.
            num_embeddings: output width; falls back to
                ``num_positional_embeddings`` when falsy.
            num_positional_embeddings: default output width.
            period_range: unused; kept only for signature compatibility. It is
                now a tuple rather than a shared mutable list default.

        Returns:
            Tensor with a trailing axis of size ``num_embeddings`` (for even
            values): the cosines of all frequencies, followed by their sines.
        """
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or num_positional_embeddings
        d = edge_index[0] - edge_index[1]
        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32,
                         device=edge_index.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    @staticmethod
    def _dist(X, coord_mask, padding_mask, top_k_neighbors, eps=1e-8):
        """Pairwise Euclidean distances and k-nearest-neighbor selection.

        The adjusted distances rank neighbors in three tiers:
        - Pairs where both residues have coordinates use their real distance.
        - Pairs involving a residue without coordinates get
          1e8 + 1e6 * sequence separation.
        - Pairs involving padding get a further +1e10, so they sort last.

        Args:
            X: per-residue positions indexed ``X[batch, residue, xyz]``.
            coord_mask: boolean mask, True where coordinates exist.
            padding_mask: boolean mask, True at padding positions.
            top_k_neighbors: number of neighbors to keep, capped at the
                sequence length. -1 keeps every position, in sequence order.
            eps: unused.

        Returns:
            ``(D_neighbors, E_idx, coord_mask_neighbors,
            residue_mask_neighbors)``. The two masks are recovered by
            thresholding the adjusted distances below the sentinels.
        """
        bsz, maxlen = X.size(0), X.size(1)
        coord_mask_2D = torch.unsqueeze(coord_mask, 1) * torch.unsqueeze(coord_mask, 2)
        residue_mask = ~padding_mask
        residue_mask_2D = torch.unsqueeze(residue_mask, 1) * torch.unsqueeze(residue_mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = coord_mask_2D * norm(dX, dim=-1)
        # sorting preference: first those with coords, then among the residues that
        # exist but are masked use distance in sequence as tie breaker, and then the
        # residues that came from padding are last
        seqpos = torch.arange(maxlen, device=X.device)
        Dseq = torch.abs(seqpos.unsqueeze(1) - seqpos.unsqueeze(0)).repeat(bsz, 1, 1)
        D_adjust = nan_to_num(D) + (~coord_mask_2D) * (1e8 + Dseq * 1e6) + (
            ~residue_mask_2D) * (1e10)
        if top_k_neighbors == -1:
            D_neighbors = D_adjust
            E_idx = seqpos.repeat(
                *D_neighbors.shape[:-1], 1)
        else:
            # Identify k nearest neighbors (including self)
            k = min(top_k_neighbors, X.size(1))
            D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
        coord_mask_neighbors = (D_neighbors < 5e7)
        residue_mask_neighbors = (D_neighbors < 5e9)
        return D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors
class Normalize(nn.Module):
    """Mean/variance standardization with a learnable per-feature affine.

    ``forward`` centers ``x`` on its mean along ``dim`` and divides by the
    epsilon-stabilized standard deviation. It then scales by ``gain`` and
    shifts by ``bias``. ``epsilon`` is applied twice: once inside the square
    root, and once more added to the resulting standard deviation.
    """

    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim`` and apply the learned affine.

        Args:
            x: input tensor; presumably its size along ``dim`` equals the
                ``features`` given at construction.
            dim: dimension to normalize over (default: last).
        """
        eps = self.epsilon
        mean = x.mean(dim, keepdim=True)
        std = (x.var(dim, keepdim=True) + eps).sqrt()
        scale, shift = self.gain, self.bias
        if dim != -1:
            # Lay the 1-D affine parameters along ``dim`` so they broadcast.
            broadcast_shape = [1] * mean.dim()
            broadcast_shape[dim] = scale.size(0)
            scale = scale.view(broadcast_shape)
            shift = shift.view(broadcast_shape)
        return scale * (x - mean) / (std + eps) + shift
class DihedralFeatures(nn.Module):
    """Embed backbone dihedral angles (phi, psi, omega) as node features."""

    def __init__(self, node_embed_dim):
        """Embed dihedral angle features.

        Args:
            node_embed_dim: output width of the linear embedding.
        """
        super(DihedralFeatures, self).__init__()
        # 3 dihedral angles; sin and cos of each angle
        node_in = 6
        # Normalization and embedding
        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """Featurize coordinates: dihedral cos/sin -> linear -> Normalize."""
        V = self._dihedrals(X)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7, return_angles=False):
        """Backbone dihedrals per residue, as angles or as cos/sin features.

        This must be a staticmethod: ``forward`` calls ``self._dihedrals(X)``,
        which would otherwise bind ``self`` to ``X`` and ``X`` to ``eps``.

        Args:
            X: coordinates indexed ``X[batch, residue, atom, xyz]``; the first
                three atoms of each residue are used (N, CA, C).
            eps: clamp margin keeping ``acos`` away from +/-1.
            return_angles: if true, return ``(phi, psi, omega)``, each shaped
                (batch, residues).

        Returns:
            By default, a (batch, residues, 6) tensor holding the cos of the
            3 angles followed by their sin. phi of the first residue and
            psi/omega of the last residue are padding (angle 0).
        """
        # First 3 coordinates are N, CA, C
        X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        # After padding the angle count is exactly 3 * residues.
        D = D.view((D.size(0), D.size(1) // 3, 3))
        if return_angles:
            phi, psi, omega = torch.unbind(D, -1)
            return phi, psi, omega
        # Lift angle representations to the circle
        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        return D_features
class GVPGraphEmbedding(GVPInputFeaturizer):
    """Embed backbone coordinates and per-residue confidence as a GVP graph.

    Raw node and edge features from ``GVPInputFeaturizer`` are projected by a
    ``GVP`` followed by ``LayerNorm``. An RBF expansion of ``confidence`` is
    added to the scalar node channels, and the batch is then passed through
    ``flatten_graph``.
    """

    def __init__(self, args):
        """Build the node, edge and confidence embedding layers.

        Args:
            args: config object providing ``top_k_neighbors``,
                ``node_hidden_dim_scalar``, ``node_hidden_dim_vector``,
                ``edge_hidden_dim_scalar`` and ``edge_hidden_dim_vector``.
        """
        super().__init__()
        self.top_k_neighbors = args.top_k_neighbors
        self.num_positional_embeddings = 16
        self.remove_edges_without_coords = True
        # Node inputs: 6 dihedral cos/sin + 1 coord-mask scalars; 2 orientation
        # vectors + 1 sidechain vector.
        node_input_dim = (7, 3)
        # Edge inputs: distance RBF + 16 positional + 2 missing-coord flags
        # (34 assumes rbf yields 16 bins — TODO confirm); 1 direction vector.
        edge_input_dim = (34, 1)
        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)
        self.embed_node = nn.Sequential(
            GVP(node_input_dim, node_hidden_dim, activations=(None, None)),
            LayerNorm(node_hidden_dim, eps=1e-4)
        )
        self.embed_edge = nn.Sequential(
            GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)),
            LayerNorm(edge_hidden_dim, eps=1e-4)
        )
        self.embed_confidence = nn.Linear(16, args.node_hidden_dim_scalar)

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Return ``(node_embeddings, edge_embeddings, edge_index)``.

        Raw features are computed under ``torch.no_grad()``; only the learned
        embedding layers take part in autograd. The outputs are whatever
        ``flatten_graph`` returns (presumably the batch merged into a single
        graph; check its definition).
        """
        with torch.no_grad():
            # Call through the class, as get_edge_features does for _dist and
            # _positional_embeddings. This stays correct whether or not the
            # base helper is declared as a staticmethod.
            node_features = GVPInputFeaturizer.get_node_features(
                coords, coord_mask)
            edge_features, edge_index = self.get_edge_features(
                coords, coord_mask, padding_mask)
        node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features)
        edge_embeddings = self.embed_edge(edge_features)
        rbf_rep = rbf(confidence, 0., 1.)
        node_embeddings = (
            node_embeddings_scalar + self.embed_confidence(rbf_rep),
            node_embeddings_vector
        )
        node_embeddings, edge_embeddings, edge_index = flatten_graph(
            node_embeddings, edge_embeddings, edge_index)
        return node_embeddings, edge_embeddings, edge_index

    def get_edge_features(self, coords, coord_mask, padding_mask):
        """Build k-nearest-neighbor edge features over CA atoms.

        Args:
            coords: backbone coordinates; ``coords[:, :, 1]`` is used as CA.
            coord_mask: boolean mask, True where coordinates exist.
            padding_mask: boolean mask, True at padding positions.

        Returns:
            ``((edge_s, edge_v), edge_index)``, with ``edge_index`` laid out as
            [B, 2, E]. Two kinds of edges get index -1:
            - edges touching padding;
            - when ``remove_edges_without_coords`` is set, edges whose
              endpoint pair lacks coordinates.
        """
        X_ca = coords[:, :, 1]
        # Get distances to the top k neighbors
        E_dist, E_idx, E_coord_mask, E_residue_mask = GVPInputFeaturizer._dist(
            X_ca, coord_mask, padding_mask, self.top_k_neighbors)
        # Flatten the graph to be batch size 1 for torch_geometric package
        dest = E_idx
        B, L, k = E_idx.shape[:3]
        src = torch.arange(L, device=E_idx.device).view([1, L, 1]).expand(B, L, k)
        # After flattening, [2, B, E]
        edge_index = torch.stack([src, dest], dim=0).flatten(2, 3)
        # After flattening, [B, E]
        E_dist = E_dist.flatten(1, 2)
        E_coord_mask = E_coord_mask.flatten(1, 2).unsqueeze(-1)
        E_residue_mask = E_residue_mask.flatten(1, 2)
        # Calculate relative positional embeddings and distance RBF
        pos_embeddings = GVPInputFeaturizer._positional_embeddings(
            edge_index,
            num_positional_embeddings=self.num_positional_embeddings,
        )
        D_rbf = rbf(E_dist, 0., 20.)
        # Calculate relative orientation
        X_src = X_ca.unsqueeze(2).expand(-1, -1, k, -1).flatten(1, 2)
        X_dest = torch.gather(
            X_ca,
            1,
            edge_index[1, :, :].unsqueeze(-1).expand([B, L * k, 3])
        )
        coord_mask_src = coord_mask.unsqueeze(2).expand(-1, -1, k).flatten(1, 2)
        coord_mask_dest = torch.gather(
            coord_mask,
            1,
            edge_index[1, :, :].expand([B, L * k])
        )
        E_vectors = X_src - X_dest
        # For the ones without coordinates, substitute in the average vector
        E_vector_mean = (
            torch.sum(E_vectors * E_coord_mask, dim=1, keepdim=True)
            / torch.sum(E_coord_mask, dim=1, keepdim=True)
        )
        E_vectors = E_vectors * E_coord_mask + E_vector_mean * ~(E_coord_mask)
        # Normalize and remove nans
        edge_s = torch.cat([D_rbf, pos_embeddings], dim=-1)
        edge_v = normalize(E_vectors).unsqueeze(-2)
        edge_s, edge_v = map(nan_to_num, (edge_s, edge_v))
        # Also add indications of whether the coordinates are present
        edge_s = torch.cat([
            edge_s,
            (~coord_mask_src).float().unsqueeze(-1),
            (~coord_mask_dest).float().unsqueeze(-1),
        ], dim=-1)
        edge_index[:, ~E_residue_mask] = -1
        if self.remove_edges_without_coords:
            edge_index[:, ~E_coord_mask.squeeze(-1)] = -1
        return (edge_s, edge_v), edge_index.transpose(0, 1)