biubiubiiu committed
Commit c9b624b
1 Parent(s): 7abb31f
.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
+ examples/** filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,130 @@
+ import gradio as gr
+ import toml
+ import torch
+ from PIL import Image
+ from torch import nn
+ from torchvision import transforms
+
+ import net
+ from function import *
+
+ cfg = toml.load("config.toml")  # static variables
+
+ # Setup device
+ if torch.cuda.is_available() and cfg["use_cuda"]:
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+
+ # Load pretrained models
+ decoder = net.decoder
+ vgg = net.vgg
+
+ decoder.eval()
+ vgg.eval()
+
+ decoder.load_state_dict(torch.load(cfg["decoder_weight"]))
+ vgg.load_state_dict(torch.load(cfg["vgg_weight"]))
+ vgg = nn.Sequential(*list(vgg.children())[:31])
+
+ vgg = vgg.to(device)
+ decoder = decoder.to(device)
+
+
+ def transform(img, size, crop):
+     transform_list = []
+     if size > 0:
+         transform_list.append(transforms.Resize(size))
+     if crop:
+         transform_list.append(transforms.CenterCrop(size))
+     transform_list.append(transforms.ToTensor())
+     transform = transforms.Compose(transform_list)
+     return transform(img)
+
+
+ @torch.inference_mode()
+ def style_transfer(content, style, style_type, alpha, keep_resolution):
+     """Stylize function"""
+     style_type = style_type.lower()
+
+     # Step 1: convert image to PyTorch Tensor
+     if keep_resolution:
+         style = style.resize(content.size, Image.ANTIALIAS)
+
+     if style_type == "efdm" and not keep_resolution:
+         content = transform(content, cfg["content_size"], cfg["crop"])
+         style = transform(style, cfg["style_size"], cfg["crop"])
+     else:
+         content = transform(content, -1, False)
+         style = transform(style, -1, False)
+
+     content = content.to(device).unsqueeze(0)
+     style = style.to(device).unsqueeze(0)
+
+     # Step 2: extract content feature and style feature
+     content_feat = vgg(content)
+     style_feat = vgg(style)
+
+     # Step 3: perform style transfer
+     transfer = {
+         "adain": adaptive_instance_normalization,
+         "adamean": adaptive_mean_normalization,
+         "adastd": adaptive_std_normalization,
+         "efdm": exact_feature_distribution_matching,
+         "hm": histogram_matching,
+     }[style_type]
+     feat = transfer(content_feat, style_feat)
+
+     # Step 4: content-style trade-off
+     feat = feat * alpha + content_feat * (1 - alpha)
+
+     # Step 5: decode to image
+     output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
+     output = transforms.ToPILImage()(output)
+
+     torch.cuda.ipc_collect()
+     torch.cuda.empty_cache()
+
+     return output
+
+
+ # Add image examples
+ example_img_pairs = {
+     "examples/content/sailboat.jpg": "examples/style/sketch.png",
+     "examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
+     "examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
+     "examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
+ }
+
+ # Customize interface
+ title = "Style Transfer with EFDM"
+ description = """
+ Gradio demo for neural style transfer using exact feature distribution matching
+ """
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"
+ content_input = gr.inputs.Image(label="Content Image", source="upload", type="pil")
+ style_input = gr.inputs.Image(label="Style Image", source="upload", type="pil")
+ style_type = gr.inputs.Radio(
+     ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], label="Method"
+ )
+ alpha_selector = gr.inputs.Slider(
+     minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Content-Style trade-off"
+ )
+ keep_resolution = gr.inputs.Checkbox(
+     default=False, label="Keep content image resolution"
+ )
+
+ iface = gr.Interface(
+     fn=style_transfer,
+     inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
+     outputs=["image"],
+     title=title,
+     description=description,
+     article=article,
+     theme="huggingface",
+     examples=[
+         [content, style, "EFDM", 1.0, False]
+         for content, style in example_img_pairs.items()
+     ],
+ )
+ iface.launch(debug=False, enable_queue=True)
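Running `python app.py` launches the Gradio demo. The core of `style_transfer` above is a dictionary dispatch over the matching functions in function.py, followed by the content-style trade-off. A minimal sketch of just that step on random stand-in features (shapes and the alpha value are illustrative; no pretrained weights are involved):

import torch
from function import adaptive_instance_normalization, exact_feature_distribution_matching

transfer = {
    "adain": adaptive_instance_normalization,
    "efdm": exact_feature_distribution_matching,
}
content_feat = torch.rand(1, 512, 32, 32)  # stand-in for a VGG relu4-1 feature map
style_feat = torch.rand(1, 512, 32, 32)
alpha = 0.7  # 1.0 = fully stylized, 0.0 = keep the content feature

feat = transfer["efdm"](content_feat, style_feat)
feat = feat * alpha + content_feat * (1 - alpha)
print(feat.shape)  # torch.Size([1, 512, 32, 32])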
config.toml ADDED
@@ -0,0 +1,8 @@
+ use_cuda = true
+
+ content_size = 512
+ style_size = 512
+ crop = true
+
+ vgg_weight = "pretrained/vgg_normalised.pth"
+ decoder_weight = "pretrained/efdm_decoder_iter_160000.pth.tar"
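These keys are read once at the top of app.py. A quick sanity check, assuming the `toml` package is installed and config.toml sits in the working directory:

import toml

cfg = toml.load("config.toml")
print(cfg["use_cuda"], cfg["content_size"], cfg["decoder_weight"])
# -> True 512 pretrained/efdm_decoder_iter_160000.pth.tar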
examples/content/einstein.jpeg ADDED

Git LFS Details

  • SHA256: d03664a496d87caf687a20b2aa7a75cbe0ae2c2fa354a5947aa063a5a143ccb4
  • Pointer size: 131 Bytes
  • Size of remote file: 303 kB
examples/content/granatum.jpg ADDED

Git LFS Details

  • SHA256: 7cd0d627b15c09f373aa613d013fb2d8ae6bd20bcce0d98aa31b963e6bcca495
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
examples/content/paris.jpeg ADDED

Git LFS Details

  • SHA256: f1a97a8989510c41006f3733e219b4a2819016e9219598ebf1363c13731fb3b0
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
examples/content/sailboat.jpg ADDED

Git LFS Details

  • SHA256: 7c381a9366dd134524c3130887d2530c4a9c563f23825e86d698414f86d5270a
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
examples/style/flowers_in_a_turquoise_vase.jpg ADDED

Git LFS Details

  • SHA256: 79fb1043df54e24253dfc82d9a5f7c1fdc34f83f3a4430c5b1dcbf7a2748c4f9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
examples/style/polasticot2.jpeg ADDED

Git LFS Details

  • SHA256: dc5b26049bbb33f5a8125fe2c8cad1a5c59f5e16ea301582d0b0b08ad055f9f6
  • Pointer size: 131 Bytes
  • Size of remote file: 688 kB
examples/style/sketch.png ADDED

Git LFS Details

  • SHA256: 15ad03557b213c98e8a0dd5806961aeb53c02f1abe0c0f2b9e66e433d5858819
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
examples/style/vangogh.jpeg ADDED

Git LFS Details

  • SHA256: b36c5349e156840781e0d84cb25097193a6a4970c04c0aba87f2e7baa102af71
  • Pointer size: 130 Bytes
  • Size of remote file: 45.9 kB
function.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ from skimage.exposure import match_histograms
+ import numpy as np
+
+ def calc_mean_std(feat, eps=1e-5):
+     # eps is a small value added to the variance to avoid divide-by-zero.
+     size = feat.size()
+     assert (len(size) == 4)
+     N, C = size[:2]
+     feat_var = feat.view(N, C, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().view(N, C, 1, 1)
+     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+     return feat_mean, feat_std
+
+
+ def adaptive_instance_normalization(content_feat, style_feat):
+     assert (content_feat.size()[:2] == style_feat.size()[:2])
+     size = content_feat.size()
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+
+     normalized_feat = (content_feat - content_mean.expand(
+         size)) / content_std.expand(size)
+     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+ ## AdaMean
+ def adaptive_mean_normalization(content_feat, style_feat):
+     assert (content_feat.size()[:2] == style_feat.size()[:2])
+     size = content_feat.size()
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+
+     normalized_feat = (content_feat - content_mean.expand(
+         size))
+     return normalized_feat + style_mean.expand(size)
+
+ ## AdaStd
+ def adaptive_std_normalization(content_feat, style_feat):
+     assert (content_feat.size()[:2] == style_feat.size()[:2])
+     size = content_feat.size()
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+
+     normalized_feat = (content_feat) / content_std.expand(size)
+     return normalized_feat * style_std.expand(size)
+
+ ## EFDM
+ def exact_feature_distribution_matching(content_feat, style_feat):
+     assert (content_feat.size() == style_feat.size())
+     B, C, W, H = content_feat.size(0), content_feat.size(1), content_feat.size(2), content_feat.size(3)
+     value_content, index_content = torch.sort(content_feat.view(B, C, -1))  # sort conducts a deep copy here.
+     value_style, _ = torch.sort(style_feat.view(B, C, -1))  # sort conducts a deep copy here.
+     inverse_index = index_content.argsort(-1)
+     new_content = content_feat.view(B, C, -1) + (value_style.gather(-1, inverse_index) - content_feat.view(B, C, -1).detach())
+
+     return new_content.view(B, C, W, H)
+
+ ## HM
+ def histogram_matching(content_feat, style_feat):
+     assert (content_feat.size() == style_feat.size())
+     B, C, W, H = content_feat.size(0), content_feat.size(1), content_feat.size(2), content_feat.size(3)
+     x_view = content_feat.view(-1, W, H)
+     image1_temp = match_histograms(np.array(x_view.detach().clone().cpu().float().transpose(0, 2)),
+                                    np.array(style_feat.view(-1, W, H).detach().clone().cpu().float().transpose(0, 2)),
+                                    multichannel=True)
+     image1_temp = torch.from_numpy(image1_temp).float().to(content_feat.device).transpose(0, 2).view(B, C, W, H)
+     return content_feat + (image1_temp - content_feat).detach()
+
+
+
+ def _calc_feat_flatten_mean_std(feat):
+     # takes 3D feat (C, H, W), returns mean and std of array within channels
+     assert (feat.size()[0] == 3)
+     assert (isinstance(feat, torch.FloatTensor))
+     feat_flatten = feat.view(3, -1)
+     mean = feat_flatten.mean(dim=-1, keepdim=True)
+     std = feat_flatten.std(dim=-1, keepdim=True)
+     return feat_flatten, mean, std
+
+
+ def _mat_sqrt(x):
+     U, D, V = torch.svd(x)
+     return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
+
+
+ def coral(source, target):
+     # assume both source and target are 3D arrays (C, H, W)
+     # Note: flatten -> f
+
+     source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
+     source_f_norm = (source_f - source_f_mean.expand_as(
+         source_f)) / source_f_std.expand_as(source_f)
+     source_f_cov_eye = \
+         torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
+
+     target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
+     target_f_norm = (target_f - target_f_mean.expand_as(
+         target_f)) / target_f_std.expand_as(target_f)
+     target_f_cov_eye = \
+         torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
+
+     source_f_norm_transfer = torch.mm(
+         _mat_sqrt(target_f_cov_eye),
+         torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
+                  source_f_norm)
+     )
+
+     source_f_transfer = source_f_norm_transfer * \
+         target_f_std.expand_as(source_f_norm) + \
+         target_f_mean.expand_as(source_f_norm)
+
+     return source_f_transfer.view(source.size())
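A property worth noting about exact_feature_distribution_matching: per channel, the output carries exactly the style feature's sorted value distribution while keeping the content's ordering. A toy check on random tensors (shapes are illustrative only):

import torch
from function import exact_feature_distribution_matching

content = torch.rand(1, 2, 4, 4)  # (B, C, W, H) toy features
style = torch.rand(1, 2, 4, 4)
out = exact_feature_distribution_matching(content, style)

# Sorting each channel of the output should recover the style's sorted values exactly
assert torch.allclose(out.view(1, 2, -1).sort(-1).values,
                      style.view(1, 2, -1).sort(-1).values)
print(out.shape)  # torch.Size([1, 2, 4, 4])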
net.py ADDED
@@ -0,0 +1,198 @@
+ import torch.nn as nn
+ import torch
+ from function import adaptive_mean_normalization as adamean
+ from function import adaptive_std_normalization as adastd
+ from function import adaptive_instance_normalization as adain
+ from function import exact_feature_distribution_matching as efdm
+ from function import histogram_matching as hm
+
+ from function import calc_mean_std
+ # import ipdb
+ from skimage.exposure import match_histograms
+ import numpy as np
+
+ decoder = nn.Sequential(
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 256, (3, 3)),
+     nn.ReLU(),
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 128, (3, 3)),
+     nn.ReLU(),
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 128, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 64, (3, 3)),
+     nn.ReLU(),
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 64, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 3, (3, 3)),
+ )
+
+ vgg = nn.Sequential(
+     nn.Conv2d(3, 3, (1, 1)),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(3, 64, (3, 3)),
+     nn.ReLU(),  # relu1-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 64, (3, 3)),
+     nn.ReLU(),  # relu1-2
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 128, (3, 3)),
+     nn.ReLU(),  # relu2-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 128, (3, 3)),
+     nn.ReLU(),  # relu2-2
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 256, (3, 3)),
+     nn.ReLU(),  # relu3-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-4
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 512, (3, 3)),
+     nn.ReLU(),  # relu4-1, this is the last layer used
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-4
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU()  # relu5-4
+ )
+
+
+ class Net(nn.Module):
+     def __init__(self, encoder, decoder, style):
+         super(Net, self).__init__()
+         enc_layers = list(encoder.children())
+         self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
+         self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
+         self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
+         self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
+         self.decoder = decoder
+         self.mse_loss = nn.MSELoss()
+         self.style = style
+
+         # fix the encoder
+         for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
+             for param in getattr(self, name).parameters():
+                 param.requires_grad = False
+
+     # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
+     def encode_with_intermediate(self, input):
+         results = [input]
+         for i in range(4):
+             func = getattr(self, 'enc_{:d}'.format(i + 1))
+             results.append(func(results[-1]))
+         return results[1:]
+
+     # extract relu4_1 from input image
+     def encode(self, input):
+         for i in range(4):
+             input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
+         return input
+
+     def calc_content_loss(self, input, target):
+         assert (input.size() == target.size())
+         assert (target.requires_grad is False)
+         return self.mse_loss(input, target)
+
+     def calc_style_loss(self, input, target):
+         # ipdb.set_trace()
+         assert (input.size() == target.size())
+         assert (target.requires_grad is False)  ## first make sure which one requires gradient and which one does not.
+         # print(input.requires_grad)  ## True
+         input_mean, input_std = calc_mean_std(input)
+         target_mean, target_std = calc_mean_std(target)
+         if self.style == 'adain':
+             return self.mse_loss(input_mean, target_mean) + \
+                    self.mse_loss(input_std, target_std)
+         elif self.style == 'adamean':
+             return self.mse_loss(input_mean, target_mean)
+         elif self.style == 'adastd':
+             return self.mse_loss(input_std, target_std)
+         elif self.style == 'efdm':
+             B, C, W, H = input.size(0), input.size(1), input.size(2), input.size(3)
+             value_content, index_content = torch.sort(input.view(B, C, -1))
+             value_style, index_style = torch.sort(target.view(B, C, -1))
+             inverse_index = index_content.argsort(-1)
+             return self.mse_loss(input.view(B, C, -1), value_style.gather(-1, inverse_index))
+         elif self.style == 'hm':
+             B, C, W, H = input.size(0), input.size(1), input.size(2), input.size(3)
+             x_view = input.view(-1, W, H)
+             image1_temp = match_histograms(np.array(x_view.detach().clone().cpu().float().transpose(0, 2)),
+                                            np.array(target.view(-1, W, H).detach().clone().cpu().float().transpose(0, 2)),
+                                            multichannel=True)
+             image1_temp = torch.from_numpy(image1_temp).float().to(input.device).transpose(0, 2).view(B, C, W, H)
+             return self.mse_loss(input.reshape(B, C, -1), image1_temp.reshape(B, C, -1))
+         else:
+             raise NotImplementedError
+
+     def forward(self, content, style, alpha=1.0):
+         assert 0 <= alpha <= 1
+         # ipdb.set_trace()
+         style_feats = self.encode_with_intermediate(style)
+         content_feat = self.encode(content)
+         # print(content_feat.requires_grad)  # False
+         # print(style_feats[-1].requires_grad)  # False
+         if self.style == 'adain':
+             t = adain(content_feat, style_feats[-1])
+         elif self.style == 'adamean':
+             t = adamean(content_feat, style_feats[-1])
+         elif self.style == 'adastd':
+             t = adastd(content_feat, style_feats[-1])
+         elif self.style == 'efdm':
+             t = efdm(content_feat, style_feats[-1])
+         elif self.style == 'hm':
+             t = hm(content_feat, style_feats[-1])
+         else:
+             raise NotImplementedError
+         t = alpha * t + (1 - alpha) * content_feat
+
+         g_t = self.decoder(t)
+         g_t_feats = self.encode_with_intermediate(g_t)
+
+         loss_c = self.calc_content_loss(g_t_feats[-1], t)  ### final feature should be the same.
+         loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
+         for i in range(1, 4):
+             loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
+         return loss_c, loss_s
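Both app.py and test.py keep only the first 31 children of this `vgg` definition, which ends at relu4-1. A small sketch showing the resulting feature shape (weights are left random here, so only the shapes are meaningful):

import torch
from torch import nn
import net

encoder = nn.Sequential(*list(net.vgg.children())[:31])  # input -> relu4-1
with torch.no_grad():
    feat = encoder(torch.rand(1, 3, 256, 256))
print(feat.shape)  # torch.Size([1, 512, 32, 32]): 512 channels, spatial size reduced by 8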
test.py ADDED
@@ -0,0 +1,252 @@
+ import argparse
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.utils import save_image
+ import time
+ import net
+ from function import adaptive_instance_normalization, coral
+ from function import adaptive_mean_normalization
+ from function import adaptive_std_normalization
+ from function import exact_feature_distribution_matching, histogram_matching
+
+ def test_transform(size, crop):
+     transform_list = []
+     if size != 0:
+         transform_list.append(transforms.Resize(size))
+     if crop:
+         transform_list.append(transforms.CenterCrop(size))
+     transform_list.append(transforms.ToTensor())
+     transform = transforms.Compose(transform_list)
+     return transform
+
+
+ def style_transfer(vgg, decoder, content, style, alpha=1.0,
+                    interpolation_weights=None, style_type='adain'):
+     assert (0.0 <= alpha <= 1.0)
+     content_f = vgg(content)
+     style_f = vgg(style)
+     if interpolation_weights:
+         _, C, H, W = content_f.size()
+         feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
+         if style_type == 'adain':
+             base_feat = adaptive_instance_normalization(content_f, style_f)
+         elif style_type == 'adamean':
+             base_feat = adaptive_mean_normalization(content_f, style_f)
+         elif style_type == 'adastd':
+             base_feat = adaptive_std_normalization(content_f, style_f)
+         elif style_type == 'efdm':
+             base_feat = exact_feature_distribution_matching(content_f, style_f)
+         elif style_type == 'hm':
+             feat = histogram_matching(content_f, style_f)
+         else:
+             raise NotImplementedError
+         for i, w in enumerate(interpolation_weights):
+             feat = feat + w * base_feat[i:i + 1]
+         content_f = content_f[0:1]
+     else:
+         if style_type == 'adain':
+             feat = adaptive_instance_normalization(content_f, style_f)
+         elif style_type == 'adamean':
+             feat = adaptive_mean_normalization(content_f, style_f)
+         elif style_type == 'adastd':
+             feat = adaptive_std_normalization(content_f, style_f)
+         elif style_type == 'efdm':
+             feat = exact_feature_distribution_matching(content_f, style_f)
+         elif style_type == 'hm':
+             feat = histogram_matching(content_f, style_f)
+         else:
+             raise NotImplementedError
+     feat = feat * alpha + content_f * (1 - alpha)
+     return decoder(feat)
+
+
+ parser = argparse.ArgumentParser()
+ # Basic options
+ parser.add_argument('--content', type=str,
+                     help='File path to the content image')
+ parser.add_argument('--content_dir', type=str,
+                     help='Directory path to a batch of content images')
+ parser.add_argument('--style', type=str,
+                     help='File path to the style image, or multiple style \
+                         images separated by commas if you want to do style \
+                         interpolation or spatial control')
+ parser.add_argument('--style_dir', type=str,
+                     help='Directory path to a batch of style images')
+ parser.add_argument('--vgg', type=str, default='pretrained/vgg_normalised.pth')
+ parser.add_argument('--decoder', type=str, default='pretrained/efdm_decoder_iter_160000.pth.tar')
+ parser.add_argument('--style_type', type=str, default='adain', help='adain | adamean | adastd | efdm')
+ parser.add_argument('--test_style_type', type=str, default='', help='adain | adamean | adastd | efdm')
+ # Additional options
+ parser.add_argument('--content_size', type=int, default=512,
+                     help='New (minimum) size for the content image, \
+                         keeping the original size if set to 0')
+ parser.add_argument('--style_size', type=int, default=512,
+                     help='New (minimum) size for the style image, \
+                         keeping the original size if set to 0')
+ parser.add_argument('--crop', action='store_true',
+                     help='do center crop to create squared image')
+ parser.add_argument('--save_ext', default='.jpg',
+                     help='The extension name of the output image')
+ parser.add_argument('--output', type=str, default='output',
+                     help='Directory to save the output image(s)')
+ parser.add_argument('--photo', action='store_true',
+                     help='apply on the photo style transfer')
+ # Advanced options
+ parser.add_argument('--preserve_color', action='store_true',
+                     help='If specified, preserve color of the content image')
+ parser.add_argument('--alpha', type=float, default=1.0,
+                     help='The weight that controls the degree of \
+                         stylization. Should be between 0 and 1')
+ parser.add_argument(
+     '--style_interpolation_weights', type=str, default='',
+     help='The weight for blending the style of multiple style images')
+
+ args = parser.parse_args()
+ if not args.test_style_type:
+     args.test_style_type = args.style_type
+
+ print('Note: the style type: %s and the pre-trained model: %s should be consistent' % (args.style_type, args.decoder))
+ print('The test style type is:', args.test_style_type)
+
+ do_interpolation = False
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ output_dir = Path(args.output + '_' + args.style_type + '_' + args.test_style_type)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ # Either --content or --content_dir should be given.
+ assert (args.content or args.content_dir)
+ if args.content:
+     content_paths = [Path(args.content)]
+ else:
+     content_dir = Path(args.content_dir)
+     content_paths = [f for f in content_dir.glob('*')]
+
+ # Either --style or --style_dir should be given.
+ assert (args.style or args.style_dir)
+ if args.style:
+     style_paths = args.style.split(',')
+     if len(style_paths) == 1:
+         style_paths = [Path(args.style)]
+     else:
+         do_interpolation = True
+         # assert (args.style_interpolation_weights != ''), \
+         #     'Please specify interpolation weights'
+         # weights = [int(i) for i in args.style_interpolation_weights.split(',')]
+         # interpolation_weights = [w / sum(weights) for w in weights]
+ else:
+     style_dir = Path(args.style_dir)
+     style_paths = [f for f in style_dir.glob('*')]
+
+ decoder = net.decoder
+ vgg = net.vgg
+
+ decoder.eval()
+ vgg.eval()
+
+ decoder.load_state_dict(torch.load(args.decoder))
+ vgg.load_state_dict(torch.load(args.vgg))
+ vgg = nn.Sequential(*list(vgg.children())[:31])
+
+ vgg.to(device)
+ decoder.to(device)
+
+ content_tf = test_transform(args.content_size, args.crop)
+ style_tf = test_transform(args.style_size, args.crop)
+
+ timer = []
+ for content_path in content_paths:
+     if do_interpolation:
+         # one content image, 4 style images
+         style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths])
+         content = content_tf(Image.open(str(content_path))) \
+             .unsqueeze(0).expand_as(style)
+         style = style.to(device)
+         content = content.to(device)
+         list = []
+         steps = [1, 0.75, 0.5, 0.25, 0]
+         for i in steps:
+             for j in steps:
+                 list.append([i * j, i * (1 - j), (1 - i) * j, (1 - i) * (1 - j)])
+         count = 1
+         for interpolation_weights in list:
+             with torch.no_grad():
+                 output = style_transfer(vgg, decoder, content, style,
+                                         args.alpha, interpolation_weights, style_type=args.test_style_type)
+             output = output.cpu()
+             output_name = output_dir / '{:s}_interpolate_{:s}_{:s}'.format(
+                 content_path.stem, str(count), args.save_ext)
+             save_image(output, str(output_name))
+             count += 1
+
+         #### content & style trade-off.
+         # alpha = [0.0, 0.25, 0.5, 0.75, 1.0]
+         # for style_path in style_paths:
+         #     content = content_tf(Image.open(str(content_path)))
+         #     style = style_tf(Image.open(str(style_path)))
+         #     if args.preserve_color:
+         #         style = coral(style, content)
+         #     style = style.to(device).unsqueeze(0)
+         #     content = content.to(device).unsqueeze(0)
+         #     ## replace the style image with Gaussian noise
+         #     # style.normal_(0, 1)
+         #     # style = torch.rand(style.size()).to(device)
+         #     ### for paired images.
+         #     if args.photo:
+         #         if content_path.stem[2:] == style_path.stem[3:]:
+         #             for sample_alpha in alpha:
+         #                 with torch.no_grad():
+         #                     output = style_transfer(vgg, decoder, content, style,
+         #                                             sample_alpha, style_type=args.test_style_type)
+         #                 output = output.cpu()
+         #                 output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
+         #                     content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
+         #                 save_image(output, str(output_name))
+         #     else:
+         #         for sample_alpha in alpha:
+         #             with torch.no_grad():
+         #                 output = style_transfer(vgg, decoder, content, style,
+         #                                         sample_alpha, style_type=args.test_style_type)
+         #             output = output.cpu()
+         #             output_name = output_dir / '{:s}_stylized_{:s}{:s}{:s}'.format(
+         #                 content_path.stem, style_path.stem, str(sample_alpha), args.save_ext)
+         #             save_image(output, str(output_name))
+     else:  # process one content and one style
+         for style_path in style_paths:
+             content = content_tf(Image.open(str(content_path)))
+             style = style_tf(Image.open(str(style_path)))
+             if args.preserve_color:
+                 style = coral(style, content)
+             style = style.to(device).unsqueeze(0)
+             content = content.to(device).unsqueeze(0)
+             ## replace the style image with Gaussian noise
+             # style.normal_(0, 1)
+             # style = torch.rand(style.size()).to(device)
+             ### for paired images.
+             if args.photo:
+                 if content_path.stem[2:] == style_path.stem[3:]:
+                     with torch.no_grad():
+                         start_time = time.time()
+                         output = style_transfer(vgg, decoder, content, style,
+                                                 args.alpha, style_type=args.test_style_type)
+                         timer.append(time.time() - start_time)
+                         print(timer)
+
+                     output = output.cpu()
+                     output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
+                         content_path.stem, style_path.stem, args.save_ext)
+                     save_image(output, str(output_name))
+             else:
+                 with torch.no_grad():
+                     output = style_transfer(vgg, decoder, content, style,
+                                             args.alpha, style_type=args.test_style_type)
+                 output = output.cpu()
+                 output_name = output_dir / '{:s}_stylized_{:s}{:s}'.format(
+                     content_path.stem, style_path.stem, args.save_ext)
+                 save_image(output, str(output_name))
+ print(torch.FloatTensor(timer).mean())
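An example invocation using the bundled example images (assumes the pretrained weights referenced by --vgg and --decoder exist under pretrained/); with these flags the result is written to output_efdm_efdm/sailboat_stylized_sketch.jpg:

python test.py --content examples/content/sailboat.jpg \
               --style examples/style/sketch.png \
               --style_type efdm --test_style_type efdm \
               --crop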