import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvCIN(nn.Module):
    """Reflection-padded convolution followed by conditional instance normalization."""

    def __init__(self, n_styles, C_in, C_out, kernel_size, padding, stride, activation=None):
        super(ConvCIN, self).__init__()
        
        self.reflection = nn.ReflectionPad2d(padding)
        self.conv = nn.Conv2d(in_channels=C_in, out_channels=C_out, kernel_size=kernel_size, stride=stride)
        nn.init.normal_(self.conv.weight, mean=0, std=1e-2)

        # No affine parameters here: the per-style scale and shift below
        # take their place (conditional instance normalization).
        self.instnorm = nn.InstanceNorm2d(C_out)

        # One (gamma, beta) row per style, initialized close to the identity transform.
        self.gamma = torch.nn.Parameter(data=torch.randn(n_styles, C_out)*1e-2 + 1, requires_grad=True)
        self.beta = torch.nn.Parameter(data=torch.randn(n_styles, C_out)*1e-2, requires_grad=True)

        self.activation = activation

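    def forward(self, x, style_1, style_2, alpha):
        # Reflection-pad, convolve, then normalize each channel per sample.
        x = self.reflection(x)
        x = self.conv(x)
        x = self.instnorm(x)

        # Select the per-style scale and shift. If a second style is given,
        # blend the two linearly: alpha weights style_1, (1 - alpha) weights style_2.
        # `is not None` is required here; `!= None` breaks when style_2 is a tensor.
        if style_2 is not None:
            gamma = alpha*self.gamma[style_1] + (1-alpha)*self.gamma[style_2]
            beta = alpha*self.beta[style_1] + (1-alpha)*self.beta[style_2]
        else:
            gamma = self.gamma[style_1]
            beta = self.beta[style_1]

        # Flatten the spatial dimensions so gamma and beta broadcast over them.
        b, c, h, w = x.size()
        x = x.view(b, c, h*w)

        x = (x*gamma.unsqueeze(-1) + beta.unsqueeze(-1)).view(b, c, h, w)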

        if self.activation == 'relu':
            x = F.relu(x)
        elif self.activation == 'sigmoid':
            x = torch.sigmoid(x)

        return x

class ResidualBlock(nn.Module):
    def __init__(self, n_styles, C_in, C_out):
        super(ResidualBlock,self).__init__()

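        # The skip connection in forward() requires C_in == C_out.
        self.convcin1 = ConvCIN(n_styles, C_in, C_out, kernel_size=3, padding=1, stride=1, activation='relu')
        # The second convolution consumes the output of the first, so its input width is C_out.
        self.convcin2 = ConvCIN(n_styles, C_out, C_out, kernel_size=3, padding=1, stride=1)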

    def forward(self, x, style_1, style_2, alpha):
        out = self.convcin1(x, style_1, style_2, alpha)
        out = self.convcin2(out, style_1, style_2, alpha)
        return x + out

class UpSampling(nn.Module):
    def __init__(self, n_styles, C_in, C_out):
        super(UpSampling,self).__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.convcin = ConvCIN(n_styles, C_in, C_out, kernel_size=3, padding=1, stride=1, activation='relu')

    def forward(self, x, style_1, style_2, alpha):
        x = self.upsample(x)
        x = self.convcin(x, style_1, style_2, alpha)
        return x

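class STModel(nn.Module):
    """Style transfer network: three ConvCIN layers (two with stride 2), five
    residual blocks, two nearest-neighbor upsampling stages, and a final
    sigmoid ConvCIN layer producing a 3-channel image in [0, 1]."""

    def __init__(self, n_styles):
        super(STModel,self).__init__()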

        self.convcin1 = ConvCIN(n_styles, C_in=3, C_out=32, kernel_size=9, padding=4, stride=1, activation='relu')
        self.convcin2 = ConvCIN(n_styles, C_in=32, C_out=64, kernel_size=3, padding=1, stride=2, activation='relu')
        self.convcin3 = ConvCIN(n_styles, C_in=64, C_out=128, kernel_size=3, padding=1, stride=2, activation='relu')

        self.rb1 = ResidualBlock(n_styles, 128, 128)
        self.rb2 = ResidualBlock(n_styles, 128, 128)
        self.rb3 = ResidualBlock(n_styles, 128, 128)
        self.rb4 = ResidualBlock(n_styles, 128, 128)
        self.rb5 = ResidualBlock(n_styles, 128, 128)

        self.upsample1 = UpSampling(n_styles, 128, 64)
        self.upsample2 = UpSampling(n_styles, 64, 32)

        self.convcin4 = ConvCIN(n_styles, C_in=32, C_out=3, kernel_size=9, padding=4, stride=1, activation='sigmoid')

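    def forward(self, x, style_1, style_2=None, alpha=0.5):
        # style_1 and style_2 index rows of each layer's gamma/beta tables.
        # When style_2 is given, the two styles are blended with weight alpha on style_1.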
        x = self.convcin1(x, style_1, style_2, alpha)
        x = self.convcin2(x, style_1, style_2, alpha)
        x = self.convcin3(x, style_1, style_2, alpha)

        x = self.rb1(x, style_1, style_2, alpha)
        x = self.rb2(x, style_1, style_2, alpha)
        x = self.rb3(x, style_1, style_2, alpha)
        x = self.rb4(x, style_1, style_2, alpha)
        x = self.rb5(x, style_1, style_2, alpha)

        x = self.upsample1(x, style_1, style_2, alpha)
        x = self.upsample2(x, style_1, style_2, alpha)

        x = self.convcin4(x, style_1, style_2, alpha)

        return x