hylee committed on
Commit
7a7f105
1 Parent(s): 7afae4c
Files changed (8) hide show
  1. app.py +88 -0
  2. packages.txt +2 -0
  3. requirements.txt +4 -0
  4. ugatit/UGATIT.py +665 -0
  5. ugatit/main.py +106 -0
  6. ugatit/ops.py +345 -0
  7. ugatit/utils.py +80 -0
  8. ugatit_test.py +372 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+ import argparse
5
+ import functools
6
+ import os
7
+ import pathlib
8
+ import sys
9
+ from typing import Callable
10
+
11
+
12
+ import gradio as gr
13
+ import huggingface_hub
14
+ import numpy as np
15
+ import PIL.Image
16
+
17
+ from io import BytesIO
18
+
19
+
20
# Static text handed to gr.Interface in main() (title / description / article).
# URL of the upstream project this demo wraps; interpolated into DESCRIPTION.
ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
# Passed as the interface title.
TITLE = 'taki0112/UGATIT'
# Passed as the interface description; the trailing blank line is part of the string.
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.

"""
# Passed as the interface article; currently only blank lines.
ARTICLE = """

"""
28
+
29
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        The populated namespace.  ``--disable-queue`` is stored inverted
        under ``enable_queue`` (``True`` unless the flag is given).
    """
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ('--device', {'type': str, 'default': 'cpu'}),
        ('--theme', {'type': str}),
        ('--live', {'action': 'store_true'}),
        ('--share', {'action': 'store_true'}),
        ('--port', {'type': int}),
        ('--disable-queue', {'dest': 'enable_queue', 'action': 'store_false'}),
        ('--allow-flagging', {'type': str, 'default': 'never'}),
        ('--allow-screenshot', {'action': 'store_true'}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
42
+
43
+
44
+
45
def run(
    image
) -> PIL.Image.Image:
    """Open the uploaded file and return it as a PIL image.

    Args:
        image: object exposing the path of the uploaded file as ``.name``
            (presumably the temp-file wrapper gradio passes for
            ``Image(type='file')`` -- confirm against the gradio version in use).

    Returns:
        The image opened from ``image.name``.  A single image is returned,
        not a tuple; the previous ``tuple[...]`` annotation was incorrect.
    """
    return PIL.Image.open(image.name)
51
+
52
+
53
def main():
    """Build the gradio interface around ``run`` and launch it.

    Closes any previously opened gradio apps, parses the command line and
    forwards the parsed options to ``gr.Interface`` / ``launch``.
    """
    gr.close_all()

    cli_args = parse_args()

    # Argument-less partial of `run`, with `run`'s metadata copied onto it.
    handler = functools.update_wrapper(functools.partial(run), run)

    input_components = [
        gr.inputs.Image(type='file', label='Input Image'),
    ]
    output_components = [
        gr.outputs.Image(type='pil', label='Result'),
    ]

    interface_options = dict(
        theme=cli_args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=cli_args.allow_screenshot,
        allow_flagging=cli_args.allow_flagging,
        live=cli_args.live,
    )
    launch_options = dict(
        enable_queue=cli_args.enable_queue,
        server_port=cli_args.port,
        share=cli_args.share,
    )

    demo = gr.Interface(
        handler,
        input_components,
        output_components,
        **interface_options
    )
    demo.launch(**launch_options)


if __name__ == '__main__':
    main()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ opencv-python-headless==4.5.5.62
2
+ Pillow==9.0.1
3
+ scipy==1.7.3
4
+ tensorflow-gpu==1.14.0
ugatit/UGATIT.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ugatit.ops import *
2
+ from utils import *
3
+ from glob import glob
4
+ import time
5
+ from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
6
+ import numpy as np
7
+
8
class UGATIT(object):
    """Graph-mode (TensorFlow 1.x style) U-GAT-IT model: two generators, two discriminators.

    The instance only stores hyper-parameters until ``build_model`` is called;
    ``train`` / ``test`` then run the graph in the session given at construction.
    Layer helpers (``conv``, ``resblock``, ...), loss helpers and image helpers
    come from the star imports at the top of the module.
    """

    def __init__(self, sess, args):
        """Copy hyper-parameters from ``args``, create the sample folder and list training files.

        Args:
            sess: session used for all ``run`` / ``save`` / ``restore`` calls
                (presumably a ``tf.Session`` -- confirm against the caller).
            args: namespace-like object providing every attribute read below.
        """
        self.light = args.light
        self.model_name = 'UGATIT_light' if self.light else 'UGATIT'

        self.sess = sess
        self.phase = args.phase
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset
        self.augment_flag = args.augment_flag

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.decay_flag = args.decay_flag
        self.decay_epoch = args.decay_epoch

        self.gan_type = args.gan_type

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.init_lr = args.lr
        self.ch = args.ch

        # Loss weights
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight
        self.ld = args.GP_ld
        self.smoothing = args.smoothing

        # Generator
        self.n_res = args.n_res

        # Discriminator
        self.n_dis = args.n_dis
        self.n_critic = args.n_critic
        self.sn = args.sn

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        # model_dir reads the attributes assigned above, so this must come last.
        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
        self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()

        print("##### Information #####")
        print("# light : ", self.light)
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)
        print("# smoothing : ", self.smoothing)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, x_init, reuse=False, scope="generator"):
        """Build one generator under variable scope ``scope``.

        Args:
            x_init: input image tensor.
            reuse: forwarded to ``tf.variable_scope`` (and to ``MLP``) to share weights.
            scope: variable-scope name; selects which generator's weights are used.

        Returns:
            ``(x, cam_logit, heatmap)``: the tanh output, the concatenated
            average-/max-pooling CAM logits, and the channel-summed feature
            map taken after the 1x1 convolution.
        """
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-Sampling: two stride-2 convolutions, doubling the channels each time.
            for i in range(2):
                x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
                x = instance_norm(x, scope='ins_norm_'+str(i))
                x = relu(x)

                channel = channel * 2

            # Down-Sampling Bottleneck
            for i in range(self.n_res):
                x = resblock(x, channel, scope='resblock_' + str(i))

            # Class Activation Map: the same 'CAM_logit' weights are applied to the
            # average-pooled and the max-pooled features (second call reuses them).
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = relu(x)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            # Gamma, Beta block
            gamma, beta = self.MLP(x, reuse=reuse)

            # Up-Sampling Bottleneck
            for i in range(self.n_res):
                x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))

            # Up-Sampling: two x2 up-samplings, halving the channels each time.
            for i in range(2):
                x = up_sample(x, scale_factor=2)
                x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
                x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
            x = tanh(x)

            return x, cam_logit, heatmap

    def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
        """Compute ``gamma`` and ``beta`` from ``x`` with a small fully connected stack.

        Returns:
            ``(gamma, beta)``, each reshaped to ``[batch_size, 1, 1, channel]``.
        """
        # NOTE(review): the width is ch * n_res, while the generator bottleneck has
        # ch * 4 channels; the two agree only when n_res == 4 -- confirm intended.
        channel = self.ch * self.n_res

        if self.light:
            x = global_avg_pooling(x)

        with tf.variable_scope(scope, reuse=reuse):
            for i in range(2):
                x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
                x = relu(x)

            gamma = fully_connected(x, channel, use_bias, scope='gamma')
            beta = fully_connected(x, channel, use_bias, scope='beta')

            gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
            beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])

            return gamma, beta

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        """Run the local and the global discriminator branch on ``x_init``.

        Returns:
            ``(D_logit, D_CAM_logit, local_heatmap, global_heatmap)`` where the
            first two are ``[local, global]`` lists.
        """
        with tf.variable_scope(scope, reuse=reuse):
            local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
            global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')

            D_logit = [local_x, global_x]
            D_CAM_logit = [local_cam, global_cam]

            return D_logit, D_CAM_logit, local_heatmap, global_heatmap

    def _discriminator_branch(self, x_init, strided_stop, reuse, scope):
        """Shared body of the global and local discriminator branches.

        ``strided_stop`` is the exclusive upper index of the stride-2
        convolutions named ``conv_1 .. conv_<strided_stop - 1>``; it is the only
        thing that differs between the two branches.
        """
        with tf.variable_scope(scope, reuse=reuse):
            channel = self.ch
            x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, strided_stop):
                x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
            x = lrelu(x, 0.2)

            channel = channel * 2

            # CAM: shared 'CAM_logit' weights on average- and max-pooled features.
            cam_x = global_avg_pooling(x)
            cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
            x_gap = tf.multiply(x, cam_x_weight)

            cam_x = global_max_pooling(x)
            cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
            x_gmp = tf.multiply(x, cam_x_weight)

            cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
            x = tf.concat([x_gap, x_gmp], axis=-1)

            x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
            x = lrelu(x, 0.2)

            heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))

            x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')

            return x, cam_logit, heatmap

    def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
        """Discriminator branch with stride-2 layers ``conv_0 .. conv_<n_dis - 2>``.

        Returns:
            ``(logit, cam_logit, heatmap)``.
        """
        return self._discriminator_branch(x_init, self.n_dis - 1, reuse, scope)

    def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
        """Discriminator branch with two fewer stride-2 layers than the global one.

        Returns:
            ``(logit, cam_logit, heatmap)``.
        """
        return self._discriminator_branch(x_init, self.n_dis - 2 - 1, reuse, scope)

    ##################################################################################
    # Model
    ##################################################################################

    def generate_a2b(self, x_A, reuse=False):
        """Apply the generator stored under scope ``generator_B``; returns ``(image, cam_logit)``."""
        out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")

        return out, cam

    def generate_b2a(self, x_B, reuse=False):
        """Apply the generator stored under scope ``generator_A``; returns ``(image, cam_logit)``."""
        out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")

        return out, cam

    def discriminate_real(self, x_A, x_B):
        """Create both discriminators (first use, no reuse) on the real batches."""
        real_A_logit, real_A_cam_logit, _, _ = self.discriminator(x_A, scope="discriminator_A")
        real_B_logit, real_B_cam_logit, _, _ = self.discriminator(x_B, scope="discriminator_B")

        return real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit

    def discriminate_fake(self, x_ba, x_ab):
        """Apply the already-created discriminators (``reuse=True``) to generated images."""
        fake_A_logit, fake_A_cam_logit, _, _ = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
        fake_B_logit, fake_B_cam_logit, _, _ = self.discriminator(x_ab, reuse=True, scope="discriminator_B")

        return fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit

    def _penalty_terms(self, logits, interpolated):
        """Return one penalty term per entry of the two-element ``logits`` list.

        The list stays empty when ``gan_type`` is none of 'wgan-lp', 'wgan-gp'
        or 'dragan', so the caller's ``sum`` then yields ``0``.
        """
        terms = []
        for i in range(2):
            grad = tf.gradients(logits[i], interpolated)[0]  # gradient of D(interpolated)
            grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

            if self.gan_type == 'wgan-lp':
                # WGAN-LP: only norms above 1 are penalised.
                terms.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
            elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
                terms.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
        return terms

    def gradient_panalty(self, real, fake, scope="discriminator_A"):
        """Gradient penalty of discriminator ``scope`` at points between ``real`` and ``fake``.

        For gan types containing 'dragan', ``fake`` is replaced by ``real`` plus
        uniform noise scaled by half the standard deviation of ``real``.

        Returns:
            ``(GP, cam_GP)``: summed penalties for the patch logits and for the
            CAM logits.
        """
        if 'dragan' in self.gan_type:
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit, cam_logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope)

        GP = self._penalty_terms(logit, interpolated)
        cam_GP = self._penalty_terms(cam_logit, interpolated)

        return sum(GP), sum(cam_GP)

    def build_model(self):
        """Build the training graph when ``phase == 'train'``, otherwise the test graph."""
        if self.phase == 'train':
            self._build_train_graph()
        else:
            self._build_test_graph()

    def _build_train_graph(self):
        """Create input pipelines, networks, losses, optimizers and summaries."""
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # Input Image
        Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        gpu_device = '/gpu:0'

        def to_batches(dataset):
            # shuffle + repeat, decode/batch, then prefetch onto the GPU device
            return dataset.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))

        trainA = to_batches(trainA)
        trainB = to_batches(trainB)

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()

        # Define Generator, Discriminator
        x_ab, cam_ab = self.generate_a2b(self.domain_A)  # A translated by a2b
        x_ba, cam_ba = self.generate_b2a(self.domain_B)  # B translated by b2a

        x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # A -> a2b -> b2a (cycle)
        x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # B -> b2a -> a2b (cycle)

        x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True)  # b2a applied to A (identity)
        x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True)  # a2b applied to B (identity)

        real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(self.domain_A, self.domain_B)
        fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(x_ba, x_ab)

        # Define Loss
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A, fake=x_ba, scope="discriminator_A")
            GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B, fake=x_ab, scope="discriminator_B")
        else:
            GP_A, GP_CAM_A = 0, 0
            GP_B, GP_CAM_B = 0, 0

        G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, fake_A_cam_logit))
        G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, fake_B_cam_logit))

        D_ad_loss_A = (discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) + discriminator_loss(self.gan_type, real_A_cam_logit, fake_A_cam_logit) + GP_A + GP_CAM_A)
        D_ad_loss_B = (discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) + discriminator_loss(self.gan_type, real_B_cam_logit, fake_B_cam_logit) + GP_B + GP_CAM_B)

        reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
        reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction

        identity_A = L1_loss(x_aa, self.domain_A)
        identity_B = L1_loss(x_bb, self.domain_B)

        cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
        cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

        Generator_A_gan = self.adv_weight * G_ad_loss_A
        Generator_A_cycle = self.cycle_weight * reconstruction_B
        Generator_A_identity = self.identity_weight * identity_A
        Generator_A_cam = self.cam_weight * cam_A

        Generator_B_gan = self.adv_weight * G_ad_loss_B
        Generator_B_cycle = self.cycle_weight * reconstruction_A
        Generator_B_identity = self.identity_weight * identity_B
        Generator_B_cam = self.cam_weight * cam_B

        Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
        Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam

        Discriminator_A_loss = self.adv_weight * D_ad_loss_A
        Discriminator_B_loss = self.adv_weight * D_ad_loss_B

        self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss('generator')
        self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss('discriminator')

        # Result Image
        self.fake_A = x_ba
        self.fake_B = x_ab

        self.real_A = self.domain_A
        self.real_B = self.domain_B

        # Training: variables are assigned to G or D by substring of their name.
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

        # Summary
        self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
        self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
        self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
        self.G_A_identity = tf.summary.scalar("G_A_identity", Generator_A_identity)
        self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)

        self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
        self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
        self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
        self.G_B_identity = tf.summary.scalar("G_B_identity", Generator_B_identity)
        self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)

        self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
        self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

        # Histogram and min/max/mean for every trainable variable whose name contains 'rho'.
        self.rho_var = []
        for var in tf.trainable_variables():
            if 'rho' in var.name:
                self.rho_var.append(tf.summary.histogram(var.name, var))
                self.rho_var.append(tf.summary.scalar(var.name + "_min", tf.reduce_min(var)))
                self.rho_var.append(tf.summary.scalar(var.name + "_max", tf.reduce_max(var)))
                self.rho_var.append(tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var)))

        g_summary_list = [self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity, self.G_A_cam,
                          self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity, self.G_B_cam,
                          self.all_G_loss]

        g_summary_list.extend(self.rho_var)
        d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

        self.G_loss = tf.summary.merge(g_summary_list)
        self.D_loss = tf.summary.merge(d_summary_list)

    def _build_test_graph(self):
        """Create single-image placeholders and both translation outputs."""
        self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
        self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')

        self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
        self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)

    def train(self):
        """Run the training loop, resuming from the latest checkpoint when one exists.

        Writes summaries every step, sample images every ``print_freq``
        iterations and checkpoints every ``save_freq`` iterations plus once at
        the end.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter // self.iteration
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        lr = self.init_lr
        # Latest batches fetched during a generator step; None until the first one,
        # so a sample dump can never hit an unbound name.
        batch_A_images = None
        fake_B = None
        for epoch in range(start_epoch, self.epoch):
            if self.decay_flag:
                # constant until decay_epoch, then linear decay towards zero at self.epoch
                lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
            for idx in range(start_batch_id, self.iteration):
                train_feed_dict = {
                    self.lr: lr
                }

                # Update D
                _, d_loss, summary_str = self.sess.run([self.D_optim,
                                                        self.Discriminator_loss, self.D_loss], feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G (only every n_critic-th step)
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    batch_A_images, fake_B, _, g_loss, summary_str = self.sess.run(
                        [self.real_A, self.fake_B,
                         self.G_optim,
                         self.Generator_loss, self.G_loss], feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                if np.mod(idx+1, self.print_freq) == 0 and batch_A_images is not None:
                    save_images(batch_A_images, [self.batch_size, 1],
                                './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
                    save_images(fake_B, [self.batch_size, 1],
                                './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Directory name that encodes the model name, dataset and main hyper-parameters."""
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'

        smoothing = '_smoothing' if self.smoothing else ''
        sn = '_sn' if self.sn else ''

        return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
                                                         self.gan_type, n_res, n_dis,
                                                         self.n_critic,
                                                         self.adv_weight, self.cycle_weight, self.identity_weight, self.cam_weight, sn, smoothing)

    def save(self, checkpoint_dir, step):
        """Save a checkpoint for ``step`` under ``checkpoint_dir/model_dir`` (created if missing)."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint from ``checkpoint_dir/model_dir`` if there is one.

        ``self.saver`` must already exist.

        Returns:
            ``(True, step)`` with the step parsed from the checkpoint file name,
            or ``(False, 0)`` when no checkpoint was found.
        """
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # checkpoint names end in "-<global_step>"
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    @staticmethod
    def _index_row(sample_file, image_path, img_size):
        """Return one complete ``<tr>...</tr>`` row of index.html for a translated image.

        Relative paths are prefixed with ``../..`` plus the path separator;
        absolute paths are used unchanged.
        """
        def src(path):
            return path if os.path.isabs(path) else '../..' + os.path.sep + path

        img_cell = "<td><img src='%s' width='%d' height='%d'></td>"
        return "".join([
            "<tr>",
            "<td>%s</td>" % os.path.basename(image_path),
            img_cell % (src(sample_file), img_size, img_size),
            img_cell % (src(image_path), img_size, img_size),
            "</tr>",
        ])

    def test(self):
        """Translate every file in testA / testB and write the results plus an index.html.

        Output images keep the input file's base name and are written to
        ``result_dir/model_dir``.
        """
        tf.global_variables_initializer().run()
        test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
        test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # (label, input files, input placeholder, output tensor) per direction
        directions = (
            ('A', test_A_files, self.test_domain_A, self.test_fake_B),  # A -> B
            ('B', test_B_files, self.test_domain_B, self.test_fake_A),  # B -> A
        )

        # write html for visual comparison; the context manager closes the file
        # even when a translation fails part-way through
        index_path = os.path.join(self.result_dir, 'index.html')
        with open(index_path, 'w') as index:
            index.write("<html><body><table><tr>")
            index.write("<th>name</th><th>input</th><th>output</th></tr>")

            for label, sample_files, placeholder, output in directions:
                for sample_file in sample_files:
                    print('Processing {} image: '.format(label) + sample_file)
                    sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
                    image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

                    fake_img = self.sess.run(output, feed_dict={placeholder: sample_image})
                    save_images(fake_img, [1, 1], image_path)

                    index.write(self._index_row(sample_file, image_path, self.img_size))

            index.write("</table></body></html>")
