import torch
import torch.nn as nn
from functools import partial


def create_activation(name):
    """Return the activation module named by ``name`` (None -> identity)."""
    if name == "relu":
        return nn.ReLU()
    elif name == "gelu":
        return nn.GELU()
    elif name == "prelu":
        return nn.PReLU()
    elif name is None:
        return nn.Identity()
    elif name == "elu":
        return nn.ELU()
    else:
        raise NotImplementedError(f"{name} is not implemented.")
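
# Usage sketch (illustrative; these call sites are assumptions, not part of
# this module): the factory lets the activation be chosen from a config string.
#     act = create_activation("gelu")   # -> nn.GELU()
#     act = create_activation(None)     # -> nn.Identity(), i.e. no activation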


def create_norm(name):
    """Return a normalization layer *constructor* (class or partial), not an instance."""
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        # The norm_type string must be one NormLayer recognizes ("graphnorm");
        # any other value makes NormLayer raise NotImplementedError.
        return partial(NormLayer, norm_type="graphnorm")
    else:
        return nn.Identity
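
# Usage sketch (illustrative): note the extra call, since create_norm returns
# a constructor that the caller instantiates with the hidden dimension.
#     norm_cls = create_norm("graphnorm")
#     norm = norm_cls(256)   # -> NormLayer(256, norm_type="graphnorm")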


class NormLayer(nn.Module):
    """Wrap batchnorm/layernorm, or apply GraphNorm-style per-graph
    normalization (cf. Cai et al., 2021) when ``norm_type == "graphnorm"``.
    """

    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # Keep the string as a sentinel; forward() dispatches on it.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError(f"{norm_type} is not implemented.")

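    # GraphNorm, as implemented below, computes per graph g in the batch:
    #     out = weight * (x - mean_scale * mu_g) / sigma_g + bias
    # where mu_g and sigma_g are the mean and std of the node features of g,
    # and mean_scale is a learnable per-feature scale on the subtracted mean.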
    def forward(self, graph, x):
        tensor = x
        # batchnorm/layernorm store a module; graphnorm stores a string sentinel.
        if self.norm is not None and not isinstance(self.norm, str):
            return self.norm(tensor)
        elif self.norm is None:
            return tensor

        # GraphNorm path: normalize node features per graph in the batch.
        # `batch_num_nodes` is a method in DGL >= 0.5 and a plain attribute
        # in older releases; support both.
        batch_list = graph.batch_num_nodes
        if callable(batch_list):
            batch_list = batch_list()
        batch_size = len(batch_list)
        batch_list = torch.as_tensor(batch_list, dtype=torch.long, device=tensor.device)
        # Map every node row to the index of the graph it belongs to.
        batch_index = (
            torch.arange(batch_size, device=tensor.device).repeat_interleave(batch_list)
        )
        batch_index = batch_index.view(
            (-1,) + (1,) * (tensor.dim() - 1)
        ).expand_as(tensor)

        # Per-graph mean, broadcast back to the node rows of each graph.
        mean = torch.zeros(batch_size, *tensor.shape[1:], device=tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)

        # Learnable scale on the subtracted mean (the alpha of GraphNorm).
        sub = tensor - mean * self.mean_scale

        # Per-graph standard deviation, with epsilon for numerical stability.
        std = torch.zeros(batch_size, *tensor.shape[1:], device=tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias
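

# Minimal smoke test, a sketch assuming a DGL-style batched graph: the stub
# class below provides only the per-graph node counts that NormLayer.forward
# reads, and the sizes are illustrative assumptions.
if __name__ == "__main__":

    class _FakeBatchedGraph:
        """Stub exposing per-graph node counts as a plain attribute."""

        def __init__(self, batch_num_nodes):
            self.batch_num_nodes = batch_num_nodes

    x = torch.randn(7, 16)                # 7 nodes total, 16 features each
    g = _FakeBatchedGraph([3, 4])         # two graphs: 3 nodes and 4 nodes
    norm = create_norm("graphnorm")(16)   # NormLayer(16, norm_type="graphnorm")
    out = norm(g, x)
    assert out.shape == x.shape
    print("graphnorm output shape:", tuple(out.shape))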