|
|
|
import torch |
|
import torchvision |
|
import torch.nn as nn |
|
import torchvision.models as models |
|
from PIL import Image |
|
from torchvision import transforms |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import matplotlib.pyplot as plt |
|
import torchvision.transforms as transforms |
|
import copy |
|
import torchvision.models as models |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
class ContentLoss(nn.Module):
    """Pass-through module that records a content loss.

    ``forward`` returns its input unchanged. As a side effect it stores the
    mean-squared error between that input and the fixed target in
    ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a fixed reference value, not something to
        # differentiate through, so it is cut out of the autograd graph.
        reference = target.detach()
        self.target = reference

    def forward(self, input):
        """Store the MSE against ``self.target`` in ``self.loss``; return ``input``."""
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input
|
|
|
|
|
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    Args:
        input: Tensor whose ``size()`` unpacks into exactly four dimensions
            ``(a, b, c, d)`` (presumably batch, feature maps, height, width
            -- confirm against callers).

    Returns:
        An ``(a * b, a * b)`` tensor: the product of the flattened features
        with their own transpose, divided by the total element count
        ``a * b * c * d``.

    Raises:
        ValueError: If ``input`` does not have exactly four dimensions.
    """
    a, b, c, d = input.size()

    # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
    # tensors (e.g. after a transpose/permute), while ``reshape`` returns a
    # view when possible and copies only when it has to.
    features = input.reshape(a * b, c * d)

    gram = torch.mm(features, features.t())

    # Normalize by the number of elements so the scale of the result does
    # not grow with the size of the feature tensor.
    return gram.div(a * b * c * d)
|
|
|
class StyleLoss(nn.Module):
    """Pass-through module that records a style loss.

    The target is the Gram matrix of ``target_feature``, computed once at
    construction. ``forward`` returns its input unchanged and stores the
    mean-squared error between the input's Gram matrix and that target in
    ``self.loss``.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        # Fixed reference value: keep it out of the autograd graph.
        self.target = target_gram.detach()

    def forward(self, input):
        """Store the Gram-matrix MSE in ``self.loss``; return ``input``."""
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input
|
|
|
|
|
|
|
# Preprocessing pipeline applied to every input image: resize to a fixed
# 128x128, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
    ]
)
|
|
|
def image_transform(image):
    """Convert an image source into a batched tensor.

    Args:
        image: One of
            * ``None`` -- nothing to convert;
            * ``str`` -- a path opened with ``PIL.Image.open``;
            * ``PIL.Image.Image`` -- used directly;
            * any object exposing ``astype`` (e.g. a NumPy array) -- cast to
              ``uint8`` and wrapped with ``Image.fromarray``.
              NOTE(review): float arrays scaled to [0, 1] would be truncated
              to 0 by the cast -- confirm callers pass 0-255 data.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise the result of the
        module-level ``transform`` applied to the RGB image, with a leading
        dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Guard clause: pass "no image" through instead of failing inside
    # the transform pipeline.
    if image is None:
        return None

    if isinstance(image, str):
        # ``convert`` loads the pixel data and returns a new image, so the
        # file handle can be released as soon as the block exits.
        with Image.open(image) as opened:
            rgb_image = opened.convert('RGB')
    elif isinstance(image, Image.Image):
        rgb_image = image.convert('RGB')
    else:
        # Let PIL infer the mode from the array shape, then convert. This
        # gives the same result for H x W x 3 input and also accepts
        # grayscale (H x W) and RGBA (H x W x 4) arrays.
        rgb_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    return transform(rgb_image).unsqueeze(0)
|
|
|
|
|
|
|
|
|
class Normalization(nn.Module):
    """Normalize an image tensor as ``(img - mean) / std``.

    Args:
        mean: Sequence or tensor of mean values.
        std: Sequence or tensor of standard-deviation values.

    Both statistics are reshaped to ``(-1, 1, 1)`` so they broadcast over the
    two trailing dimensions of the image (presumably a ``[B, C, H, W]`` or
    ``[C, H, W]`` tensor -- confirm against callers).

    They are stored as non-persistent buffers: they follow the module through
    ``.to()``, ``.cuda()`` and ``.double()``, but are not added to
    ``state_dict()``.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', self._as_stat(mean), persistent=False)
        self.register_buffer('std', self._as_stat(std), persistent=False)

    @staticmethod
    def _as_stat(values):
        """Return a detached copy of ``values`` shaped ``(-1, 1, 1)``."""
        # ``as_tensor(...).detach().clone()`` copies like ``torch.tensor``
        # does, without the copy-construct warning that ``torch.tensor``
        # emits when ``values`` is already a tensor.
        return torch.as_tensor(values).detach().clone().reshape(-1, 1, 1)

    def forward(self, img):
        """Return ``img`` shifted by ``mean`` and scaled by ``std``."""
        return (img - self.mean) / self.std
|
|
|
|
|
|
|
|
|
|
|
# Default layer names after which losses are inserted by
# get_style_model_and_losses: content after the 4th convolution, style
# after each of the first five convolutions.
content_layers_default = ['conv_4']
style_layers_default = ['conv_{}'.format(depth) for depth in range(1, 6)]
|
|
|
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Build a sequential model with content/style loss modules inserted.

    The children of ``cnn`` are copied in order into a new ``nn.Sequential``
    that starts with a ``Normalization`` module. Each child is named after
    its type and the running count of convolutions seen so far (``conv_N``,
    ``relu_N``, ``pool_N``, ``bn_N``). Right after any layer whose name is
    listed in ``content_layers`` / ``style_layers`` a ``ContentLoss`` /
    ``StyleLoss`` module is appended, with its target computed by running
    ``content_img`` / ``style_img`` through the model built so far.

    Args:
        cnn: Module whose direct children are ``Conv2d``, ``ReLU``,
            ``MaxPool2d`` or ``BatchNorm2d`` layers.
        normalization_mean: Mean passed to ``Normalization``.
        normalization_std: Standard deviation passed to ``Normalization``.
        style_img: Input used to compute the style targets.
        content_img: Input used to compute the content targets.
        content_layers: Layer names after which a content loss is inserted.
        style_layers: Layer names after which a style loss is inserted.

    Returns:
        ``(model, style_losses, content_losses)``: the model truncated after
        its last loss module, and the lists of inserted loss modules.

    Raises:
        RuntimeError: If ``cnn`` has a child of any other layer type.
    """
    normalization = Normalization(normalization_mean, normalization_std)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    conv_count = 0  # incremented every time a convolution is seen
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_count += 1
            name = 'conv_{}'.format(conv_count)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(conv_count)
            # Swap in an out-of-place ReLU (presumably so an in-place op
            # cannot overwrite activations a preceding loss module was
            # given -- confirm).
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(conv_count)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(conv_count)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants: compute them without recording an
            # autograd graph instead of building one and detaching it.
            with torch.no_grad():
                target = model(content_img)
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(conv_count), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img)
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(conv_count), style_loss)
            style_losses.append(style_loss)

    # Drop everything after the last loss module; later layers cannot
    # influence any recorded loss. If no loss module was inserted, only the
    # first module (the normalization) is kept.
    last_loss_index = 0
    for index in range(len(model) - 1, -1, -1):
        if isinstance(model[index], (ContentLoss, StyleLoss)):
            last_loss_index = index
            break

    model = model[:(last_loss_index + 1)]

    return model, style_losses, content_losses
|
|
|
def get_input_optimizer(input_img):
    """Return an L-BFGS optimizer whose only parameter is ``input_img``."""
    params = [input_img]
    return optim.LBFGS(params)
|
|