ugatit/main.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ugatit.UGATIT import UGATIT
2
+ import argparse
3
+ from ugatit.utils import *
4
+
5
+ """parsing and configuration"""
6
+
7
+ def parse_args():
8
+ desc = "Tensorflow implementation of U-GAT-IT"
9
+ parser = argparse.ArgumentParser(description=desc)
10
+ parser.add_argument('--phase', type=str, default='train', help='[train / test]')
11
+ parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
12
+ parser.add_argument('--dataset', type=str, default='selfie2anime', help='dataset_name')
13
+
14
+ parser.add_argument('--epoch', type=int, default=100, help='The number of epochs to run')
15
+ parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
16
+ parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
17
+ parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq')
18
+ parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
19
+ parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')
20
+ parser.add_argument('--decay_epoch', type=int, default=50, help='decay epoch')
21
+
22
+ parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
23
+ parser.add_argument('--GP_ld', type=int, default=10, help='The gradient penalty lambda')
24
+ parser.add_argument('--adv_weight', type=int, default=1, help='Weight about GAN')
25
+ parser.add_argument('--cycle_weight', type=int, default=10, help='Weight about Cycle')
26
+ parser.add_argument('--identity_weight', type=int, default=10, help='Weight about Identity')
27
+ parser.add_argument('--cam_weight', type=int, default=1000, help='Weight about CAM')
28
+ parser.add_argument('--gan_type', type=str, default='lsgan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
29
+
30
+ parser.add_argument('--smoothing', type=str2bool, default=True, help='AdaLIN smoothing effect')
31
+
32
+ parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
33
+ parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
34
+ parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
35
+ parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')
36
+ parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')
37
+
38
+ parser.add_argument('--img_size', type=int, default=256, help='The size of image')
39
+ parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
40
+ parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')
41
+
42
+ parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
43
+ help='Directory name to save the checkpoints')
44
+ parser.add_argument('--result_dir', type=str, default='results',
45
+ help='Directory name to save the generated images')
46
+ parser.add_argument('--log_dir', type=str, default='logs',
47
+ help='Directory name to save training logs')
48
+ parser.add_argument('--sample_dir', type=str, default='samples',
49
+ help='Directory name to save the samples on training')
50
+
51
+ return check_args(parser.parse_args())
52
+
53
+ """checking arguments"""
54
def check_args(args):
    """Create the output directories and sanity-check the numeric arguments.

    Args:
        args: parsed namespace providing ``checkpoint_dir``, ``result_dir``,
            ``log_dir``, ``sample_dir``, ``epoch`` and ``batch_size``.

    Returns:
        The same ``args`` object. Out-of-range ``epoch`` / ``batch_size``
        values are only reported on stdout; they do not abort the run.
    """
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # Explicit comparisons instead of ``assert`` inside a bare ``except``:
    # asserts are stripped under ``python -O`` and a bare except hides
    # unrelated errors.

    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
    return args
