from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

from .features import GVPGraphEmbedding
from .gvp_modules import GVPConvLayer, LayerNorm
from .gvp_utils import unflatten_graph


class GVPEncoder(nn.Module):
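    """Graph encoder built from geometric vector perceptron (GVP) convolution layers.

    Embeds backbone coordinates into a residue graph, then refines the node
    and edge features with a stack of GVPConvLayers.
    """
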
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_graph = GVPGraphEmbedding(args)

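        # GVP hidden dimensions are (scalar channels, vector channels) tuples.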
        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)

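        # Activations passed through to each GVPConvLayer.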
        conv_activations = (F.relu, torch.sigmoid)
        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(
                node_hidden_dim,
                edge_hidden_dim,
                drop_rate=args.dropout,
                vector_gate=True,
                attention_heads=0,
                n_message=3,
                conv_activations=conv_activations,
                n_edge_gvps=0,
                eps=1e-4,
                layernorm=True,
            )
            for _ in range(args.num_encoder_layers)
        )

    def forward(self, coords, coord_mask, padding_mask, confidence):
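        """Encode backbone coordinates into per-residue GVP embeddings.

        The graph embedding produces flattened node and edge features plus an
        edge index; each GVPConvLayer then updates the node and edge features
        by message passing over that graph.
        """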
        node_embeddings, edge_embeddings, edge_index = self.embed_graph(
            coords, coord_mask, padding_mask, confidence)

        for layer in self.encoder_layers:
            node_embeddings, edge_embeddings = layer(
                node_embeddings, edge_index, edge_embeddings)

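        # Restore the batch dimension that embed_graph flattened away.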
        node_embeddings = unflatten_graph(node_embeddings, coords.shape[0])
        return node_embeddings