import torch
import torch.nn as nn
from torch_scatter import scatter_mean

from .layer import GVP, GVPConvLayer, LayerNorm


class AttentionPooling(nn.Module):
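    '''
    Dot-product attention over node features: the two input feature sets are
    summed, projected to query/key/value spaces, and one attention-weighted
    score per node is returned (shape [node_num, 1]).

    Instantiated as `attention_classifier` in SubgraphClassficationModel,
    where it serves as an alternative (currently commented-out) classification head.

    :param input_dim: dimension of the input node features
    :param attention_dim: dimension of the query/key projections
    '''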
    def __init__(self, input_dim, attention_dim):
        super(AttentionPooling, self).__init__()
        self.attention_dim = attention_dim
        self.query_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.key_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.value_layer = nn.Linear(input_dim, 1, bias=True)  # value layer outputs one score per node
        self.softmax = nn.Softmax(dim=1)

    def forward(self, nodes_features1, nodes_features2):
        # nodes_features1 and nodes_features2 are both of shape [node_num, input_dim]
        nodes_features = nodes_features1 + nodes_features2  # could also be concatenation or another fusion
        query = self.query_layer(nodes_features)
        key = self.key_layer(nodes_features)
        value = self.value_layer(nodes_features)
        attention_scores = torch.matmul(query, key.transpose(-2, -1))  # [node_num, node_num]
        attention_scores = self.softmax(attention_scores)
        pooled_features = torch.matmul(attention_scores, value)  # [node_num, 1]
        return pooled_features


class AutoGraphEncoder(nn.Module):
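    '''
    GVP-GNN encoder with a per-node classification head: node and edge
    features are embedded with GVP layers, passed through `GVPConvLayer`
    message passing, and the resulting scalar node embeddings are mapped to
    `node_in_dim[0]` class logits per node (trained with cross-entropy).

    :param node_in_dim: node dimensions in input graph
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param attention_dim: unused here; kept for interface symmetry
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''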
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim, attention_dim=64,
                 num_layers=4, drop_rate=0.1) -> None:
        super().__init__()
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )
        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.dense = nn.Sequential(
            nn.Linear(ns, 2 * ns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * ns, node_in_dim[0])  # number of node label classes
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, h_V, edge_index, h_E, node_s_labels):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)  # scalar node embeddings, shape [num_nodes, ns]

        logits = self.dense(out)
        loss = self.loss_fn(logits, node_s_labels)
        return loss, logits

    def get_embedding(self, h_V, edge_index, h_E):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)
        return out


class SubgraphClassficationModel(nn.Module):
    '''
    Binary classifier over (parent graph, subgraph) pairs: both graphs are
    encoded with shared GVP-GNN layers, mean-pooled into graph embeddings,
    concatenated, and scored by a dense head trained with BCEWithLogitsLoss.

    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param attention_dim: hidden dimension of the (optional) AttentionPooling head
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim, attention_dim=64,
                 num_layers=4, drop_rate=0.1):
        super(SubgraphClassficationModel, self).__init__()
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )
        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.attention_classifier = AttentionPooling(ns, attention_dim)

        self.dense = nn.Sequential(
            nn.Linear(2 * ns, 2 * ns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * ns, 1)
        )

        self.loss_fn = nn.BCEWithLogitsLoss()
    def forward(self, h_V_parent, edge_index_parent, h_E_parent, batch_parent,
                h_V_subgraph, edge_index_subgraph, h_E_subgraph, batch_subgraph,
                labels):
        '''
        :param h_V_parent, h_V_subgraph: tuples (s, V) of node embeddings
        :param edge_index_parent, edge_index_subgraph: `torch.Tensor` of shape [2, num_edges]
        :param h_E_parent, h_E_subgraph: tuples (s, V) of edge embeddings
        :param batch_parent, batch_subgraph: `torch.Tensor` assigning each node to its graph in the batch
        :param labels: binary pair labels for the BCE loss
        '''
        # encode the parent graphs and mean-pool node embeddings per graph
        h_V_parent = self.W_v(h_V_parent)
        h_E_parent = self.W_e(h_E_parent)
        for layer in self.layers:
            h_V_parent = layer(h_V_parent, edge_index_parent, h_E_parent)
        out_parent = self.W_out(h_V_parent)
        out_parent = scatter_mean(out_parent, batch_parent, dim=0)

        # encode the candidate subgraphs with the same (shared) layers
        h_V_subgraph = self.W_v(h_V_subgraph)
        h_E_subgraph = self.W_e(h_E_subgraph)
        for layer in self.layers:
            h_V_subgraph = layer(h_V_subgraph, edge_index_subgraph, h_E_subgraph)
        out_subgraph = self.W_out(h_V_subgraph)
        out_subgraph = scatter_mean(out_subgraph, batch_subgraph, dim=0)

        labels = labels.float()
        out = torch.cat([out_parent, out_subgraph], dim=1)
        logits = self.dense(out)
        # logits = self.attention_classifier(out_parent, out_subgraph)
        loss = self.loss_fn(logits.squeeze(-1), labels)
        return loss, logits
    def get_embedding(self, h_V, edge_index, h_E, batch):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)
        out = scatter_mean(out, batch, dim=0)  # graph-level embedding, shape [num_graphs, ns]
        return out
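

if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original file). It assumes that
    # GVP, GVPConvLayer and LayerNorm in `.layer` follow the gvp-pytorch
    # conventions (node/edge features passed as (scalar, vector) tuples) and
    # that this module is run inside its package, e.g. `python -m <pkg>.model`,
    # so the relative import above resolves. All shapes below are illustrative.
    num_nodes, num_edges = 10, 40
    node_in_dim, node_h_dim = (6, 3), (100, 16)
    edge_in_dim, edge_h_dim = (32, 1), (32, 1)

    h_V = (torch.randn(num_nodes, node_in_dim[0]),
           torch.randn(num_nodes, node_in_dim[1], 3))
    h_E = (torch.randn(num_edges, edge_in_dim[0]),
           torch.randn(num_edges, edge_in_dim[1], 3))
    edge_index = torch.randint(0, num_nodes, (2, num_edges))

    # per-node label prediction with the encoder model
    encoder = AutoGraphEncoder(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
                               num_layers=2)
    node_labels = torch.randint(0, node_in_dim[0], (num_nodes,))
    loss, logits = encoder(h_V, edge_index, h_E, node_labels)
    print("encoder loss:", loss.item(), "logits:", tuple(logits.shape))

    # parent/subgraph pair classification; here the "subgraph" is just the same
    # dummy graph, and a single-graph batch vector is used
    model = SubgraphClassficationModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
                                       num_layers=2)
    batch = torch.zeros(num_nodes, dtype=torch.long)
    pair_labels = torch.ones(1)
    loss, logits = model(h_V, edge_index, h_E, batch,
                         h_V, edge_index, h_E, batch,
                         pair_labels)
    print("pair loss:", loss.item(), "logits:", tuple(logits.shape))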