FAPM_demo / esm /inverse_folding /gvp_modules.py
wenkai's picture
Upload 31 files
3f0529e verified
raw
history blame
18.6 kB
# Contents of this file are from the open source code for
#
# Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
# Learning from Protein Structure with Geometric Vector Perceptrons. In
# International Conference on Learning Representations.
#
# MIT License
#
# Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import typing as T
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
def tuple_size(tp):
    '''
    Returns the sizes of the tensors in a tuple such as (s, V).

    :param tp: iterable of `torch.Tensor` or None entries
    :return: tuple with each tensor's `.size()`, or 0 for None entries
    '''
    return tuple(t.size() if t is not None else 0 for t in tp)
def tuple_sum(tp1, tp2):
    '''
    Elementwise sum of two tuples (s, V).

    :param tp1: tuple (s1, V1); V1 may be None for scalar-only tuples
    :param tp2: tuple (s2, V2); V2 may be None for scalar-only tuples
    :return: (s1 + s2, V1 + V2), or (s1 + s2, None) when both V are None
    :raises TypeError: if exactly one of V1, V2 is None
    '''
    s1, v1 = tp1
    s2, v2 = tp2
    if v1 is None and v2 is None:
        return (s1 + s2, None)
    # A None on only one side means the tuples have mismatched vector
    # channels; fail loudly instead of silently dropping one side's vectors.
    if v1 is None or v2 is None:
        raise TypeError(
            "tuple_sum: cannot add a vector tensor to None "
            "(both tuples must have vector channels, or neither)")
    return (s1 + s2, v1 + v2)