79
+
80
+ """main"""
81
+ def main():
82
+ # parse arguments
83
+ args = parse_args()
84
+ if args is None:
85
+ exit()
86
+
87
+ # open session
88
+ with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
89
+ gan = UGATIT(sess, args)
90
+
91
+ # build graph
92
+ gan.build_model()
93
+
94
+ # show network architecture
95
+ show_all_variables()
96
+
97
+ if args.phase == 'train' :
98
+ gan.train()
99
+ print(" [*] Training finished!")
100
+
101
+ if args.phase == 'test' :
102
+ gan.test()
103
+ print(" [*] Test finished!")
104
+
105
+ if __name__ == '__main__':
106
+ main()
ugatit/ops.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow.contrib as tf_contrib
3
+
4
+ # Xavier : tf_contrib.layers.xavier_initializer()
5
+ # He : tf_contrib.layers.variance_scaling_initializer()
6
+ # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
7
+ # l2_decay : tf_contrib.layers.l2_regularizer(0.0001)
8
+
9
+ weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
10
+ weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
11
+
12
+ ##################################################################################
13
+ # Layer
14
+ ##################################################################################
15
+
16
+ def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
17
+ with tf.variable_scope(scope):
18
+ if pad > 0 :
19
+ if (kernel - stride) % 2 == 0:
20
+ pad_top = pad
21
+ pad_bottom = pad
22
+ pad_left = pad
23
+ pad_right = pad
24
+
25
+ else:
26
+ pad_top = pad
27
+ pad_bottom = kernel - stride - pad_top
28
+ pad_left = pad
29
+ pad_right = kernel - stride - pad_left
30
+
31
+ if pad_type == 'zero':
32
+ x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
33
+ if pad_type == 'reflect':
34
+ x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
35
+
36
+ if sn :
37
+ w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
38
+ regularizer=weight_regularizer)
39
+ x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
40
+ strides=[1, stride, stride, 1], padding='VALID')
41
+ if use_bias :
42
+ bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
43
+ x = tf.nn.bias_add(x, bias)
44
+
45
+ else :
46
+ x = tf.layers.conv2d(inputs=x, filters=channels,
47
+ kernel_size=kernel, kernel_initializer=weight_init,
48
+ kernel_regularizer=weight_regularizer,
49
+ strides=stride, use_bias=use_bias)
50
+
51
+
52
+ return x
53
+
54
+ def fully_connected_with_w(x, use_bias=True, sn=False, reuse=False, scope='linear'):
55
+ with tf.variable_scope(scope, reuse=reuse):
56
+ x = flatten(x)
57
+ bias = 0.0
58
+ shape = x.get_shape().as_list()
59
+ channels = shape[-1]
60
+
61
+ w = tf.get_variable("kernel", [channels, 1], tf.float32,
62
+ initializer=weight_init, regularizer=weight_regularizer)
63
+
64
+ if sn :
65
+ w = spectral_norm(w)
66
+
67
+ if use_bias :
68
+ bias = tf.get_variable("bias", [1],
69
+ initializer=tf.constant_initializer(0.0))
70
+
71
+ x = tf.matmul(x, w) + bias
72
+ else :
73
+ x = tf.matmul(x, w)
74
+
75
+ if use_bias :
76
+ weights = tf.gather(tf.transpose(tf.nn.bias_add(w, bias)), 0)
77
+ else :
78
+ weights = tf.gather(tf.transpose(w), 0)
79
+
80
+ return x, weights
81
+
82
+ def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
83
+ with tf.variable_scope(scope):
84
+ x = flatten(x)
85
+ shape = x.get_shape().as_list()
86
+ channels = shape[-1]
87
+
88
+ if sn:
89
+ w = tf.get_variable("kernel", [channels, units], tf.float32,
90
+ initializer=weight_init, regularizer=weight_regularizer)
91
+ if use_bias:
92
+ bias = tf.get_variable("bias", [units],
93
+ initializer=tf.constant_initializer(0.0))
94
+
95
+ x = tf.matmul(x, spectral_norm(w)) + bias
96
+ else:
97
+ x = tf.matmul(x, spectral_norm(w))
98
+
99
+ else :
100
+ x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)
101
+
102
+ return x
103
+
104
+ def flatten(x) :
105
+ return tf.layers.flatten(x)
106
+
107
+ ##################################################################################
108
+ # Residual-block
109
+ ##################################################################################
110
+
111
+ def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
112
+ with tf.variable_scope(scope):
113
+ with tf.variable_scope('res1'):
114
+ x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
115
+ x = instance_norm(x)
116
+ x = relu(x)
117
+
118
+ with tf.variable_scope('res2'):
119
+ x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
120
+ x = instance_norm(x)
121
+
122
+ return x + x_init
123
+
124
+ def adaptive_ins_layer_resblock(x_init, channels, gamma, beta, use_bias=True, smoothing=True, scope='adaptive_resblock') :
125
+ with tf.variable_scope(scope):
126
+ with tf.variable_scope('res1'):
127
+ x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
128
+ x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
129
+ x = relu(x)
130
+
131
+ with tf.variable_scope('res2'):
132
+ x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
133
+ x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
134
+
135
+ return x + x_init
136
+
137
+
138
+ ##################################################################################
139
+ # Sampling
140
+ ##################################################################################
141
+
142
+ def up_sample(x, scale_factor=2):
143
+ _, h, w, _ = x.get_shape().as_list()
144
+ new_size = [h * scale_factor, w * scale_factor]
145
+ return tf.image.resize_nearest_neighbor(x, size=new_size)
146
+
147
+
148
+ def global_avg_pooling(x):
149
+ gap = tf.reduce_mean(x, axis=[1, 2])
150
+ return gap
151
+
152
+ def global_max_pooling(x):
153
+ gmp = tf.reduce_max(x, axis=[1, 2])
154
+ return gmp
155
+
156
+ ##################################################################################
157
+ # Activation function
158
+ ##################################################################################
159
+
160
+ def lrelu(x, alpha=0.01):
161
+ # pytorch alpha is 0.01
162
+ return tf.nn.leaky_relu(x, alpha)
163
+
164
+
165
+ def relu(x):
166
+ return tf.nn.relu(x)
167
+
168
+
169
+ def tanh(x):
170
+ return tf.tanh(x)
171
+
172
+ def sigmoid(x) :
173
+ return tf.sigmoid(x)
174
+
175
+ ##################################################################################
176
+ # Normalization function
177
+ ##################################################################################
178
+
179
+ def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True, scope='instance_layer_norm') :
180
+ with tf.variable_scope(scope):
181
+ ch = x.shape[-1]
182
+ eps = 1e-5
183
+
184
+ ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
185
+ x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
186
+
187
+ ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
188
+ x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))
189
+
190
+ rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
191
+
192
+ if smoothing :
193
+ rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)
194
+
195
+ x_hat = rho * x_ins + (1 - rho) * x_ln
196
+
197
+
198
+ x_hat = x_hat * gamma + beta
199
+
200
+ return x_hat
201
+
202
+ def instance_norm(x, scope='instance_norm'):
203
+ return tf_contrib.layers.instance_norm(x,
204
+ epsilon=1e-05,
205
+ center=True, scale=True,
206
+ scope=scope)
207
+
208
+ def layer_norm(x, scope='layer_norm') :
209
+ return tf_contrib.layers.layer_norm(x,
210
+ center=True, scale=True,
211
+ scope=scope)
212
+
213
+ def layer_instance_norm(x, scope='layer_instance_norm') :
214
+ with tf.variable_scope(scope):
215
+ ch = x.shape[-1]
216
+ eps = 1e-5
217
+
218
+ ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
219
+ x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
220
+
221
+ ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
222
+ x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))
223
+
224
+ rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
225
+
226
+ gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
227
+ beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))
228
+
229
+ x_hat = rho * x_ins + (1 - rho) * x_ln
230
+
231
+ x_hat = x_hat * gamma + beta
232
+
233
+ return x_hat
234
+
235
+ def spectral_norm(w, iteration=1):
236
+ w_shape = w.shape.as_list()
237
+ w = tf.reshape(w, [-1, w_shape[-1]])
238
+
239
+ u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
240
+
241
+ u_hat = u
242
+ v_hat = None
243
+ for i in range(iteration):
244
+ """
245
+ power iteration
246
+ Usually iteration = 1 will be enough
247
+ """
248
+ v_ = tf.matmul(u_hat, tf.transpose(w))
249
+ v_hat = tf.nn.l2_normalize(v_)
250
+
251
+ u_ = tf.matmul(v_hat, w)
252
+ u_hat = tf.nn.l2_normalize(u_)
253
+
254
+ u_hat = tf.stop_gradient(u_hat)
255
+ v_hat = tf.stop_gradient(v_hat)
256
+
257
+ sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
258
+
259
+ with tf.control_dependencies([u.assign(u_hat)]):
260
+ w_norm = w / sigma
261
+ w_norm = tf.reshape(w_norm, w_shape)
262
+
263
+
264
+ return w_norm
265
+
266
+ ##################################################################################
267
+ # Loss function
268
+ ##################################################################################
269
+
270
+ def L1_loss(x, y):
271
+ loss = tf.reduce_mean(tf.abs(x - y))
272
+
273
+ return loss
274
+
275
+ def cam_loss(source, non_source) :
276
+
277
+ identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(source), logits=source))
278
+ non_identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(non_source), logits=non_source))
279
+
280
+ loss = identity_loss + non_identity_loss
281
+
282
+ return loss
283
+
284
+ def regularization_loss(scope_name) :
285
+ """
286
+ If you want to use "Regularization"
287
+ g_loss += regularization_loss('generator')
288
+ d_loss += regularization_loss('discriminator')
289
+ """
290
+ collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
291
+
292
+ loss = []
293
+ for item in collection_regularization :
294
+ if scope_name in item.name :
295
+ loss.append(item)
296
+
297
+ return tf.reduce_sum(loss)
298
+
299
+
300
def discriminator_loss(loss_func, real, fake):
    """Return the discriminator adversarial loss summed over two logit entries.

    ``real`` and ``fake`` are indexed at positions 0 and 1. ``loss_func``
    selects the formulation: any name containing 'wgan', or exactly
    'lsgan', 'gan', 'dragan' or 'hinge'. An unrecognised name contributes 0.
    """
    total = 0
    on_real = 0
    on_fake = 0

    for idx in (0, 1):
        if 'wgan' in loss_func:
            on_real = -tf.reduce_mean(real[idx])
            on_fake = tf.reduce_mean(fake[idx])
        elif loss_func == 'lsgan':
            on_real = tf.reduce_mean(tf.squared_difference(real[idx], 1.0))
            on_fake = tf.reduce_mean(tf.square(fake[idx]))
        elif loss_func in ('gan', 'dragan'):
            on_real = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(real[idx]), logits=real[idx]))
            on_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.zeros_like(fake[idx]), logits=fake[idx]))
        elif loss_func == 'hinge':
            on_real = tf.reduce_mean(relu(1.0 - real[idx]))
            on_fake = tf.reduce_mean(relu(1.0 + fake[idx]))

        # Accumulate left to right, exactly as sum() over a list would.
        total = total + (on_real + on_fake)

    return total
