File size: 1,346 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn


class EvoNorm2d(nn.Module):
    """EvoNorm-S0-style normalization-activation layer for (N, C, H, W) inputs.

    With ``nonlinearity=True`` the layer computes::

        y = x * sigmoid(v * x) / group_std(x) * weight + bias

    where ``group_std`` is the standard deviation taken over each group of
    ``C // group`` channels together with all spatial positions. With
    ``nonlinearity=False`` it reduces to a per-channel affine transform
    ``y = x * weight + bias``; ``group`` is then unused and no ``v``
    parameter is created.

    Args:
        num_features: Number of channels ``C`` of the input.
        eps: Added to the group variance before the square root.
        nonlinearity: Whether to apply the sigmoid gate and group
            normalization (and create the learnable ``v`` parameter).
        group: Number of channel groups. ``num_features`` must be divisible
            by it when ``nonlinearity`` is True, otherwise the reshape in
            ``group_std`` fails.
    """

    __constants__ = ["num_features", "eps", "nonlinearity", "group"]

    def __init__(self, num_features, eps=1e-5, nonlinearity=True, group=32):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        self.group = group

        # Per-channel parameters shaped to broadcast against (N, C, H, W);
        # allocated uninitialized and filled by reset_parameters().
        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if self.nonlinearity:
            self.v = nn.Parameter(torch.empty(1, num_features, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and v to ones and bias to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=32):
        """Return the group standard deviation of ``x``, one value per channel.

        Args:
            x: Tensor of shape (N, C, H, W); C must be divisible by ``groups``.
            groups: Number of channel groups.

        Returns:
            Tensor of shape (N, C, 1, 1) holding ``sqrt(var + eps)``. ``var``
            is the biased variance over the ``C // groups`` channels of each
            group and all H, W positions. Every channel in a group shares the
            same value.
        """
        N, C, H, W = x.shape
        channels_per_group = C // groups
        grouped = x.reshape(N, groups, channels_per_group, H, W)
        # Statistics span the group's channels *and* space (dims 2-4).
        # Reducing only over H, W would give per-channel statistics and make
        # the grouping meaningless.
        # Biased variance with eps inside the sqrt stays finite even for a
        # single element per group (e.g. 1x1 maps). The unbiased torch.std
        # default returns NaN there.
        var = torch.var(grouped, dim=(2, 3, 4), unbiased=False, keepdim=True)
        std = torch.sqrt(var + self.eps)
        # (N, G, 1, 1, 1) -> (N, G, C//G, 1, 1) -> (N, C, 1, 1)
        return std.expand(N, groups, channels_per_group, 1, 1).reshape(N, C, 1, 1)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (N, C, H, W)."""
        if not self.nonlinearity:
            # Plain per-channel affine transform.
            return x * self.weight + self.bias
        # Swish-like gate with a learnable per-channel slope v.
        num = x * torch.sigmoid(self.v * x)
        return num / self.group_std(x, self.group) * self.weight + self.bias