Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
from function import adaptive_mean_normalization as adamean | |
from function import adaptive_std_normalization as adastd | |
from function import adaptive_instance_normalization as adain | |
from function import exact_feature_distribution_matching as efdm | |
from function import histogram_matching as hm | |
from function import calc_mean_std | |
# import ipdb | |
from skimage.exposure import match_histograms | |
import numpy as np | |
def _build_decoder():
    """Assemble the decoder that maps 512-channel encoder features back to 3 channels.

    Each ``(in, out)`` pair in ``plan`` expands to reflection-pad(1) -> 3x3 conv
    -> ReLU; ``'up'`` inserts a 2x nearest-neighbour upsample. The ReLU that
    would follow the final 64 -> 3 convolution is dropped, so the output has no
    activation.

    Module order is significant: ``nn.Sequential`` names its parameters by
    position (``'1.weight'``, ``'1.bias'``, ...), so any saved decoder
    state dict depends on this exact layout.
    """
    plan = [
        (512, 256), 'up',
        (256, 256), (256, 256), (256, 256), (256, 128), 'up',
        (128, 128), (128, 64), 'up',
        (64, 64), (64, 3),
    ]
    layers = []
    for step in plan:
        if step == 'up':
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            continue
        in_ch, out_ch = step
        layers += [
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(in_ch, out_ch, (3, 3)),
            nn.ReLU(),
        ]
    layers.pop()  # no activation after the final 3-channel projection
    return nn.Sequential(*layers)


decoder = _build_decoder()


def _build_vgg():
    """Assemble the VGG-19-style encoder (conv layers only, up to relu5_4).

    Layout: a leading 1x1 conv (3 -> 3 channels), then five stages of
    reflection-pad(1) -> 3x3 conv -> ReLU units, with a 2x2 stride-2
    ceil-mode max-pool between consecutive stages (none after the last).

    Module indices must not change: ``Net`` slices this Sequential at
    positions 4, 11, 18 and 31 to obtain relu1_1 .. relu4_1, and
    ``nn.Sequential`` state-dict keys are positional.
    """
    stages = [
        [(3, 64), (64, 64)],                                    # relu1_1 .. relu1_2
        [(64, 128), (128, 128)],                                # relu2_1 .. relu2_2
        [(128, 256), (256, 256), (256, 256), (256, 256)],       # relu3_1 .. relu3_4
        [(256, 512), (512, 512), (512, 512), (512, 512)],       # relu4_1 .. relu4_4
        [(512, 512), (512, 512), (512, 512), (512, 512)],       # relu5_1 .. relu5_4
    ]
    layers = [nn.Conv2d(3, 3, (1, 1))]
    for stage_idx, convs in enumerate(stages):
        if stage_idx:
            # downsample between stages only
            layers.append(nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True))
        for in_ch, out_ch in convs:
            layers += [
                nn.ReflectionPad2d((1, 1, 1, 1)),
                nn.Conv2d(in_ch, out_ch, (3, 3)),
                nn.ReLU(),
            ]
    return nn.Sequential(*layers)


vgg = _build_vgg()


