# Contents of this file are from the open source code for
#
#   Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
#   Learning from Protein Structure with Geometric Vector Perceptrons. In
#   International Conference on Learning Representations.
#
# MIT License
#
# Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import typing as T

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing


def tuple_size(tp):
    '''
    Returns the sizes of the tensors in a tuple, using `0` for any entry
    that is `None`.
    '''
    return tuple([0 if a is None else a.size() for a in tp])


def tuple_sum(tp1, tp2):
    '''
    Sums two tuples (s, V) elementwise.

    A `None` vector-channel entry means "no vector channels". If either
    operand has `None` vectors, the result's vector channel is `None`;
    otherwise the vector channels are added.
    '''
    s1, v1 = tp1
    s2, v2 = tp2
    # The previous check read `v2 is None and v2 is None`, which made the
    # function asymmetric: tuple_sum(a, b) returned None vectors when only
    # b's were None, but tuple_sum(b, a) crashed on `None + tensor`.
    if v1 is None or v2 is None:
        return (s1 + s2, None)
    return (s1 + s2, v1 + v2)


def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of tuples (s, V) elementwise.

    :param dim: dimension along which to concatenate when viewed
                as the `dim` index for the scalar-channel tensors.
                This means that `dim=-1` will be applied as
                `dim=-2` for the vector-channel tensors.
    '''
    # Normalise to a non-negative index relative to the scalar tensors so the
    # same index addresses the channel axis of the (one rank higher) vectors.
    dim %= len(args[0][0].shape)
    s_args, v_args = list(zip(*args))
    return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)


def tuple_index(x, idx):
    '''
    Indexes into a tuple (s, V) along the first dimension.

    :param idx: any object which can be used to index into a `torch.Tensor`
    '''
    return x[0][idx], x[1][idx]


def randn(n, dims, device="cpu"):
    '''
    Returns random tuples (s, V) drawn elementwise from a normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    return torch.randn(n, dims[0], device=device), \
        torch.randn(n, dims[1], 3, device=device)


def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of a tensor with `eps` added to the sum of squares before the
    square root, so the result stays strictly positive for zero vectors.

    :param sqrt: if `False`, returns the sum of squares plus `eps`
                 (i.e. the squared norm, offset by `eps`)
    '''
    # Adding eps is used instead of torch.clamp(..., min=eps) because
    # clamp was found to be slow.
    out = torch.sum(torch.square(x), axis, keepdims) + eps
    return torch.sqrt(out) if sqrt else out


def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple.
    Should be used only with `_merge(s, V)` and only if the tuple
    representation cannot be used.

    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    '''
    v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
    s = x[..., :-3*nv]
    return s, v


def _merge(s, v):
    '''
    Merges a tuple (s, V) into a single `torch.Tensor`, where the
    vector channels are flattened and appended to the scalar channels.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
    return torch.cat([s, v], -1)


class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
    :param vector_gate: if `True`, the vector gate is computed from the output
                        scalars by a linear layer (`wg`) instead of from the
                        output vector norms
    :param activations: tuple of functions (scalar_act, vector_act)
    :param tuple_io: whether to keep accepting tuple inputs and outputs when vi
                     or vo = 0
    :param eps: offset passed to `_norm_no_nan` for the norm computations
    '''
    def __init__(self, in_dims, out_dims, h_dim=None, vector_gate=False,
                 activations=(F.relu, torch.sigmoid), tuple_io=True,
                 eps=1e-8):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.tuple_io = tuple_io
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if vector_gate:
                    self.wg = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)

        self.vector_gate = vector_gate
        self.scalar_act, self.vector_act = activations
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if self.vi:
            s, v = x
            # Move the channel axis last so nn.Linear mixes vector channels
            # while leaving the xyz axis untouched.
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            vn = _norm_no_nan(vh, axis=-2, eps=self.eps)
            s = self.ws(torch.cat([s, vn], -1))
            if self.scalar_act:
                s = self.scalar_act(s)
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    g = self.wg(s).unsqueeze(-1)
                else:
                    g = _norm_no_nan(v, axis=-1, keepdims=True, eps=self.eps)
                # NOTE: when vector_act is None, no gating is applied at all
                # (even if vector_gate is True); kept for checkpoint parity.
                if self.vector_act:
                    g = self.vector_act(g)
                    v = v * g
        else:
            if self.tuple_io:
                assert x[1] is None
                x = x[0]
            s = self.ws(x)
            if self.scalar_act:
                s = self.scalar_act(s)
            if self.vo:
                # No vector input: emit zero vectors. Match the scalar dtype so
                # half/double-precision models don't get a float32 tensor here.
                v = torch.zeros(list(s.shape)[:-1] + [self.vo, 3],
                                device=s.device, dtype=s.dtype)

        if self.vo:
            return (s, v)
        elif self.tuple_io:
            return (s, None)
        else:
            return s


