import torch | |
import torch.nn as nn | |
import torchvision | |
from torchvision.models import vgg19 | |
import utils | |
from utils import batch_wct, batch_histogram_matching | |
class Encoder(nn.Module):
    """Pretrained VGG-19 feature stack, split into consecutive stages.

    The ``features`` part of ``torchvision.models.vgg19(pretrained=True)`` is
    walked from index 0 up to and including ``max(layers)``. Every index
    listed in ``layers`` closes the current ``nn.Sequential`` stage, so the
    model ends up as an ``nn.ModuleList`` of stages. Inside a stage each
    sub-module is registered under its original VGG index (as a string).
    """

    def __init__(self, layers=[1, 6, 11, 20]):
        """Build the stages.

        layers: VGG ``features`` indices after which a stage ends; the
            largest entry is the last VGG layer that gets copied.
        """
        super().__init__()
        backbone = torchvision.models.vgg19(pretrained=True).features
        self.encoder = nn.ModuleList()

        stage = nn.Sequential()
        final_index = max(layers)
        for idx in range(final_index + 1):
            # Keep the VGG index as the sub-module name.
            stage.add_module(str(idx), backbone[idx])
            if idx not in layers:
                continue
            # A requested layer closes the current stage and opens a new one.
            self.encoder.append(stage)
            stage = nn.Sequential()

    def forward(self, x):
        """Run ``x`` through every stage in order.

        Returns a list with the output of each stage, in order; the last
        entry is the output of the final stage.
        """
        outputs = []
        current = x
        for stage in self.encoder:
            current = stage(current)
            outputs.append(current)
        return outputs
# We need to copy the whole architecture because the outputs of the layers listed in "layers" are needed to compute the loss.
class Decoder(nn.Module):
    """Decoder that mirrors the VGG-19 feature stack in reverse order.

    A VGG-19 built with ``pretrained=False`` is used only as a blueprint.
    Its ``features`` are visited from index ``max(layers) - 1`` down to 0:

    * every ``nn.Conv2d`` becomes ``ReflectionPad2d -> Conv2d -> ReLU`` with
      the input/output channel counts swapped,
    * every ``nn.MaxPool2d`` becomes ``nn.Upsample(scale_factor=2)``,
    * any other layer type adds nothing.

    An index contained in ``layers`` closes the current stage. Whatever is
    collected after the last boundary forms the final stage, minus its last
    module, so that the final convolution is not followed by a ReLU.
    Sub-modules are named with one running counter shared by all stages.
    """

    def __init__(self, layers=[1, 6, 11, 20]):
        """Build the decoder stages.

        layers: VGG ``features`` indices that mark stage boundaries; the
            walk starts one index below the largest entry.
        """
        super().__init__()
        blueprint = torchvision.models.vgg19(pretrained=False).features

        # Collect the mirrored modules stage by stage, walking VGG backwards.
        stages = [[]]
        for idx in reversed(range(max(layers))):
            source = blueprint[idx]
            if isinstance(source, nn.Conv2d):
                # Swap the channel counts so the decoder undoes the encoder conv.
                stages[-1].extend([
                    nn.ReflectionPad2d(padding=(1, 1, 1, 1)),
                    nn.Conv2d(
                        source.out_channels,
                        source.in_channels,
                        kernel_size=source.kernel_size,
                    ),
                    nn.ReLU(),
                ])
            elif isinstance(source, nn.MaxPool2d):
                # Down-sampling in the encoder turns into up-sampling here.
                stages[-1].append(nn.Upsample(scale_factor=2))
            if idx in layers:
                stages.append([])

        # The final stage loses its last module (the ReLU after the last conv).
        stages[-1] = stages[-1][:-1]

        # Register everything; module names continue counting across stages.
        self.decoder = nn.ModuleList()
        position = 0
        for members in stages:
            stage = nn.Sequential()
            for member in members:
                stage.add_module(str(position), member)
                position += 1
            self.decoder.append(stage)

    def forward(self, x):
        """Pass ``x`` through all decoder stages in order and return the result."""
        for stage in self.decoder:
            x = stage(x)
        return x
