EdwardoSunny's picture
finished
85ab89d
raw
history blame
2.97 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def flatten_graph(node_embeddings, edge_embeddings, edge_index):
    """
    Merge a padded batch of graphs into one graph made of disjoint components.

    The result behaves like a batch of size one (each example becomes a
    disconnected subgraph), which is the layout pytorch-geometric expects.

    Args:
        node_embeddings: node features as a (scalar, vector) tuple
            - scalar: batch_size x nodes x node_embed_dim
            - vector: batch_size x nodes x node_embed_dim x 3
        edge_embeddings: edge features as a (scalar, vector) tuple
            - scalar: batch_size x edges x edge_embed_dim
            - vector: batch_size x edges x edge_embed_dim x 3
        edge_index: batch_size x 2 (source node, target node) x edges.
            An edge slot is discarded only when both endpoints are -1.

    Returns:
        node_embeddings: (scalar, vector) tuple
            - scalar: total_nodes x node_embed_dim
            - vector: total_nodes x node_embed_dim x 3
        edge_embeddings: (scalar, vector) tuple with padded edges removed
            - scalar: total_edges x edge_embed_dim
            - vector: total_edges x edge_embed_dim x 3
        edge_index: 2 x total_edges, with each example's node ids shifted
            by example_idx * nodes so they index the flattened node tensors.
    """
    node_s, node_v = node_embeddings
    edge_s, edge_v = edge_embeddings
    num_graphs = node_s.shape[0]
    nodes_per_graph = node_s.shape[1]

    # Fold the batch axis into the node axis.
    flat_nodes = (node_s.flatten(0, 1), node_v.flatten(0, 1))

    # One flag per (example, edge) slot; False only when both endpoints are -1.
    keep = (edge_index != -1).any(dim=1).reshape(-1)

    # Example b's nodes occupy rows [b * N, (b + 1) * N) of the flattened node
    # tensors, so its endpoints are shifted by b * N. Padding slots also get
    # shifted, but they are dropped by `keep` below.
    offsets = (
        torch.arange(num_graphs, device=edge_index.device) * nodes_per_graph
    )
    shifted = edge_index + offsets[:, None, None]

    # (B, 2, E) -> (2, B, E) -> (2, B * E); examples stay contiguous along the
    # edge axis, matching the order of `keep` and of the flattened edge features.
    merged_index = shifted.transpose(0, 1).flatten(start_dim=1)

    flat_edges = tuple(t.flatten(0, 1)[keep, :] for t in (edge_s, edge_v))
    return flat_nodes, flat_edges, merged_index[:, keep]
def unflatten_graph(node_embeddings, batch_size):
    """
    Split flattened node embeddings back into a per-example batch.

    This reverses the node flattening done by flatten_graph: the leading
    total_nodes axis is reshaped to batch_size x nodes. All trailing feature
    dimensions are preserved as-is.

    Args:
        node_embeddings: node features as a (scalar, vector) tuple
            - scalar: total_nodes x node_embed_dim
            - vector: total_nodes x node_embed_dim x 3
        batch_size: int; total_nodes must be divisible by it (otherwise
            torch.reshape raises).

    Returns:
        node_embeddings: (scalar, vector) tuple
            - scalar: batch_size x nodes x node_embed_dim
            - vector: batch_size x nodes x node_embed_dim x 3
    """
    x_s, x_v = node_embeddings
    # Keep every trailing dim explicitly. Hard-coding shape[1] / shape[1:3]
    # would let reshape silently fold any extra trailing dims into the node
    # axis instead of raising.
    x_s = x_s.reshape(batch_size, -1, *x_s.shape[1:])
    x_v = x_v.reshape(batch_size, -1, *x_v.shape[1:])
    return (x_s, x_v)