class Net(nn.Module):
    """Training wrapper pairing a frozen encoder with a trainable decoder.

    ``style`` selects both the feature transform applied in ``forward`` and
    the style loss in ``calc_style_loss``: one of ``'adain'``, ``'adamean'``,
    ``'adastd'``, ``'efdm'`` or ``'hm'``. Any other value raises
    ``NotImplementedError`` when a loss / forward pass is computed (not at
    construction time).
    """

    def __init__(self, encoder, decoder, style):
        super(Net, self).__init__()
        # The slice points assume the ``vgg`` layout defined in this file.
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()
        self.style = style

        # Freeze the encoder: only the decoder receives gradients.
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    def encode_with_intermediate(self, input):
        """Return the outputs of enc_1 .. enc_4 (relu1_1, relu2_1, relu3_1, relu4_1)."""
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def encode(self, input):
        """Return only the final encoder output (relu4_1)."""
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        """MSE between ``input`` and ``target``; ``target`` must not require grad."""
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    @staticmethod
    def _match_histograms_per_channel(image, reference):
        """Histogram-match ``image`` to ``reference`` independently per last-axis channel.

        scikit-image deprecated the ``multichannel`` flag in 0.19 in favour of
        ``channel_axis`` and later dropped it, so try the new keyword first and
        fall back only for releases that predate it.
        """
        try:
            return match_histograms(image, reference, channel_axis=-1)
        except TypeError:
            return match_histograms(image, reference, multichannel=True)

    def calc_style_loss(self, input, target):
        """Style loss between generated features ``input`` and style features ``target``.

        ``input`` and ``target`` must have identical shapes and ``target`` must
        not require grad.

        * adain / adamean / adastd: MSE on the per-channel statistics returned
          by ``calc_mean_std`` (mean + std, mean only, std only).
        * efdm: MSE between ``input`` and the style values re-ordered to follow
          ``input``'s per-channel ranking (exact distribution matching).
        * hm: MSE between ``input`` and a histogram-matched copy computed by
          scikit-image on the CPU (no gradient flows through the copy).

        Raises:
            NotImplementedError: if ``self.style`` is not one of the above.
        """
        assert (input.size() == target.size())
        assert (target.requires_grad is False)

        if self.style in ('adain', 'adamean', 'adastd'):
            # Statistics are only needed by these variants; skip them otherwise.
            input_mean, input_std = calc_mean_std(input)
            target_mean, target_std = calc_mean_std(target)
            if self.style == 'adamean':
                return self.mse_loss(input_mean, target_mean)
            if self.style == 'adastd':
                return self.mse_loss(input_std, target_std)
            return self.mse_loss(input_mean, target_mean) + \
                self.mse_loss(input_std, target_std)

        if self.style == 'efdm':
            B, C = input.size(0), input.size(1)
            flat_input = input.view(B, C, -1)
            _, content_order = torch.sort(flat_input)
            sorted_style, _ = torch.sort(target.view(B, C, -1))
            # Rank of every input element; gathering the style value of the same
            # rank reproduces the style distribution in the input's ordering.
            inverse_index = content_order.argsort(-1)
            return self.mse_loss(flat_input, sorted_style.gather(-1, inverse_index))

        if self.style == 'hm':
            B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
            # (B, C, H, W) -> (B*C, H, W) -> channel-last (W, H, B*C) numpy copies,
            # so every (batch, channel) slice is matched independently.
            source = np.array(input.detach().reshape(-1, H, W).cpu().float().transpose(0, 2))
            reference = np.array(target.detach().reshape(-1, H, W).cpu().float().transpose(0, 2))
            matched = self._match_histograms_per_channel(source, reference)
            matched = torch.from_numpy(matched).float().to(input.device)
            matched = matched.transpose(0, 2).reshape(B, C, H, W)
            return self.mse_loss(input.reshape(B, C, -1), matched.reshape(B, C, -1))

        raise NotImplementedError

    def forward(self, content, style, alpha=1.0):
        """Run one transfer step and return ``(loss_c, loss_s)``.

        Args:
            content: content batch fed to the encoder.
            style: style batch fed to the encoder.
            alpha: blend weight between the transformed features (1.0) and the
                raw content features (0.0); must lie in [0, 1].

        Returns:
            loss_c: MSE between the re-encoded decoder output (last stage) and
                the blended target features ``t``.
            loss_s: sum of ``calc_style_loss`` over the four encoder stages.

        Raises:
            NotImplementedError: if ``self.style`` is not a known variant.
        """
        assert 0 <= alpha <= 1
        transforms = {
            'adain': adain,
            'adamean': adamean,
            'adastd': adastd,
            'efdm': efdm,
            'hm': hm,
        }
        if self.style not in transforms:
            # Fail before running the (expensive) encoder passes.
            raise NotImplementedError
        transform = transforms[self.style]

        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)

        t = transform(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat
        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        # The re-encoded output should reproduce the target features.
        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for gen_feat, style_feat in zip(g_t_feats[1:], style_feats[1:]):
            loss_s += self.calc_style_loss(gen_feat, style_feat)
        return loss_c, loss_s