Upload neural_style.py
neural_style.py +503 -0
neural_style.py
ADDED
@@ -0,0 +1,503 @@
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel

import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)

# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)

parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')

parser.add_argument("-multidevice_strategy", default='4,7,29')
params = parser.parse_args()
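# Example invocations (added comment; illustrative only, using the flags
# defined above and the repository's default example images):
#   python neural_style.py -gpu 0 -backend cudnn -num_iterations 1000
#   python neural_style.py -content_image examples/inputs/tubingen.jpg \
#       -style_image examples/inputs/seated-nude.jpg -image_size 512 -gpu c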


Image.MAX_IMAGE_PIXELS = 1000000000  # Support gigapixel images


class TransferParams():
    style_image = 'examples/inputs/seated-nude.jpg'
    style_blend_weights = None
    content_image = 'examples/inputs/tubingen.jpg'
    image_size = 512
    gpu = 0
    content_weight = 5e0
    style_weight = 1e2
    normalize_weights = False
    tv_weight = 1e-3
    num_iterations = 1000
    init = 'random'
    init_image = None
    optimizer = 'lbfgs'
    learning_rate = 1e0
    lbfgs_num_correction = 100
    print_iter = 50
    save_iter = 100
    output_image = 'out.png'
    log_level = 10
    style_scale = 1.0
    original_colors = 0
    pooling = 'max'
    model_file = 'models/vgg19-d01eb7cb.pth'
    disable_check = False
    backend = 'nn'
    cudnn_autotune = False
    seed = -1
    content_layers = 'relu4_2'
    style_layers = 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1'
    multidevice_strategy = '4,7,29'

def main():
    transfer(params)

def transfer(params):
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
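    # Illustration (added comment, not part of the original upload): with two
    # style images and "-style_blend_weights 3,7", the weights normalize to
    # [0.3, 0.7], so the second style contributes 70% of each Gram target below.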

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
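    # Added comment: c and r count the Conv2d and ReLU layers seen so far;
    # layerList['C'][c] and layerList['R'][r] (built by loadCaffemodel) are
    # assumed to hold the Caffe-style names (conv1_1, relu1_1, ...) that
    # -content_layers and -style_layers refer to.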
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)

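    # Sketch of the result (added comment, default layers): the truncated VGG
    # is rebuilt with loss modules spliced in after their target layers, e.g.
    #   conv1_1 -> relu1_1 -> StyleLoss -> ... -> relu4_2 -> ContentLoss -> ...
    # and construction stops once every requested content/style layer is placed.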
    if multidevice:
        net = setup_multi_device(net)

    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net, multidevice)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'  # make content modules pass-throughs so style passes don't overwrite their targets

    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i+1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img)

    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img.clone())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)

            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)


# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        # L-BFGS runs all num_iterations inside a single .step() call
        # (via max_iter), so the caller's loop only needs one pass.
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr = params.learning_rate)
        # Adam advances one iteration per .step(), so loop num_iterations times.
        loopVal = params.num_iterations - 1
    return optimizer, loopVal


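# Accepted -gpu forms, as handled below (added comment): a single device index
# ("0"), "c" for CPU, or a comma-separated list for multi-device ("0,1" or
# "c,0"); the first entry determines where the backward pass sums the losses.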
def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device


def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices must be one less than the number of -gpu devices."

    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net


# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor

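# Worked example (added comment): a pure-red RGB pixel (1.0, 0.0, 0.0) becomes
# (256, 0, 0) after scaling, is reordered to BGR (0, 0, 256), and after
# subtracting the Caffe mean pixel (103.939, 116.779, 123.68) ends up as
# roughly (-103.9, -116.8, 132.3); deprocess() below inverts these steps.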

# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')


# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "
    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ",' ') + st.replace(", ",')')
                print(n() + "(" + ((ks).replace(",",'x' + ks, 1) + st).replace(", ",','))
        else:
            print(n())
    print(")")


# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())

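# Example (added comment): with the defaults (-content_layers relu4_2,
# -image_size 512), the captured content target has 512 channels as its
# largest dimension, so -content_weight 5.0 becomes 5.0 / 512 ~= 0.0098
# when -normalize_weights is set.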

# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input

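# Mode protocol shared by ContentLoss and StyleLoss (added comment):
# transfer() first runs the network with mode='capture' to record targets,
# then switches to mode='loss' so each forward pass stores its weighted MSE
# in self.loss; mode='None' makes the module a plain pass-through.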

class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())

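# In matrix terms (added comment): with F the C x (H*W) matrix of flattened
# feature maps, this returns G = F F^T, whose (i, j) entry is the inner
# product of channels i and j -- a texture statistic that discards spatial
# layout. StyleLoss divides G by the element count before comparing targets.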

# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                # accumulate blended Gram targets (alpha form replaces the
                # deprecated Tensor.add(scalar, tensor) overload)
                self.target = self.target.add(self.G.detach(), alpha=self.blend_weight)
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input


class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input

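# TV regularization (added comment): this is anisotropic total variation,
# strength * (sum |vertical pixel diffs| + sum |horizontal pixel diffs|),
# which penalizes high-frequency noise and smooths the generated image.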

if __name__ == "__main__":
    main()