File size: 775 Bytes
db26c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch

from torch import nn


class Norm1D(nn.Module):
    """Batch or instance normalization for channels-last 3-D tensors.

    Wraps ``nn.BatchNorm1d`` / ``nn.InstanceNorm1d`` so it can be applied to
    inputs laid out as ``(N, L, C)``. The input is permuted to the
    ``(N, C, L)`` layout those modules expect, normalized per channel, and
    permuted back, so the output has the same shape as the input.

    Args:
        dim: Number of channels ``C`` (``num_features`` of the wrapped module).
        ntype: Which normalization to use: ``'batch'`` or ``'instance'``.
        affine: Whether the wrapped module has a learnable scale and shift.
        eps: Value added to the variance for numerical stability. Defaults
            to ``1e-10``, the value this class always used previously.

    Raises:
        KeyError: If ``ntype`` is not one of the supported names.
    """

    def __init__(self, dim, ntype='batch', affine=False, eps=1e-10):
        super().__init__()
        clazz_dict = {'batch': nn.BatchNorm1d, 'instance': nn.InstanceNorm1d}
        try:
            clazz = clazz_dict[ntype]
        except KeyError:
            # Keep KeyError (callers may catch it) but report the valid choices.
            raise KeyError(
                f"Unsupported ntype {ntype!r}; expected one of {sorted(clazz_dict)}"
            ) from None
        # Attribute name ``nn_norm`` is kept so existing state_dict keys still load.
        self.nn_norm = clazz(dim, eps=eps, affine=affine)

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, L, C)``; returns a tensor of the same shape."""
        # (N, L, C) -> (N, C, L). Swapping the last two axes is its own inverse,
        # so the same permutation restores the original layout.
        return self.nn_norm(x.permute(0, 2, 1)).permute(0, 2, 1)


class Norm2D(nn.Module):
    """Batch or instance normalization for channels-last 4-D tensors.

    Wraps ``nn.BatchNorm2d`` / ``nn.InstanceNorm2d`` so it can be applied to
    inputs laid out as ``(N, H, W, C)``. The input is permuted to the
    ``(N, C, H, W)`` layout those modules expect, normalized per channel,
    and permuted back, so the output has the same shape as the input.

    Args:
        dim: Number of channels ``C`` (``num_features`` of the wrapped module).
        ntype: Which normalization to use: ``'batch'`` or ``'instance'``.
        affine: Whether the wrapped module has a learnable scale and shift.
        eps: Value added to the variance for numerical stability. Defaults
            to ``1e-10``, the value this class always used previously.

    Raises:
        KeyError: If ``ntype`` is not one of the supported names.
    """

    def __init__(self, dim, ntype='batch', affine=False, eps=1e-10):
        super().__init__()
        clazz_dict = {'batch': nn.BatchNorm2d, 'instance': nn.InstanceNorm2d}
        try:
            clazz = clazz_dict[ntype]
        except KeyError:
            # Keep KeyError (callers may catch it) but report the valid choices.
            raise KeyError(
                f"Unsupported ntype {ntype!r}; expected one of {sorted(clazz_dict)}"
            ) from None
        # Attribute name ``nn_norm`` is kept so existing state_dict keys still load.
        self.nn_norm = clazz(dim, eps=eps, affine=affine)

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, H, W, C)``; returns a tensor of the same shape."""
        # (N, H, W, C) -> (N, C, H, W) for the norm, then back to (N, H, W, C).
        return self.nn_norm(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)