class AdaIN(nn.Module):
    """Adaptive instance normalisation of content features towards style features.

    Each (sample, channel) slice of the content is normalised with its own
    spatial mean/std and then rescaled and shifted with the mean/std of the
    matching style slice.
    """

    def __init__(self):
        super(AdaIN, self).__init__()

    def forward(self, content, style, style_strength=1.0, eps=1e-5):
        """Apply AdaIN to one content/style pair.

        content: tensor of shape B * C * H * W
        style: tensor with the same B and C as ``content``; its spatial
            dimensions are flattened, so they do not have to match H * W
        style_strength: blend factor; 0 returns the content unchanged,
            1 returns the fully re-normalised content
        eps: added to the content std to avoid division by zero

        Returns a tensor of shape B * C * H * W.

        note that AdaIN does computation on a pair of content - style img
        """
        b, c, h, w = content.size()

        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or sliced feature maps; for contiguous inputs it is the
        # same zero-copy view as before.
        flat_content = content.reshape(b, c, -1)
        flat_style = style.reshape(b, c, -1)

        # Statistics over the flattened spatial axis, using torch.std_mean's
        # default settings.
        content_std, content_mean = torch.std_mean(flat_content, dim=2, keepdim=True)
        style_std, style_mean = torch.std_mean(flat_style, dim=2, keepdim=True)

        normalized_content = (flat_content - content_mean) / (content_std + eps)
        stylized_content = (normalized_content * style_std) + style_mean
        stylized_content = stylized_content.reshape(b, c, h, w)

        # Linear blend between the untouched content and the stylised result.
        output = (1 - style_strength) * content + style_strength * stylized_content
        return output
class Style_Transfer_Network(nn.Module):
    """Encoder -> AdaIN -> decoder style-transfer pipeline.

    The content image and every style image are encoded, the deepest content
    feature is re-normalised towards each style with AdaIN, the results are
    combined as a weighted sum, and the decoder maps that feature back to an
    image.
    """

    def __init__(self, layers=[1, 6, 11, 20]):
        """Create the encoder, decoder and AdaIN module.

        layers: passed unchanged to both ``Encoder`` and ``Decoder``.
        """
        super(Style_Transfer_Network, self).__init__()
        self.encoder = Encoder(layers)
        self.decoder = Decoder(layers)
        self.adain = AdaIN()

    def forward(self, content, styles, style_strength=1., interpolation_weights=None, preserve_color=None, train=False):
        """Stylise ``content`` with one or more style images.

        content: content image batch fed to the encoder
        styles: non-empty sequence of style image batches
        style_strength: forwarded to AdaIN as its blend factor
        interpolation_weights: one weight per style; defaults to a uniform
            ``1 / len(styles)``. Styles and weights are paired with ``zip``,
            so if the lengths differ the surplus entries are ignored.
        preserve_color: ``None``, ``'whitening_and_coloring'`` or
            ``'histogram_matching'``
        train: when true, also return the combined AdaIN feature

        Returns the stylised image, or ``(stylised image, combined feature)``
        when ``train`` is true.

        Raises ValueError if ``styles`` is empty.
        """
        num_styles = len(styles)
        if num_styles == 0:
            # Fail early instead of a ZeroDivisionError below or an obscure
            # error inside the decoder.
            raise ValueError("styles must contain at least one style image")

        if interpolation_weights is None:
            interpolation_weights = [1 / num_styles] * num_styles

        use_wct = preserve_color == 'whitening_and_coloring'
        use_histogram = preserve_color == 'histogram_matching'

        # encode the content image; only the deepest feature is used for AdaIN
        content_feature = self.encoder(content)[-1]

        # encode style images
        style_features = []
        for style in styles:
            if use_wct or use_histogram:
                # Both colour-preserving modes pre-process the style image
                # with batch_wct against the content image.
                style = batch_wct(style, content)
            style_features.append(self.encoder(style)[-1])

        transformed_features = []
        for style_feature, interpolation_weight in zip(style_features, interpolation_weights):
            AdaIN_feature = self.adain(content_feature, style_feature, style_strength) * interpolation_weight
            if use_histogram:
                # NOTE(review): the 0.9 damping factor is undocumented;
                # confirm its purpose with the original author.
                AdaIN_feature = AdaIN_feature * 0.9
            transformed_features.append(AdaIN_feature)
        transformed_feature = sum(transformed_features)

        stylized_image = self.decoder(transformed_feature)

        # Optional colour post-processing against the content image.
        if use_wct:
            stylized_image = batch_wct(stylized_image, content)
        if use_histogram:
            stylized_image = batch_histogram_matching(stylized_image, content)

        if train:
            return stylized_image, transformed_feature
        return stylized_image