Spaces: biubiubiiu
Commit c9b624b • 1 Parent(s): 7abb31f
add EFDM
Browse files
- .gitattributes +2 -0
- app.py +130 -0
- config.toml +8 -0
- examples/content/einstein.jpeg +3 -0
- examples/content/granatum.jpg +3 -0
- examples/content/paris.jpeg +3 -0
- examples/content/sailboat.jpg +3 -0
- examples/style/flowers_in_a_turquoise_vase.jpg +3 -0
- examples/style/polasticot2.jpeg +3 -0
- examples/style/sketch.png +3 -0
- examples/style/vangogh.jpeg +3 -0
- function.py +112 -0
- net.py +198 -0
- test.py +252 -0
.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pth.tar filter=lfs diff=lfs merge=lfs -text
+examples/** filter=lfs diff=lfs merge=lfs -text
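The two added rules route the pretrained decoder checkpoint (*.pth.tar) and the bundled example images (examples/**) through Git LFS, so the repository stores lightweight pointers rather than the binary blobs themselves.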
app.py ADDED
@@ -0,0 +1,130 @@
import gradio as gr
import toml
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

import net
from function import *

cfg = toml.load("config.toml")  # static variables

# Setup device
if torch.cuda.is_available() and cfg["use_cuda"]:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load pretrained models
decoder = net.decoder
vgg = net.vgg

decoder.eval()
vgg.eval()

decoder.load_state_dict(torch.load(cfg["decoder_weight"]))
vgg.load_state_dict(torch.load(cfg["vgg_weight"]))
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg = vgg.to(device)
decoder = decoder.to(device)


def transform(img, size, crop):
    transform_list = []
    if size > 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform(img)


@torch.inference_mode()
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize function"""
    style_type = style_type.lower()

    # Step 1: convert image to PyTorch Tensor
    if keep_resolution:
        style = style.resize(content.size, Image.ANTIALIAS)

    if style_type == "efdm" and not keep_resolution:
        content = transform(content, cfg["content_size"], cfg["crop"])
        style = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content = transform(content, -1, False)
        style = transform(style, -1, False)

    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    # Step 2: extract content feature and style feature
    content_feat = vgg(content)
    style_feat = vgg(style)

    # Step 3: perform style transfer
    transfer = {
        "adain": adaptive_instance_normalization,
        "adamean": adaptive_mean_normalization,
        "adastd": adaptive_std_normalization,
        "efdm": exact_feature_distribution_matching,
        "hm": histogram_matching,
    }[style_type]
    feat = transfer(content_feat, style_feat)

    # Step 4: content-style trade-off
    feat = feat * alpha + content_feat * (1 - alpha)

    # Step 5: decode to image
    output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
    output = transforms.ToPILImage()(output)

    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

    return output


# Add image examples
example_img_pairs = {
    "examples/content/sailboat.jpg": "examples/style/sketch.png",
    "examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
    "examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
    "examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
}

# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer using exact feature distribution matching
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"
content_input = gr.inputs.Image(label="Content Image", source="upload", type="pil")
style_input = gr.inputs.Image(label="Style Image", source="upload", type="pil")
style_type = gr.inputs.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], label="Method"
)
alpha_selector = gr.inputs.Slider(
    minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.inputs.Checkbox(
    default=False, label="Keep content image resolution"
)

iface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
    outputs=["image"],
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    examples=[
        [content, style, "EFDM", 1.0, False]
        for content, style in example_img_pairs.items()
    ],
)
iface.launch(debug=False, enable_queue=True)
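For reference, the stylize pipeline can also be exercised without the Gradio UI. This is a minimal sketch (not part of the commit), assuming the pretrained weights referenced in config.toml are present and that it runs in the same session after app.py's definitions (cfg, vgg, decoder, style_transfer) are loaded:

# Minimal sketch (illustrative only): call the demo's stylize function
# directly on the example images bundled with this commit.
from PIL import Image

content = Image.open("examples/content/sailboat.jpg").convert("RGB")
style = Image.open("examples/style/sketch.png").convert("RGB")

# Same arguments the Gradio interface passes: method name, trade-off, resolution flag.
result = style_transfer(content, style, "EFDM", 1.0, False)
result.save("sailboat_sketch_efdm.jpg")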
config.toml ADDED
@@ -0,0 +1,8 @@
use_cuda = true

content_size = 512
style_size = 512
crop = true

vgg_weight = "pretrained/vgg_normalised.pth"
decoder_weight = "pretrained/efdm_decoder_iter_160000.pth.tar"
examples/content/einstein.jpeg ADDED (Git LFS Details)
examples/content/granatum.jpg ADDED (Git LFS Details)
examples/content/paris.jpeg ADDED (Git LFS Details)
examples/content/sailboat.jpg ADDED (Git LFS Details)
examples/style/flowers_in_a_turquoise_vase.jpg ADDED (Git LFS Details)
examples/style/polasticot2.jpeg ADDED (Git LFS Details)
examples/style/sketch.png ADDED (Git LFS Details)
examples/style/vangogh.jpeg ADDED (Git LFS Details)
function.py ADDED
@@ -0,0 +1,112 @@
import torch
from skimage.exposure import match_histograms
import numpy as np


def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


## AdaMean
def adaptive_mean_normalization(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = content_feat - content_mean.expand(size)
    return normalized_feat + style_mean.expand(size)


## AdaStd
def adaptive_std_normalization(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = content_feat / content_std.expand(size)
    return normalized_feat * style_std.expand(size)


## EFDM
def exact_feature_distribution_matching(content_feat, style_feat):
    assert (content_feat.size() == style_feat.size())
    B, C, W, H = content_feat.size(0), content_feat.size(1), content_feat.size(2), content_feat.size(3)
    value_content, index_content = torch.sort(content_feat.view(B, C, -1))  # sort performs a deep copy here.
    value_style, _ = torch.sort(style_feat.view(B, C, -1))  # sort performs a deep copy here.
    inverse_index = index_content.argsort(-1)
    new_content = content_feat.view(B, C, -1) + (value_style.gather(-1, inverse_index) - content_feat.view(B, C, -1).detach())
    return new_content.view(B, C, W, H)


## HM
def histogram_matching(content_feat, style_feat):
    assert (content_feat.size() == style_feat.size())
    B, C, W, H = content_feat.size(0), content_feat.size(1), content_feat.size(2), content_feat.size(3)
    x_view = content_feat.view(-1, W, H)
    image1_temp = match_histograms(np.array(x_view.detach().clone().cpu().float().transpose(0, 2)),
                                   np.array(style_feat.view(-1, W, H).detach().clone().cpu().float().transpose(0, 2)),
                                   multichannel=True)
    image1_temp = torch.from_numpy(image1_temp).float().to(content_feat.device).transpose(0, 2).view(B, C, W, H)
    return content_feat + (image1_temp - content_feat).detach()


def _calc_feat_flatten_mean_std(feat):
    # takes a 3D feat (C, H, W), returns the mean and std of the array within channels
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    feat_flatten = feat.view(3, -1)
    mean = feat_flatten.mean(dim=-1, keepdim=True)
    std = feat_flatten.std(dim=-1, keepdim=True)
    return feat_flatten, mean, std


def _mat_sqrt(x):
    U, D, V = torch.svd(x)
    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())


def coral(source, target):
    # assume both source and target are 3D arrays (C, H, W)
    # Note: flatten -> f
    source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
    source_f_norm = (source_f - source_f_mean.expand_as(source_f)) / source_f_std.expand_as(source_f)
    source_f_cov_eye = \
        torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

    target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
    target_f_norm = (target_f - target_f_mean.expand_as(target_f)) / target_f_std.expand_as(target_f)
    target_f_cov_eye = \
        torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

    source_f_norm_transfer = torch.mm(
        _mat_sqrt(target_f_cov_eye),
        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
                 source_f_norm)
    )

    source_f_transfer = source_f_norm_transfer * \
                        target_f_std.expand_as(source_f_norm) + \
                        target_f_mean.expand_as(source_f_norm)

    return source_f_transfer.view(source.size())
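The EFDM function above replaces each content activation with the style activation of the same rank: it sorts the content feature, gathers the sorted style values back through the content's rank order, and uses the x + (y - x.detach()) trick so the forward value equals the matched style values while gradients still flow to the content branch. A quick sanity check of that property, illustrative and not part of the commit:

# Sanity check (illustrative only): after EFDM, the sorted values of the
# output equal the sorted values of the style feature per (batch, channel).
import torch
from function import exact_feature_distribution_matching

content = torch.randn(2, 3, 8, 8, requires_grad=True)
style = torch.randn(2, 3, 8, 8)
out = exact_feature_distribution_matching(content, style)

assert torch.allclose(out.view(2, 3, -1).sort(-1).values,
                      style.view(2, 3, -1).sort(-1).values, atol=1e-6)
assert out.requires_grad  # gradients still reach the content branch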
net.py ADDED
@@ -0,0 +1,198 @@
import torch.nn as nn
import torch
from function import adaptive_mean_normalization as adamean
from function import adaptive_std_normalization as adastd
from function import adaptive_instance_normalization as adain
from function import exact_feature_distribution_matching as efdm
from function import histogram_matching as hm

from function import calc_mean_std
# import ipdb
from skimage.exposure import match_histograms
import numpy as np

decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


class Net(nn.Module):
    def __init__(self, encoder, decoder, style):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()
        self.style = style

        # fix the encoder
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    # extract relu4_1 from input image
    def encode(self, input):
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input

    def calc_content_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        # ipdb.set_trace()
        assert (input.size() == target.size())
        assert (target.requires_grad is False)  ## first make sure which one requires gradient and which one does not.
        # print(input.requires_grad)  ## True
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        if self.style == 'adain':
            return self.mse_loss(input_mean, target_mean) + \
                   self.mse_loss(input_std, target_std)
        elif self.style == 'adamean':
            return self.mse_loss(input_mean, target_mean)
        elif self.style == 'adastd':
            return self.mse_loss(input_std, target_std)
        elif self.style == 'efdm':
            B, C, W, H = input.size(0), input.size(1), input.size(2), input.size(3)
            value_content, index_content = torch.sort(input.view(B, C, -1))
            value_style, index_style = torch.sort(target.view(B, C, -1))
            inverse_index = index_content.argsort(-1)
            return self.mse_loss(input.view(B, C, -1), value_style.gather(-1, inverse_index))
        elif self.style == 'hm':
            B, C, W, H = input.size(0), input.size(1), input.size(2), input.size(3)
            x_view = input.view(-1, W, H)
            image1_temp = match_histograms(np.array(x_view.detach().clone().cpu().float().transpose(0, 2)),
                                           np.array(target.view(-1, W, H).detach().clone().cpu().float().transpose(0, 2)),
                                           multichannel=True)
            image1_temp = torch.from_numpy(image1_temp).float().to(input.device).transpose(0, 2).view(B, C, W, H)
            return self.mse_loss(input.reshape(B, C, -1), image1_temp.reshape(B, C, -1))
        else:
            raise NotImplementedError

    def forward(self, content, style, alpha=1.0):
        assert 0 <= alpha <= 1
        # ipdb.set_trace()
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        # print(content_feat.requires_grad)  # False
        # print(style_feats[-1].requires_grad)  # False
        if self.style == 'adain':
            t = adain(content_feat, style_feats[-1])
        elif self.style == 'adamean':
            t = adamean(content_feat, style_feats[-1])
        elif self.style == 'adastd':
            t = adastd(content_feat, style_feats[-1])
        elif self.style == 'efdm':
            t = efdm(content_feat, style_feats[-1])
        elif self.style == 'hm':
            t = hm(content_feat, style_feats[-1])
        else:
            raise NotImplementedError
        t = alpha * t + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)  ### the final feature should be the same.
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s
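net.py defines the training wrapper (frozen VGG encoder, trainable decoder, content and style losses), but no training loop ships with this commit. A hypothetical single optimization step, with stand-in data and an assumed style-loss weight:

# Hypothetical training step (no training script is part of this commit).
# The 256x256 random batches and the 10.0 style weight are assumptions;
# in practice net.vgg would first be loaded with vgg_normalised.pth.
import torch
import net

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
content_batch = torch.rand(4, 3, 256, 256, device=device)  # stand-in data
style_batch = torch.rand(4, 3, 256, 256, device=device)    # stand-in data

network = net.Net(net.vgg, net.decoder, style='efdm').to(device).train()
optimizer = torch.optim.Adam(network.decoder.parameters(), lr=1e-4)

loss_c, loss_s = network(content_batch, style_batch)
loss = loss_c + 10.0 * loss_s  # style weight is an assumed hyperparameter
optimizer.zero_grad()
loss.backward()
optimizer.step()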
test.py ADDED
@@ -0,0 +1,252 @@
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import time
import net
from function import adaptive_instance_normalization, coral
from function import adaptive_mean_normalization
from function import adaptive_std_normalization
from function import exact_feature_distribution_matching, histogram_matching


def test_transform(size, crop):
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform


def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None, style_type='adain'):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        if style_type == 'adain':
            base_feat = adaptive_instance_normalization(content_f, style_f)
        elif style_type == 'adamean':
            base_feat = adaptive_mean_normalization(content_f, style_f)
        elif style_type == 'adastd':
            base_feat = adaptive_std_normalization(content_f, style_f)
        elif style_type == 'efdm':
            base_feat = exact_feature_distribution_matching(content_f, style_f)
        elif style_type == 'hm':
            base_feat = histogram_matching(content_f, style_f)
        else:
            raise NotImplementedError
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        if style_type == 'adain':
            feat = adaptive_instance_normalization(content_f, style_f)
        elif style_type == 'adamean':
            feat = adaptive_mean_normalization(content_f, style_f)
        elif style_type == 'adastd':
            feat = adaptive_std_normalization(content_f, style_f)
        elif style_type == 'efdm':
            feat = exact_feature_distribution_matching(content_f, style_f)
        elif style_type == 'hm':
            feat = histogram_matching(content_f, style_f)
        else:
            raise NotImplementedError
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str,
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
                    help='File path to the style image, or multiple style \
                    images separated by commas if you want to do style \
                    interpolation or spatial control')
parser.add_argument('--style_dir', type=str,
                    help='Directory path to a batch of style images')
parser.add_argument('--vgg', type=str, default='pretrained/vgg_normalised.pth')
parser.add_argument('--decoder', type=str, default='pretrained/efdm_decoder_iter_160000.pth.tar')
parser.add_argument('--style_type', type=str, default='adain', help='adain | adamean | adastd | efdm')
parser.add_argument('--test_style_type', type=str, default='', help='adain | adamean | adastd | efdm')
# Additional options
parser.add_argument('--content_size', type=int, default=512,
                    help='New (minimum) size for the content image, \
                    keeping the original size if set to 0')
parser.add_argument('--style_size', type=int, default=512,
                    help='New (minimum) size for the style image, \
                    keeping the original size if set to 0')
parser.add_argument('--crop', action='store_true',
                    help='do a center crop to create a square image')
parser.add_argument('--save_ext', default='.jpg',
                    help='The extension name of the output image')
parser.add_argument('--output', type=str, default='output',
                    help='Directory to save the output image(s)')
parser.add_argument('--photo', action='store_true',
                    help='apply to photo style transfer')
# Advanced options
parser.add_argument('--preserve_color', action='store_true',
                    help='If specified, preserve color of the content image')
parser.add_argument('--alpha', type=float, default=1.0,
                    help='The weight that controls the degree of \
                    stylization. Should be between 0 and 1')
parser.add_argument(
    '--style_interpolation_weights', type=str, default='',
    help='The weight for blending the style of multiple style images')

args = parser.parse_args()
if not args.test_style_type:
    args.test_style_type = args.style_type

print('Note: the style type: %s and the pre-trained model: %s should be consistent' % (args.style_type, args.decoder))
print('The test style type is:', args.test_style_type)

do_interpolation = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

output_dir = Path(args.output + '_' + args.style_type + '_' + args.test_style_type)
output_dir.mkdir(exist_ok=True, parents=True)

# Either --content or --content_dir should be given.
assert (args.content or args.content_dir)
if args.content:
    content_paths = [Path(args.content)]
else:
    content_dir = Path(args.content_dir)
    content_paths = [f for f in content_dir.glob('*')]

# Either --style or --style_dir should be given.
assert (args.style or args.style_dir)
if args.style:
    style_paths = args.style.split(',')
    if len(style_paths) == 1:
        style_paths = [Path(args.style)]
    else:
        do_interpolation = True
        # assert (args.style_interpolation_weights != ''), \
        #     'Please specify interpolation weights'
        # weights = [int(i) for i in args.style_interpolation_weights.split(',')]
        # interpolation_weights = [w / sum(weights) for w in weights]
else:
    style_dir = Path(args.style_dir)
    style_paths = [f for f in style_dir.glob('*')]

decoder = net.decoder
vgg = net.vgg

decoder.eval()
vgg.eval()

decoder.load_state_dict(torch.load(args.decoder))
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:31])

vgg.to(device)
decoder.to(device)

content_tf = test_transform(args.content_size, args.crop)
style_tf = test_transform(args.style_size, args.crop)

timer = []
for content_path in content_paths:
    if do_interpolation:
        # one content image, 4 style images
        style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths])
        content = content_tf(Image.open(str(content_path))) \
            .unsqueeze(0).expand_as(style)
        style = style.to(device)
        content = content.to(device)
        weight_list = []
        steps = [1, 0.75, 0.5, 0.25, 0]
        for i in steps:
            for j in steps:
                weight_list.append([i * j, i * (1 - j), (1 - i) * j, (1 - i) * (1 - j)])
        count = 1
        for interpolation_weights in weight_list:
            with torch.no_grad():
                output = style_transfer(vgg, decoder, content, style,
                                        args.alpha, interpolation_weights, style_type=args.test_style_type)
            output = output.cpu()
            output_name = output_dir / '{:s}_interpolate_{:s}_{:s}'.format(
                content_path.stem, str(count), args.save_ext)
            save_image(output, str(output_name))
            count += 1

    #### content & style trade-off.
    # alpha = [0.0, 0.25, 0.5, 0.75, 1.0]
    # for style_path in style_paths:
    #     content = content_tf(Image.open(str(content_path)))
    #     style = style_tf(Image.open(str(style_path)))
    #     if args.preserve_color:
    #         style = coral(style, content)
    #     style = style.to(device).unsqueeze(0)
    #     content = content.to(device).unsqueeze(0)
    #     ## replace the style image with Gaussian noise
    #     # style.normal_(0,1)
    #     # style = torch.rand(style.size()).to(device)
    #     ### for paired images.
    #     if args.photo:
    #         if content_path.stem[2:] == style_path.stem[3:]:
    #             for sample_alpha in alpha:
    #                 with torch.no_grad():
    #                     output = style_transfer(vgg, decoder, content, style,
    #                                             sample_alpha, style_type=args.test_style_type)
    #                 output = output.cpu()
    #                 output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
    #                     content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
    #                 save_image(output, str(output_name))
    #     else:
    #         for sample_alpha in alpha:
    #             with torch.no_grad():
    #                 output = style_transfer(vgg, decoder, content, style,
    #                                         sample_alpha, style_type=args.test_style_type)
    #             output = output.cpu()
    #             output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
    #                 content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
    #             save_image(output, str(output_name))
    else:  # process one content and one style
        for style_path in style_paths:
            content = content_tf(Image.open(str(content_path)))
            style = style_tf(Image.open(str(style_path)))
            if args.preserve_color:
                style = coral(style, content)
            style = style.to(device).unsqueeze(0)
            content = content.to(device).unsqueeze(0)
            ## replace the style image with Gaussian noise
            # style.normal_(0,1)
            # style = torch.rand(style.size()).to(device)
            ### for paired images.
            if args.photo:
                if content_path.stem[2:] == style_path.stem[3:]:
                    with torch.no_grad():
                        start_time = time.time()
                        output = style_transfer(vgg, decoder, content, style,
                                                args.alpha, style_type=args.test_style_type)
                        timer.append(time.time() - start_time)
                        print(timer)

                    output = output.cpu()
                    output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
                        content_path.stem, style_path.stem, args.save_ext)
                    save_image(output, str(output_name))
            else:
                with torch.no_grad():
                    output = style_transfer(vgg, decoder, content, style,
                                            args.alpha, style_type=args.test_style_type)
                output = output.cpu()
                output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
                    content_path.stem, style_path.stem, args.save_ext)
                save_image(output, str(output_name))
print(torch.FloatTensor(timer).mean())
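For reference, a typical single-pair invocation consistent with the options above, using the example images bundled in this commit (a usage sketch, not documented in the commit itself):

python test.py --content examples/content/sailboat.jpg --style examples/style/sketch.png --style_type efdm --crop

With these flags the script loads the EFDM decoder checkpoint from its default path and writes the result to output_efdm_efdm/, following the output directory naming scheme built from --output, --style_type, and --test_style_type.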