# flexpert / Flexpert-Design / src/modules/graphtrans_module.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PositionalEncodings(nn.Module):
def __init__(self, num_embeddings, period_range = None):
if period_range is None:
period_range = [2,1000]
super(PositionalEncodings, self).__init__()
self.num_embeddings = num_embeddings
self.period_range = period_range
def forward(self, E_idx):
N_nodes = E_idx.size(1)
ii = torch.arange(N_nodes, dtype=torch.float32, device = E_idx.device).view((1, -1, 1))
d = (E_idx.float() - ii).unsqueeze(-1)
# Original Transformer frequencies
frequency = torch.exp(torch.arange(0, self.num_embeddings, 2, dtype=torch.float32, device = E_idx.device) * -(np.log(10000.0) / self.num_embeddings))
angles = d * frequency.view((1,1,1,-1))
return torch.cat((torch.cos(angles), torch.sin(angles)), -1)
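# PositionalEncodings maps the relative sequence offset d = (neighbor index - node index) of every
# edge to sinusoidal features, following the original Transformer scheme: for frequencies
# f_k = 10000^(-2k / num_embeddings), the output concatenates cos(d * f_k) and sin(d * f_k),
# giving num_embeddings features per (node, neighbor) pair.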
class ProteinFeatures(nn.Module):
def __init__(self, edge_features, node_features, num_positional_embeddings=16, num_rbf=16, top_k=30, features_type='full', augment_eps=0., dropout=0.1):
        """Extract Protein Features"""
        super(ProteinFeatures, self).__init__()
self.edge_features = edge_features
self.node_features = node_features
self.top_k = top_k
self.augment_eps = augment_eps
self.num_rbf = num_rbf
self.num_positional_embeddings = num_positional_embeddings
## Feature types ##
self.features_type = features_type
self.feature_dimensions = {
'coarse': (3, num_positional_embeddings + num_rbf + 7),
'full': (6, num_positional_embeddings + num_rbf + 7),
'dist': (6, num_positional_embeddings + num_rbf),
'hbonds': (3, 2 * num_positional_embeddings)}
## Positional encoding ##
self.embeddings = PositionalEncodings(num_positional_embeddings)
self.dropout = nn.Dropout(dropout)
## Normalization and embedding ##
node_in, edge_in = self.feature_dimensions[features_type]
self.node_embedding = nn.Linear(node_in, node_features, bias=True)
self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
self.norm_nodes = Normalize(node_features)
self.norm_edges = Normalize(edge_features)
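    # feature_dimensions maps each features_type to (node_in, edge_in):
    #   'full'   - 6 node dims (cos/sin of the three backbone dihedrals) and
    #              num_positional_embeddings + num_rbf + 7 edge dims (positional + distance RBF + orientation)
    #   'dist'   - same node dims, but edges carry only the positional and RBF terms
    #   'coarse' - 3 node dims from the CA-trace virtual angles, same edges as 'full'
    #   'hbonds' - hydrogen-bond and contact channels packed alongside the positional encoding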
def _dist(self, X, mask, eps=1E-6):
""" Pairwise Euclidean Distance """
mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
D = (1. - mask_2D)*10000 + mask_2D* torch.sqrt(torch.sum(dX**2, 3) + eps)
D_max, _ = torch.max(D, -1, keepdim=True)
D_adjust = D + (1. - mask_2D) * (D_max+1)
D_neighbors, E_idx = torch.topk(D_adjust, min(self.top_k, D_adjust.shape[-1]), dim=-1, largest=False)
mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
return D_neighbors, E_idx, mask_neighbors
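    # _dist builds a k-nearest-neighbor graph over CA coordinates: distances involving padded
    # residues are pushed above the per-row maximum via mask_2D, so topk(largest=False) only
    # selects valid neighbors; E_idx holds their indices and mask_neighbors flags which of the
    # k slots correspond to real residue pairs.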
def _rbf(self, D):
""" Distance Radial Basis Function """
D_min, D_max, D_count = 0., 20., self.num_rbf
D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
D_mu = D_mu.view([1,1,1,-1])
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
return torch.exp(-((D_expand - D_mu) / D_sigma)**2) # return RBF
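    # _rbf lifts each neighbor distance onto num_rbf Gaussian basis functions with centers
    # spaced evenly on [0, 20] Angstrom and width (D_max - D_min) / D_count, turning a scalar
    # distance into a smooth num_rbf-dimensional edge feature.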
def _quaternions(self, R):
""" Convert a batch of 3D rotations [R] to quaternions [Q] """
diag = torch.diagonal(R, dim1=-2, dim2=-1)
Rxx, Ryy, Rzz = diag.unbind(-1)
magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
Rxx - Ryy - Rzz,
- Rxx + Ryy - Rzz,
- Rxx - Ryy + Rzz
], -1)))
_R = lambda i,j: R[:,:,:,i,j]
signs = torch.sign(torch.stack([
_R(2,1) - _R(1,2),
_R(0,2) - _R(2,0),
_R(1,0) - _R(0,1)
], -1))
xyz = signs * magnitudes
w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
Q = torch.cat((xyz, w), -1)
Q = F.normalize(Q, dim=-1)
return Q
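    # _quaternions applies the standard rotation-matrix-to-quaternion conversion: |x|, |y|, |z|
    # come from the diagonal of R, their signs from the antisymmetric part (R21-R12, R02-R20,
    # R10-R01), and w from the trace; abs and F.relu guard the square roots against small
    # negative arguments before the quaternion is re-normalized.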
def _contacts(self, D_neighbors, mask_neighbors, cutoff=8):
""" Contacts """
D_neighbors = D_neighbors.unsqueeze(-1)
return mask_neighbors * (D_neighbors < cutoff).type(torch.float32) # return neighbor_C
def _hbonds(self, X, E_idx, mask_neighbors, eps=1E-3):
""" Hydrogen bonds and contact map """
X_atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))
# Virtual hydrogens
X_atoms['C_prev'] = F.pad(X_atoms['C'][:,1:,:], (0,0,0,1), 'constant', 0)
X_atoms['H'] = X_atoms['N'] + F.normalize(
F.normalize(X_atoms['N'] - X_atoms['C_prev'], -1)
+ F.normalize(X_atoms['N'] - X_atoms['CA'], -1)
, -1)
def _distance(X_a, X_b):
return torch.norm(X_a[:,None,:,:] - X_b[:,:,None,:], dim=-1)
def _inv_distance(X_a, X_b):
return 1. / (_distance(X_a, X_b) + eps)
U = (0.084 * 332) * (
_inv_distance(X_atoms['O'], X_atoms['N'])
+ _inv_distance(X_atoms['C'], X_atoms['H'])
- _inv_distance(X_atoms['O'], X_atoms['H'])
- _inv_distance(X_atoms['C'], X_atoms['N'])
)
HB = (U < -0.5).type(torch.float32)
neighbor_HB = mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
return neighbor_HB
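    # _hbonds places a virtual amide hydrogen from the local backbone geometry and scores each
    # residue pair with a DSSP-style electrostatic energy
    # U = 0.084 * 332 * (1/r_ON + 1/r_CH - 1/r_OH - 1/r_CN); pairs with U < -0.5 are treated
    # as hydrogen-bonded and gathered onto the kNN edges.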
def _orientations_coarse(self, X, E_idx, eps=1e-6):
# Pair features
# Shifted slices of unit vectors
dX = X[:,1:,:] - X[:,:-1,:]
U = F.normalize(dX, dim=-1)
u_2 = U[:,:-2,:]
u_1 = U[:,1:-1,:]
u_0 = U[:,2:,:]
# Backbone normals
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
# Bond angle calculation
cosA = -(u_1 * u_0).sum(-1)
cosA = torch.clamp(cosA, -1+eps, 1-eps)
A = torch.acos(cosA)
# Angle between normals
cosD = (n_2 * n_1).sum(-1)
cosD = torch.clamp(cosD, -1+eps, 1-eps)
D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
# Backbone features
AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
AD_features = F.pad(AD_features, (0,0,1,2), 'constant', 0)
# Build relative orientations
o_1 = F.normalize(u_2 - u_1, dim=-1)
        O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
O = O.view(list(O.shape[:2]) + [9])
O = F.pad(O, (0,0,1,2), 'constant', 0)
O_neighbors = gather_nodes(O, E_idx)
X_neighbors = gather_nodes(X, E_idx)
# Re-view as rotation matrices
O = O.view(list(O.shape[:2]) + [3,3])
O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3,3])
# Rotate into local reference frames
dX = X_neighbors - X.unsqueeze(-2)
dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
dU = F.normalize(dU, dim=-1)
R = torch.matmul(O.unsqueeze(2).transpose(-1,-2), O_neighbors)
Q = self._quaternions(R)
# Orientation features
O_features = torch.cat((dU,Q), dim=-1)
return AD_features, O_features
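    # _orientations_coarse builds a local right-handed frame O for every residue from the CA-trace
    # unit vectors and encodes each edge (i, j) as the unit displacement dU expressed in frame i
    # (3 dims) plus the relative rotation O_i^T O_j as a quaternion (4 dims), i.e. the "+7" edge
    # features above. AD_features are the per-residue virtual bond angle A and torsion D of the
    # CA trace, used as node features in the 'coarse' mode.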
def _dihedrals(self, X, eps=1e-7):
# First 3 coordinates are N, CA, C
X = X[:,:,:3,:].reshape(X.shape[0], 3*X.shape[1], 3)
# Shifted slices of unit vectors
dX = X[:,1:,:] - X[:,:-1,:]
U = F.normalize(dX, dim=-1)
u_2 = U[:,:-2,:]
u_1 = U[:,1:-1,:]
u_0 = U[:,2:,:]
# Backbone normals
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
# Angle between normals
cosD = (n_2 * n_1).sum(-1)
cosD = torch.clamp(cosD, -1+eps, 1-eps)
D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
# This scheme will remove phi[0], psi[-1], omega[-1]
D = F.pad(D, (1,2), 'constant', 0)
D = D.view((D.size(0), int(D.size(1)/3), 3))
return torch.cat((torch.cos(D), torch.sin(D)), 2) # return D_features
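    # _dihedrals flattens the (N, CA, C) backbone atoms into a single chain of 3L points and
    # computes the torsion angle around each consecutive bond, which recovers the phi, psi and
    # omega dihedrals of every residue; padding zeroes the undefined phi[0], psi[-1] and omega[-1],
    # and the angles are returned as (cos, sin) pairs, i.e. 6 node features per residue.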
def forward(self, X, L, mask):
""" Featurize coordinates as an attributed graph """
# Data augmentation
if self.training and self.augment_eps > 0:
X = X + self.augment_eps * torch.randn_like(X)
# Build k-Nearest Neighbors graph
X_ca = X[:,:,1,:] # [32, 483, 3]
D_neighbors, E_idx, mask_neighbors = self._dist(X_ca, mask) # [32, 483, 30], [32, 483, 30], [32, 483, 30, 1]
# Pairwise features
AD_features, O_features = self._orientations_coarse(X_ca, E_idx) # [32, 483, 3], [32, 483, 30, 7]
RBF = self._rbf(D_neighbors) # [32, 483, 30, 16]
# Pairwise embeddings
E_positional = self.embeddings(E_idx) # [32, 483, 30, 16]
if self.features_type == 'coarse':
# Coarse backbone features
V = AD_features
E = torch.cat((E_positional, RBF, O_features), -1)
elif self.features_type == 'hbonds':
# Hydrogen bonds and contacts
neighbor_HB = self._hbonds(X, E_idx, mask_neighbors)
            neighbor_C = self._contacts(D_neighbors, mask_neighbors)
# Dropout
neighbor_C = self.dropout(neighbor_C)
neighbor_HB = self.dropout(neighbor_HB)
# Pack
V = mask.unsqueeze(-1) * torch.ones_like(AD_features)
neighbor_C = neighbor_C.expand(-1,-1,-1, int(self.num_positional_embeddings / 2))
neighbor_HB = neighbor_HB.expand(-1,-1,-1, int(self.num_positional_embeddings / 2))
E = torch.cat((E_positional, neighbor_C, neighbor_HB), -1)
elif self.features_type == 'full':
# Full backbone angles
V = self._dihedrals(X) # [32, 483, 6]
E = torch.cat((E_positional, RBF, O_features), -1) # [32, 483, 30, 39]
elif self.features_type == 'dist':
# Full backbone angles
V = self._dihedrals(X)
E = torch.cat((E_positional, RBF), -1)
# Embed the nodes
V = self.node_embedding(V) # [32, 483, 6] --> [32, 483, 128]
V = self.norm_nodes(V) # [32, 483, 128] --> [32, 483, 128]
E = self.edge_embedding(E) # [32, 483, 30, 39] --> [32, 483, 30, 128]
E = self.norm_edges(E) # [32, 483, 30, 128] --> [32, 483, 30, 128]
return V, E, E_idx
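# ProteinFeatures.forward returns node features V [B, N, node_features], edge features
# E [B, N, K, edge_features] and neighbor indices E_idx [B, N, K]; the helper functions below
# gather node and edge tensors along those indices.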
def gather_edges(edges, neighbor_idx):
# Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C]
neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
return torch.gather(edges, 2, neighbors) # return edge_features
def gather_nodes(nodes, neighbor_idx):
# Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
# Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
neighbors_flat = neighbor_idx.view((neighbor_idx.shape[0], -1))
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2)) # [B, N*K, C]
    # Gather and re-pack
    neighbor_features = torch.gather(nodes, 1, neighbors_flat) # [B, N*K, C]
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1]) # [B, N, K, C]
return neighbor_features
def gather_nodes_t(nodes, neighbor_idx):
# Features [B,N,C] at Neighbor index [B,K] => Neighbor features[B,K,C]
idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
return torch.gather(nodes, 1, idx_flat) # return node features
def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
h_nodes = gather_nodes(h_nodes, E_idx)
return torch.cat([h_neighbors, h_nodes], -1)
class TransformerLayer(nn.Module):
def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1):
super(TransformerLayer, self).__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
self.num_in = num_in
self.dropout = nn.Dropout(dropout)
self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
self.attention = NeighborAttention(num_hidden, num_in, num_heads)
self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
def forward(self, h_V, h_E, mask_V=None, mask_attend=None): # h_V: [32, 482, 128], h_E: [32, 482, 30, 256], mask_V: [32, 482], mask_attend: [32, 482, 30]
""" Parallel computation of full transformer layer """
# Self-attention
dh = self.attention(h_V, h_E, mask_attend)
h_V = self.norm[0](h_V + self.dropout(dh))
# Position-wise feedforward
dh = self.dense(h_V)
h_V = self.norm[1](h_V + self.dropout(dh))
if mask_V is not None:
mask_V = mask_V.unsqueeze(-1)
h_V = mask_V * h_V
return h_V
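    # The layer is a post-norm Transformer block restricted to the k neighbors: neighbor attention
    # followed by a position-wise feed-forward network, each wrapped in dropout + residual +
    # Normalize, with padded nodes zeroed out through mask_V.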
    def step(self, t, h_V, h_E, E_idx, mask_V=None, mask_attend=None):
        """ Sequential computation of step t of a transformer layer """
        # Self-attention (NeighborAttention.step needs E_idx to gather neighbor states)
        h_V_t = h_V[:,t,:]
        dh_t = self.attention.step(t, h_V, h_E, E_idx, mask_attend)
h_V_t = self.norm[0](h_V_t + self.dropout(dh_t))
# Position-wise feedforward
dh_t = self.dense(h_V_t)
h_V_t = self.norm[1](h_V_t + self.dropout(dh_t))
if mask_V is not None:
mask_V_t = mask_V[:,t].unsqueeze(-1)
h_V_t = mask_V_t * h_V_t
return h_V_t
class MPNNLayer(nn.Module):
def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
super(MPNNLayer, self).__init__()
self.num_hidden = num_hidden
self.num_in = num_in
self.scale = scale
self.dropout = nn.Dropout(dropout)
self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)
def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
""" Parallel computation of full transformer layer """
# Concatenate h_V_i to h_E_ij
h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_E.size(-2),-1)
h_EV = torch.cat([h_V_expand, h_E], -1)
h_message = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV)))))
if mask_attend is not None:
h_message = mask_attend.unsqueeze(-1) * h_message
dh = torch.sum(h_message, -2) / self.scale
h_V = self.norm[0](h_V + self.dropout(dh))
# Position-wise feedforward
dh = self.dense(h_V)
h_V = self.norm[1](h_V + self.dropout(dh))
if mask_V is not None:
mask_V = mask_V.unsqueeze(-1)
h_V = mask_V * h_V
return h_V
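    # MPNNLayer is the attention-free alternative: each edge message is a 3-layer MLP over the
    # concatenation [h_V_i, h_E_ij], messages are summed over the k neighbors and divided by
    # `scale` (default 30, the neighborhood size) instead of being attention-weighted, then fed
    # through the same residual / Normalize / feed-forward structure as the Transformer layer.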
class Normalize(nn.Module):
def __init__(self, features, epsilon=1e-6):
super(Normalize, self).__init__()
self.gain = nn.Parameter(torch.ones(features))
self.bias = nn.Parameter(torch.zeros(features))
self.epsilon = epsilon
def forward(self, x, dim=-1):
mu = x.mean(dim, keepdim=True)
sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
gain = self.gain
bias = self.bias
# Reshape
if dim != -1:
shape = [1] * len(mu.size())
shape[dim] = self.gain.size()[0]
gain = gain.view(shape)
bias = bias.view(shape)
return gain * (x - mu) / (sigma + self.epsilon) + bias
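# Normalize is a layer normalization with learnable gain and bias; the optional `dim` argument
# allows normalizing over an axis other than the last one.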
class PositionWiseFeedForward(nn.Module):
def __init__(self, num_hidden, num_ff):
super(PositionWiseFeedForward, self).__init__()
self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
self.W_out = nn.Linear(num_ff, num_hidden, bias=True)
def forward(self, h_V):
h = F.relu(self.W_in(h_V))
h = self.W_out(h)
return h
class NeighborAttention(nn.Module):
def __init__(self, num_hidden, num_in, num_heads=4):
super(NeighborAttention, self).__init__()
self.num_heads = num_heads
self.num_hidden = num_hidden
# Self-attention layers: {queries, keys, values, output}
self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
self.W_K = nn.Linear(num_in, num_hidden, bias=False)
self.W_V = nn.Linear(num_in, num_hidden, bias=False)
self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)
return
def _masked_softmax(self, attend_logits, mask_attend, dim=-1):
""" Numerically stable masked softmax """
negative_inf = np.finfo(np.float32).min
attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf, device=attend_logits.device))
attend = F.softmax(attend_logits, dim)
attend = mask_attend * attend
return attend
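    # _masked_softmax sets logits at masked positions to the most negative float32 value so they
    # receive (numerically) zero probability, then multiplies by the mask again so that rows whose
    # neighbors are all masked come out exactly zero instead of uniform.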
def forward(self, h_V, h_E, mask_attend=None):
""" Self-attention, graph-structured O(Nk)
Args:
h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, 3*N_hidden]
mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
h_V: Node update
"""
# Queries, Keys, Values
n_batch, n_nodes, n_neighbors = h_E.shape[:3]
n_heads = self.num_heads
d = int(self.num_hidden / n_heads)
Q = self.W_Q(h_V).view([n_batch, n_nodes, 1, n_heads, 1, d])
K = self.W_K(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d, 1])
V = self.W_V(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d])
# Attention with scaled inner product
        # The n_neighbors dimension carries the attention weights: each weight is the scaled dot product between a neighbor and its center node
attend_logits = torch.matmul(Q, K).view([n_batch, n_nodes, n_neighbors, n_heads]).transpose(-2,-1)
attend_logits = attend_logits / np.sqrt(d) # [N_batch, N_nodes, n_heads, K]
if mask_attend is not None:
# Masked softmax
mask = mask_attend.unsqueeze(2).expand(-1,-1,n_heads,-1) # [N_batch, N_nodes, n_heads, K]
attend = self._masked_softmax(attend_logits, mask)
else:
attend = F.softmax(attend_logits, -1)
# Attentive reduction
        h_V_update = torch.matmul(attend.unsqueeze(-2), V.transpose(2,3)) # [32, 482, 4, 1, 30] x [32, 482, 4, 30, 32] --> [32, 482, 4, 1, 32], i.e. the attention-weighted aggregation of neighbor values
h_V_update = h_V_update.view([n_batch, n_nodes, self.num_hidden])
h_V_update = self.W_O(h_V_update)
return h_V_update
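    # Restricting keys and values to the k gathered neighbors makes the attention O(N*k) rather
    # than O(N^2): Q has shape [B, N, 1, H, 1, d] and K [B, N, K, H, d, 1], so a single batched
    # matmul produces one logit per (node, neighbor, head).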
def step(self, t, h_V, h_E, E_idx, mask_attend=None):
""" Self-attention for a specific time step t
Args:
h_V: Node features [N_batch, N_nodes, N_hidden]
h_E: Neighbor features [N_batch, N_nodes, K, N_in]
E_idx: Neighbor indices [N_batch, N_nodes, K]
mask_attend: Mask for attention [N_batch, N_nodes, K]
Returns:
h_V_t: Node update
"""
# Dimensions
n_batch, n_nodes, n_neighbors = h_E.shape[:3]
n_heads = self.num_heads
        d = self.num_hidden // n_heads
# Per time-step tensors
h_V_t = h_V[:,t,:]
h_E_t = h_E[:,t,:,:]
E_idx_t = E_idx[:,t,:]
# Single time-step
h_V_neighbors_t = gather_nodes_t(h_V, E_idx_t)
E_t = torch.cat([h_E_t, h_V_neighbors_t], -1)
# Queries, Keys, Values
Q = self.W_Q(h_V_t).view([n_batch, 1, n_heads, 1, d])
K = self.W_K(E_t).view([n_batch, n_neighbors, n_heads, d, 1])
V = self.W_V(E_t).view([n_batch, n_neighbors, n_heads, d])
# Attention with scaled inner product
attend_logits = torch.matmul(Q, K).view([n_batch, n_neighbors, n_heads]).transpose(-2,-1)
attend_logits = attend_logits / np.sqrt(d)
if mask_attend is not None:
# Masked softmax
# [N_batch, K] -=> [N_batch, N_heads, K]
mask_t = mask_attend[:,t,:].unsqueeze(1).expand(-1,n_heads,-1)
attend = self._masked_softmax(attend_logits, mask_t)
else:
            attend = F.softmax(attend_logits, -1)
# Attentive reduction
        h_V_t_update = torch.matmul(attend.unsqueeze(-2), V.transpose(1,2))
        h_V_t_update = self.W_O(h_V_t_update.view([n_batch, self.num_hidden]))
        return h_V_t_update
class Struct2Seq(nn.Module):
def __init__(self, num_letters, node_features, edge_features,
hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
vocab=33, k_neighbors=30, protein_features='full', augment_eps=0.,
dropout=0.1, forward_attention_decoder=True, use_mpnn=False):
""" Graph labeling network """
super(Struct2Seq, self).__init__()
# Hyperparameters
self.node_features = node_features
self.edge_features = edge_features
        self.hidden_dim = hidden_dim
        # Featurizer (forward_sequential below expects self.features; built from the ProteinFeatures module defined above)
        self.features = ProteinFeatures(edge_features=edge_features, node_features=node_features,
            top_k=k_neighbors, features_type=protein_features, augment_eps=augment_eps, dropout=dropout)
        # Embedding layers
self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
self.W_s = nn.Embedding(vocab, hidden_dim)
layer = MPNNLayer if use_mpnn else TransformerLayer
# Encoder layers
self.encoder_layers = nn.ModuleList([
layer(hidden_dim, hidden_dim*2, dropout=dropout)
for _ in range(num_encoder_layers)
])
# Decoder layers
self.forward_attention_decoder = forward_attention_decoder
self.decoder_layers = nn.ModuleList([
layer(hidden_dim, hidden_dim*3, dropout=dropout)
for _ in range(num_decoder_layers)
])
self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)
# Initialization
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def _autoregressive_mask(self, E_idx):
N_nodes = E_idx.size(1)
ii = torch.arange(N_nodes, device=E_idx.device)
ii = ii.view((1, -1, 1))
mask = E_idx < ii
mask = mask.type(torch.float32)
return mask
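    # _autoregressive_mask marks, for every node i, which of its k neighbors have a smaller index
    # (j < i), i.e. positions that are already decoded; the decoder attends to decoder states on
    # these edges (mask_bw) and to frozen encoder states on future edges (mask_fw).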
def forward_sequential(self, X, S, L, mask=None):
""" Compute the transformer layer sequentially, for purposes of debugging """
        # Data augmentation (mirrors the Gaussian noise applied inside ProteinFeatures when training)
        if self.training and self.features.augment_eps > 0:
            X = X + self.features.augment_eps * torch.randn_like(X)
# Prepare node and edge embeddings
V, E, E_idx = self.features(X, L, mask)
h_V = self.W_v(V)
h_E = self.W_e(E)
# Encoder is unmasked self-attention
mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
for layer in self.encoder_layers:
h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)
# Decoder alternates masked self-attention
mask_attend = self._autoregressive_mask(E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
N_batch, N_nodes = X.size(0), X.size(1)
        log_probs = torch.zeros((N_batch, N_nodes, self.W_out.out_features), device=X.device)
h_S = torch.zeros_like(h_V)
h_V_stack = [h_V] + [torch.zeros_like(h_V) for _ in range(len(self.decoder_layers))]
for t in range(N_nodes):
# Hidden layers
E_idx_t = E_idx[:,t:t+1,:]
h_E_t = h_E[:,t:t+1,:,:]
h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
# Stale relational features for future states
h_ESV_encoder_t = mask_fw[:,t:t+1,:,:] * cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
for l, layer in enumerate(self.decoder_layers):
# Updated relational features for future states
h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
h_V_t = h_V_stack[l][:,t:t+1,:]
h_ESV_t = mask_bw[:,t:t+1,:,:] * h_ESV_decoder_t + h_ESV_encoder_t
h_V_stack[l+1][:,t,:] = layer(
h_V_t, h_ESV_t, mask_V=mask[:,t:t+1]
).squeeze(1)
# Sampling step
h_V_t = h_V_stack[-1][:,t,:]
logits = self.W_out(h_V_t)
log_probs[:,t,:] = F.log_softmax(logits, dim=-1)
# Update
h_S[:,t,:] = self.W_s(S[:,t])
return log_probs
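    # forward_sequential decodes one residue at a time: at step t the edge features combine the
    # sequence embeddings of already-decoded positions (mask_bw) with stale encoder node states
    # for not-yet-decoded neighbors (mask_fw), each decoder layer keeps its own stack of per-step
    # states, and W_out yields per-residue log-probabilities over the amino-acid alphabet.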