# EdwardoSunny's picture
# finished
# 85ab89d
# raw
# history blame
# 1.88 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
import torch
import torch.nn as nn
import torch.nn.functional as F
from .features import GVPGraphEmbedding
from .gvp_modules import GVPConvLayer, LayerNorm
from .gvp_utils import unflatten_graph
class GVPEncoder(nn.Module):
    """GVP graph encoder: graph embedding followed by a stack of conv layers.

    The input structure is turned into node/edge features by
    ``GVPGraphEmbedding``. Those features are then refined by
    ``args.num_encoder_layers`` identically configured ``GVPConvLayer``
    modules.

    Args:
        args: Configuration namespace. This class reads
            ``node_hidden_dim_scalar``, ``node_hidden_dim_vector``,
            ``edge_hidden_dim_scalar``, ``edge_hidden_dim_vector``,
            ``dropout`` and ``num_encoder_layers``. The whole namespace is
            also forwarded to ``GVPGraphEmbedding``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Create the embedding first: submodule registration order (and thus
        # state_dict key order and parameter-init RNG order) depends on it.
        self.embed_graph = GVPGraphEmbedding(args)

        # Hidden sizes as (scalar, vector) pairs for nodes and edges.
        node_dims = (args.node_hidden_dim_scalar, args.node_hidden_dim_vector)
        edge_dims = (args.edge_hidden_dim_scalar, args.edge_hidden_dim_vector)
        # One activation pair shared by every layer; how GVPConvLayer applies
        # each element is defined in gvp_modules (presumably scalar/vector).
        activations = (F.relu, torch.sigmoid)

        def make_conv_layer():
            # Fixed hyper-parameters; only the sizes and dropout come from args.
            return GVPConvLayer(
                node_dims,
                edge_dims,
                drop_rate=args.dropout,
                vector_gate=True,
                attention_heads=0,
                n_message=3,
                conv_activations=activations,
                n_edge_gvps=0,
                eps=1e-4,
                layernorm=True,
            )

        self.encoder_layers = nn.ModuleList(
            [make_conv_layer() for _ in range(args.num_encoder_layers)]
        )

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Encode structures and return the final node embeddings.

        Args:
            coords: Coordinate tensor. Its ``shape[0]`` (presumably the batch
                size) is handed to ``unflatten_graph``.
            coord_mask: Forwarded unchanged to ``GVPGraphEmbedding``.
            padding_mask: Forwarded unchanged to ``GVPGraphEmbedding``.
            confidence: Forwarded unchanged to ``GVPGraphEmbedding``.

        Returns:
            The node embeddings produced by the last conv layer, passed
            through ``unflatten_graph``. Edge embeddings are discarded.
        """
        nodes, edges, edge_index = self.embed_graph(
            coords, coord_mask, padding_mask, confidence
        )
        # Each layer updates both node and edge features; edge_index is fixed.
        for conv in self.encoder_layers:
            nodes, edges = conv(nodes, edge_index, edges)
        return unflatten_graph(nodes, coords.shape[0])