# PLA-Net: gcn_lib/sparse/torch_vertex.py
# Last commit 799e642 by juliocesar-io: "Added initial app"
import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric as tg
from .torch_nn import MLP, act_layer, norm_layer, BondEncoder, MM_BondEncoder
from .torch_edge import DilatedKnnGraph
from .torch_message import GenMessagePassing, MsgNorm
from torch_geometric.utils import remove_self_loops, add_self_loops
class GENConv(GenMessagePassing):
    """
    GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf
    SoftMax & PowerMean Aggregation

    Each message is ``ReLU(x_j + e_ji) + eps`` (or ``ReLU(x_j) + eps`` when no
    edge features are given). The messages are aggregated by
    ``GenMessagePassing``, optionally rescaled with ``MsgNorm``, added to the
    node's own features, and the sum is passed through an MLP.

    Args:
        in_dim: Node feature size. Encoded edge features are also projected to
            this size so they can be added to ``x_j``.
        emb_dim: Output size of the MLP.
        args: Config object; only ``args.advs`` is read. It selects
            ``MM_BondEncoder`` instead of ``BondEncoder``.
        aggr, t, learn_t, p, learn_p, y, learn_y: Forwarded unchanged to
            ``GenMessagePassing``.
        msg_norm: If True, apply ``MsgNorm`` to the aggregated message.
        learn_msg_scale: Forwarded to ``MsgNorm``.
        encode_edge: If True, edge attributes are embedded before use.
        bond_encoder: With ``encode_edge``, use a bond encoder instead of a
            ``torch.nn.Linear`` projection.
        edge_feat_dim: Raw edge feature size. Required when ``encode_edge`` is
            True and ``bond_encoder`` is False.
        norm: Normalization passed to ``MLP``.
        mlp_layers: Controls the MLP channel list
            ``[in_dim] + [2 * in_dim] * (mlp_layers - 1) + [emb_dim]``.
        eps: Constant added to every message.

    Raises:
        ValueError: If a linear edge encoder is requested without
            ``edge_feat_dim``.
    """

    def __init__(self, in_dim, emb_dim, args,
                 aggr='softmax',
                 t=1.0, learn_t=False,
                 p=1.0, learn_p=False,
                 y=0.0, learn_y=False,
                 msg_norm=False, learn_msg_scale=True,
                 encode_edge=False, bond_encoder=False,
                 edge_feat_dim=None,
                 norm='batch', mlp_layers=2,
                 eps=1e-7):
        super(GENConv, self).__init__(aggr=aggr,
                                      t=t, learn_t=learn_t,
                                      p=p, learn_p=learn_p,
                                      y=y, learn_y=learn_y)

        # Hidden layers are twice as wide as the input.
        channels_list = [in_dim] + [in_dim * 2] * (mlp_layers - 1) + [emb_dim]
        self.mlp = MLP(channels=channels_list,
                       norm=norm,
                       last_lin=True)

        self.msg_encoder = torch.nn.ReLU()
        self.eps = eps

        self.encode_edge = encode_edge
        self.bond_encoder = bond_encoder
        self.advs = args.advs

        self.msg_norm = MsgNorm(learn_msg_scale=learn_msg_scale) if msg_norm else None

        if self.encode_edge:
            if self.bond_encoder:
                encoder_cls = MM_BondEncoder if self.advs else BondEncoder
                self.edge_encoder = encoder_cls(emb_dim=in_dim)
            else:
                if edge_feat_dim is None:
                    raise ValueError(
                        'edge_feat_dim is required when encode_edge=True '
                        'and bond_encoder=False')
                self.edge_encoder = torch.nn.Linear(edge_feat_dim, in_dim)

    def forward(self, x, edge_index, edge_attr=None):
        """Run one GENConv layer.

        Args:
            x: Node features.
            edge_index: Graph connectivity passed to ``propagate``.
            edge_attr: Optional edge features. They are embedded with
                ``edge_encoder`` when ``encode_edge`` is set and are otherwise
                used as-is. They must then be addable to ``x_j`` (assumed
                same feature width; confirm at call sites).

        Returns:
            ``mlp(x + m)``, where ``m`` is the aggregated (and optionally
            message-normalized) message.
        """
        if self.encode_edge and edge_attr is not None:
            edge_emb = self.edge_encoder(edge_attr)
        else:
            edge_emb = edge_attr

        m = self.propagate(edge_index, x=x, edge_attr=edge_emb)

        if self.msg_norm is not None:
            m = self.msg_norm(x, m)

        # Residual connection around the aggregation, then the MLP.
        return self.mlp(x + m)

    def message(self, x_j, edge_attr=None):
        """Return ``ReLU(x_j + edge_attr) + eps`` (edge term omitted if None)."""
        if edge_attr is not None:
            msg = x_j + edge_attr
        else:
            msg = x_j
        # ReLU output is >= 0, so adding eps makes every message strictly
        # positive (presumably required by the power-mean aggregation).
        return self.msg_encoder(msg) + self.eps

    def update(self, aggr_out):
        """Identity update: return the aggregated messages unchanged."""
        return aggr_out