325
+
326
def generator_loss(loss_func, fake):
    """Return the generator adversarial loss summed over two logit entries.

    ``fake`` is indexed at positions 0 and 1. ``loss_func`` selects the
    formulation: any name containing 'wgan', or exactly 'lsgan', 'gan',
    'dragan' or 'hinge'. An unrecognised name contributes 0.
    """
    total = 0
    on_fake = 0

    for idx in (0, 1):
        if 'wgan' in loss_func:
            on_fake = -tf.reduce_mean(fake[idx])
        elif loss_func == 'lsgan':
            on_fake = tf.reduce_mean(tf.squared_difference(fake[idx], 1.0))
        elif loss_func in ('gan', 'dragan'):
            on_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.ones_like(fake[idx]), logits=fake[idx]))
        elif loss_func == 'hinge':
            on_fake = -tf.reduce_mean(fake[idx])

        # Accumulate left to right, exactly as sum() over a list would.
        total = total + on_fake

    return total
ugatit/utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.contrib import slim
3
+ import cv2
4
+ import os, random
5
+ import numpy as np
6
+
7
+ class ImageData:
8
+
9
+ def __init__(self, load_size, channels, augment_flag):
10
+ self.load_size = load_size
11
+ self.channels = channels
12
+ self.augment_flag = augment_flag
13
+
14
+ def image_processing(self, filename):
15
+ x = tf.read_file(filename)
16
+ x_decode = tf.image.decode_jpeg(x, channels=self.channels)
17
+ img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
18
+ img = tf.cast(img, tf.float32) / 127.5 - 1
19
+
20
+ if self.augment_flag :
21
+ augment_size = self.load_size + (30 if self.load_size == 256 else 15)
22
+ p = random.random()
23
+ if p > 0.5:
24
+ img = augmentation(img, augment_size)
25
+
26
+ return img
27
+
28
def load_test_data(image_path, size=256):
    """Load one image file as a single-image batch ready for the generator.

    The image is read with OpenCV as colour, converted from BGR to RGB,
    resized to ``size`` x ``size``, given a leading batch axis and rescaled
    with ``x / 127.5 - 1``.

    Args:
        image_path: path of the image file to read.
        size: target height and width in pixels.

    Returns:
        The rescaled array with a leading axis of length 1.

    Raises:
        OSError: if OpenCV cannot read or decode ``image_path``.
    """
    img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        raise OSError('Could not read image file: {}'.format(image_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, dsize=(size, size))

    img = np.expand_dims(img, axis=0)
    img = img/127.5 - 1

    return img
38
+
39
+ def augmentation(image, augment_size):
40
+ seed = random.randint(0, 2 ** 31 - 1)
41
+ ori_image_shape = tf.shape(image)
42
+ image = tf.image.random_flip_left_right(image, seed=seed)
43
+ image = tf.image.resize_images(image, [augment_size, augment_size])
44
+ image = tf.random_crop(image, ori_image_shape, seed=seed)
45
+ return image
46
+
47
+ def save_images(images, size, image_path):
48
+ return imsave(inverse_transform(images), size, image_path)
49
+
50
+ def inverse_transform(images):
51
+ return ((images+1.) / 2) * 255.0
52
+
53
+
54
+ def imsave(images, size, path):
55
+ images = merge(images, size)
56
+ images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
57
+
58
+ return cv2.imwrite(path, images)
59
+
60
def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: array whose axis 0 enumerates images and whose axes 1 and 2
            are height and width.
        size: ``(rows, cols)`` of the grid; images fill it row by row.

    Returns:
        A float array of shape ``(height * rows, width * cols, 3)``. Cells
        with no image stay zero.
    """
    tile_h = images.shape[1]
    tile_w = images.shape[2]
    n_rows, n_cols = size[0], size[1]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))

    for position, tile in enumerate(images):
        # Row-major placement: quotient is the grid row, remainder the column.
        row, col = divmod(position, n_cols)
        top, left = tile_h * row, tile_w * col
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    return canvas
69
+
70
+ def show_all_variables():
71
+ model_vars = tf.trainable_variables()
72
+ slim.model_analyzer.analyze_vars(model_vars, print_info=True)
73
+
74
def check_folder(log_dir):
    """Make sure the directory ``log_dir`` exists and return the path.

    Missing intermediate directories are created too.

    Args:
        log_dir: directory path to create if it is not already present.

    Returns:
        ``log_dir`` unchanged.
    """
    try:
        os.makedirs(log_dir)
    except FileExistsError:
        # Already there. Attempting the call and tolerating this error avoids
        # the window between an exists() check and makedirs() in which
        # another process could create the directory first.
        pass
    return log_dir
78
+
79
def str2bool(x):
    """Parse a command-line string into a bool.

    Args:
        x: the string to interpret.

    Returns:
        True only when ``x`` equals ``'true'`` ignoring case; False for
        every other string.
    """
    # The previous ``x.lower() in ('true')`` compared against a plain string
    # (the parentheses do not make a tuple), so it was a substring test and
    # '', 't' or 'ru' were all accepted as True.
    return x.lower() == 'true'
ugatit_test.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ugatit.ops import *
2
+ from ugatit.utils import *
3
+ from glob import glob
4
+ import time
5
+ from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
6
+ import numpy as np
7
+ from ugatit.utils import *
8
+
9
+ class UgatitTest:
10
+
11
+ def __init__(self, sess):
12
+ self.light = False
13
+
14
+ if self.light:
15
+ self.model_name = 'UGATIT_light'
16
+ else:
17
+ self.model_name = 'UGATIT'
18
+
19
+ self.sess = sess
20
+ self.phase = 'test'
21
+ self.checkpoint_dir = '/home/hylee/cartoon/UGATIT/checkpoint'
22
+ self.result_dir = 'results'
23
+ self.log_dir = 'logs'
24
+ self.dataset_name = 'selfie2anime'
25
+ self.augment_flag = True
26
+
27
+ self.epoch = 100
28
+ self.iteration = 10000
29
+ self.decay_flag = True
30
+ self.decay_epoch = 50
31
+
32
+ self.gan_type = 'lsgan'
33
+
34
+ self.batch_size = 1
35
+ self.print_freq = 1000
36
+ self.save_freq = 1000
37
+
38
+ self.init_lr = 0.0001
39
+ self.ch = 64
40
+
41
+ """ Weight """
42
+ self.adv_weight = 1
43
+ self.cycle_weight = 10
44
+ self.identity_weight = 10
45
+ self.cam_weight = 1000
46
+ self.ld = 10
47
+ self.smoothing = True
48
+
49
+ """ Generator """
50
+ self.n_res = 4
51
+
52
+ """ Discriminator """
53
+ self.n_dis = 6
54
+ self.n_critic = 1
55
+ self.sn = True
56
+
57
+ self.img_size = 256
58
+ self.img_ch = 3
59
+
60
+ self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
61
+ check_folder(self.sample_dir)
62
+
63
+ # self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
64
+ self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
65
+ self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
66
+ self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))
67
+
68
+ print()
69
+
70
+ print("##### Information #####")
71
+ print("# light : ", self.light)
72
+ print("# gan type : ", self.gan_type)
73
+ print("# dataset : ", self.dataset_name)
74
+ print("# max dataset number : ", self.dataset_num)
75
+ print("# batch_size : ", self.batch_size)
76
+ print("# epoch : ", self.epoch)
77
+ print("# iteration per epoch : ", self.iteration)
78
+ print("# smoothing : ", self.smoothing)
79
+
80
+ print()
81
+
82
+ print("##### Generator #####")
83
+ print("# residual blocks : ", self.n_res)
84
+
85
+ print()
86
+
87
+ print("##### Discriminator #####")
88
+ print("# discriminator layer : ", self.n_dis)
89
+ print("# the number of critic : ", self.n_critic)
90
+ print("# spectral normalization : ", self.sn)
91
+
92
+ print()
93
+
94
+ print("##### Weight #####")
95
+ print("# adv_weight : ", self.adv_weight)
96
+ print("# cycle_weight : ", self.cycle_weight)
97
+ print("# identity_weight : ", self.identity_weight)
98
+ print("# cam_weight : ", self.cam_weight)
99
+
100
+ ##################################################################################
101
+ # Generator
102
+ ##################################################################################
103
+
104
+ def generator(self, x_init, reuse=False, scope="generator"):
105
+ channel = self.ch
106
+ with tf.variable_scope(scope, reuse=reuse) :
107
+ x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
108
+ x = instance_norm(x, scope='ins_norm')
109
+ x = relu(x)
110
+
111
+ # Down-Sampling
112
+ for i in range(2) :
113
+ x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
114
+ x = instance_norm(x, scope='ins_norm_'+str(i))
115
+ x = relu(x)
116
+
117
+ channel = channel * 2
118
+
119
+ # Down-Sampling Bottleneck
120
+ for i in range(self.n_res):
121
+ x = resblock(x, channel, scope='resblock_' + str(i))
122
+
123
+
124
+ # Class Activation Map
125
+ cam_x = global_avg_pooling(x)
126
+ cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
127
+ x_gap = tf.multiply(x, cam_x_weight)
128
+
129
+ cam_x = global_max_pooling(x)
130
+ cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
131
+ x_gmp = tf.multiply(x, cam_x_weight)
132
+
133
+
134
+ cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
135
+ x = tf.concat([x_gap, x_gmp], axis=-1)
136
+
137
+ x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
138
+ x = relu(x)
139
+
140
+ heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
141
+
142
+ # Gamma, Beta block
143
+ gamma, beta = self.MLP(x, reuse=reuse)
144
+
145
+ # Up-Sampling Bottleneck
146
+ for i in range(self.n_res):
147
+ x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))
148
+
149
+ # Up-Sampling
150
+ for i in range(2) :
151
+ x = up_sample(x, scale_factor=2)
152
+ x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
153
+ x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
154
+ x = relu(x)
155
+
156
+ channel = channel // 2
157
+
158
+
159
+ x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
160
+ x = tanh(x)
161
+
162
+ return x, cam_logit, heatmap
163
+
164
+ def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
165
+ channel = self.ch * self.n_res
166
+
167
+ if self.light :
168
+ x = global_avg_pooling(x)
169
+
170
+ with tf.variable_scope(scope, reuse=reuse):
171
+ for i in range(2) :
172
+ x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
173
+ x = relu(x)
174
+
175
+
176
+ gamma = fully_connected(x, channel, use_bias, scope='gamma')
177
+ beta = fully_connected(x, channel, use_bias, scope='beta')
178
+
179
+ gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
180
+ beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])
181
+
182
+ return gamma, beta
183
+
184
+ ##################################################################################
185
+ # Discriminator
186
+ ##################################################################################
187
+
188
+ def discriminator(self, x_init, reuse=False, scope="discriminator"):
189
+ D_logit = []
190
+ D_CAM_logit = []
191
+ with tf.variable_scope(scope, reuse=reuse) :
192
+ local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
193
+ global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')
194
+
195
+ D_logit.extend([local_x, global_x])
196
+ D_CAM_logit.extend([local_cam, global_cam])
197
+
198
+ return D_logit, D_CAM_logit, local_heatmap, global_heatmap
199
+
200
+ def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
201
+ with tf.variable_scope(scope, reuse=reuse):
202
+ channel = self.ch
203
+ x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
204
+ x = lrelu(x, 0.2)
205
+
206
+ for i in range(1, self.n_dis - 1):
207
+ x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
208
+ x = lrelu(x, 0.2)
209
+
210
+ channel = channel * 2
211
+
212
+ x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
213
+ x = lrelu(x, 0.2)
214
+
215
+ channel = channel * 2
216
+
217
+ cam_x = global_avg_pooling(x)
218
+ cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
219
+ x_gap = tf.multiply(x, cam_x_weight)
220
+
221
+ cam_x = global_max_pooling(x)
222
+ cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
223
+ x_gmp = tf.multiply(x, cam_x_weight)
224
+
225
+ cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
226
+ x = tf.concat([x_gap, x_gmp], axis=-1)
227
+
228
+ x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
229
+ x = lrelu(x, 0.2)
230
+
231
+ heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
232
+
233
+
234
+ x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
235
+
236
+ return x, cam_logit, heatmap
237
+
238
+ def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
239
+ with tf.variable_scope(scope, reuse=reuse) :
240
+ channel = self.ch
241
+ x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
242
+ x = lrelu(x, 0.2)
243
+
244
+ for i in range(1, self.n_dis - 2 - 1):
245
+ x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
246
+ x = lrelu(x, 0.2)
247
+
248
+ channel = channel * 2
249
+
250
+ x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
251
+ x = lrelu(x, 0.2)
252
+
253
+ channel = channel * 2
254
+
255
+ cam_x = global_avg_pooling(x)
256
+ cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
257
+ x_gap = tf.multiply(x, cam_x_weight)
258
+
259
+ cam_x = global_max_pooling(x)
260
+ cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
261
+ x_gmp = tf.multiply(x, cam_x_weight)
262
+
263
+ cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
264
+ x = tf.concat([x_gap, x_gmp], axis=-1)
265
+
266
+ x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
267
+ x = lrelu(x, 0.2)
268
+
269
+ heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
270
+
271
+ x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
272
+
273
+ return x, cam_logit, heatmap
274
+
275
+ def generate_a2b(self, x_A, reuse=False):
276
+ out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
277
+
278
+ return out, cam
279
+
280
+ def generate_b2a(self, x_B, reuse=False):
281
+ out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
282
+
283
+ return out, cam
284
+ def build_model(self):
285
+ self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
286
+ self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')
287
+
288
+ self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
289
+ self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
290
+
291
+ @property
292
+ def model_dir(self):
293
+ n_res = str(self.n_res) + 'resblock'
294
+ n_dis = str(self.n_dis) + 'dis'
295
+
296
+ if self.smoothing:
297
+ smoothing = '_smoothing'
298
+ else:
299
+ smoothing = ''
300
+
301
+ if self.sn:
302
+ sn = '_sn'
303
+ else:
304
+ sn = ''
305
+
306
+ return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
307
+ self.gan_type, n_res, n_dis,
308
+ self.n_critic,
309
+ self.adv_weight, self.cycle_weight, self.identity_weight,
310
+ self.cam_weight, sn, smoothing)
311
+
312
+ def load(self, checkpoint_dir):
313
+ print(" [*] Reading checkpoints...")
314
+ checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
315
+
316
+ ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
317
+ if ckpt and ckpt.model_checkpoint_path:
318
+ ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
319
+ self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
320
+ counter = int(ckpt_name.split('-')[-1])
321
+ print(" [*] Success to read {}".format(ckpt_name))
322
+ return True, counter
323
+ else:
324
+ print(" [*] Failed to find a checkpoint")
325
+ return False, 0
326
+
327
+ def loadModel(self):
328
+ tf.global_variables_initializer().run(session=self.sess)
329
+
330
+ self.saver = tf.train.Saver()
331
+ could_load, checkpoint_counter = self.load(self.checkpoint_dir)
332
+ self.result_dir = os.path.join(self.result_dir, self.model_dir)
333
+ check_folder(self.result_dir)
334
+
335
+ if could_load:
336
+ print(" [*] Load SUCCESS")
337
+ else:
338
+ print(" [!] Load failed...")
339
+
340
+ def test(self, sample_file):
341
+ # A -> B
342
+ print('Processing A image: ' + sample_file)
343
+ sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
344
+ image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
345
+
346
+ fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image})
347
+ save_images(fake_img, [1, 1], image_path)
348
+
349
+ return image_path
350
+
351
+
352
+ gan = None
353
def main_test(img_path):
    """Translate one image with the cached model and return the output path.

    The first call opens a TensorFlow session, builds the test graph and
    loads the checkpoint; later calls reuse the module-level ``gan``.

    Args:
        img_path: path of the input image.

    Returns:
        Absolute path of the generated image file.
    """
    global gan
    if gan is None:
        # Open the session only when a model is actually being built. Opening
        # one on every call left an unused session behind each time the
        # cached model was reused.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        model = UgatitTest(sess)
        # build graph
        model.build_model()
        # show network architecture
        show_all_variables()
        # Load once: loadModel re-initialises variables, creates a Saver and
        # appends model_dir to result_dir each time it runs.
        model.loadModel()
        # Publish only after setup succeeded, so a failed build is retried.
        gan = model

    result = gan.test(img_path)
    print(" [*] Test finished!")
    print(result)
    return os.path.abspath(result)
370
+
371
+ if __name__ == '__main__':
372
+ main_test('/home/hylee/cartoon/myp2c/imgs/src/im4.jpg')