neural-style-transfer / network.py
1-13-am's picture
Update network.py
9e3b4df
raw
history blame
4.83 kB
import torch
import torch.nn as nn
import torchvision
from torchvision.models import vgg19
import utils
from utils import batch_wct, batch_histogram_matching
class Encoder(nn.Module):
def __init__(self, layers = [1, 6, 11, 20]):
super(Encoder, self).__init__()
vgg = torchvision.models.vgg19(pretrained=True).features
self.encoder = nn.ModuleList()
temp_seq = nn.Sequential()
for i in range(max(layers)+1):
temp_seq.add_module(str(i), vgg[i])
if i in layers:
self.encoder.append(temp_seq)
temp_seq = nn.Sequential()
def forward(self, x):
features = []
for layer in self.encoder:
x = layer(x)
features.append(x)
return features
# need to copy the whole architecture bcuz we will need outputs from "layers" layers to compute the loss
class Decoder(nn.Module):
def __init__(self, layers=[1, 6, 11, 20]):
super(Decoder, self).__init__()
vgg = torchvision.models.vgg19(pretrained=False).features
self.decoder = nn.ModuleList()
temp_seq = nn.Sequential()
count = 0
for i in range(max(layers)-1, -1, -1):
if isinstance(vgg[i], nn.Conv2d):
# get number of in/out channels
out_channels = vgg[i].in_channels
in_channels = vgg[i].out_channels
kernel_size = vgg[i].kernel_size
# make a [reflection pad + convolution + relu] layer
temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
count += 1
temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
count += 1
temp_seq.add_module(str(count), nn.ReLU())
count += 1
# change down-sampling(MaxPooling) --> upsampling
elif isinstance(vgg[i], nn.MaxPool2d):
temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
count += 1
if i in layers:
self.decoder.append(temp_seq)
temp_seq = nn.Sequential()
# append last conv layers without ReLU activation
self.decoder.append(temp_seq[:-1])
def forward(self, x):
y = x
for layer in self.decoder:
y = layer(y)
return y
class AdaIN(nn.Module):
def __init__(self):
super(AdaIN, self).__init__()
def forward(self, content, style, style_strength=1.0, eps=1e-5):
"""
content: tensor of shape B * C * H * W
style: tensor of shape B * C * H * W
note that AdaIN does computation on a pair of content - style img"""
b, c, h, w = content.size()
content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True)
style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True)
normalized_content = (content.view(b, c, -1) - content_mean) / (content_std+eps)
stylized_content = (normalized_content * style_std) + style_mean
output = (1-style_strength) * content + style_strength * stylized_content.view(b, c, h, w)
return output
class Style_Transfer_Network(nn.Module):
def __init__(self, layers = [1, 6, 11, 20]):
super(Style_Transfer_Network, self).__init__()
self.encoder = Encoder(layers)
self.decoder = Decoder(layers)
self.adain = AdaIN()
def forward(self, content, styles, style_strength = 1., interpolation_weights = None, preserve_color = None, train = False):
if interpolation_weights is None:
interpolation_weights = [1/len(styles)] * len(styles)
# encode the content image
content_feature = self.encoder(content)
# encode style images
style_features = []
for style in styles:
if preserve_color == 'whitening_and_coloring' or preserve_color == 'histogram_matching':
style = batch_wct(style, content)
style_features.append(self.encoder(style))
transformed_features = []
for style_feature, interpolation_weight in zip(style_features, interpolation_weights):
AdaIN_feature = self.adain(content_feature[-1], style_feature[-1], style_strength) * interpolation_weight
if preserve_color == 'histogram_matching':
AdaIN_feature *= 0.9
transformed_features.append(AdaIN_feature)
transformed_feature = sum(transformed_features)
stylized_image = self.decoder(transformed_feature)
if preserve_color == "whitening_and_coloring":
stylized_image = batch_wct(stylized_image, content)
if preserve_color == "histogram_matching":
stylized_image = batch_histogram_matching(stylized_image, content)
if train:
return stylized_image, transformed_feature
else:
return stylized_image