File size: 2,163 Bytes
981b0ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.utils.spectral_norm as SpectralNorm


def GroupNorm(in_channels):
    ec = 32
    assert in_channels % ec == 0
    return torch.nn.GroupNorm(num_groups=in_channels//32, num_channels=in_channels, eps=1e-6, affine=True)

def swish(x):
    return x*torch.sigmoid(x)

class ResTextBlockV2(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = GroupNorm(in_channels)
        self.conv1 = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1))
        self.norm2 = GroupNorm(out_channels)
        self.conv2 = SpectralNorm(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1))
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return x + x_in
    



def calc_mean_std_4D(feat, eps=1e-6):
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(prior_feat, lq_feat):
    size = prior_feat.size()
    lq_mean, lq_std = calc_mean_std_4D(lq_feat)
    prior_mean, prior_std = calc_mean_std_4D(prior_feat)

    normalized_feat = (prior_feat - prior_mean.expand(size)) / prior_std.expand(size)
    return normalized_feat * lq_std.expand(size) + lq_mean.expand(size)


def network_param(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    return num_params / 1e6