# Source: zdou0830 / desco, commit 749745d (raw file, 1.35 kB).
import torch
import torch.nn as nn
class EvoNorm2d(nn.Module):
    """EvoNorm-S0 normalization-activation layer for ``(N, C, H, W)`` inputs.

    With ``nonlinearity=True`` the forward pass computes::

        y = x * sigmoid(v * x) / group_std(x) * weight + bias

    where ``group_std`` is taken per sample, jointly over the spatial
    positions and the channels of each of ``group`` contiguous channel
    groups (no batch statistics are used). With ``nonlinearity=False`` the
    layer is only the per-channel affine map ``x * weight + bias``.

    Args:
        num_features: number of input channels ``C``.
        eps: added to the group variance before the square root, keeping the
            divisor and its gradient finite when a group is constant.
        nonlinearity: if True, apply the gated normalization and create the
            learnable ``v``; if False, apply only the affine transform.
        group: number of channel groups; must divide the channel count when
            ``nonlinearity`` is True.
    """

    __constants__ = ["num_features", "eps", "nonlinearity"]

    def __init__(self, num_features, eps=1e-5, nonlinearity=True, group=32):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        self.group = group

        # Per-channel scale/shift shaped to broadcast against (N, C, H, W).
        # Allocated uninitialized; reset_parameters() fills them in.
        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if self.nonlinearity:
            # Learnable per-channel slope inside the sigmoid gate.
            self.v = nn.Parameter(torch.empty(1, num_features, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize ``weight`` (and ``v`` if present) to ones, ``bias`` to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=32):
        """Return the per-sample group standard deviation of ``x``.

        Channels of ``x`` (shape ``(N, C, H, W)``) are split into ``groups``
        contiguous groups. The biased variance is computed over each group's
        channels and spatial positions, ``sqrt(var + eps)`` is formed, and the
        result is broadcast back to every channel of its group.

        Returns:
            Tensor of shape ``(N, C, 1, 1)``.

        Raises:
            ValueError: if ``groups`` is not a positive divisor of ``C``.
        """
        N, C, H, W = x.shape
        if groups <= 0 or C % groups != 0:
            raise ValueError(
                f"EvoNorm2d: channel count {C} must be divisible by groups={groups}"
            )
        channels_per_group = C // groups
        grouped = x.reshape(N, groups, channels_per_group, H, W)
        # Reduce over (channels-in-group, H, W) so `groups` actually groups.
        # Biased variance: a single-element group yields 0, not NaN.
        var = grouped.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
        # eps inside the sqrt keeps d(std)/d(var) finite when var == 0.
        std = torch.sqrt(var + self.eps)
        return std.expand(N, groups, channels_per_group, 1, 1).reshape(N, C, 1, 1)

    def forward(self, x):
        """Apply the layer to ``x``; the nonlinear path expects ``(N, C, H, W)``."""
        if not self.nonlinearity:
            return x * self.weight + self.bias
        gated = x * torch.sigmoid(self.v * x)
        return gated / self.group_std(x, self.group) * self.weight + self.bias