# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def flatten_graph(node_embeddings, edge_embeddings, edge_index):
    """
    Flattens the graph into a batch of size one (with disconnected subgraphs
    for each example) to be compatible with the pytorch-geometric package.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x nodes x node_embed_dim
            - vector: shape batch_size x nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x edges x edge_embed_dim
            - vector: shape batch_size x edges x edge_embed_dim x 3
        edge_index: shape batch_size x 2 (source node and target node) x edges

    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape total_nodes x node_embed_dim
            - vector: shape total_nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings in tuple form (scalar, vector)
            - scalar: shape total_edges x edge_embed_dim
            - vector: shape total_edges x edge_embed_dim x 3
        edge_index: shape 2 x total_edges
    """
    x_s, x_v = node_embeddings
    e_s, e_v = edge_embeddings
    batch_size, N = x_s.shape[0], x_s.shape[1]
    node_embeddings = (torch.flatten(x_s, 0, 1), torch.flatten(x_v, 0, 1))
    edge_embeddings = (torch.flatten(e_s, 0, 1), torch.flatten(e_v, 0, 1))

    # Edges marked with -1 in edge_index are padding; build a mask so they can
    # be dropped after flattening.
    edge_mask = torch.any(edge_index != -1, dim=1)
    # Re-number the nodes by adding batch_idx * N to each batch
    edge_index = edge_index + (
        torch.arange(batch_size, device=edge_index.device) * N
    ).unsqueeze(-1).unsqueeze(-1)
    edge_index = edge_index.permute(1, 0, 2).flatten(1, 2)
    edge_mask = edge_mask.flatten()
    edge_index = edge_index[:, edge_mask]
    edge_embeddings = (
        edge_embeddings[0][edge_mask, :],
        edge_embeddings[1][edge_mask, :],
    )
    return node_embeddings, edge_embeddings, edge_index
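

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal, hypothetical demo of flatten_graph: the embedding dimensions and
# graph sizes below are arbitrary, and it assumes padded edges are marked with
# -1 in edge_index, as the mask in flatten_graph expects.
def _demo_flatten_graph():
    batch_size, nodes, edges = 2, 5, 4
    node_embeddings = (
        torch.randn(batch_size, nodes, 16),     # scalar features
        torch.randn(batch_size, nodes, 16, 3),  # vector features
    )
    edge_embeddings = (
        torch.randn(batch_size, edges, 8),
        torch.randn(batch_size, edges, 8, 3),
    )
    edge_index = torch.randint(0, nodes, (batch_size, 2, edges))
    edge_index[1, :, -1] = -1  # mark the last edge of the second example as padding
    (x_s, x_v), (e_s, e_v), flat_index = flatten_graph(
        node_embeddings, edge_embeddings, edge_index
    )
    # Resulting shapes: x_s (10, 16), x_v (10, 16, 3), e_s (7, 8),
    # e_v (7, 8, 3), flat_index (2, 7) with node ids in [0, 10).
    return (x_s, x_v), (e_s, e_v), flat_index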


def unflatten_graph(node_embeddings, batch_size):
    """
    Unflattens node embeddings.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape total_nodes x node_embed_dim
            - vector: shape total_nodes x node_embed_dim x 3
        batch_size: int

    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x nodes x node_embed_dim
            - vector: shape batch_size x nodes x node_embed_dim x 3
    """
    x_s, x_v = node_embeddings
    # Recover the per-example node dimension from the flattened tensors.
    x_s = x_s.reshape(batch_size, -1, x_s.shape[1])
    x_v = x_v.reshape(batch_size, -1, x_v.shape[1], x_v.shape[2])
    return (x_s, x_v)
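

# --- Illustrative sketch (not part of the original module) -------------------
# A hypothetical round-trip check: flattening batched node embeddings and then
# calling unflatten_graph should recover the original batched shapes and
# values. Dimensions here are arbitrary.
def _demo_unflatten_graph():
    batch_size, nodes = 2, 5
    x_s = torch.randn(batch_size, nodes, 16)
    x_v = torch.randn(batch_size, nodes, 16, 3)
    flat = (torch.flatten(x_s, 0, 1), torch.flatten(x_v, 0, 1))
    out_s, out_v = unflatten_graph(flat, batch_size)
    assert out_s.shape == x_s.shape and out_v.shape == x_v.shape
    assert torch.equal(out_s, x_s) and torch.equal(out_v, x_v)
    return out_s, out_v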