import torch
import torch.nn as nn
from torch_geometric.nn import GATv2Conv, GINConv
# MLP with LeakyReLU activation and skip connections on the hidden layers
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        # input projection, hidden layers, output projection
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, hidden_dim)]
            + [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layer - 1)]
            + [nn.Linear(hidden_dim, out_dim)]
        )
        self.activation = nn.LeakyReLU(negative_slope=0.05)

    def forward(self, x):
        for idx, layer in enumerate(self.layers):
            if idx == 0:
                # input projection: linear + activation
                x = self.activation(layer(x))
            elif idx == len(self.layers) - 1:
                # output projection: linear only
                x = layer(x)
            else:
                # hidden layer: activation wrapped in a residual (skip) connection
                x = x + self.activation(layer(x))
        return x
# Bias-free variant: every Linear has bias=False, and hidden activations are
# followed by affine-free LayerNorm; hidden layers use skip connections
class MLPBiasFree(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, num_layer):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(in_dim, hidden_dim, bias=False)]
            + [nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_layer - 2)]
            + [nn.Linear(hidden_dim, out_dim, bias=False)]
        )
        self.layernorms = nn.ModuleList(
            [nn.LayerNorm(hidden_dim, elementwise_affine=False) for _ in range(num_layer - 1)]
        )
        self.activation = nn.ReLU()  # nn.Tanh()

    def forward(self, x):
        for idx, layer in enumerate(self.layers):
            if idx == 0:
                # input projection: linear -> activation -> layernorm
                x = self.layernorms[idx](self.activation(layer(x)))
            elif idx == len(self.layers) - 1:
                # output projection: linear only
                x = layer(x)
            else:
                # hidden layer with skip connection, followed by layernorm
                x = self.layernorms[idx](x + self.activation(layer(x)))
        return x
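
# Example usage of the MLP blocks (hypothetical sizes, shown for illustration only):
#   mlp = MLP(in_dim=16, out_dim=4, hidden_dim=32, num_layer=3)
#   y = mlp(torch.randn(8, 16))                       # -> shape (8, 4)
#   mlp_bf = MLPBiasFree(in_dim=16, out_dim=4, hidden_dim=32, num_layer=3)
#   y_bf = mlp_bf(torch.randn(8, 16))                 # -> shape (8, 4)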
class GNN(nn.Module):
    # if gnn_model=='GAT', hidden_dim needs to be divisible by gat_attn_head (=8)
    def __init__(self, gnn_model, num_layer, node_dim, hidden_dim, out_dim):
        super().__init__()
        self.x_linear = nn.Linear(node_dim, hidden_dim)       # node-feature encoder
        self.x_linear_out = nn.Linear(hidden_dim, out_dim)    # output decoder
        if gnn_model == 'GAT':
            gat_attn_head = 8
            self.gnnconv_list = nn.ModuleList(
                [GATv2Conv(in_channels=hidden_dim, out_channels=hidden_dim // gat_attn_head, heads=gat_attn_head)
                 for _ in range(num_layer)]
            )
        elif gnn_model == 'GIN':
            mlp_num_layer = 2
            # each GIN layer keeps the hidden dimension so the layers can be
            # stacked and their outputs summed in forward()
            self.gnnconv_list = nn.ModuleList(
                [GINConv(MLP(hidden_dim, hidden_dim, hidden_dim, mlp_num_layer))
                 for _ in range(num_layer)]
            )
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        x = self.x_linear(x)
        x_sum = x
        for gnnconv in self.gnnconv_list:
            x = self.relu(x)
            x = gnnconv(x=x, edge_index=edge_index)
            x_sum = x_sum + x  # out-of-place add: keeps the autograd graph intact
        # average the encoder output and every layer's output, then decode
        x = x_sum / (len(self.gnnconv_list) + 1)
        x = self.x_linear_out(x)
        return x
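
# Minimal smoke test (a sketch: the sizes and the toy random graph below are
# hypothetical, chosen only to exercise both backbones end to end).
if __name__ == "__main__":
    num_nodes, node_dim, hidden_dim, out_dim = 6, 10, 32, 4
    x = torch.randn(num_nodes, node_dim)
    # toy undirected edge list in COO format, shape (2, num_edges)
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                               [1, 0, 2, 1, 4, 3]], dtype=torch.long)
    for model_name in ('GAT', 'GIN'):
        model = GNN(gnn_model=model_name, num_layer=2,
                    node_dim=node_dim, hidden_dim=hidden_dim, out_dim=out_dim)
        out = model(x, edge_index)
        print(model_name, out.shape)  # expected: torch.Size([6, 4])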