|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
def flatten_graph(node_embeddings, edge_embeddings, edge_index):
    """
    Flattens the graph into a batch of size one (with disconnected subgraphs
    for each example) to be compatible with the pytorch-geometric package.

    Node indices of example ``b`` are shifted by ``b * nodes`` so that every
    example occupies its own contiguous range of the flattened node tensor.
    Padded edges, marked with -1 in ``edge_index``, are removed together with
    their edge embeddings.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x nodes x node_embed_dim
            - vector: shape batch_size x nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x edges x edge_embed_dim
            - vector: shape batch_size x edges x edge_embed_dim x 3
        edge_index: shape batch_size x 2 (source node and target node) x edges;
            an edge with -1 at either endpoint is treated as padding
    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape total_nodes x node_embed_dim
            - vector: shape total_nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings in tuple form (scalar, vector)
            - scalar: shape total_edges x edge_embed_dim
            - vector: shape total_edges x edge_embed_dim x 3
        edge_index: shape 2 x total_edges, indexing into the flattened nodes
    """
    x_s, x_v = node_embeddings
    e_s, e_v = edge_embeddings
    batch_size, num_nodes = x_s.shape[0], x_s.shape[1]
    node_embeddings = (torch.flatten(x_s, 0, 1), torch.flatten(x_v, 0, 1))
    edge_embeddings = (torch.flatten(e_s, 0, 1), torch.flatten(e_v, 0, 1))

    # Keep an edge only if BOTH endpoints are real nodes. Masking with `any`
    # would keep half-padded edges, and the offset below would then turn the
    # -1 endpoint into the last node of the previous example (or -1 for the
    # first example), silently linking unrelated subgraphs.
    edge_mask = torch.all(edge_index != -1, dim=1).flatten()

    # Shift node indices of example b by b * num_nodes; the (batch_size, 1, 1)
    # offset broadcasts over the (source/target, edges) dimensions.
    offsets = torch.arange(batch_size, device=edge_index.device) * num_nodes
    edge_index = edge_index + offsets.view(-1, 1, 1)
    # (batch_size, 2, edges) -> (2, batch_size * edges), in the same order as
    # the flattened edge embeddings so one mask selects both consistently.
    edge_index = edge_index.permute(1, 0, 2).flatten(1, 2)

    edge_index = edge_index[:, edge_mask]
    edge_embeddings = (
        edge_embeddings[0][edge_mask],
        edge_embeddings[1][edge_mask],
    )
    return node_embeddings, edge_embeddings, edge_index
|
|
|
|
|
def unflatten_graph(node_embeddings, batch_size):
    """
    Restores the leading batch dimension on flattened node embeddings.

    The number of nodes per example is inferred as
    total_nodes / batch_size, so total_nodes must be divisible by batch_size.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape total_nodes x node_embed_dim
            - vector: shape total_nodes x node_embed_dim x 3
        batch_size: int, number of examples to split the nodes into
    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x nodes x node_embed_dim
            - vector: shape batch_size x nodes x node_embed_dim x 3
    """
    scalar, vector = node_embeddings
    # Trailing feature dims are kept as-is; -1 lets torch infer the node count.
    scalar_dim = scalar.size(1)
    vector_dims = (vector.size(1), vector.size(2))
    return (
        torch.reshape(scalar, (batch_size, -1, scalar_dim)),
        torch.reshape(vector, (batch_size, -1) + vector_dims),
    )
|
|
|
|
|
|