class _VDropout(nn.Module):
    '''
    Vector channel dropout where the elements of each
    vector channel are dropped together.
    '''
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
                  (the last axis holds the vector components)
        :return: `x` unchanged in eval mode; otherwise `x` with whole vector
                 channels zeroed and the survivors rescaled by
                 1 / (1 - drop_rate). `None` is passed through.
        '''
        if x is None:
            return None
        if not self.training:
            return x
        if self.drop_rate >= 1:
            # Everything is dropped; rescaling by 1 / (1 - 1) would give
            # 0 / 0 = NaN. Match nn.Dropout(p=1), which returns zeros.
            return torch.zeros_like(x)
        # One Bernoulli draw per vector channel (all but the last axis),
        # broadcast over the vector components.
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=x.device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x


class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)


class LayerNorm(nn.Module):
    '''
    Combined LayerNorm for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    Scalars go through `nn.LayerNorm`; vectors are divided by the root-mean
    squared norm of their non-zero vector channels (zero channels stay zero).
    '''
    def __init__(self, dims, tuple_io=True, eps=1e-8):
        super(LayerNorm, self).__init__()
        self.tuple_io = tuple_io
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)
        self.eps = eps

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if not self.v:
            if self.tuple_io:
                return self.scalar_norm(x[0]), None
            return self.scalar_norm(x)
        s, v = x
        # Squared norm (+eps) of each vector channel.
        vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False, eps=self.eps)
        # Channels whose squared norm is essentially just eps are treated as
        # zero vectors and excluded from the mean.
        nonzero_mask = (vn > 2 * self.eps)
        vn = torch.sum(vn * nonzero_mask, dim=-2, keepdim=True) \
            / (self.eps + torch.sum(nonzero_mask, dim=-2, keepdim=True))
        vn = torch.sqrt(vn + self.eps)
        v = nonzero_mask * (v / vn)
        return self.scalar_norm(s), v


class GVPConv(MessagePassing):
    '''
    Graph convolution / message passing with Geometric Vector Perceptrons.
    Takes in a graph with node and edge embeddings,
    and returns new node embeddings.

    This does NOT do residual updates and pointwise feedforward layers
    ---see `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param vector_gate: passed to the non-final message GVPs
    :param module_list: preconstructed message function, overrides n_layers
    :param aggr: should be "add" if some incoming edges are masked, as in
                 a masked autoregressive decoder architecture
    :param eps: stored on the module; not forwarded to the message GVPs
    :param activations: activations for the first message GVP
                        (when n_layers > 1)
    '''
    def __init__(self, in_dims, out_dims, edge_dims, n_layers=3,
                 vector_gate=False, module_list=None, aggr="mean", eps=1e-8,
                 activations=(F.relu, torch.sigmoid)):
        super(GVPConv, self).__init__(aggr=aggr)
        self.eps = eps
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        module_list = module_list or []
        if not module_list:
            # Message input is cat(source node, edge, destination node).
            if n_layers == 1:
                module_list.append(
                    GVP((2*self.si + self.se, 2*self.vi + self.ve),
                        (self.so, self.vo), activations=(None, None)))
            else:
                module_list.append(
                    GVP((2*self.si + self.se, 2*self.vi + self.ve), out_dims,
                        vector_gate=vector_gate, activations=activations)
                )
                # NOTE(review): middle layers use GVP's default activations,
                # not `activations`; left as-is so pretrained weights keep
                # producing the same outputs.
                for _ in range(n_layers - 2):
                    module_list.append(GVP(out_dims, out_dims,
                                           vector_gate=vector_gate))
                module_list.append(GVP(out_dims, out_dims,
                                       activations=(None, None)))
        self.message_func = nn.Sequential(*module_list)

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        '''
        x_s, x_v = x
        # Vectors are flattened to 2-D so they can go through `propagate`
        # alongside the scalars; `message` restores the (n, 3) layout.
        message = self.propagate(edge_index,
                                 s=x_s,
                                 v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
                                 edge_attr=edge_attr)
        return _split(message, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)
        return _merge(*message)


class GVPConvLayer(nn.Module):
    '''
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param vector_gate: passed to the gated GVPs in every sub-network
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
           with a different set of input node embeddings for messages
           where src >= dst
    :param attention_heads: must be 0; any other value raises
                            `NotImplementedError`
    :param conv_activations: activations for the first message GVP
    :param n_edge_gvps: if > 0, edge embeddings are also residually updated
                        by a GVP stack of this depth
    :param layernorm: if `False`, `nn.Identity` replaces every LayerNorm
    :param eps: offset used by the LayerNorms and the message GVPs' conv
    '''
    def __init__(self, node_dims, edge_dims, vector_gate=False,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False, attention_heads=0,
                 conv_activations=(F.relu, torch.sigmoid),
                 n_edge_gvps=0, layernorm=True, eps=1e-8):

        super(GVPConvLayer, self).__init__()
        if attention_heads == 0:
            self.conv = GVPConv(
                node_dims, node_dims, edge_dims, n_layers=n_message,
                vector_gate=vector_gate,
                aggr="add" if autoregressive else "mean",
                activations=conv_activations,
                eps=eps,
            )
        else:
            raise NotImplementedError
        if layernorm:
            self.norm = nn.ModuleList([LayerNorm(node_dims, eps=eps)
                                       for _ in range(2)])
        else:
            self.norm = nn.ModuleList([nn.Identity() for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        ff_func = []
        if n_feedforward == 1:
            ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
        else:
            hid_dims = 4*node_dims[0], 2*node_dims[1]
            ff_func.append(GVP(node_dims, hid_dims, vector_gate=vector_gate))
            for _ in range(n_feedforward - 2):
                ff_func.append(GVP(hid_dims, hid_dims,
                                   vector_gate=vector_gate))
            ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)

        self.edge_message_func = None
        if n_edge_gvps > 0:
            si, vi = node_dims
            se, ve = edge_dims
            module_list = [
                GVP((2*si + se, 2*vi + ve), edge_dims, vector_gate=vector_gate)
            ]
            for _ in range(n_edge_gvps - 2):
                module_list.append(GVP(edge_dims, edge_dims,
                                       vector_gate=vector_gate))
            if n_edge_gvps > 1:
                module_list.append(GVP(edge_dims, edge_dims,
                                       activations=(None, None)))
            self.edge_message_func = nn.Sequential(*module_list)
            if layernorm:
                self.edge_norm = LayerNorm(edge_dims, eps=eps)
            else:
                self.edge_norm = nn.Identity()
            self.edge_dropout = Dropout(drop_rate)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node
                embeddings `x` will still be the base of the update and the
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated.
        :return: tuple (updated node embeddings, edge embeddings); the edge
                 embeddings are updated only when `n_edge_gvps > 0`.
        '''
        if self.edge_message_func is not None:
            src, dst = edge_index
            if autoregressive_x is None:
                x_src = x[0][src], x[1][src]
            else:
                # Edges with src < dst read x; the rest read autoregressive_x.
                mask = (src < dst).unsqueeze(-1)
                x_src = (
                    torch.where(mask, x[0][src], autoregressive_x[0][src]),
                    torch.where(mask.unsqueeze(-1), x[1][src],
                                autoregressive_x[1][src])
                )
            x_dst = x[0][dst], x[1][dst]
            x_edge = (
                torch.cat([x_src[0], edge_attr[0], x_dst[0]], dim=-1),
                torch.cat([x_src[1], edge_attr[1], x_dst[1]], dim=-2)
            )
            edge_attr_dh = self.edge_message_func(x_edge)
            edge_attr = self.edge_norm(
                tuple_sum(edge_attr, self.edge_dropout(edge_attr_dh)))

        if autoregressive_x is not None:
            # Guarding this import here to remove the dependency on
            # torch_scatter, since this isn't used in ESM-IF1
            from torch_scatter import scatter_add
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)

            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward,
                          edge_attr_backward)
            )

            # Divide the summed messages by each node's in-degree (at least 1)
            # to get a mean. NOTE(review): assumes the layer was built with
            # autoregressive=True, i.e. the conv aggregates with "add".
            count = scatter_add(torch.ones_like(dst), dst,
                                dim_size=dh[0].size(0)
                                ).clamp(min=1).unsqueeze(-1)

            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)

        if node_mask is not None:
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if node_mask is not None:
            # NOTE(review): this writes into the caller's input tensors in
            # place; kept as-is to preserve existing behavior.
            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
            x = x_
        return x, edge_attr