# Spaces:
# Running
# Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from dgl.nn.pytorch.conv import GINConv | |
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling | |
class ApplyNodeFunc(nn.Module):
    """Node-update function handed to GINConv: ``relu(bn(mlp(h)))``."""

    def __init__(self, mlp):
        """Wrap ``mlp`` with a BatchNorm + ReLU head.

        Parameters
        ----------
        mlp : nn.Module
            Network applied to the node features. It must expose an
            ``output_dim`` attribute, which sizes the BatchNorm layer.
        """
        super().__init__()
        self.mlp = mlp
        self.bn = nn.BatchNorm1d(mlp.output_dim)

    def forward(self, h):
        """Run the MLP, batch-normalise its output and apply ReLU."""
        return F.relu(self.bn(self.mlp(h)))
class MLP(nn.Module):
    """Multi-layer perceptron whose last layer is a plain linear projection."""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """Build the linear (and BatchNorm) layers.

        Parameters
        ----------
        num_layers : int
            Number of linear layers; must be at least 1. With exactly one
            layer the model is a single ``nn.Linear`` and no BatchNorm is used.
        input_dim : int
            Dimensionality of the input features.
        hidden_dim : int
            Width of every hidden layer (unused when ``num_layers == 1``).
        output_dim : int
            Dimensionality of the output.

        Raises
        ------
        ValueError
            If ``num_layers < 1``.
        """
        super(MLP, self).__init__()
        self.linear_or_not = True  # True => single-linear-layer model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        if num_layers == 1:
            self.linear = nn.Linear(input_dim, output_dim)
            return

        self.linear_or_not = False
        # Layer i maps widths[i] -> widths[i + 1]; only the first and last
        # widths differ from hidden_dim.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])]
        )
        # One BatchNorm per hidden layer; the output layer is left un-normalised.
        self.batch_norms = torch.nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)]
        )

    def forward(self, x):
        """Apply the network to ``x``.

        Every hidden layer computes ``relu(batch_norm(linear(h)))``; the final
        layer is linear with no normalisation or activation.
        """
        if self.linear_or_not:
            return self.linear(x)
        h = x
        # zip stops after the hidden layers because there is one fewer
        # BatchNorm than linear layers.
        for lin, norm in zip(self.linears, self.batch_norms):
            h = F.relu(norm(lin(h)))
        return self.linears[-1](h)
class GIN(nn.Module):
    """Graph Isomorphism Network (GIN) node encoder.

    Stacks ``num_layers - 1`` GINConv layers, each followed by BatchNorm, ReLU
    and dropout. The per-layer node representations are then merged with a
    jumping-knowledge (``JK``) strategy.
    """

    # Jumping-knowledge modes understood by ``forward``.
    _JK_MODES = ('sum', 'max', 'concat', 'last')

    def __init__(self, input_dim, hidden_dim, num_layers, num_mlp_layers=2,
                 dropout=0.1, learn_eps=False, neighbor_pooling_type='sum',
                 JK='sum'):
        """Build the GIN layers.

        Parameters
        ----------
        input_dim : int
            Dimensionality of the input node features.
        hidden_dim : int
            Dimensionality of the hidden units at ALL layers.
        num_layers : int
            Number of layers counting the input layer; ``num_layers - 1``
            GINConv layers are built, so this must be at least 2.
        num_mlp_layers : int
            Number of linear layers in the MLP inside each GINConv.
        dropout : float
            Dropout probability applied after every GINConv layer.
        learn_eps : bool
            If True, learn epsilon to distinguish center nodes from neighbors.
            If False, aggregate neighbors and center nodes altogether.
            The initial epsilon is 0.
        neighbor_pooling_type : str
            How GINConv aggregates neighbors (e.g. 'sum', 'mean' or 'max').
        JK : str
            How ``forward`` combines the per-layer node representations:
            'sum', 'max', 'concat' or 'last'.

        Raises
        ------
        ValueError
            If ``num_layers < 2`` or ``JK`` is not a supported mode.
        """
        super(GIN, self).__init__()
        if num_layers < 2:
            raise ValueError(
                "num_layers must be at least 2 (it counts the input layer), "
                "got {}".format(num_layers))
        if JK not in self._JK_MODES:
            raise ValueError(
                "JK must be one of {}, got {!r}".format(self._JK_MODES, JK))
        self.num_layers = num_layers
        self.learn_eps = learn_eps
        self.JK = JK

        # One GINConv (wrapping an MLP) and one BatchNorm per GIN layer.
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layers - 1):
            # Only the first layer consumes the raw input features.
            in_dim = input_dim if layer == 0 else hidden_dim
            mlp = MLP(num_mlp_layers, in_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type,
                        init_eps=0, learn_eps=self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        self.drop = nn.Dropout(dropout)

    def forward(self, g, Perturb=None):
        """Compute node representations for graph ``g``.

        Parameters
        ----------
        g : DGLGraph
            Graph whose input node features are stored in ``g.ndata['h']``.
            The features are read but not removed, so the same graph can be
            passed through the model more than once.
        Perturb : Tensor, optional
            Added to the input node features before the first GINConv layer.
            It must be broadcastable to those features.

        Returns
        -------
        Tensor
            The per-layer node representations combined according to
            ``self.JK``:
            - 'sum': summed over layers;
            - 'max': element-wise max over layers;
            - 'concat': concatenated along dim 1;
            - 'last': the last layer's output.

        Raises
        ------
        ValueError
            If ``self.JK`` is not a supported mode.
        """
        # Read rather than pop: popping removed the features from the graph,
        # so a second forward pass on the same graph raised KeyError.
        h = g.ndata['h'].float()
        if Perturb is not None:
            h = h + Perturb

        # Output of every GIN layer (the raw input features are not included).
        hidden_rep = []
        for layer, norm in zip(self.ginlayers, self.batch_norms):
            h = layer(g, h)
            h = F.relu(norm(h))
            h = self.drop(h)
            hidden_rep.append(h)

        if self.JK == 'sum':
            return torch.stack(hidden_rep, dim=0).sum(dim=0)
        if self.JK == 'max':
            return torch.stack(hidden_rep, dim=0).max(dim=0).values
        if self.JK == 'concat':
            return torch.cat(hidden_rep, dim=1)
        if self.JK == 'last':
            return hidden_rep[-1]
        # Previously fell through and silently returned None.
        raise ValueError(
            "JK must be one of {}, got {!r}".format(self._JK_MODES, self.JK))