| | import math |
| | import functools |
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_softmax |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from src.model.gvp import GVP, _norm_no_nan, tuple_sum, Dropout, LayerNorm, \ |
| | tuple_cat, tuple_index, _rbf, _normalize |
| |
|
| |
|
| | def tuple_mul(tup, val): |
| | if isinstance(val, torch.Tensor): |
| | return (tup[0] * val, tup[1] * val.unsqueeze(-1)) |
| | return (tup[0] * val, tup[1] * val) |
| |
|
| |
|
| | class GVPBlock(nn.Module): |
| | def __init__(self, in_dims, out_dims, n_layers=1, |
| | activations=(F.relu, torch.sigmoid), vector_gate=False, |
| | dropout=0.0, skip=False, layernorm=False): |
| | super(GVPBlock, self).__init__() |
| | self.si, self.vi = in_dims |
| | self.so, self.vo = out_dims |
| | assert not skip or (self.si == self.so and self.vi == self.vo) |
| | self.skip = skip |
| |
|
| | GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate) |
| |
|
| | module_list = [] |
| | if n_layers == 1: |
| | module_list.append(GVP_(in_dims, out_dims, activations=(None, None))) |
| | else: |
| | module_list.append(GVP_(in_dims, out_dims)) |
| | for i in range(n_layers - 2): |
| | module_list.append(GVP_(out_dims, out_dims)) |
| | module_list.append(GVP_(out_dims, out_dims, activations=(None, None))) |
| |
|
| | self.layers = nn.Sequential(*module_list) |
| |
|
| | self.norm = LayerNorm(out_dims, learnable_vector_weight=True) if layernorm else None |
| | self.dropout = Dropout(dropout) if dropout > 0 else None |
| |
|
| | def forward(self, x): |
| | """ |
| | :param x: tuple (s, V) of `torch.Tensor` |
| | :return: tuple (s, V) of `torch.Tensor` |
| | """ |
| |
|
| | dx = self.layers(x) |
| |
|
| | if self.dropout is not None: |
| | dx = self.dropout(dx) |
| |
|
| | if self.skip: |
| | x = tuple_sum(x, dx) |
| | else: |
| | x = dx |
| |
|
| | if self.norm is not None: |
| | x = self.norm(x) |
| |
|
| | return x |
| |
|
| |
|
| | class GeometricPNA(nn.Module): |
| | def __init__(self, d_in, d_out): |
| | """ Map features to global features """ |
| | super().__init__() |
| | si, vi = d_in |
| | so, vo = d_out |
| | self.gvp = GVPBlock((4 * si + 3 * vi, vi), d_out) |
| |
|
| | def forward(self, x, batch_mask, batch_size=None): |
| | """ x: tuple (s, V) """ |
| | s, v = x |
| |
|
| | sm = scatter_mean(s, batch_mask, dim=0, dim_size=batch_size) |
| | smi = scatter_min(s, batch_mask, dim=0, dim_size=batch_size)[0] |
| | sma = scatter_max(s, batch_mask, dim=0, dim_size=batch_size)[0] |
| | sstd = scatter_std(s, batch_mask, dim=0, dim_size=batch_size) |
| |
|
| | vnorm = _norm_no_nan(v) |
| | vm = scatter_mean(v, batch_mask, dim=0, dim_size=batch_size) |
| | vmi = scatter_min(vnorm, batch_mask, dim=0, dim_size=batch_size)[0] |
| | vma = scatter_max(vnorm, batch_mask, dim=0, dim_size=batch_size)[0] |
| | vstd = scatter_std(vnorm, batch_mask, dim=0, dim_size=batch_size) |
| |
|
| | z = torch.hstack((sm, smi, sma, sstd, vmi, vma, vstd)) |
| | out = self.gvp((z, vm)) |
| | return out |
| |
|
| |
|
| | class TupleLinear(nn.Module): |
| | def __init__(self, in_dims, out_dims, bias=True): |
| | super().__init__() |
| | self.si, self.vi = in_dims |
| | self.so, self.vo = out_dims |
| | assert self.si and self.so |
| | self.ws = nn.Linear(self.si, self.so, bias=bias) |
| | self.wv = nn.Linear(self.vi, self.vo, bias=bias) if self.vi and self.vo else None |
| |
|
| | def forward(self, x): |
| | if self.vi: |
| | s, v = x |
| |
|
| | s = self.ws(s) |
| |
|
| | if self.vo: |
| | v = v.transpose(-1, -2) |
| | v = self.wv(v) |
| | v = v.transpose(-1, -2) |
| |
|
| | else: |
| | s = self.ws(x) |
| |
|
| | if self.vo: |
| | v = torch.zeros(s.size(0), self.vo, 3, device=s.device) |
| |
|
| | return (s, v) if self.vo else s |
| |
|
| |
|
| | class GVPTransformerLayer(nn.Module): |
| | """ |
| | Full graph transformer layer with Geometric Vector Perceptrons. |
| | Inspired by |
| | - GVP: Jing, Bowen, et al. "Learning from protein structure with geometric vector perceptrons." arXiv preprint arXiv:2009.01411 (2020). |
| | - Transformer architecture: Vignac, Clement, et al. "Digress: Discrete denoising diffusion for graph generation." arXiv preprint arXiv:2209.14734 (2022). |
| | - Invariant point attention: Jumper, John, et al. "Highly accurate protein structure prediction with AlphaFold." Nature 596.7873 (2021): 583-589. |
| | |
| | :param node_dims: node embedding dimensions (n_scalar, n_vector) |
| | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) |
| | :param global_dims: global feature dimension (n_scalar, n_vector) |
| | :param dk: key dimension, (n_scalar, n_vector) |
| | :param dv: node value dimension, (n_scalar, n_vector) |
| | :param de: edge value dimension, (n_scalar, n_vector) |
| | :param db: dimension of edge contribution to attention, int |
| | :param attn_heads: number of attention heads, int |
| | :param n_feedforward: number of GVPs to use in feedforward function |
| | :param drop_rate: drop probability in all dropout layers |
| | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs |
| | :param vector_gate: whether to use vector gating. |
| | (vector_act will be used as sigma^+ in vector gating if `True`) |
| | :param attention: can be used to turn off the attention mechanism |
| | """ |
| |
|
| | def __init__(self, node_dims, edge_dims, global_dims, dk, dv, de, db, |
| | attn_heads, n_feedforward=1, drop_rate=0.0, |
| | activations=(F.relu, torch.sigmoid), vector_gate=False, |
| | attention=True): |
| |
|
| | super(GVPTransformerLayer, self).__init__() |
| |
|
| | self.attention = attention |
| |
|
| | dq = dk |
| | self.dq = dq |
| | self.dk = dk |
| | self.dv = dv |
| | self.de = de |
| | self.db = db |
| |
|
| | self.h = attn_heads |
| |
|
| | self.q = TupleLinear(node_dims, tuple_mul(dq, self.h), bias=False) if self.attention else None |
| | self.k = TupleLinear(node_dims, tuple_mul(dk, self.h), bias=False) if self.attention else None |
| | self.vx = TupleLinear(node_dims, tuple_mul(dv, self.h), bias=False) |
| |
|
| | self.ve = TupleLinear(edge_dims, tuple_mul(de, self.h), bias=False) |
| | self.b = TupleLinear(edge_dims, (db * self.h, 0), bias=False) if self.attention else None |
| |
|
| | m_dim = tuple_sum(tuple_mul(dv, self.h), tuple_mul(de, self.h)) |
| | self.msg = GVPBlock(m_dim, m_dim, n_feedforward, |
| | activations=activations, vector_gate=vector_gate) |
| |
|
| | m_dim = tuple_sum(m_dim, global_dims) |
| | self.x_out = GVPBlock(m_dim, node_dims, n_feedforward, |
| | activations=activations, vector_gate=vector_gate) |
| | self.x_norm = LayerNorm(node_dims, learnable_vector_weight=True) |
| | self.x_dropout = Dropout(drop_rate) |
| |
|
| | e_dim = tuple_sum(tuple_mul(node_dims, 2), edge_dims, global_dims) |
| | if self.attention: |
| | e_dim = (e_dim[0] + 3 * attn_heads, e_dim[1]) |
| | self.e_out = GVPBlock(e_dim, edge_dims, n_feedforward, |
| | activations=activations, vector_gate=vector_gate) |
| | self.e_norm = LayerNorm(edge_dims, learnable_vector_weight=True) |
| | self.e_dropout = Dropout(drop_rate) |
| |
|
| | self.pna_x = GeometricPNA(node_dims, node_dims) |
| | self.pna_e = GeometricPNA(edge_dims, edge_dims) |
| | self.y = GVP(global_dims, global_dims, activations=(None, None), vector_gate=vector_gate) |
| | _dim = tuple_sum(node_dims, edge_dims, global_dims) |
| | self.y_out = GVPBlock(_dim, global_dims, n_feedforward, |
| | activations=activations, vector_gate=vector_gate) |
| | self.y_norm = LayerNorm(global_dims, learnable_vector_weight=True) |
| | self.y_dropout = Dropout(drop_rate) |
| |
|
| | def forward(self, x, edge_index, batch_mask, edge_attr, global_attr=None, |
| | node_mask=None): |
| | """ |
| | :param x: tuple (s, V) of `torch.Tensor` |
| | :param edge_index: array of shape [2, n_edges] |
| | :param batch_mask: array indicating different graphs |
| | :param edge_attr: tuple (s, V) of `torch.Tensor` |
| | :param global_attr: tuple (s, V) of `torch.Tensor` |
| | :param node_mask: array of type `bool` to index into the first |
| | dim of node embeddings (s, V). If not `None`, only |
| | these nodes will be updated. |
| | """ |
| |
|
| | row, col = edge_index |
| | n = len(x[0]) |
| | batch_size = len(torch.unique(batch_mask)) |
| |
|
| | |
| | if self.attention: |
| | Q = self.q(x) |
| | K = self.k(x) |
| | b = self.b(edge_attr) |
| |
|
| | qs, qv = Q |
| | ks, kv = K |
| | attn_s = (qs[row] * ks[col]).reshape(len(row), self.h, self.dq[0]).sum(dim=-1) |
| | |
| | |
| | |
| | attn_v = (qv[row] * kv[col]).reshape(len(row), self.h, self.dq[1], 3).sum(dim=(-2, -1)) |
| | attn_e = b.reshape(b.size(0), self.h, self.db).sum(dim=-1) |
| |
|
| | attn = attn_s / math.sqrt(3 * self.dk[0]) + \ |
| | attn_v / math.sqrt(9 * self.dk[1]) + \ |
| | attn_e / math.sqrt(3 * self.db) |
| | attn = scatter_softmax(attn, row, dim=0) |
| | attn = attn.unsqueeze(-1) |
| |
|
| | |
| | Vx = self.vx(x) |
| | Ve = self.ve(edge_attr) |
| |
|
| | mx = (Vx[0].reshape(Vx[0].size(0), self.h, self.dv[0]), |
| | Vx[1].reshape(Vx[1].size(0), self.h, self.dv[1], 3)) |
| | me = (Ve[0].reshape(Ve[0].size(0), self.h, self.de[0]), |
| | Ve[1].reshape(Ve[1].size(0), self.h, self.de[1], 3)) |
| |
|
| | mx = tuple_index(mx, col) |
| | if self.attention: |
| | mx = tuple_mul(mx, attn) |
| | me = tuple_mul(me, attn) |
| |
|
| | _m = tuple_cat(mx, me) |
| | _m = (_m[0].flatten(1), _m[1].flatten(1, 2)) |
| | m = self.msg(_m) |
| | m = (scatter_mean(m[0], row, dim=0, dim_size=n), |
| | scatter_mean(m[1], row, dim=0, dim_size=n)) |
| | if global_attr is not None: |
| | m = tuple_cat(m, tuple_index(global_attr, batch_mask)) |
| | X_out = self.x_norm(tuple_sum(x, self.x_dropout(self.x_out(m)))) |
| |
|
| | _e = tuple_cat(tuple_index(x, row), tuple_index(x, col), edge_attr) |
| | if self.attention: |
| | _e = (torch.cat([_e[0], attn_s, attn_v, attn_e], dim=-1), _e[1]) |
| | if global_attr is not None: |
| | _e = tuple_cat(_e, tuple_index(global_attr, batch_mask[row])) |
| | E_out = self.e_norm(tuple_sum(edge_attr, self.e_dropout(self.e_out(_e)))) |
| |
|
| | _y = tuple_cat(self.pna_x(x, batch_mask, batch_size), |
| | self.pna_e(edge_attr, batch_mask[row], batch_size)) |
| | if global_attr is not None: |
| | _y = tuple_cat(_y, self.y(global_attr)) |
| | y_out = self.y_norm(tuple_sum(global_attr, self.y_dropout(self.y_out(_y)))) |
| | else: |
| | y_out = self.y_norm(self.y_dropout(self.y_out(_y))) |
| |
|
| | if node_mask is not None: |
| | X_out[0][~node_mask], X_out[1][~node_mask] = tuple_index(x, ~node_mask) |
| |
|
| | return X_out, E_out, y_out |
| |
|
| |
|
| | class GVPTransformerModel(torch.nn.Module): |
| | """ |
| | GVP-Transformer model |
| | |
| | :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors) |
| | :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V) |
| | :param node_out_nf: node dimensions in output graph, tuple (s, V) |
| | :param edge_in_nf: edge dimension in input graph (scalars) |
| | :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers, |
| | tuple (s, V) |
| | :param edge_out_nf: edge dimensions in output graph, tuple (s, V) |
| | :param num_layers: number of GVP-GNN layers |
| | :param drop_rate: rate to use in all dropout layers |
| | :param reflection_equiv: bool, use reflection-sensitive feature based on the |
| | cross product if False |
| | :param d_max: |
| | :param num_rbf: |
| | :param vector_gate: use vector gates in all GVPs |
| | :param attention: can be used to turn off the attention mechanism |
| | """ |
| | def __init__(self, node_in_dim, node_h_dim, node_out_nf, edge_in_nf, |
| | edge_h_dim, edge_out_nf, num_layers, dk, dv, de, db, dy, |
| | attn_heads, n_feedforward, drop_rate, reflection_equiv=True, |
| | d_max=20.0, num_rbf=16, vector_gate=False, attention=True): |
| |
|
| | super(GVPTransformerModel, self).__init__() |
| |
|
| | self.reflection_equiv = reflection_equiv |
| | self.d_max = d_max |
| | self.num_rbf = num_rbf |
| |
|
| | |
| | if not isinstance(node_in_dim, tuple): |
| | node_in_dim = (node_in_dim, 0) |
| |
|
| | edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1) |
| | if not self.reflection_equiv: |
| | edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1) |
| |
|
| | self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate) |
| | self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | self.dy = dy |
| | self.layers = nn.ModuleList( |
| | GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db, |
| | attn_heads, n_feedforward=n_feedforward, |
| | drop_rate=drop_rate, vector_gate=vector_gate, |
| | activations=(F.relu, None), attention=attention) |
| | for _ in range(num_layers)) |
| |
|
| | self.W_v_out = GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate) |
| | self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None): |
| | """ |
| | :param h: |
| | :param x: |
| | :param edge_index: |
| | :param batch_mask: |
| | :param edge_attr: |
| | :return: scalar and vector-valued edge features |
| | """ |
| | row, col = edge_index |
| | coord_diff = x[row] - x[col] |
| | dist = coord_diff.norm(dim=-1) |
| | rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, |
| | device=x.device) |
| |
|
| | edge_s = torch.cat([h[row], h[col], rbf], dim=1) |
| | edge_v = _normalize(coord_diff).unsqueeze(-2) |
| |
|
| | if edge_attr is not None: |
| | edge_s = torch.cat([edge_s, edge_attr], dim=1) |
| |
|
| | if not self.reflection_equiv: |
| | mean = scatter_mean(x, batch_mask, dim=0, |
| | dim_size=batch_mask.max() + 1) |
| | row, col = edge_index |
| | cross = torch.cross(x[row] - mean[batch_mask[row]], |
| | x[col] - mean[batch_mask[col]], dim=1) |
| | cross = _normalize(cross).unsqueeze(-2) |
| |
|
| | edge_v = torch.cat([edge_v, cross], dim=-2) |
| |
|
| | return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v) |
| |
|
| | def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None): |
| |
|
| | bs = len(batch_mask.unique()) |
| |
|
| | |
| | h_v = h if v is None else (h, v) |
| | h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr) |
| |
|
| | h_v = self.W_v(h_v) |
| | h_e = self.W_e(h_e) |
| | h_y = (torch.zeros(bs, self.dy[0], device=h.device), |
| | torch.zeros(bs, self.dy[1], 3, device=h.device)) |
| |
|
| | for layer in self.layers: |
| | h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y) |
| |
|
| | |
| | |
| | h, vel = self.W_v_out(h_v) |
| | |
| |
|
| | edge_attr = self.W_e_out(h_e) |
| |
|
| | |
| | return h, vel.squeeze(-2), edge_attr |
| |
|
| |
|
| | if __name__ == "__main__": |
| | from src.model.gvp import randn |
| | from scipy.spatial.transform import Rotation |
| |
|
| | def test_equivariance(model, nodes, edges, glob_feat): |
| | random = torch.as_tensor(Rotation.random().as_matrix(), |
| | dtype=torch.float32, device=device) |
| |
|
| | with torch.no_grad(): |
| | X_out, E_out, y_out = model(nodes, edges, glob_feat) |
| | n_v_rot, e_v_rot, y_v_rot = nodes[1] @ random, edges[1] @ random, glob_feat[1] @ random |
| | X_out_v_rot = X_out[1] @ random |
| | E_out_v_rot = E_out[1] @ random |
| | y_out_v_rot = y_out[1] @ random |
| | X_out_prime, E_out_prime, y_out_prime = model((nodes[0], n_v_rot), (edges[0], e_v_rot), (glob_feat[0], y_v_rot)) |
| |
|
| | assert torch.allclose(X_out[0], X_out_prime[0], atol=1e-5, rtol=1e-4) |
| | assert torch.allclose(X_out_v_rot, X_out_prime[1], atol=1e-5, rtol=1e-4) |
| | assert torch.allclose(E_out[0], E_out_prime[0], atol=1e-5, rtol=1e-4) |
| | assert torch.allclose(E_out_v_rot, E_out_prime[1], atol=1e-5, rtol=1e-4) |
| | assert torch.allclose(y_out[0], y_out_prime[0], atol=1e-5, rtol=1e-4) |
| | assert torch.allclose(y_out_v_rot, y_out_prime[1], atol=1e-5, rtol=1e-4) |
| | print("SUCCESS") |
| |
|
| |
|
| | n_nodes = 300 |
| | n_edges = 10000 |
| | batch_size = 6 |
| |
|
| | node_dim = (16, 8) |
| | edge_dim = (8, 4) |
| | global_dim = (4, 2) |
| | dk = (6, 3) |
| | dv = (7, 4) |
| | de = (5, 2) |
| | db = 10 |
| | attn_heads = 9 |
| |
|
| | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| |
|
| |
|
| | nodes = randn(n_nodes, node_dim, device=device) |
| | edges = randn(n_edges, edge_dim, device=device) |
| | glob_feat = randn(batch_size, global_dim, device=device) |
| | edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device) |
| | batch_idx = torch.randint(0, batch_size, (n_nodes,), device=device) |
| |
|
| | model = GVPTransformerLayer(node_dim, edge_dim, global_dim, dk, dv, de, db, |
| | attn_heads, n_feedforward = 2, |
| | drop_rate = 0.1).to(device).eval() |
| |
|
| | model_fn = lambda h_V, h_E, h_y: model(h_V, edge_index, batch_idx, h_E, h_y) |
| | test_equivariance(model_fn, nodes, edges, glob_feat) |
| |
|