File size: 4,410 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """Residual block: two 3x3 convolution stages plus a 1x1-convolution shortcut.

    The main path is ``[norm] -> LeakyReLU -> conv3x3 -> [resample] -> [norm]
    -> LeakyReLU -> conv3x3``; the shortcut is ``conv1x1 -> [resample]``.
    The block returns the element-wise sum of both paths.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by both paths.
        down_sample: If true, insert ``AvgPool2d(kernel_size=2)`` in both
            paths. Takes precedence over ``up_sample`` when both are set.
        up_sample: If true (and ``down_sample`` is false), insert a bilinear
            ``Upsample(scale_factor=2)`` in both paths.
        norm: If true, each convolution stage starts with ``InstanceNorm2d``.
    """

    def __init__(self, in_channel, out_channel, down_sample=False, up_sample=False, norm=True):
        super().__init__()

        def resample_layers():
            # A fresh module per call, so the two paths never share an instance.
            if down_sample:
                return [nn.AvgPool2d(kernel_size=2)]
            if up_sample:
                return [nn.Upsample(scale_factor=2, mode="bilinear")]
            return []

        def conv_stage(stage_in):
            # (optional norm) -> activation -> 3x3 conv that keeps spatial size.
            stage = [nn.InstanceNorm2d(stage_in)] if norm else []
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            stage.append(nn.Conv2d(stage_in, out_channel, kernel_size=3, stride=1, padding=1))
            return stage

        # Module order inside each Sequential fixes the state_dict keys, and
        # construction order fixes seeded weight initialisation: keep both.
        self.main_path = nn.Sequential(
            *conv_stage(in_channel),
            *resample_layers(),
            *conv_stage(out_channel),
        )
        self.side_path = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0),
            *resample_layers(),
        )

    def forward(self, x):
        """Return ``main_path(x) + side_path(x)``."""
        # NOTE: with norm=False the main path begins with an in-place
        # LeakyReLU, so `x` itself is overwritten before the shortcut reads
        # it. The main path must therefore run first to keep results stable.
        residual = self.main_path(x)
        shortcut = self.side_path(x)
        return residual + shortcut


class AdaIn(nn.Module):
    """Adaptive instance normalisation conditioned on a style vector.

    The input is instance-normalised, then scaled and shifted per channel by
    two values predicted from ``style_vector`` with separate linear layers.

    Args:
        in_channel: Number of channels of the feature map to modulate.
        vector_size: Length of the style vector fed to the linear layers.

    Attributes set here:
        eps: Numerical-stability term passed to ``F.instance_norm``.
    """

    def __init__(self, in_channel, vector_size):
        super().__init__()
        self.eps = 1e-5
        self.std_style_fc = nn.Linear(vector_size, in_channel)
        self.mean_style_fc = nn.Linear(vector_size, in_channel)

    def forward(self, x, style_vector):
        """Return ``scale(style) * instance_norm(x) + shift(style)``.

        Assumes ``x`` is ``(N, in_channel, H, W)`` and ``style_vector`` is
        ``(N, vector_size)`` -- TODO confirm with callers. The two trailing
        singleton axes added below are what make the per-channel terms
        broadcast over two spatial dimensions.
        """
        std_style = self.std_style_fc(style_vector)
        mean_style = self.mean_style_fc(style_vector)

        std_style = std_style.unsqueeze(-1).unsqueeze(-1)
        mean_style = mean_style.unsqueeze(-1).unsqueeze(-1)

        # Use the configured epsilon; previously `self.eps` was stored but the
        # functional call silently fell back to its own default (also 1e-5).
        x = F.instance_norm(x, eps=self.eps)
        x = std_style * x + mean_style
        return x


class AdaInResBlock(nn.Module):
    """Residual block whose two convolution stages are modulated by ``AdaIn``.

    Main path: ``AdaIn -> LeakyReLU -> conv3x3 -> [upsample] -> AdaIn ->
    LeakyReLU -> conv3x3``. Shortcut: ``conv1x1 -> [upsample]``. The block
    returns the element-wise sum of both paths.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by both paths.
        up_sample: If true, insert a bilinear ``Upsample(scale_factor=2)``
            after the first convolution and after the shortcut convolution.
        vector_size: Length of the conditioning vector given to both ``AdaIn``
            layers. Defaults to ``257 + 512``, the value that used to be
            hard-coded (presumably two concatenated descriptors of those
            sizes -- confirm with the caller).
    """

    def __init__(self, in_channel, out_channel, up_sample=False, vector_size=257 + 512):
        super().__init__()
        self.vector_size = vector_size
        self.up_sample = up_sample

        self.adain1 = AdaIn(in_channel, self.vector_size)
        self.adain2 = AdaIn(out_channel, self.vector_size)

        # First stage: activation, 3x3 conv, then the optional 2x upsampling.
        main_module_list = [
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
        ]
        if up_sample:
            main_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear"))
        self.main_path1 = nn.Sequential(*main_module_list)

        # Second stage keeps both channel count and spatial size.
        self.main_path2 = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
        )

        # Shortcut: 1x1 conv to match channels, resampled like the main path.
        side_module_list = [nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)]
        if up_sample:
            side_module_list.append(nn.Upsample(scale_factor=2, mode="bilinear"))
        self.side_path = nn.Sequential(*side_module_list)

    def forward(self, x, id_vector):
        """Return the modulated main path plus the shortcut.

        ``id_vector`` is passed unchanged to both ``AdaIn`` layers.
        """
        x1 = self.adain1(x, id_vector)
        x1 = self.main_path1(x1)
        x2 = self.side_path(x)

        x1 = self.adain2(x1, id_vector)
        x1 = self.main_path2(x1)

        return x1 + x2


class UpSamplingBlock(nn.Module):
    """Two upsampling ``ResBlock`` stages followed by two output heads.

    ``net`` applies two 256-channel ``ResBlock(up_sample=True)`` stages.
    ``i_r_net`` maps the result to 3 channels (LeakyReLU, then 3x3 conv).
    ``m_r_net`` maps it to 1 channel (3x3 conv, then sigmoid).
    """

    def __init__(self):
        super().__init__()
        stages = [ResBlock(256, 256, up_sample=True) for _ in range(2)]
        self.net = nn.Sequential(*stages)
        self.i_r_net = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1),
        )
        self.m_r_net = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the tuple ``(i_r, m_r)`` computed from the upsampled features."""
        features = self.net(x)
        # NOTE: i_r_net starts with an in-place LeakyReLU, so `features` is
        # already activated when m_r_net reads it. Keep this call order.
        i_r_out = self.i_r_net(features)
        m_r_out = self.m_r_net(features)
        return i_r_out, m_r_out