class MRConv(nn.Module):
    """
    Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751)

    For each edge, the difference ``x[edge_index[0]] - x[edge_index[1]]`` is
    scattered onto node ``edge_index[1]`` using the ``aggr`` reduction. The
    result is concatenated with the node's own features and passed through an
    MLP that maps ``2 * in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
        super(MRConv, self).__init__()
        mlp_channels = [in_channels * 2, out_channels]
        self.nn = MLP(mlp_channels, act, norm, bias)
        self.aggr = aggr

    def forward(self, x, edge_index):
        """Aggregate relative neighbour features and apply the MLP."""
        src_idx = edge_index[0]
        dst_idx = edge_index[1]
        relative = torch.index_select(x, 0, src_idx) - torch.index_select(x, 0, dst_idx)
        # NOTE(review): ``tg.utils.scatter_`` is only available in older
        # PyTorch Geometric releases; confirm against the pinned version.
        x_j = tg.utils.scatter_(self.aggr, relative, dst_idx, dim_size=x.shape[0])
        stacked = torch.cat([x, x_j], dim=1)
        return self.nn(stacked)
class EdgConv(tg.nn.EdgeConv):
    """
    Edge convolution layer (with activation, batch normalization)

    Thin wrapper over ``torch_geometric.nn.EdgeConv``. Its edge network is an
    ``MLP`` mapping ``2 * in_channels`` to ``out_channels``, and edge results
    are combined with ``aggr``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
        edge_mlp = MLP([in_channels * 2, out_channels], act, norm, bias)
        super(EdgConv, self).__init__(edge_mlp, aggr)

    def forward(self, x, edge_index):
        """Delegate directly to ``torch_geometric.nn.EdgeConv.forward``."""
        return super().forward(x, edge_index)
