libokj's picture
Upload 110 files
c0ec7e6
raw
history blame
1.99 kB
from torch import cat, nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, global_add_pool
class GIN(nn.Module):
    r"""
    Graph Isomorphism Network encoder mapping a batch of graphs to fixed-size embeddings.

    From `GraphDTA <https://doi.org/10.1093/bioinformatics/btaa921>`_ (Nguyen et al., 2020),
    based on `Graph Isomorphism Network <https://arxiv.org/abs/1810.00826>`_ (Xu et al., 2019)

    Architecture: five ``GINConv`` layers, each wrapping a two-layer MLP
    (``Linear -> ReLU -> Linear``) and followed by ReLU and ``BatchNorm1d``.
    Node embeddings are then sum-pooled per graph and projected by a linear
    layer, followed by ReLU and dropout.

    Args:
        num_features: Number of input node features (last dimension of ``data.x``).
        out_channels: Size of the output graph embedding.
        dropout: Dropout probability applied to the output embedding while training.
        hidden_dim: Width of every hidden GIN layer. Defaults to 32.
    """

    def __init__(
        self,
        num_features: int,
        out_channels: int,
        dropout: float,
        *,
        hidden_dim: int = 32,
    ):
        super().__init__()
        dim = hidden_dim
        self.dropout = dropout
        # Not used by ``forward`` (which calls ``F.relu``); kept so that existing
        # code accessing ``model.relu`` keeps working.
        self.relu = nn.ReLU()
        # Explicit names conv1..conv5 / bn1..bn5 are kept (rather than a ModuleList)
        # so state_dict keys, and therefore saved checkpoints, stay unchanged.
        self.conv1 = GINConv(self._mlp(num_features, dim))
        self.bn1 = nn.BatchNorm1d(dim)
        self.conv2 = GINConv(self._mlp(dim, dim))
        self.bn2 = nn.BatchNorm1d(dim)
        self.conv3 = GINConv(self._mlp(dim, dim))
        self.bn3 = nn.BatchNorm1d(dim)
        self.conv4 = GINConv(self._mlp(dim, dim))
        self.bn4 = nn.BatchNorm1d(dim)
        self.conv5 = GINConv(self._mlp(dim, dim))
        self.bn5 = nn.BatchNorm1d(dim)
        self.fc1_xd = Linear(dim, out_channels)

    @staticmethod
    def _mlp(in_dim: int, dim: int) -> Sequential:
        """Build the ``Linear(in_dim, dim) -> ReLU -> Linear(dim, dim)`` MLP used by one GINConv."""
        return Sequential(Linear(in_dim, dim), ReLU(), Linear(dim, dim))

    def forward(self, data):
        """
        Encode a batch of graphs.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (node-to-graph assignment). Presumably a
                ``torch_geometric`` ``Data``/``Batch`` — confirm against callers.

        Returns:
            Tensor with one ``out_channels``-wide embedding per graph, after
            ReLU and (in training mode) dropout.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in layers:
            # Ordering follows the original: convolution -> ReLU -> batch norm.
            x = bn(F.relu(conv(x, edge_index)))
        # Sum node embeddings within each graph, as indicated by ``batch``.
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1_xd(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x