def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of tuples (s, V) elementwise.

    :param args: tuples (s, V) of `torch.Tensor`
    :param dim: dimension along which to concatenate when viewed
        as the `dim` index for the scalar-channel tensors.
        This means that `dim=-1` will be applied as
        `dim=-2` for the vector-channel tensors.
    :return: tuple (cat of all s, cat of all V)
    '''
    # Turn `dim` into a non-negative index relative to the scalar rank. The
    # vector tensors carry one extra trailing axis, so the same non-negative
    # index addresses the corresponding vector axis.
    scalar_rank = len(args[0][0].shape)
    axis = dim % scalar_rank
    scalars = [pair[0] for pair in args]
    vectors = [pair[1] for pair in args]
    return torch.cat(scalars, dim=axis), torch.cat(vectors, dim=axis)
def tuple_index(x, idx):
    '''
    Indexes into a tuple (s, V) along the first dimension.

    :param x: tuple (s, V) of `torch.Tensor`
    :param idx: any object which can be used to index into a `torch.Tensor`
    :return: tuple (s[idx], V[idx])
    '''
    scalars, vectors = x[0], x[1]
    return scalars[idx], vectors[idx]
def randn(n, dims, device="cpu"):
    '''
    Returns random tuples (s, V) drawn elementwise from a standard normal
    distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)
    :param device: device on which to allocate both tensors
    :return: (s, V) with s.shape = (n, n_scalar) and
        V.shape = (n, n_vector, 3)
    '''
    n_scalar, n_vector = dims[0], dims[1]
    # Draw scalars before vectors so the RNG stream is consumed in a fixed order.
    scalars = torch.randn(n, n_scalar, device=device)
    vectors = torch.randn(n, n_vector, 3, device=device)
    return scalars, vectors
def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
'''
L2 norm of tensor clamped above a minimum value `eps`.
:param sqrt: if `False`, returns the square of the L2 norm
'''
# clamp is slow
# out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
out = torch.sum(torch.square(x), axis, keepdims) + eps
return torch.sqrt(out) if sqrt else out
def _split(x, nv):
'''
Splits a merged representation of (s, V) back into a tuple.
Should be used only with `_merge(s, V)` and only if the tuple
representation cannot be used.
:param x: the `torch.Tensor` returned from `_merge`
:param nv: the number of vector channels in the input to `_merge`
'''
v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
s = x[..., :-3*nv]
return s, v
def _merge(s, v):
'''
Merges a tuple (s, V) into a single `torch.Tensor`, where the
vector channels are flattened and appended to the scalar channels.
Should be used only if the tuple representation cannot be used.
Use `_split(x, nv)` to reverse.
'''
v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
return torch.cat([s, v], -1)
class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional;
        defaults to max(n_vector_in, n_vector_out)
    :param vector_gate: if True (and there are vector inputs and outputs),
        vector outputs are multiplied by a linear projection of the scalar
        outputs instead of by their own norms
    :param activations: tuple of functions (scalar_act, vector_act); either
        entry may be None to skip that activation
    :param tuple_io: whether to keep accepting tuple inputs and outputs when vi
        or vo = 0
    :param eps: passed to `_norm_no_nan` when computing vector norms
    '''
    def __init__(self, in_dims, out_dims, h_dim=None, vector_gate=False,
                 activations=(F.relu, torch.sigmoid), tuple_io=True,
                 eps=1e-8):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.tuple_io = tuple_io
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if vector_gate:
                    self.wg = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)
        self.vector_gate = vector_gate
        self.scalar_act, self.vector_act = activations
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor` with V of shape
            (..., n_vector, 3); when there are no vector inputs, either
            (s, None) (tuple_io=True) or a single `torch.Tensor` s
            (tuple_io=False)
        :return: tuple (s, V) of `torch.Tensor` when there are vector
            outputs; otherwise (s, None) if tuple_io, else a single
            `torch.Tensor` s
        '''
        if self.vi:
            s, v = x
            # Put the channel axis last so the Linear layers mix channels:
            # (..., vi, 3) -> (..., 3, vi).
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            # Norm over the 3 coordinates of each hidden channel.
            vn = _norm_no_nan(vh, axis=-2, eps=self.eps)
            s = self.ws(torch.cat([s, vn], -1))
            if self.scalar_act:
                s = self.scalar_act(s)
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    # Gate each output channel with a projection of the
                    # scalar outputs (after scalar_act, if any).
                    g = self.wg(s).unsqueeze(-1)
                else:
                    g = _norm_no_nan(v, axis=-1, keepdims=True, eps=self.eps)
                if self.vector_act:
                    g = self.vector_act(g)
                v = v * g
        else:
            if self.tuple_io:
                assert x[1] is None
                x = x[0]
            s = self.ws(x)
            if self.scalar_act:
                s = self.scalar_act(s)
            if self.vo:
                # No vector inputs: emit all-zero vectors. Match the scalars'
                # dtype and device so downstream cats and sums do not mix
                # precisions (e.g. under .double() or half precision).
                v = torch.zeros(list(s.shape)[:-1] + [self.vo, 3],
                                device=s.device, dtype=s.dtype)
        if self.vo:
            return (s, v)
        elif self.tuple_io:
            return (s, None)
        else:
            return s
class _VDropout(nn.Module):
'''
Vector channel dropout where the elements of each
vector channel are dropped together.
'''
def __init__(self, drop_rate):
super(_VDropout, self).__init__()
self.drop_rate = drop_rate
def forward(self, x):
'''
:param x: `torch.Tensor` corresponding to vector channels
'''
if x is None:
return None
device = x.device
if not self.training:
return x
mask = torch.bernoulli(
(1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
).unsqueeze(-1)
x = mask * x / (1 - self.drop_rate)
return x
class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    Scalars use standard `nn.Dropout`; vectors use `_VDropout`, which drops
    whole vector channels. Both share the same drop rate.

    :param drop_rate: drop probability for both scalars and vector channels
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
            or single `torch.Tensor`
            (will be assumed to be scalar channels)
        :return: dropped-out input in the same form (tuple or tensor)
        '''
        # Exact type check (not isinstance), so only plain tensors are
        # treated as scalar-only input.
        if type(x) is not torch.Tensor:
            scalars, vectors = x
            return self.sdropout(scalars), self.vdropout(vectors)
        return self.sdropout(x)
class LayerNorm(nn.Module):
    '''
    Combined LayerNorm for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    Scalars go through a standard `nn.LayerNorm`. For vectors (last axis =
    coordinates, axis -2 = channels), channels whose eps-shifted squared norm
    is at most 2 * eps are zeroed. Every other channel is divided by
    sqrt(mean of the kept channels' eps-shifted squared norms + eps).

    :param dims: tuple (n_scalar, n_vector)
    :param tuple_io: when n_vector is 0, whether input/output are still
        tuples (s, None) rather than a bare scalar tensor
    :param eps: stabiliser used in the vector normalisation
    '''
    def __init__(self, dims, tuple_io=True, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.tuple_io = tuple_io
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
            or single `torch.Tensor`
            (will be assumed to be scalar channels)
        :return: normalised input in the same form
        '''
        if self.v:
            s, v = x
            eps = self.eps
            sq_norms = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False,
                                    eps=eps)
            keep = sq_norms > 2 * eps
            n_kept = torch.sum(keep, dim=-2, keepdim=True)
            mean_sq = torch.sum(sq_norms * keep, dim=-2, keepdim=True) / (
                eps + n_kept)
            scale = torch.sqrt(mean_sq + eps)
            return self.scalar_norm(s), keep * (v / scale)
        if self.tuple_io:
            return self.scalar_norm(x[0]), None
        return self.scalar_norm(x)
class GVPConv(MessagePassing):
    '''
    Graph convolution / message passing with Geometric Vector Perceptrons.
    Takes in a graph with node and edge embeddings,
    and returns new node embeddings.

    This does NOT do residual updates and pointwise feedforward layers
    ---see `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param vector_gate: passed to the first and any middle GVPs of the
        default message function (not to the final GVP, and not used when
        n_layers == 1)
    :param module_list: preconstructed message function, overrides n_layers;
        a non-empty list is used as-is
    :param aggr: should be "add" if some incoming edges are masked, as in
        a masked autoregressive decoder architecture
    :param eps: stored as `self.eps`. NOTE(review): it is not forwarded to
        the GVPs built here, which use GVP's own default eps.
    :param activations: activations of the first GVP of the default message
        function when n_layers > 1; middle GVPs use GVP's default
        activations and the final GVP uses none
    '''
    def __init__(self, in_dims, out_dims, edge_dims, n_layers=3,
                 vector_gate=False, module_list=None, aggr="mean", eps=1e-8,
                 activations=(F.relu, torch.sigmoid)):
        super(GVPConv, self).__init__(aggr=aggr)
        self.eps = eps
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        module_list = module_list or []
        if not module_list:
            # The first message GVP receives the concatenation
            # [(s_j, v_j), edge_attr, (s_i, v_i)] built in `message`,
            # hence the 2*node + edge input dims.
            if n_layers == 1:
                module_list.append(
                    GVP((2*self.si + self.se, 2*self.vi + self.ve),
                        (self.so, self.vo), activations=(None, None)))
            else:
                module_list.append(
                    GVP((2*self.si + self.se, 2*self.vi + self.ve), out_dims,
                        vector_gate=vector_gate, activations=activations)
                )
                for i in range(n_layers - 2):
                    module_list.append(GVP(out_dims, out_dims,
                                           vector_gate=vector_gate))
                module_list.append(GVP(out_dims, out_dims,
                                       activations=(None, None)))
        # NOTE(review): construction order fixes the Sequential indices and
        # therefore the state_dict keys; keep it stable for checkpoints.
        self.message_func = nn.Sequential(*module_list)

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`; V is reshaped below as
            (n_nodes, n_vector, 3)
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :return: tuple (s, V) of aggregated messages per node, with
            `self.vo` vector channels
        '''
        x_s, x_v = x
        # Flatten each node's vectors into one row so they travel through
        # `propagate` as a 2-D tensor; `message` restores the (.., 3) layout.
        message = self.propagate(edge_index,
                    s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
                    edge_attr=edge_attr)
        return _split(message, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        # NOTE(review): argument names are matched by PyG's `propagate`
        # (presumably `_j` = source node, `_i` = target node with the default
        # flow); do not rename them.
        v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)
        # Merge back into a single tensor so PyG can aggregate it.
        return _merge(*message)
