|
import torch |
|
import torch.nn as nn |
|
from Network import vgg19, decoder |
|
from utils import adaptive_instance_normalization |
|
|
|
class AdaINNet(nn.Module): |
|
""" |
|
AdaIN Style Transfer Network |
|
|
|
Args: |
|
vgg_weight: pretrained vgg19 weight |
|
""" |
|
def __init__(self, vgg_weight): |
|
super().__init__() |
|
self.encoder = vgg19(vgg_weight) |
|
|
|
|
|
self.encoder = nn.Sequential(*list(self.encoder.children())[:22]) |
|
|
|
|
|
for parameter in self.encoder.parameters(): |
|
parameter.requires_grad = False |
|
|
|
self.decoder = decoder() |
|
|
|
self.mseloss = nn.MSELoss() |
|
|
|
""" |
|
Computes style loss of two images |
|
|
|
Args: |
|
x (torch.FloatTensor): content image tensor |
|
y (torch.FloatTensor): style image tensor |
|
|
|
Return: |
|
Mean Squared Error between x.mean, y.mean and MSE between x.std, y.std |
|
""" |
|
def _style_loss(self, x, y): |
|
return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \ |
|
self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3])) |
|
|
|
def forward(self, content, style, alpha=1.0): |
|
|
|
content_enc = self.encoder(content) |
|
style_enc = self.encoder(style) |
|
|
|
|
|
transfer_enc = adaptive_instance_normalization(content_enc, style_enc) |
|
|
|
|
|
out = self.decoder(transfer_enc) |
|
|
|
|
|
style_relu11 = self.encoder[:3](style) |
|
out_relu11 = self.encoder[:3](out) |
|
|
|
|
|
style_relu21 = self.encoder[3:8](style_relu11) |
|
out_relu21 = self.encoder[3:8](out_relu11) |
|
|
|
|
|
style_relu31 = self.encoder[8:13](style_relu21) |
|
out_relu31 = self.encoder[8:13](out_relu21) |
|
|
|
|
|
out_enc = self.encoder[13:](out_relu31) |
|
|
|
|
|
content_loss = self.mseloss(out_enc, transfer_enc) |
|
style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \ |
|
self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc) |
|
|
|
return content_loss, style_loss |