Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .features import GVPGraphEmbedding | |
from .gvp_modules import GVPConvLayer, LayerNorm | |
from .gvp_utils import unflatten_graph | |
class GVPEncoder(nn.Module):
    """GVP graph encoder.

    Builds a graph embedding of the inputs with ``GVPGraphEmbedding``. It then
    refines the node and edge embeddings with a stack of identically
    configured ``GVPConvLayer`` layers. The per-node embeddings are returned
    after ``unflatten_graph`` regroups them using the first dimension of
    ``coords``.

    Args:
        args: Configuration object read via attribute access. This class reads
            ``node_hidden_dim_scalar``, ``node_hidden_dim_vector``,
            ``edge_hidden_dim_scalar``, ``edge_hidden_dim_vector``,
            ``dropout`` and ``num_encoder_layers``. The whole object is also
            forwarded to ``GVPGraphEmbedding``, which may read further fields
            (not visible here).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_graph = GVPGraphEmbedding(args)

        # Hidden sizes are passed as (scalar, vector) pairs.
        node_hidden_dim = (args.node_hidden_dim_scalar,
                           args.node_hidden_dim_vector)
        edge_hidden_dim = (args.edge_hidden_dim_scalar,
                           args.edge_hidden_dim_vector)

        # Forwarded to every GVPConvLayer. Presumably these are the
        # (scalar, vector-gate) nonlinearities -- confirm in gvp_modules.
        conv_activations = (F.relu, torch.sigmoid)

        # One GVPConvLayer per encoder layer, all with identical
        # hyperparameters. The loop index is unused, hence `_`.
        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(
                node_hidden_dim,
                edge_hidden_dim,
                drop_rate=args.dropout,
                vector_gate=True,
                attention_heads=0,
                n_message=3,
                conv_activations=conv_activations,
                n_edge_gvps=0,
                eps=1e-4,
                layernorm=True,
            )
            for _ in range(args.num_encoder_layers)
        )

    def forward(self, coords, coord_mask, padding_mask, confidence):
        """Encode the inputs into per-node embeddings.

        Args:
            coords, coord_mask, padding_mask, confidence: Forwarded unchanged
                to ``GVPGraphEmbedding``. Their expected shapes and dtypes are
                defined there (not visible here -- TODO confirm).

        Returns:
            The node embeddings produced by the last encoder layer, passed
            through ``unflatten_graph`` together with ``coords.shape[0]``
            (presumably the batch size -- confirm). The edge embeddings from
            the final layer are discarded.
        """
        node_embeddings, edge_embeddings, edge_index = self.embed_graph(
            coords, coord_mask, padding_mask, confidence)

        # Each layer consumes and returns updated node and edge embeddings.
        # edge_index (presumably the graph connectivity) is never reassigned,
        # so it is shared by every layer.
        for layer in self.encoder_layers:
            node_embeddings, edge_embeddings = layer(
                node_embeddings, edge_index, edge_embeddings)

        node_embeddings = unflatten_graph(node_embeddings, coords.shape[0])
        return node_embeddings