import torch
import torch.nn as nn
import torchvision
from utils import batch_wct, batch_histogram_matching

class Encoder(nn.Module):
    def __init__(self, layers=[1, 6, 11, 20]):
        super(Encoder, self).__init__()
        vgg = torchvision.models.vgg19(pretrained=True).features

        # Split the pretrained VGG19 into sequential chunks so the forward
        # pass can return the intermediate activations at the chosen layers
        # (relu1_1, relu2_1, relu3_1, relu4_1 for the defaults).
        self.encoder = nn.ModuleList()
        temp_seq = nn.Sequential()
        for i in range(max(layers) + 1):
            temp_seq.add_module(str(i), vgg[i])
            if i in layers:
                self.encoder.append(temp_seq)
                temp_seq = nn.Sequential()

    def forward(self, x):
        # Return the feature map of every chunk, not just the last one;
        # the intermediate features are needed for the loss.
        features = []
        for layer in self.encoder:
            x = layer(x)
            features.append(x)
        return features

# The decoder mirrors the full encoder architecture, because the outputs at
# the chosen layers are needed to compute the loss.
class Decoder(nn.Module):
    def __init__(self, layers=[1, 6, 11, 20]):
        super(Decoder, self).__init__()
        vgg = torchvision.models.vgg19(pretrained=False).features

        # Walk the encoder backwards and build a mirrored block for each layer.
        self.decoder = nn.ModuleList()
        temp_seq = nn.Sequential()
        count = 0
        for i in range(max(layers) - 1, -1, -1):
            if isinstance(vgg[i], nn.Conv2d):
                # Swap the in/out channel counts of the mirrored convolution.
                out_channels = vgg[i].in_channels
                in_channels = vgg[i].out_channels
                kernel_size = vgg[i].kernel_size

                # Build a [reflection pad -> convolution -> ReLU] block.
                temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1, 1, 1, 1)))
                count += 1
                temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
                count += 1
                temp_seq.add_module(str(count), nn.ReLU())
                count += 1

            # Replace down-sampling (MaxPool2d) with nearest-neighbour upsampling.
            elif isinstance(vgg[i], nn.MaxPool2d):
                temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
                count += 1

            if i in layers:
                self.decoder.append(temp_seq)
                temp_seq = nn.Sequential()

        # Append the last conv block without its final ReLU activation,
        # so the output image is not clamped to non-negative values.
        self.decoder.append(temp_seq[:-1])

    def forward(self, x):
        y = x
        for layer in self.decoder:
            y = layer(y)
        return y

class AdaIN(nn.Module):
    def __init__(self):
        super(AdaIN, self).__init__()

    def forward(self, content, style, style_strength=1.0, eps=1e-5):
        """
        Adaptive Instance Normalization:
            AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)

        content: tensor of shape (B, C, H, W)
        style:   tensor of shape (B, C, H, W)
        Note that AdaIN operates on a matched content/style pair, computing
        the statistics per sample and per channel.
        """
        b, c, h, w = content.size()

        content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True)
        style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True)

        # Normalize the content features, then re-scale and re-shift them
        # with the style statistics.
        normalized_content = (content.view(b, c, -1) - content_mean) / (content_std + eps)
        stylized_content = (normalized_content * style_std) + style_mean

        # Blend between the raw content and the stylized features.
        output = (1 - style_strength) * content + style_strength * stylized_content.view(b, c, h, w)
        return output

class Style_Transfer_Network(nn.Module):
    def __init__(self, layers=[1, 6, 11, 20]):
        super(Style_Transfer_Network, self).__init__()
        self.encoder = Encoder(layers)
        self.decoder = Decoder(layers)
        self.adain = AdaIN()

    def forward(self, content, styles, style_strength=1.0, interpolation_weights=None, preserve_color=None, train=False):
        if interpolation_weights is None:
            interpolation_weights = [1 / len(styles)] * len(styles)

        # Encode the content image.
        content_feature = self.encoder(content)

        # Encode the style images, optionally matching their colors to the
        # content image first.
        style_features = []
        for style in styles:
            if preserve_color in ('whitening_and_coloring', 'histogram_matching'):
                style = batch_wct(style, content)
            style_features.append(self.encoder(style))

        # Apply AdaIN to the deepest feature map of each style and blend the
        # results with the interpolation weights.
        transformed_features = []
        for style_feature, interpolation_weight in zip(style_features, interpolation_weights):
            AdaIN_feature = self.adain(content_feature[-1], style_feature[-1], style_strength) * interpolation_weight
            if preserve_color == 'histogram_matching':
                AdaIN_feature *= 0.9
            transformed_features.append(AdaIN_feature)
        transformed_feature = sum(transformed_features)

        # Decode back to image space and re-apply color preservation if requested.
        stylized_image = self.decoder(transformed_feature)
        if preserve_color == 'whitening_and_coloring':
            stylized_image = batch_wct(stylized_image, content)
        elif preserve_color == 'histogram_matching':
            stylized_image = batch_histogram_matching(stylized_image, content)

        if train:
            return stylized_image, transformed_feature
        return stylized_image
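
# A minimal usage sketch (illustrative only; the tensor sizes and the 0.8
# style strength below are arbitrary choices, not part of the original module).
# With the default layers the encoder downsamples by 8x and the decoder
# upsamples back, so a 256x256 input yields a 256x256 stylized image.
if __name__ == '__main__':
    model = Style_Transfer_Network().eval()
    content = torch.randn(1, 3, 256, 256)
    styles = [torch.randn(1, 3, 256, 256)]
    with torch.no_grad():
        stylized = model(content, styles, style_strength=0.8)
    print(stylized.shape)  # expected: torch.Size([1, 3, 256, 256])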