class GATConv(nn.Module):
    """
    Graph Attention Convolution layer (with activation, batch normalization)

    Wraps ``torch_geometric.nn.GATConv`` with ``heads`` attention heads. The
    wrapped layer concatenates its heads (PyG default ``concat=True``), so it
    produces ``out_channels * heads`` features. The optional activation and
    normalization are then applied in that order.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, heads=8):
        super(GATConv, self).__init__()
        self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias)
        # Heads are concatenated, so the normalization layer must be sized for
        # out_channels * heads features, not out_channels.
        out_features = out_channels * heads
        m = []
        if act:
            m.append(act_layer(act))
        if norm:
            m.append(norm_layer(norm, out_features))
        self.unlinear = nn.Sequential(*m)

    def forward(self, x, edge_index):
        """Apply attention convolution, then activation / normalization."""
        out = self.unlinear(self.gconv(x, edge_index))
        return out
class SAGEConv(tg.nn.SAGEConv):
    r"""GraphSAGE-style operator (after `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_) with a learnable
    update network.

    Messages are :math:`\mathbf{x}_j \mathbf{\Theta}`, or
    :math:`(\mathbf{x}_j - \mathbf{x}_i) \mathbf{\Theta}` when ``relative`` is
    set, where :math:`\mathbf{\Theta}` is the parent layer's ``weight``. After
    aggregation into :math:`\mathbf{m}_i`, the node update is

    .. math::
        \mathbf{x}^{\prime}_i = \mathrm{nn}\left(\mathbf{x}_i \,\|\,
        \mathbf{m}_i\right) + \mathbf{b},

    optionally followed by :math:`\ell_2` normalization. When ``forward`` is
    called without ``size``, existing self-loops are replaced by exactly one
    self-loop per node.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        nn (torch.nn.Module): Update network applied to the concatenation of
            each node's input features and its aggregated message.
        norm (optional): Any value other than :obj:`None` (even :obj:`False`)
            enables :math:`\ell_2` normalization of the output; :obj:`None`
            disables it. (default: :obj:`True`)
        bias (bool, optional): Forwarded to the parent constructor. If the
            parent creates ``self.bias``, it is added after ``nn``.
            (default: :obj:`True`)
        relative (bool, optional): If set, messages use the feature
            difference :math:`\mathbf{x}_j - \mathbf{x}_i`.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    NOTE(review): this class relies on the parent's ``weight``, ``bias`` and
    ``normalize`` attributes and on a positional ``(normalize, bias)``
    constructor order. These match older PyTorch Geometric releases; confirm
    against the pinned version.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 nn,
                 norm=True,
                 bias=True,
                 relative=False,
                 **kwargs):
        # Any non-None ``norm`` turns on l2-normalization in the parent layer.
        normalize = norm is not None
        super(SAGEConv, self).__init__(in_channels, out_channels, normalize, bias, **kwargs)
        # Set only after Module.__init__ has run.
        self.relative = relative
        self.nn = nn

    def forward(self, x, edge_index, size=None):
        """Propagate over ``edge_index`` and return the updated node features.

        A 1-D ``x`` is treated as one feature per node.
        """
        if size is None:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        return self.propagate(edge_index, size=size, x=x)

    def message(self, x_i, x_j):
        """Project neighbour (or relative neighbour) features with ``weight``."""
        if self.relative:
            x = torch.matmul(x_j - x_i, self.weight)
        else:
            x = torch.matmul(x_j, self.weight)
        return x

    def update(self, aggr_out, x):
        """Apply ``nn`` to ``[x, aggr_out]``, add bias, optionally l2-normalize."""
        out = self.nn(torch.cat((x, aggr_out), dim=1))
        if self.bias is not None:
            out = out + self.bias
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        return out
class RSAGEConv(SAGEConv):
    """
    Residual SAGE convolution layer (with activation, batch normalization)

    A ``SAGEConv`` whose update network is an ``MLP`` that maps the
    concatenated ``[x, aggregated message]`` (``out_channels + in_channels``
    features) to ``out_channels``. ``norm`` is given both to that MLP and to
    ``SAGEConv``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, relative=False):
        update_mlp = MLP([out_channels + in_channels, out_channels], act, norm, bias)
        super(RSAGEConv, self).__init__(in_channels, out_channels, update_mlp, norm, bias, relative)
class SemiGCNConv(nn.Module):
    """
    SemiGCN convolution layer (with activation, batch normalization)

    Applies ``torch_geometric.nn.GCNConv``, then an optional activation, then
    an optional normalization layer sized for ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
        super(SemiGCNConv, self).__init__()
        self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias)
        post_layers = [act_layer(act)] if act else []
        if norm:
            post_layers.append(norm_layer(norm, out_channels))
        self.unlinear = nn.Sequential(*post_layers)

    def forward(self, x, edge_index):
        """Run GCNConv followed by the activation / normalization stack."""
        hidden = self.gconv(x, edge_index)
        return self.unlinear(hidden)
class GinConv(tg.nn.GINConv):
    """
    GINConv layer (with activation, batch normalization)

    Wraps ``torch_geometric.nn.GINConv`` with an ``MLP`` mapping
    ``in_channels`` to ``out_channels``. The ``aggr`` argument is accepted but
    not forwarded or used.
    """

    def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='add'):
        gin_mlp = MLP([in_channels, out_channels], act, norm, bias)
        super(GinConv, self).__init__(gin_mlp)

    def forward(self, x, edge_index):
        """Delegate directly to ``torch_geometric.nn.GINConv.forward``."""
        return super().forward(x, edge_index)
class GraphConv(nn.Module):
    """
    Static graph convolution layer

    Selects the underlying convolution by name (case-insensitive): 'edge',
    'mr', 'gat', 'gcn', 'gin', 'sage' (RSAGEConv, relative=False) or 'rsage'
    (RSAGEConv, relative=True). For 'gat', each of the ``heads`` heads gets
    ``out_channels // heads`` channels. Any other name raises
    NotImplementedError.
    """

    def __init__(self, in_channels, out_channels, conv='edge',
                 act='relu', norm=None, bias=True, heads=8):
        super(GraphConv, self).__init__()
        # Lazy builders so that only the selected convolution is constructed.
        builders = {
            'edge': lambda: EdgConv(in_channels, out_channels, act, norm, bias),
            'mr': lambda: MRConv(in_channels, out_channels, act, norm, bias),
            'gat': lambda: GATConv(in_channels, out_channels // heads, act, norm, bias, heads),
            'gcn': lambda: SemiGCNConv(in_channels, out_channels, act, norm, bias),
            'gin': lambda: GinConv(in_channels, out_channels, act, norm, bias),
            'sage': lambda: RSAGEConv(in_channels, out_channels, act, norm, bias, False),
            'rsage': lambda: RSAGEConv(in_channels, out_channels, act, norm, bias, True),
        }
        build = builders.get(conv.lower())
        if build is None:
            raise NotImplementedError('conv {} is not implemented'.format(conv))
        self.gconv = build()

    def forward(self, x, edge_index):
        """Apply the selected convolution."""
        result = self.gconv(x, edge_index)
        return result
class DynConv(GraphConv):
    """
    Dynamic graph convolution layer

    On every forward pass, builds ``edge_index`` with ``DilatedKnnGraph``
    (constructed from ``kernel_size``, ``dilation`` and ``**kwargs``) and then
    applies the static ``GraphConv`` selected by ``conv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
                 norm=None, bias=True, heads=8, **kwargs):
        super(DynConv, self).__init__(in_channels, out_channels, conv, act, norm, bias, heads)
        self.k, self.d = kernel_size, dilation
        self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs)

    def forward(self, x, batch=None):
        """Rebuild the graph from ``x`` (and ``batch``), then convolve."""
        knn_edges = self.dilated_knn_graph(x, batch)
        return super().forward(x, knn_edges)
class PlainDynBlock(nn.Module):
    """
    Plain Dynamic graph convolution block

    A single channel-preserving ``DynConv`` with no skip connection.
    ``res_scale`` is stored but not used by ``forward``.
    """

    def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, res_scale=1, **kwargs):
        super(PlainDynBlock, self).__init__()
        self.body = DynConv(channels, channels, kernel_size, dilation, conv,
                            act, norm, bias, **kwargs)
        self.res_scale = res_scale

    def forward(self, x, batch=None):
        """Return ``(body(x, batch), batch)``."""
        out = self.body(x, batch)
        return out, batch
class ResDynBlock(nn.Module):
    """
    Residual Dynamic graph convolution block

    Computes ``body(x, batch) + x * res_scale`` with a channel-preserving
    ``DynConv`` body, and passes ``batch`` through unchanged.
    """

    def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
                 bias=True, res_scale=1, **kwargs):
        super(ResDynBlock, self).__init__()
        self.body = DynConv(channels, channels, kernel_size, dilation, conv,
                            act, norm, bias, **kwargs)
        self.res_scale = res_scale

    def forward(self, x, batch=None):
        """Return ``(body(x, batch) + x * res_scale, batch)``."""
        out = self.body(x, batch)
        return out + x * self.res_scale, batch
class DenseDynBlock(nn.Module):
    """
    Dense Dynamic graph convolution block

    Concatenates the input with the ``DynConv`` output along dim 1, so the
    feature width grows by ``out_channels``. ``batch`` passes through
    unchanged.
    """

    def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, bias=True, **kwargs):
        super(DenseDynBlock, self).__init__()
        self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv,
                            act, norm, bias, **kwargs)

    def forward(self, x, batch=None):
        """Return ``(cat([x, body(x, batch)], dim=1), batch)``."""
        new_features = self.body(x, batch)
        return torch.cat([x, new_features], dim=1), batch
class ResGraphBlock(nn.Module):
    """
    Residual Static graph convolution block

    Computes ``body(x, edge_index) + x * res_scale`` with a channel-preserving
    ``GraphConv`` body, and passes ``edge_index`` through unchanged.
    """

    def __init__(self, channels, conv='edge', act='relu', norm=None, bias=True, heads=8, res_scale=1):
        super(ResGraphBlock, self).__init__()
        self.body = GraphConv(channels, channels, conv, act, norm, bias, heads)
        self.res_scale = res_scale

    def forward(self, x, edge_index):
        """Return ``(body(x, edge_index) + x * res_scale, edge_index)``."""
        out = self.body(x, edge_index)
        return out + x * self.res_scale, edge_index
class DenseGraphBlock(nn.Module):
    """
    Dense Static graph convolution block

    Concatenates the input with the ``GraphConv`` output along dim 1, so the
    feature width grows by ``out_channels``. ``edge_index`` passes through
    unchanged.
    """

    def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True, heads=8):
        super(DenseGraphBlock, self).__init__()
        self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads)

    def forward(self, x, edge_index):
        """Return ``(cat([x, body(x, edge_index)], dim=1), edge_index)``."""
        new_features = self.body(x, edge_index)
        return torch.cat([x, new_features], dim=1), edge_index