jiaxianustc's picture
history blame
5.07 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
class ApplyNodeFunc(nn.Module):
    """Update the node feature hv with MLP, BN and ReLU.

    Computes ``relu(batch_norm(mlp(h)))``. Instances are used as the
    ``apply_func`` of the GINConv layers built by ``GIN``.

    Parameters
    ----------
    mlp : nn.Module
        Module applied to the node features. It must expose an ``output_dim``
        attribute (as ``MLP`` does); that value sizes the BatchNorm layer.
    """

    def __init__(self, mlp):
        super(ApplyNodeFunc, self).__init__()
        self.mlp = mlp
        # BatchNorm width must match the MLP's output width.
        self.bn = nn.BatchNorm1d(self.mlp.output_dim)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))``.

        NOTE(review): BatchNorm1d expects ``h`` to be (num_nodes, features)
        (or 3-D); presumably node features are 2-D here — confirm with callers.
        """
        h = self.mlp(h)
        h = self.bn(h)
        h = F.relu(h)
        return h
class MLP(nn.Module):
    """MLP with linear output.

    With ``num_layers == 1`` this is a single ``nn.Linear``. Otherwise it is
    ``num_layers`` linear layers, where every layer except the last is
    followed by BatchNorm1d and ReLU; the final layer has no activation.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """MLP layers construction.

        Parameters
        ----------
        num_layers : int
            The number of linear layers (must be >= 1).
        input_dim : int
            The dimensionality of input features.
        hidden_dim : int
            The dimensionality of hidden units at ALL hidden layers.
        output_dim : int
            The dimensionality of the output (e.g. the number of classes
            for prediction, or a hidden width when used inside a GIN layer).

        Raises
        ------
        ValueError
            If ``num_layers < 1``.
        """
        super(MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        # Exposed so wrappers such as ApplyNodeFunc can size their BatchNorm.
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for _ in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            # One BatchNorm per hidden layer; the output layer stays linear.
            for _ in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        """Apply the MLP to ``x``.

        In the multi-layer case each hidden step is
        ``relu(batch_norm(linear(h)))`` and the last linear layer is applied
        without an activation.
        """
        if self.linear_or_not:
            # If linear model
            return self.linear(x)
        # If MLP: len(batch_norms) == len(linears) - 1, so zip pairs every
        # hidden linear layer with its BatchNorm.
        h = x
        for linear, norm in zip(self.linears[:-1], self.batch_norms):
            h = F.relu(norm(linear(h)))
        return self.linears[-1](h)
class GIN(nn.Module):
    """GIN (Graph Isomorphism Network) node encoder.

    Stacks ``num_layers - 1`` GINConv layers, each followed by BatchNorm,
    ReLU and dropout. The per-layer node representations are then combined
    with a jumping-knowledge readout selected by ``JK``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_mlp_layers=2,
                 dropout=0.1, learn_eps=False, neighbor_pooling_type='sum',
                 JK='sum'):
        """Model parameters setting.

        Parameters
        ----------
        input_dim : int
            The dimensionality of input node features.
        hidden_dim : int
            The dimensionality of hidden units at ALL layers.
        num_layers : int
            One more than the number of GIN layers: ``num_layers - 1``
            GINConv layers are built, so this must be at least 2.
        num_mlp_layers : int
            The number of linear layers in each GIN layer's MLP.
        dropout : float
            Dropout ratio applied after every GIN layer.
        learn_eps : bool
            If True, learn epsilon to distinguish center nodes from neighbors.
            If False, aggregate neighbors and center nodes altogether.
        neighbor_pooling_type : str
            How to aggregate neighbors (sum, mean, or max).
        JK : str
            How ``forward`` combines the per-layer node representations:
            'sum', 'max', 'concat' or 'last'.

        Raises
        ------
        ValueError
            If ``num_layers < 2``; no GIN layer would be built and
            ``forward`` would have nothing to combine.
        """
        super(GIN, self).__init__()
        if num_layers < 2:
            raise ValueError(
                "num_layers must be at least 2 (num_layers - 1 GIN layers are built)")
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # One GINConv + BatchNorm pair per GIN layer.
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(self.num_layers - 1):
            # Only the first layer consumes the raw input features.
            in_dim = input_dim if layer == 0 else hidden_dim
            mlp = MLP(num_mlp_layers, in_dim, hidden_dim, hidden_dim)
            # Positional GINConv args: aggregator type, init_eps=0, learn_eps.
            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.drop = nn.Dropout(dropout)
        self.JK = JK

    def forward(self, g, Perturb=None):
        """Encode the nodes of ``g``.

        Parameters
        ----------
        g : DGL graph
            Graph whose input node features are stored in ``g.ndata['h']``.
            NOTE: the feature is *popped* from ``g.ndata`` (removed from the
            graph) and cast to float.
        Perturb : tensor, optional
            Added to the input features before the first GIN layer. It must
            broadcast against ``g.ndata['h']``; presumably it is an
            adversarial feature perturbation (confirm with callers).

        Returns
        -------
        tensor
            Node representations combined across the GIN layers according to
            ``self.JK``:
            - 'sum' / 'max': element-wise sum / max over layers.
            - 'concat': feature-wise concatenation, with
              ``hidden_dim * (num_layers - 1)`` columns.
            - 'last': the last layer's output.

        Raises
        ------
        ValueError
            If ``self.JK`` is not one of the supported modes.
        """
        h = g.ndata.pop('h').float()
        # Output of every GIN layer (the raw input is not included).
        hidden_rep = []
        for i in range(self.num_layers - 1):
            if i == 0 and Perturb is not None:
                h = h + Perturb
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            h = self.drop(h)
            hidden_rep.append(h)

        if self.JK == 'sum':
            return torch.stack(hidden_rep, dim=0).sum(dim=0)
        if self.JK == 'max':
            return torch.stack(hidden_rep, dim=0).max(dim=0)[0]
        if self.JK == 'concat':
            return torch.cat(hidden_rep, dim=1)
        if self.JK == 'last':
            return hidden_rep[-1]
        raise ValueError(
            "Unsupported JK mode %r; expected 'sum', 'max', 'concat' or 'last'"
            % (self.JK,))