class GVPConvLayer(nn.Module):
    '''
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param vector_gate: forwarded to the gated GVPs of the message,
        feedforward and edge-update functions
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
        with a different set of input node embeddings for messages
        where src >= dst (the conv then aggregates with "add")
    :param attention_heads: must be 0; other values raise NotImplementedError
    :param conv_activations: activations passed to `GVPConv`
    :param n_edge_gvps: if > 0, edge embeddings are also residually updated
        by this many GVPs applied to (x_src, edge_attr, x_dst)
    :param layernorm: if False, the LayerNorms are replaced by identities
    :param eps: passed to `GVPConv` and to the LayerNorms
    '''
    def __init__(self, node_dims, edge_dims, vector_gate=False,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False, attention_heads=0,
                 conv_activations=(F.relu, torch.sigmoid),
                 n_edge_gvps=0, layernorm=True, eps=1e-8):
        super(GVPConvLayer, self).__init__()
        # NOTE(review): submodule creation order and names determine the
        # state_dict keys; keep them stable for pretrained checkpoints.
        if attention_heads == 0:
            self.conv = GVPConv(
                node_dims, node_dims, edge_dims, n_layers=n_message,
                vector_gate=vector_gate,
                aggr="add" if autoregressive else "mean",
                activations=conv_activations,
                eps=eps,
            )
        else:
            raise NotImplementedError
        if layernorm:
            self.norm = nn.ModuleList([LayerNorm(node_dims, eps=eps) for _ in range(2)])
        else:
            self.norm = nn.ModuleList([nn.Identity() for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        ff_func = []
        if n_feedforward == 1:
            ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
        else:
            hid_dims = 4*node_dims[0], 2*node_dims[1]
            ff_func.append(GVP(node_dims, hid_dims, vector_gate=vector_gate))
            for i in range(n_feedforward-2):
                ff_func.append(GVP(hid_dims, hid_dims, vector_gate=vector_gate))
            ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)

        self.edge_message_func = None
        if n_edge_gvps > 0:
            si, vi = node_dims
            se, ve = edge_dims
            # Input is [x_src, edge_attr, x_dst] concatenated (see forward).
            module_list = [
                GVP((2*si + se, 2*vi + ve), edge_dims, vector_gate=vector_gate)
            ]
            for i in range(n_edge_gvps - 2):
                module_list.append(GVP(edge_dims, edge_dims,
                                       vector_gate=vector_gate))
            if n_edge_gvps > 1:
                module_list.append(GVP(edge_dims, edge_dims,
                                       activations=(None, None)))
            self.edge_message_func = nn.Sequential(*module_list)
            if layernorm:
                self.edge_norm = LayerNorm(edge_dims, eps=eps)
            else:
                self.edge_norm = nn.Identity()
            self.edge_dropout = Dropout(drop_rate)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
            If not `None`, will be used as src node embeddings
            for forming messages where src >= dst. The current node
            embeddings `x` will still be the base of the update and the
            pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
            dim of node embeddings (s, V). If not `None`, only
            these nodes will be updated. The tensors in `x` are not
            modified; updated copies are returned.
        :return: tuple (x, edge_attr) of updated node and edge embeddings
        '''
        if self.edge_message_func:
            src, dst = edge_index
            if autoregressive_x is None:
                x_src = x[0][src], x[1][src]
            else:
                # Edges with src < dst read the current embeddings; the
                # rest read autoregressive_x.
                mask = (src < dst).unsqueeze(-1)
                x_src = (
                    torch.where(mask, x[0][src], autoregressive_x[0][src]),
                    torch.where(mask.unsqueeze(-1), x[1][src],
                                autoregressive_x[1][src])
                )
            x_dst = x[0][dst], x[1][dst]
            x_edge = (
                torch.cat([x_src[0], edge_attr[0], x_dst[0]], dim=-1),
                torch.cat([x_src[1], edge_attr[1], x_dst[1]], dim=-2)
            )
            edge_attr_dh = self.edge_message_func(x_edge)
            edge_attr = self.edge_norm(tuple_sum(edge_attr,
                                                 self.edge_dropout(edge_attr_dh)))

        if autoregressive_x is not None:
            # Guarding this import here to remove the dependency on torch_scatter, since this isn't used
            # in ESM-IF1
            from torch_scatter import scatter_add
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)

            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
            )

            # Divide the summed messages by each node's in-degree (at least 1)
            # to obtain a mean. This presumes self.conv aggregates with "add",
            # i.e. the layer was built with autoregressive=True.
            count = scatter_add(torch.ones_like(dst), dst,
                                dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)

            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)

        if node_mask is not None:
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if node_mask is not None:
            # Scatter the updated rows into copies rather than into the
            # caller's tensors. In-place writes would mutate the input and
            # fail under autograd when the inputs are leaves requiring grad.
            s_full, v_full = x_[0].clone(), x_[1].clone()
            s_full[node_mask], v_full[node_mask] = x[0], x[1]
            x = s_full, v_full
        return x, edge_attr