qninhdt committed on
Commit 89ce6b3 · verified · 1 Parent(s): 504a41e

Upload 28 files

.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/augan_result.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/augan_uncer.png filter=lfs diff=lfs merge=lfs -text
38
+ datasets/swim/testA/GP010594_frame_000017_rgb_anon.png filter=lfs diff=lfs merge=lfs -text
39
+ datasets/swim/testA/GP010594_frame_000021_rgb_anon.png filter=lfs diff=lfs merge=lfs -text
40
+ datasets/swim/testA/GP010594_frame_000087_rgb_anon.png filter=lfs diff=lfs merge=lfs -text
41
+ datasets/swim/testB/GOPR0351_frame_000159_rgb_ref_anon.png filter=lfs diff=lfs merge=lfs -text
42
+ datasets/swim/testB/GOPR0351_frame_000161_rgb_ref_anon.png filter=lfs diff=lfs merge=lfs -text
43
+ datasets/swim/testB/GOPR0355_frame_000138_rgb_ref_anon.png filter=lfs diff=lfs merge=lfs -text
AUGAN.py ADDED
@@ -0,0 +1,738 @@
1
+ from collections import namedtuple
2
+ from models import generator_resnet, discriminator
3
+ from utils import *
4
+ from loss_utils import *
5
+ from ops import *
6
+ import time
7
+ import matplotlib.pyplot as plt
8
+ from glob import glob
9
+
10
+
11
+ class AUGAN(object):
12
+ def __init__(self, sess, args):
13
+ self.sess = sess
14
+ self.batch_size = args.batch_size
15
+ self.image_size = args.fine_size
16
+ self.input_c_dim = args.input_nc
17
+ self.output_c_dim = args.output_nc
18
+ self.L1_lambda = args.L1_lambda
19
+ self.conf_lambda = args.conf_lambda
20
+ self.dataset_dir = args.dataset_dir
21
+ self.n_d = args.n_d
22
+ self.n_scale = args.n_scale
23
+ self.ndf = args.ndf
24
+ self.load_size = args.load_size
25
+ self.fine_size = args.fine_size
26
+ self.generator = generator_resnet
27
+ self.discriminator = discriminator
28
+ if args.use_lsgan:
29
+ self.criterionGAN = mae_criterion
30
+ self.criterionGAN_list = mae_criterion_list
31
+ else:
32
+ self.criterionGAN = sce_criterion
33
+ self.criterionGAN_list = sce_criterion_list
34
+
35
+ self.use_uncertainty = args.use_uncertainty
36
+
37
+ OPTIONS = namedtuple(
38
+ "OPTIONS",
39
+ "batch_size image_size \
40
+ gf_dim df_dim output_c_dim is_training",
41
+ )
42
+ self.options = OPTIONS._make(
43
+ (
44
+ args.batch_size,
45
+ args.fine_size,
46
+ args.ngf,
47
+ args.ndf // args.n_d,
48
+ args.output_nc,
49
+ args.phase == "train",
50
+ )
51
+ )
52
+ self.save_conf = args.save_conf
53
+ self._build_model()
54
+ self.saver = tf.compat.v1.train.Saver()
55
+ self.pool = ImagePool(args.max_size)
56
+
57
+ def _build_model(self):
58
+ self.real_data = tf.compat.v1.placeholder(
59
+ tf.float32,
60
+ [
61
+ self.batch_size,
62
+ self.image_size,
63
+ self.image_size * 2,
64
+ self.input_c_dim + self.output_c_dim,
65
+ ],
66
+ name="real_A_and_B_images",
67
+ )
68
+
69
+ self.real_A = self.real_data[:, :, :, : self.input_c_dim]
70
+ self.real_B = self.real_data[
71
+ :, :, :, self.input_c_dim : self.input_c_dim + self.output_c_dim
72
+ ]
73
+
74
+ A_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
75
+ B_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
76
+ A_label[:, :, :, 0] = 1.0
77
+ B_label[:, :, :, 1] = 1.0
78
+ self.A_label = tf.convert_to_tensor(A_label)
79
+ self.B_label = tf.convert_to_tensor(B_label)
80
+
81
+ (
82
+ self.fake_B,
83
+ self.rec_realA,
84
+ self.realA_percep,
85
+ self.transA_percep,
86
+ self.pred_confA,
87
+ ) = self.generator(
88
+ self.real_A, self.options, transfer=True, reuse=False, name="generatorA2B"
89
+ )
90
+ self.fake_A_, self.rec_fakeB, self.fakeB_percep, _, _ = self.generator(
91
+ self.fake_B, self.options, transfer=False, reuse=False, name="generatorB2A"
92
+ )
93
+ self.fake_A, self.rec_realB, self.realB_percep, _, _ = self.generator(
94
+ self.real_B, self.options, transfer=False, reuse=True, name="generatorB2A"
95
+ )
96
+ self.fake_B_, self.rec_fakeA, self.fakeA_percep, self.trans_fakeA_percep, _ = (
97
+ self.generator(
98
+ self.fake_A,
99
+ self.options,
100
+ transfer=True,
101
+ reuse=True,
102
+ name="generatorA2B",
103
+ )
104
+ )
105
+
106
+ self.g_adv_total = 0.0
107
+ self.g_adv = 0.0
108
+ self.g_adv_rec = 0.0
109
+ self.g_adv_recfake = 0.0
110
+
111
+ self.percep_loss = tf.reduce_mean(
112
+ tf.abs(
113
+ tf.reduce_mean(self.transA_percep, axis=3)
114
+ - tf.reduce_mean(self.fakeB_percep, axis=3)
115
+ )
116
+ ) + tf.reduce_mean(
117
+ tf.abs(
118
+ tf.reduce_mean(self.realB_percep, axis=3)
119
+ - tf.reduce_mean(self.fakeA_percep, axis=3)
120
+ )
121
+ )
122
+
123
+ for i in range(self.n_d):
124
+ self.DB_fake = self.discriminator(
125
+ self.fake_B, self.options, reuse=False, name=str(i) + "_discriminatorB"
126
+ )
127
+ self.DA_fake = self.discriminator(
128
+ self.fake_A, self.options, reuse=False, name=str(i) + "_discriminatorA"
129
+ )
130
+
131
+ self.g_adv_total += self.criterionGAN_list(
132
+ self.DA_fake, get_ones_like(self.DA_fake)
133
+ ) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
134
+
135
+ self.g_adv += self.criterionGAN_list(
136
+ self.DA_fake, get_ones_like(self.DA_fake)
137
+ ) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
138
+
139
+ self.g_loss_a2b = (
140
+ self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
141
+ + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
142
+ + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
143
+ )
144
+ self.g_loss_b2a = (
145
+ self.criterionGAN_list(self.DA_fake, get_ones_like(self.DA_fake))
146
+ + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
147
+ + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
148
+ )
149
+
150
+ self.g_A_recon_loss = self.L1_lambda * abs_criterion(
151
+ self.rec_realA, self.real_A
152
+ )
153
+ self.g_B_recon_loss = self.L1_lambda * abs_criterion(
154
+ self.rec_realB, self.real_B
155
+ )
156
+ if self.use_uncertainty:
157
+ self.g_A_cycle_loss = self.conf_lambda * conf_criterion(
158
+ self.real_A, self.fake_A_, self.pred_confA
159
+ )
160
+ else:
161
+ self.g_A_cycle_loss = self.L1_lambda * abs_criterion(
162
+ self.real_A, self.fake_A_
163
+ )
164
+ self.g_B_cycle_loss = self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
165
+
166
+ self.g_loss = (
167
+ self.g_adv_total
168
+ + self.g_A_recon_loss
169
+ + self.g_B_recon_loss
170
+ + self.g_A_cycle_loss
171
+ + self.g_B_cycle_loss
172
+ + self.percep_loss
173
+ )
174
+
175
+ self.g_rec_real = abs_criterion(self.rec_realA, self.real_A) + abs_criterion(
176
+ self.rec_realB, self.real_B
177
+ )
178
+ self.g_rec_cycle = abs_criterion(self.real_A, self.fake_A_) + abs_criterion(
179
+ self.real_B, self.fake_B_
180
+ )
181
+
182
+ self.fake_A_sample = tf.compat.v1.placeholder(
183
+ tf.float32,
184
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
185
+ name="fake_A_sample",
186
+ )
187
+ self.fake_B_sample = tf.compat.v1.placeholder(
188
+ tf.float32,
189
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
190
+ name="fake_B_sample",
191
+ )
192
+ self.rec_A_sample = tf.compat.v1.placeholder(
193
+ tf.float32,
194
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
195
+ name="rec_A_sample",
196
+ )
197
+ self.rec_B_sample = tf.compat.v1.placeholder(
198
+ tf.float32,
199
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
200
+ name="rec_B_sample",
201
+ )
202
+ self.rec_fakeA_sample = tf.compat.v1.placeholder(
203
+ tf.float32,
204
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
205
+ name="rec_fakeA_sample",
206
+ )
207
+ self.rec_fakeB_sample = tf.compat.v1.placeholder(
208
+ tf.float32,
209
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
210
+ name="rec_fakeB_sample",
211
+ )
212
+
213
+ self.d_loss_item = []
214
+ self.d_loss_item_rec = []
215
+ self.d_loss_item_recfake = []
216
+
217
+ for i in range(self.n_d):
218
+ self.DB_real = self.discriminator(
219
+ self.real_B, self.options, reuse=True, name=str(i) + "_discriminatorB"
220
+ )
221
+ self.DA_real = self.discriminator(
222
+ self.real_A, self.options, reuse=True, name=str(i) + "_discriminatorA"
223
+ )
224
+ self.DB_fake_sample = self.discriminator(
225
+ self.fake_B_sample,
226
+ self.options,
227
+ reuse=True,
228
+ name=str(i) + "_discriminatorB",
229
+ )
230
+ self.DA_fake_sample = self.discriminator(
231
+ self.fake_A_sample,
232
+ self.options,
233
+ reuse=True,
234
+ name=str(i) + "_discriminatorA",
235
+ )
236
+ self.db_loss_real = self.criterionGAN_list(
237
+ self.DB_real, get_ones_like(self.DB_real)
238
+ )
239
+ self.db_loss_fake = self.criterionGAN_list(
240
+ self.DB_fake_sample, get_zeros_like(self.DB_fake_sample)
241
+ )
242
+ self.db_loss = self.db_loss_real * 0.5 + self.db_loss_fake * 0.5
243
+ self.da_loss_real = self.criterionGAN_list(
244
+ self.DA_real, get_ones_like(self.DA_real)
245
+ )
246
+ self.da_loss_fake = self.criterionGAN_list(
247
+ self.DA_fake_sample, get_zeros_like(self.DA_fake_sample)
248
+ )
249
+ self.da_loss = self.da_loss_real * 0.5 + self.da_loss_fake * 0.5
250
+ self.d_loss = self.da_loss + self.db_loss
251
+ self.d_loss_item.append(self.d_loss)
252
+
253
+ self.g_loss_a2b_sum = tf.compat.v1.summary.scalar("g_loss_a2b", self.g_loss_a2b)
254
+ self.g_loss_b2a_sum = tf.compat.v1.summary.scalar("g_loss_b2a", self.g_loss_b2a)
255
+ self.g_loss_sum = tf.compat.v1.summary.scalar("g_loss", self.g_loss)
256
+ self.g_sum = tf.compat.v1.summary.merge(
257
+ [self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum]
258
+ )
259
+ self.db_loss_sum = tf.compat.v1.summary.scalar("db_loss", self.db_loss)
260
+ self.da_loss_sum = tf.compat.v1.summary.scalar("da_loss", self.da_loss)
261
+ self.d_loss_sum = tf.compat.v1.summary.scalar("d_loss", self.d_loss)
262
+ self.db_loss_real_sum = tf.compat.v1.summary.scalar(
263
+ "db_loss_real", self.db_loss_real
264
+ )
265
+ self.db_loss_fake_sum = tf.compat.v1.summary.scalar(
266
+ "db_loss_fake", self.db_loss_fake
267
+ )
268
+ self.da_loss_real_sum = tf.compat.v1.summary.scalar(
269
+ "da_loss_real", self.da_loss_real
270
+ )
271
+ self.da_loss_fake_sum = tf.compat.v1.summary.scalar(
272
+ "da_loss_fake", self.da_loss_fake
273
+ )
274
+ self.d_sum = tf.compat.v1.summary.merge(
275
+ [
276
+ self.da_loss_sum,
277
+ self.da_loss_real_sum,
278
+ self.da_loss_fake_sum,
279
+ self.db_loss_sum,
280
+ self.db_loss_real_sum,
281
+ self.db_loss_fake_sum,
282
+ self.d_loss_sum,
283
+ ]
284
+ )
285
+
286
+ self.test_A = tf.compat.v1.placeholder(
287
+ tf.float32,
288
+ [self.batch_size, self.image_size, self.image_size * 2, self.input_c_dim],
289
+ name="test_A",
290
+ )
291
+ self.test_B = tf.compat.v1.placeholder(
292
+ tf.float32,
293
+ [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
294
+ name="test_B",
295
+ )
296
+
297
+ (
298
+ self.testB,
299
+ self.rec_testA,
300
+ self.testA_percep,
301
+ self.trans_testA_percep,
302
+ self.test_pred_confA,
303
+ ) = self.generator(
304
+ self.test_A, self.options, transfer=True, reuse=True, name="generatorA2B"
305
+ )
306
+ self.rec_cycle_A, self.refine_testB, self.testB_percep, _, _ = self.generator(
307
+ self.testB, self.options, transfer=False, reuse=True, name="generatorB2A"
308
+ )
309
+
310
+ self.testA, self.rec_testB, _, _, _ = self.generator(
311
+ self.test_B, self.options, transfer=False, reuse=True, name="generatorB2A"
312
+ )
313
+ self.rec_cycle_B, self.refine_testA, _, _, _ = self.generator(
314
+ self.testA, self.options, True, True, name="generatorA2B"
315
+ )
316
+
317
+ t_vars = tf.compat.v1.trainable_variables()
318
+
319
+ self.g_vars = [var for var in t_vars if "generator" in var.name]
320
+ self.p_vars = [var for var in t_vars if "percep" in var.name]
321
+ self.d_vars_item = []
322
+ for i in range(self.n_d):
323
+ self.d_vars = [
324
+ var for var in t_vars if str(i) + "_discriminator" in var.name
325
+ ]
326
+ self.d_vars_item.append(self.d_vars)
327
+
328
+ def train(self, args):
329
+
330
+ self.lr = tf.compat.v1.placeholder(tf.float32, None, name="learning_rate")
331
+
332
+ ### generator optimizer (graph-mode Adam; Keras optimizers require a tape here)
333
+ self.g_optim = tf.compat.v1.train.AdamOptimizer(
334
+ learning_rate=self.lr, beta1=args.beta1
335
+ ).minimize(self.g_loss, var_list=self.g_vars)
336
+
337
+ ### discriminator optimizers (the original minimized the generator loss here by mistake)
338
+ self.d_optim_item = []
339
+ for i in range(self.n_d):
340
+ self.d_optim = tf.compat.v1.train.AdamOptimizer(
341
+ learning_rate=self.lr, beta1=args.beta1
342
+ ).minimize(self.d_loss_item[i], var_list=self.d_vars_item[i])
343
+ self.d_optim_item.append(self.d_optim)
344
+
345
+ init_op = tf.compat.v1.global_variables_initializer()
346
+ self.sess.run(init_op)
347
+ self.writer = tf.compat.v1.summary.FileWriter(
348
+ os.path.join(args.checkpoint_dir, "logs"), self.sess.graph
349
+ )
350
+
351
+ counter = 1
352
+ start_time = time.time()
353
+
354
+ if args.continue_train:
355
+ if self.load(args.checkpoint_dir):
356
+ print(" [*] Load SUCCESS")
357
+ else:
358
+ print(" [!] Load failed...")
359
+
360
+ print("Training.........................")
361
+ for epoch in range(args.epoch):
362
+ dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainA"))
363
+ dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainB"))
364
+ if (len(dataA) == 0) or (len(dataB) == 0):
365
+ raise Exception("No files found in the dataset")
366
+ else:
367
+ print(
368
+ "Data found in the dataset. length of A: ",
369
+ len(dataA),
370
+ " B: ",
371
+ len(dataB),
372
+ )
373
+ np.random.shuffle(dataA)
374
+ np.random.shuffle(dataB)
375
+ batch_idxs = (
376
+ min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
377
+ )
378
+ lr = (
379
+ args.lr
380
+ if epoch < args.epoch_step
381
+ else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step)
382
+ )
383
+
384
+ for idx in range(0, batch_idxs):
385
+ print("Epoch: [%2d] [%4d/%4d] " % (epoch, idx, batch_idxs))
386
+ batch_files = list(
387
+ zip(
388
+ dataA[idx * self.batch_size : (idx + 1) * self.batch_size],
389
+ dataB[idx * self.batch_size : (idx + 1) * self.batch_size],
390
+ )
391
+ )
392
+ batch_images = [
393
+ load_train_data(batch_file, args.load_size, args.fine_size)
394
+ for batch_file in batch_files
395
+ ]
396
+ batch_images = np.array(batch_images).astype(np.float32)
397
+ # Update G network and record fake outputs
398
+ print("Training G network----------------------")
399
+ (
400
+ fake_A,
401
+ fake_B,
402
+ rec_A,
403
+ rec_B,
404
+ rec_fake_A,
405
+ rec_fake_B,
406
+ _,
407
+ g_loss,
408
+ gan_loss,
409
+ percep,
410
+ g_adv,
411
+ g_A_recon_loss,
412
+ g_B_recon_loss,
413
+ g_A_cycle_loss,
414
+ g_B_cycle_loss,
415
+ summary_str,
416
+ ) = self.sess.run(
417
+ [
418
+ self.fake_A,
419
+ self.fake_B,
420
+ self.rec_realA,
421
+ self.rec_realB,
422
+ self.rec_fakeA,
423
+ self.rec_fakeB,
424
+ self.g_optim,
425
+ self.g_loss,
426
+ self.g_adv_total,
427
+ self.percep_loss,
428
+ self.g_adv,
429
+ self.g_A_recon_loss,
430
+ self.g_B_recon_loss,
431
+ self.g_A_cycle_loss,
432
+ self.g_B_cycle_loss,
433
+ self.g_sum,
434
+ ],
435
+ feed_dict={self.real_data: batch_images, self.lr: lr},
436
+ )
437
+ self.writer.add_summary(summary_str, counter)
438
+ [fake_A, fake_B] = self.pool([fake_A, fake_B])
439
+
440
+ # Update D network
441
+ print("Training D network----------------------")
442
+ loss_print = []
443
+ for i in range(self.n_d):
444
+ _, d_loss, d_sum = self.sess.run(
445
+ [self.d_optim_item[i], self.d_loss_item[i], self.d_sum],
446
+ feed_dict={
447
+ self.real_data: batch_images,
448
+ self.fake_A_sample: fake_A,
449
+ self.fake_B_sample: fake_B,
450
+ self.lr: lr,
451
+ },
452
+ )
453
+
454
+ loss_print.append(d_loss)
455
+
456
+ counter += 1
457
+ print(
458
+ (
459
+ "Epoch: [%2d] [%4d/%4d] time: %4.4f g_loss: %4.4f gan:%4.4f adv:%4.4f g_percep:%4.4f "
460
+ % (
461
+ epoch,
462
+ idx,
463
+ batch_idxs,
464
+ time.time() - start_time,
465
+ g_loss,
466
+ gan_loss,
467
+ g_adv,
468
+ percep,
469
+ )
470
+ )
471
+ )
472
+
473
+ if np.mod(counter, args.print_freq) == 1:
474
+ self.sample_model(args.sample_dir, epoch, idx)
475
+
476
+ if np.mod(counter, args.save_freq) == 2:
477
+ self.save(args.checkpoint_dir, counter)
478
+
479
+ def save(self, checkpoint_dir, step):
480
+ model_name = "cyclegan.model"
481
+ model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
482
+ checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
483
+
484
+ if not os.path.exists(checkpoint_dir):
485
+ os.makedirs(checkpoint_dir)
486
+
487
+ self.saver.save(
488
+ self.sess, os.path.join(checkpoint_dir, model_name), global_step=step
489
+ )
490
+
491
+ def load(self, checkpoint_dir):
492
+ print(" [*] Reading checkpoint...")
493
+
494
+ model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
495
+ checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
496
+
497
+ ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
498
+ if ckpt and ckpt.model_checkpoint_path:
499
+ ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
500
+ self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
501
+ return True
502
+ else:
503
+ return False
504
+
505
+ def sample_model(self, sample_dir, epoch, idx):
506
+ dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
507
+ dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
508
+ if (len(dataA) == 0) or (len(dataB) == 0):
509
+ raise Exception("No files found in the test directory")
510
+ np.random.shuffle(dataA)
511
+ np.random.shuffle(dataB)
512
+ batch_files = list(zip(dataA[: self.batch_size], dataB[: self.batch_size]))
513
+ sample_images = [
514
+ load_train_data(batch_file, self.load_size, self.fine_size, is_testing=True)
515
+ for batch_file in batch_files
516
+ ]
517
+ sample_images = np.array(sample_images).astype(np.float32)
518
+
519
+ fake_A, fake_B = self.sess.run(
520
+ [self.fake_A, self.fake_B], feed_dict={self.real_data: sample_images}
521
+ )
522
+ real_A = sample_images[:, :, :, :3]
523
+ real_B = sample_images[:, :, :, 3:]
524
+
525
+ merge_A = np.concatenate([real_B, fake_A], axis=2)
526
+ merge_B = np.concatenate([real_A, fake_B], axis=2)
527
+ check_folder("./{}/{:02d}".format(sample_dir, epoch))
528
+ save_images(
529
+ merge_A,
530
+ [self.batch_size, 1],
531
+ "./{}/{:02d}/A_{:04d}.jpg".format(sample_dir, epoch, idx),
532
+ )
533
+ save_images(
534
+ merge_B,
535
+ [self.batch_size, 1],
536
+ "./{}/{:02d}/B_{:04d}.jpg".format(sample_dir, epoch, idx),
537
+ )
538
+
539
+ def test(self, args):
540
+ total_time = 0
541
+
542
+ init_op = tf.compat.v1.global_variables_initializer()
543
+ self.sess.run(init_op)
544
+ if args.which_direction == "AtoB":
545
+ sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
546
+ elif args.which_direction == "BtoA":
547
+ sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
548
+ else:
549
+ raise Exception("--which_direction must be AtoB or BtoA")
550
+
551
+ if len(sample_files) == 0:
552
+ raise Exception("No files found in the test directory")
553
+
554
+ # print(sample_files)
555
+
556
+ if self.load(args.checkpoint_dir):
557
+ print(" [*] Load SUCCESS")
558
+ else:
559
+ print(" [!] Load failed...")
560
+ out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
561
+ (
562
+ self.testB,
563
+ self.refine_testB,
564
+ self.test_A,
565
+ self.rec_testA,
566
+ self.rec_cycle_A,
567
+ self.testA_percep,
568
+ self.test_pred_confA,
569
+ )
570
+ if args.which_direction == "AtoB"
571
+ else (
572
+ self.testA,
573
+ self.refine_testA,
574
+ self.test_B,
575
+ self.rec_testB,
576
+ self.rec_cycle_B,
577
+ self.testB_percep,
578
+ self.test_pred_confA,
579
+ )
580
+ )
581
+ for sample_file in sample_files:
582
+ # print('Processing image: ' + sample_file)
583
+ sample_image = [load_test_data(sample_file, args.fine_size)]
584
+ start_time = time.time()
585
+ sample_image = np.array(sample_image).astype(np.float32)
586
+ image_path = os.path.join(
587
+ args.test_dir,
588
+ "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
589
+ )
590
+ ori_path = os.path.join(
591
+ args.test_dir,
592
+ "{0}_{1}".format("ori", os.path.basename(sample_file)),
593
+ )
594
+ conf_path = os.path.join(
595
+ args.conf_dir,
596
+ "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
597
+ )
598
+
599
+ (fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
600
+ end_time = time.time()
601
+ # merge = np.concatenate([sample_image, fake_img], axis=2)
602
+ save_images(fake_img[0], [1], image_path)
603
+ save_images(sample_image[0], [1], ori_path)
604
+ # save_images(merge, [1, 1], image_path)
605
+ total_time = total_time + (end_time - start_time)
606
+
607
+ if args.save_conf:
608
+
609
+ if args.which_direction == "AtoB":
610
+ pass
611
+ else:
612
+ raise Exception(
613
+ "--conf map only can be estimated in AtoB direction"
614
+ )
615
+
616
+ conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
617
+ conf_img_sq = np.squeeze(conf_img)
618
+ plt.imshow(
619
+ conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
620
+ )
621
+ plt.savefig(conf_path)
622
+ print(
623
+ f"Average time taken to convert images: {total_time/len(sample_files)} seconds"
624
+ )
625
+
626
+ def convert(self, args, datadir="./inf_data"):
627
+ total_time = 0
628
+
629
+ init_op = tf.compat.v1.global_variables_initializer()
630
+ self.sess.run(init_op)
631
+
632
+ if self.load(args.checkpoint_dir):
633
+ print(" [*] Load SUCCESS")
634
+ else:
635
+ raise Exception("-- Cannot Load Model. Train or Add model first")
636
+
637
+ if args.which_direction == "AtoB":
638
+ sample_files = glob(datadir)
639
+ elif args.which_direction == "BtoA":
640
+ sample_files = glob(datadir)
641
+ else:
642
+ raise Exception("--which_direction must be AtoB or BtoA")
643
+
644
+ print(sample_files)
645
+
646
+ out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
647
+ (
648
+ self.testB,
649
+ self.refine_testB,
650
+ self.test_A,
651
+ self.rec_testA,
652
+ self.rec_cycle_A,
653
+ self.testA_percep,
654
+ self.test_pred_confA,
655
+ )
656
+ if args.which_direction == "AtoB"
657
+ else (
658
+ self.testA,
659
+ self.refine_testA,
660
+ self.test_B,
661
+ self.rec_testB,
662
+ self.rec_cycle_B,
663
+ self.testB_percep,
664
+ self.test_pred_confA,
665
+ )
666
+ )
667
+ for sample_file in sample_files:
668
+ print("Processing image: " + sample_file)
669
+ sample_image = [load_test_data(sample_file, args.fine_size)]
670
+ start_time = time.time()
671
+ sample_image = np.array(sample_image).astype(np.float32)
672
+ image_path = os.path.join(
673
+ args.test_dir,
674
+ "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
675
+ )
676
+ conf_path = os.path.join(
677
+ args.conf_dir,
678
+ "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
679
+ )
680
+
681
+ (fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
682
+ end_time = time.time()
683
+ merge = np.concatenate([sample_image, fake_img], axis=2)
684
+ save_images(merge, [1, 1], image_path)
685
+ total_time = total_time + (end_time - start_time)
686
+ print(f"Time taken to convert image: {end_time - start_time} seconds")
687
+
688
+ if args.save_conf:
689
+
690
+ if args.which_direction == "AtoB":
691
+ pass
692
+ else:
693
+ raise Exception(
694
+ "--conf map only can be estimated in AtoB direction"
695
+ )
696
+
697
+ conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
698
+ conf_img_sq = np.squeeze(conf_img)
699
+ plt.imshow(
700
+ conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
701
+ )
702
+ plt.savefig(conf_path)
703
+ print(
704
+ f"Average time taken to convert images: {total_time/len(sample_files)} seconds"
705
+ )
706
+
707
+ def convert_image(self, args, input_image_path, output_dir):
708
+ init_op = tf.compat.v1.global_variables_initializer()
709
+ if self.load(args.checkpoint_dir):
710
+ print(" [*] Load SUCCESS")
711
+ with tf.compat.v1.Session() as sess:
712
+ sess.run(init_op)
713
+ # Load the input image
714
+ input_image = [load_test_data(input_image_path, self.fine_size)]
715
+ input_image = np.array(input_image).astype(np.float32)
716
+
717
+ # Get the generator output
718
+ if args.which_direction == "AtoB":
719
+ out_var = self.testB
720
+ in_var = self.test_A
721
+ else:
722
+ out_var = self.testA
723
+ in_var = self.test_B
724
+
725
+ # Run the model to obtain the converted image
726
+ start_time = time.time()
727
+ converted_image = sess.run(out_var, feed_dict={in_var: input_image})
728
+ end_time = time.time()
729
+
730
+ # Save the converted image
731
+ output_image_path = os.path.join(
732
+ output_dir, os.path.basename(input_image_path)
733
+ )
734
+ merge = np.concatenate([input_image, converted_image], axis=2)
735
+ save_images(merge, [1, 1], output_image_path)
736
+
737
+ # Print the time taken
738
+ print(f"Time taken to convert image: {end_time - start_time} seconds")
README.md ADDED
@@ -0,0 +1,125 @@
1
+ # Adverse Weather Image Translation with Asymmetric and Uncertainty-aware GAN (AU-GAN)
2
+ Official TensorFlow implementation of [Adverse Weather Image Translation with Asymmetric and Uncertainty-aware GAN](https://www.bmvc2021-virtualconference.com/assets/papers/1443.pdf) (AU-GAN)\
3
+ Jeong-gi Kwak, Youngsaeng Jin, Yuanming Li, Dongsik Yoon, Donghyeon Kim and Hanseok Ko </br>
4
+ *British Machine Vision Conference (BMVC), 2021*
5
+ </br>
6
+
7
+ ## Intro
8
+
9
+ ### Night &rarr; Day ([BDD100K](https://bdd-data.berkeley.edu/))
10
+ <img src="./assets/augan_bdd.png" width="800">
11
+
12
+ ### Rainy night &rarr; Day ([Alderley](https://wiki.qut.edu.au/pages/viewpage.action?pageId=181178395))
13
+ <img src="./assets/augan_alderley.png" width="800">
14
+ </br>
15
+
16
+
17
+ ## Architecture
18
+ <img src="./assets/augan_model.png" width="800">
19
+ Our generator has an asymmetric structure for editing day&rarr;night and night&rarr;day.
20
+ Please refer to our paper for details.
21
+
22
+ ## **Envs**
23
+
24
+ ```bash
25
+
26
+ git clone https://github.com/jgkwak95/AU-GAN.git
27
+ cd AU-GAN
28
+
29
+ # Create virtual environment
30
+ conda create -y --name augan python=3.6.7
31
+ conda activate augan
32
+
33
+ conda install tensorflow-gpu==1.14.0 # Tensorflow 1.14
34
+ pip install --no-cache-dir -r requirements.txt
35
+
36
+ ```
37
+
38
+ ## **Preparing datasets**
39
+
40
+ **Night &rarr; Day** </br>
41
+ [Berkeley DeepDrive dataset](https://bdd-data.berkeley.edu/) contains 100,000 high resolution images of the urban roads for autonomous driving.</br></br>
42
+ **Rainy night &rarr; Day** </br>
43
+ [Alderley dataset](https://wiki.qut.edu.au/pages/viewpage.action?pageId=181178395) consists of images of two domains,
44
+ rainy night and daytime. It was collected while driving the same route in each weather environment.</br>
45
+ </br>
46
+ Please download datasets and then construct them following [ForkGAN](https://github.com/zhengziqiang/ForkGAN)
47
+
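+ The `glob` patterns in `AUGAN.py` expect the following layout under `datasets/` (inferred from the code; any image files inside the four folders will do):
+
+ ```
+ datasets/<dataset_dir>/
+ ├── trainA/   # source domain, e.g. night
+ ├── trainB/   # target domain, e.g. day
+ ├── testA/
+ └── testB/
+ ```
+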
48
+ ## Pretrained Model
49
+
50
+ Download the pretrained model for BDD100K(256x512) [here](https://drive.google.com/file/d/1rvIF3yE9MwPWj0kD4IEstETyMQXYAHzr/view?usp=sharing) and unzip it to ./check/bdd_exp/bdd100k_256/
51
+
52
+ ## Training
53
+
54
+ ```bash
55
+
56
+ # Alderley (256x512)
57
+ python main_uncer.py --dataset_dir alderley \
58
+ --phase train \
59
+ --experiment_name alderley_exp \
60
+ --batch_size 8 \
61
+ --load_size 286 \
62
+ --fine_size 256 \
63
+ --use_uncertainty True
64
+
65
+ ```
66
+
67
+ ```bash
68
+
69
+ # BDD100k (256x512)
70
+ python main_uncer.py --dataset_dir bdd100k \
71
+ --phase train \
72
+ --experiment_name bdd_exp \
73
+ --batch_size 8 \
74
+ --load_size 286 \
75
+ --fine_size 256 \
76
+ --use_uncertainty True
77
+
78
+ ```
79
+
80
+ ## Test
81
+
82
+ ```bash
83
+
84
+ # Alderley (256x512)
85
+ python main_uncer.py --dataset_dir alderley \
86
+ --phase test \
87
+ --experiment_name alderley_exp \
88
+ --batch_size 1 \
89
+ --load_size 286 \
90
+ --fine_size 256
91
+
92
+ ```
93
+
94
+ ```bash
95
+
96
+ # BDD100k (256x512)
97
+ python main_uncer.py --dataset_dir bdd100k \
98
+ --phase test \
99
+ --experiment_name bdd_exp \
100
+ --batch_size 1 \
101
+ --load_size 286 \
102
+ --fine_size 256
103
+
104
+
105
+ ```
106
+ ## Additional results
107
+ <img src="./assets/augan_result.png" width="800">
108
+
109
+ More results in [paper](https://www.bmvc2021-virtualconference.com/assets/papers/1443.pdf) and [supplementary]()
110
+
111
+ ## Uncertainty map
112
+ <img src="./assets/augan_uncer.png" width="800">
113
+
114
+ ## **Citation**
115
+ If our code is helpful for your research, please cite our paper:
116
+ ```
117
+ @article{kwak2021adverse,
118
+ title={Adverse weather image translation with asymmetric and uncertainty-aware GAN},
119
+ author={Kwak, Jeong-gi and Jin, Youngsaeng and Li, Yuanming and Yoon, Dongsik and Kim, Donghyeon and Ko, Hanseok},
120
+ journal={arXiv preprint arXiv:2112.04283},
121
+ year={2021}
122
+ }
123
+ ```
124
+ ## Acknowledgments
125
+ Our code is built upon the [ForkGAN](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123480154.pdf) implementation.
__pycache__/AUGAN.cpython-36.pyc ADDED
Binary file (14.8 kB).
 
__pycache__/loss_utils.cpython-36.pyc ADDED
Binary file (1.62 kB).
 
__pycache__/models.cpython-36.pyc ADDED
Binary file (4.3 kB).
 
__pycache__/ops.cpython-36.pyc ADDED
Binary file (6.78 kB).
 
__pycache__/utils.cpython-36.pyc ADDED
Binary file (4.68 kB).
 
assets/augan_alderley.png ADDED
assets/augan_bdd.png ADDED
assets/augan_model.png ADDED
assets/augan_result.png ADDED

Git LFS Details

  • SHA256: 896e866d3e3883df451964b53436267c751ea52c77d0780a195a4b093504197f
  • Pointer size: 132 Bytes
  • Size of remote file: 7.59 MB
assets/augan_uncer.png ADDED

Git LFS Details

  • SHA256: 175b60024e17eabcb2c481b7a069811b84815fc4b0d502f19d4fc69c23d135f8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.02 MB
cc.sh ADDED
@@ -0,0 +1,7 @@
1
+ python main.py --dataset_dir swim \
2
+ --phase test \
3
+ --experiment_name bdd_exp \
4
+ --batch_size 1 \
5
+ --which_direction BtoA \
6
+ --load_size 286 \
7
+ --fine_size 256
check.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fc7f6d4f5f9c503bc69e1fdb454d8cc4f652d8c8966875dcd46aab83cdd0ff4
3
+ size 173513070
datasets/swim/testA/GP010594_frame_000017_rgb_anon.png ADDED

Git LFS Details

  • SHA256: a92a10b1852c3e2ad652830f26999e2871f703bf8a4ddb3953eacc30ea392ca2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
datasets/swim/testA/GP010594_frame_000021_rgb_anon.png ADDED

Git LFS Details

  • SHA256: b95aeb7cf482a097f44ca074dd184089adb721825693dfd59c2e19454e16f730
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
datasets/swim/testA/GP010594_frame_000087_rgb_anon.png ADDED

Git LFS Details

  • SHA256: a7c5162e2915fff348a3f4e8c15186f1abbfdb42b2e94687e478a4894cf72666
  • Pointer size: 132 Bytes
  • Size of remote file: 2.15 MB
datasets/swim/testB/GOPR0351_frame_000159_rgb_ref_anon.png ADDED

Git LFS Details

  • SHA256: 67a5f90e9a07e6c2600c9815fcbd971a4c1b8a922347341d295af9fb68cde1d1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.02 MB
datasets/swim/testB/GOPR0351_frame_000161_rgb_ref_anon.png ADDED

Git LFS Details

  • SHA256: eefd9cd2cc6ab1299912d2422e252bb1d63fd7823fbefa7167c588ada74a3f2c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
datasets/swim/testB/GOPR0355_frame_000138_rgb_ref_anon.png ADDED

Git LFS Details

  • SHA256: 8d813c52a0cc39243c706f7528bf6a6974f53f65b4b14471f1afb67b95bb3049
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB
inference.py ADDED
File without changes
loss_utils.py ADDED
@@ -0,0 +1,51 @@
1
+ import tensorflow as tf
2
+
3
+ epsilon = 1e-7
4
+
5
+ def conf_criterion_lp(im1, im2, conf_sigma): # factorized laplacian distribution
6
+ loss = tf.abs(im1 - im2)
7
+ if conf_sigma is not None:
8
+ loss = loss * 2 / (conf_sigma + epsilon) + tf.math.log(conf_sigma * 2 + epsilon)
9
+ loss = tf.reduce_mean(loss)
10
+ else:
11
+ loss = tf.reduce_mean(loss)
12
+
13
+ return loss
14
+
15
+
16
+
17
+ def conf_criterion(im1, im2, conf_sigma): # gaussian distribution
18
+ loss = tf.abs(im1 - im2)
19
+ if conf_sigma is not None:
20
+ loss = tf.math.exp(-conf_sigma) * 5 * loss + conf_sigma / 2
21
+ loss = tf.reduce_mean(loss)
22
+ else:
23
+ loss = tf.reduce_mean(loss)
24
+
25
+ return loss
26
+
27
+
28
+ def abs_criterion(in_, target):
29
+ return tf.reduce_mean(tf.abs(in_ - target))
30
+
31
+
32
+ def mae_criterion(in_, target):
33
+ return tf.reduce_mean((in_ - target) ** 2)
34
+
35
+
36
+ def sce_criterion(logits, labels):
37
+ return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
38
+
39
+
40
+ def mae_criterion_list(in_, target):
41
+ loss = 0.0
42
+ for i in range(len(target)):
43
+ loss += tf.reduce_mean((in_[i] - target[i]) ** 2)
44
+ return loss / len(target)
45
+
46
+
47
+ def sce_criterion_list(logits, labels):
48
+ loss = 0.0
49
+ for i in range(len(labels)):
50
+ loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits[i], labels=labels[i]))
51
+ return loss / len(labels)
main.py ADDED
@@ -0,0 +1,193 @@
1
+ import argparse
2
+ import tensorflow as tf
3
+ import os
4
+ from utils import *
5
+ from AUGAN import AUGAN
6
+ from ops import *
7
+ import time
8
+
9
+ parser = argparse.ArgumentParser(description="")
10
+ parser.add_argument(
11
+ "--dataset_dir", dest="dataset_dir", default="bdd100k", help="path of the dataset"
12
+ )
13
+ parser.add_argument(
14
+ "--experiment_name",
15
+ dest="experiment_name",
16
+ type=str,
17
+ default="bdd_exp",
18
+ help="name of experiment",
19
+ )
20
+ parser.add_argument("--epoch", dest="epoch", type=int, default=20, help="# of epoch")
21
+ parser.add_argument(
22
+ "--epoch_step",
23
+ dest="epoch_step",
24
+ type=int,
25
+ default=10,
26
+ help="# of epoch to decay lr",
27
+ )
28
+ parser.add_argument(
29
+ "--batch_size", dest="batch_size", type=int, default=1, help="# images in batch"
30
+ )
31
+ parser.add_argument(
32
+ "--train_size",
33
+ dest="train_size",
34
+ type=int,
35
+ default=1e8,
36
+ help="# images used to train",
37
+ )
38
+ parser.add_argument(
39
+ "--load_size",
40
+ dest="load_size",
41
+ type=int,
42
+ default=286,
43
+ help="scale images to this size",
44
+ )
45
+ parser.add_argument(
46
+ "--fine_size",
47
+ dest="fine_size",
48
+ type=int,
49
+ default=256,
50
+ help="then crop to this size",
51
+ )
52
+ parser.add_argument(
53
+ "--ngf",
54
+ dest="ngf",
55
+ type=int,
56
+ default=64,
57
+ help="# of gen filters in first conv layer",
58
+ )
59
+ parser.add_argument(
60
+ "--ndf",
61
+ dest="ndf",
62
+ type=int,
63
+ default=64,
64
+ help="# of discri filters in first conv layer",
65
+ )
66
+ parser.add_argument(
67
+ "--n_d", dest="n_d", type=int, default=2, help="# of discriminators"
68
+ )
69
+ parser.add_argument(
70
+ "--n_scale", dest="n_scale", type=int, default=2, help="# of scales"
71
+ )
72
+ parser.add_argument(
73
+ "--gpu", dest="gpu", type=int, default=0, help="# index of gpu device"
74
+ )
75
+ parser.add_argument(
76
+ "--input_nc", dest="input_nc", type=int, default=3, help="# of input image channels"
77
+ )
78
+ parser.add_argument(
79
+ "--output_nc",
80
+ dest="output_nc",
81
+ type=int,
82
+ default=3,
83
+ help="# of output image channels",
84
+ )
85
+ parser.add_argument(
86
+ "--lr", dest="lr", type=float, default=0.0002, help="initial learning rate for adam"
87
+ )
88
+ parser.add_argument(
89
+ "--beta1", dest="beta1", type=float, default=0.5, help="momentum term of adam"
90
+ )
91
+ parser.add_argument(
92
+ "--which_direction", dest="which_direction", default="AtoB", help="AtoB or BtoA "
93
+ )
94
+ parser.add_argument("--phase", dest="phase", default="test", help="train, test")
95
+ parser.add_argument(
96
+ "--save_freq",
97
+ dest="save_freq",
98
+ type=int,
99
+ default=1000,
100
+ help="save a model every save_freq iterations",
101
+ )
102
+ parser.add_argument(
103
+ "--print_freq",
104
+ dest="print_freq",
105
+ type=int,
106
+ default=100,
107
+ help="print the debug information every print_freq iterations",
108
+ )
109
+ parser.add_argument(
110
+ "--L1_lambda",
111
+ dest="L1_lambda",
112
+ type=float,
113
+ default=10.0,
114
+ help="weight on L1 term in objective",
115
+ )
116
+ parser.add_argument(
117
+ "--conf_lambda",
118
+ dest="conf_lambda",
119
+ type=float,
120
+ default=1.0,
121
+ help="weight on L1 term in objective",
122
+ )
123
+ parser.add_argument(
124
+ "--use_resnet",
125
+ dest="use_resnet",
126
+ type=bool,
127
+ default=True,
128
+ help="generation network using reidule block",
129
+ )
130
+ parser.add_argument(
131
+ "--use_lsgan",
132
+ dest="use_lsgan",
133
+ type=bool,
134
+ default=True,
135
+ help="gan loss defined in lsgan",
136
+ )
137
+ parser.add_argument(
138
+ "--use_uncertainty",
139
+ dest="use_uncertainty",
140
+ type=bool,
141
+ default=True,
142
+ help="max size of image pool, 0 means do not use image pool",
143
+ )
144
+ parser.add_argument(
145
+ "--max_size",
146
+ dest="max_size",
147
+ type=int,
148
+ default=50,
149
+ help="max size of image pool, 0 means do not use image pool",
150
+ )
151
+ parser.add_argument(
152
+ "--continue_train",
153
+ dest="continue_train",
154
+ type=bool,
155
+ default=False,
156
+ help="if continue training, load the latest model: 1: true, 0: false",
157
+ )
158
+ parser.add_argument(
159
+ "--save_conf",
160
+ dest="save_conf",
161
+ type=bool,
162
+ default=False,
163
+ help="save conf map in test phase",
164
+ )
165
+ args = parser.parse_args()
166
+
167
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
168
+
169
+
170
+ def main(_):
171
+
172
+ set_path(args, args.experiment_name)
173
+
174
+ tfconfig = tf.compat.v1.ConfigProto(allow_soft_placement=True)
175
+ tfconfig.gpu_options.allow_growth = True
176
+ with tf.compat.v1.Session(config=tfconfig) as sess:
177
+ model = AUGAN(sess, args)
178
+ # show_all_variables()
179
+ # model.train(args) if args.phase == 'train' \
180
+ # else model.test(args)
181
+
182
+ if args.phase == "train":
183
+ model.train(args)
184
+ elif args.phase == "test":
185
+ model.test(args)
186
+ elif args.phase == "convert":
187
+ model.convert_image(args, "inf_data/b1ca2e5d-84cf9134.jpg", "out")
188
+ else:
189
+ raise Exception("Give a phase")
190
+
191
+
192
+ if __name__ == "__main__":
193
+ tf.compat.v1.app.run()
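One caveat with the flags above: `argparse` with `type=bool` converts any non-empty string to `True`, so `--use_uncertainty False` still enables it. A common workaround (not part of this commit) is a small converter:

```python
import argparse

def str2bool(v):
    # bool("False") is True, so parse the string explicitly.
    if isinstance(v, bool):
        return v
    return v.lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--use_uncertainty", type=str2bool, default=True)
print(parser.parse_args(["--use_uncertainty", "False"]).use_uncertainty)  # False
```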
models.py ADDED
@@ -0,0 +1,178 @@
1
+ from collections import namedtuple
2
+ from utils import *
3
+ from ops import *
4
+ import time
5
+ from glob import glob
6
+
7
+
8
+ def gaussian_noise_layer(input_layer, std):
9
+ noise = tf.random.normal(
10
+ shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32
11
+ )
12
+ return input_layer + noise
13
+
14
+
15
+ def generator_resnet(image, options, transfer=False, reuse=False, name="generator"):
16
+ with tf.compat.v1.variable_scope(name):
17
+ if reuse:
18
+ tf.compat.v1.get_variable_scope().reuse_variables()
19
+ else:
20
+ assert tf.compat.v1.get_variable_scope().reuse is False
21
+
22
+ def residule_block_dilated(x, dim, ks=3, s=1, name="res", down=False):
23
+ if down:
24
+ dim = dim * 2
25
+ y = instance_norm(
26
+ dilated_conv2d(x, dim, ks, s, padding="SAME", name=name + "_c1"),
27
+ name + "_bn1",
28
+ )
29
+ y = tf.nn.relu(y)
30
+ y = instance_norm(
31
+ dilated_conv2d(y, dim, ks, s, padding="SAME", name=name + "_c2"),
32
+ name + "_bn2",
33
+ )
34
+ out = y + x
35
+ if down:
36
+ out = tf.nn.relu(
37
+ instance_norm(
38
+ conv2d(out, dim // 2, 3, 1, name=name + "_down_c"),
39
+ name + "_in_down",
40
+ )
41
+ )
42
+ return out
43
+
44
+ def residual_block(x_init, dim, ks=3, s=1, name="resblock", down=False):
45
+ with tf.compat.v1.variable_scope(name):
46
+ if down:
47
+ dim = dim * 2
48
+
49
+ with tf.compat.v1.variable_scope("res1"):
50
+ x = instance_norm(
51
+ conv2d(x_init, dim, ks, s, padding="SAME", name=name + "_c1"),
52
+ name + "_in1",
53
+ )
54
+ x = tf.nn.relu(x)
55
+
56
+ with tf.compat.v1.variable_scope("res2"):
57
+
58
+ x = instance_norm(
59
+ conv2d(x, dim, ks, s, padding="SAME", name=name + "_c2"),
60
+ name + "_in2",
61
+ )
62
+
63
+ out = x + x_init
64
+
65
+ if down:
66
+ out = tf.nn.relu(
67
+ instance_norm(
68
+ conv2d(out, dim // 2, 3, 1, name=name + "_down_c"),
69
+ name + "_in_down",
70
+ )
71
+ )
72
+ return out
73
+
74
+ ### Encoder architecture
75
+ c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
76
+ c1 = tf.nn.relu(
77
+ instance_norm(
78
+ conv2d(c0, options.gf_dim, 7, 1, padding="VALID", name="g_e1_c"),
79
+ "g_e1_bn",
80
+ )
81
+ )
82
+ c2 = tf.nn.relu(
83
+ instance_norm(
84
+ conv2d(c1, options.gf_dim * 2, 3, 2, name="g_e2_c"), "g_e2_bn"
85
+ )
86
+ )
87
+ c3 = tf.nn.relu(
88
+ instance_norm(
89
+ conv2d(c2, options.gf_dim * 4, 3, 2, name="g_e3_c"), "g_e3_bn"
90
+ )
91
+ )
92
+ r1 = residule_block_dilated(c3, options.gf_dim * 4, name="g_r1")
93
+ r2 = residule_block_dilated(r1, options.gf_dim * 4, name="g_r2")
94
+ r3 = residule_block_dilated(r2, options.gf_dim * 4, name="g_r3")
95
+ r4 = residule_block_dilated(r3, options.gf_dim * 4, name="g_r4")
96
+ # r5 = residule_block_dilated(r4, options.gf_dim * 4, name='g_r5')
97
+
98
+ if transfer:
99
+ t1 = residual_block(r4, options.gf_dim * 4, name="g_t1")
100
+ t2 = residual_block(t1, options.gf_dim * 4, name="g_t2")
101
+ t3 = residual_block(t2, options.gf_dim * 4, name="g_t3")
102
+ t4 = residual_block(t3, options.gf_dim * 4, name="g_t4")
103
+ # feature = tf.concat([r4, t4], axis=3, name='g_concat')
104
+ # down = True
105
+ feature = t4
106
+ else:
107
+ feature = r4
108
+ t4 = None
109
+ down = False
110
+
111
+ ### translation decoder architecture
112
+ r6 = residule_block_dilated(feature, options.gf_dim * 4, name="g_r6")
113
+ r7 = residule_block_dilated(r6, options.gf_dim * 4, name="g_r7")
114
+ r8 = residule_block_dilated(r7, options.gf_dim * 4, name="g_r8")
115
+ r9 = residule_block_dilated(r8, options.gf_dim * 4, name="g_r9")
116
+ d1 = deconv2d(r9, options.gf_dim * 2, 3, 2, name="g_d1_dc")
117
+ d1 = tf.nn.relu(instance_norm(d1, "g_d1_bn"))
118
+ d2 = deconv2d(d1, options.gf_dim, 3, 2, name="g_d2_dc")
119
+ d2 = tf.nn.relu(instance_norm(d2, "g_d2_bn"))
120
+ d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
121
+ pred = tf.nn.tanh(
122
+ conv2d(d2, options.output_c_dim, 7, 1, padding="VALID", name="g_pred_c")
123
+ )
124
+
125
+ ### reconstruction decoder architecture
126
+ r5 = gaussian_noise_layer(r4, 0.02)
127
+ r6_rec = residule_block_dilated(r5, options.gf_dim * 4, name="g_r6_rec")
128
+ r6_rec = gaussian_noise_layer(r6_rec, 0.02)
129
+ r7_rec = residule_block_dilated(r6_rec, options.gf_dim * 4, name="g_r7_rec")
130
+ r8_rec = residule_block_dilated(r7_rec, options.gf_dim * 4, name="g_r8_rec")
131
+ r9_rec = residule_block_dilated(r8_rec, options.gf_dim * 4, name="g_r9_rec")
132
+ d1_rec = deconv2d(r9_rec, options.gf_dim * 2, 3, 2, name="g_d1_dc_rec")
133
+ d1_rec = tf.nn.relu(instance_norm(d1_rec, "g_d1_bn_rec"))
134
+ d2_rec = deconv2d(d1_rec, options.gf_dim, 3, 2, name="g_d2_dc_rec")
135
+ d2_rec = tf.nn.relu(instance_norm(d2_rec, "g_d2_bn_rec"))
136
+ d2_rec = tf.pad(d2_rec, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
137
+ pred_rec = tf.nn.tanh(
138
+ conv2d(
139
+ d2_rec, options.output_c_dim, 7, 1, padding="VALID", name="g_pred_c_rec"
140
+ )
141
+ )
142
+
143
+ ## confidence prediction
144
+
145
+ if transfer:
146
+
147
+ d_conf = deconv2d(d1, options.gf_dim, 3, 2, name="g_d_dc_conf")
148
+ d_conf = tf.nn.relu(instance_norm(d_conf, "g_d_bn_conf"))
149
+ d_conf = tf.pad(d_conf, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
150
+ pred_conf = tf.nn.softplus(
151
+ conv2d(d_conf, 1, 7, 1, padding="VALID", name="g_pred_c_conf")
152
+ )
153
+
154
+ else:
155
+ pred_conf = None
156
+
157
+ return pred, pred_rec, r4, t4, pred_conf
158
+
159
+
160
+ def discriminator(image, options, n_scale=2, reuse=False, name="discriminator"):
161
+ images = []
162
+ for i in range(n_scale):
163
+ images.append(
164
+ tf.compat.v1.image.resize_bicubic(
165
+ image, [get_shape(image)[1] // (2**i), get_shape(image)[2] // (2**i)]
166
+ )
167
+ )
168
+ with tf.compat.v1.variable_scope(name):
169
+ if reuse:
170
+ tf.compat.v1.get_variable_scope().reuse_variables()
171
+ else:
172
+ assert tf.compat.v1.get_variable_scope().reuse is False
173
+ images = dis_down(images, 4, 2, n_scale, options.df_dim, "d_h0_conv_scale_")
174
+ images = dis_down(images, 4, 2, n_scale, options.df_dim * 2, "d_h1_conv_scale_")
175
+ images = dis_down(images, 4, 2, n_scale, options.df_dim * 4, "d_h2_conv_scale_")
176
+ images = dis_down(images, 4, 2, n_scale, options.df_dim * 8, "d_h3_conv_scale_")
177
+ images = final_conv(images, n_scale, "d_pred_scale_")
178
+ return images
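The generator's asymmetry shows up in its return signature: with `transfer=True` (night→day) the extra residual blocks and the softplus confidence head are active, while with `transfer=False` the last two outputs come back as `None`. A graph-mode sketch (shapes and option values are illustrative, mirroring how `AUGAN._build_model` calls it):

```python
import tensorflow as tf
from collections import namedtuple
from models import generator_resnet

tf.compat.v1.disable_eager_execution()

OPTIONS = namedtuple("OPTIONS", "batch_size image_size gf_dim df_dim output_c_dim is_training")
opts = OPTIONS(1, 256, 64, 32, 3, False)

x = tf.compat.v1.placeholder(tf.float32, [1, 256, 512, 3])
# Night -> day: translation blocks + confidence map
fake, rec, feat, trans_feat, conf = generator_resnet(
    x, opts, transfer=True, reuse=False, name="generatorA2B")
# Day -> night: same encoder/decoders, no transfer blocks, no confidence head
_, _, _, t4, pred_conf = generator_resnet(
    x, opts, transfer=False, reuse=False, name="generatorB2A")
assert t4 is None and pred_conf is None
```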
ops.py ADDED
@@ -0,0 +1,246 @@
1
+ import tensorflow as tf
2
+
3
+ # import tensorflow.contrib.slim as slim
4
+ import tf_slim as slim
5
+ import math
6
+ import pprint
7
+
8
+ pp = pprint.PrettyPrinter()
9
+ get_stddev = lambda x, k_h, k_w: 1 / math.sqrt(k_w * k_h * x.get_shape()[-1])
10
+ # import tensorflow.contrib as tf_contrib
11
+
12
+ # weight_init = tf_contrib.layers.xavier_initializer()
13
+ weight_init = tf.initializers.GlorotUniform()
14
+ weight_regularizer = None
15
+
16
+
17
+ def batch_norm(x, name="batch_norm"):
18
+ # return tf.contrib.layers.batch_norm(
19
+ # x, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope=name
20
+ # )
21
+ return tf.keras.layers.BatchNormalization(
22
+ momentum=0.9, epsilon=1e-5, scale=True, name=name
23
+ )(x)
24
+
25
+
26
+ def instance_norm(input, name="instance_norm"):
27
+ with tf.compat.v1.variable_scope(name):
28
+ depth = input.get_shape()[3]
29
+ scale = tf.compat.v1.get_variable(
30
+ "scale",
31
+ [depth],
32
+ initializer=tf.keras.initializers.RandomNormal(
33
+ mean=1.0, stddev=0.02, seed=None
34
+ ),
35
+ )
36
+ offset = tf.compat.v1.get_variable(
37
+ "offset", [depth], initializer=tf.constant_initializer(0.0)
38
+ )
39
+ mean, variance = tf.nn.moments(input, axes=[1, 2], keepdims=True)
40
+ epsilon = 1e-5
41
+ inv = tf.math.rsqrt(variance + epsilon)
42
+ normalized = (input - mean) * inv
43
+ return scale * normalized + offset
44
+
45
+
46
+ def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding="SAME", name="conv2d"):
47
+ with tf.compat.v1.variable_scope(name):
48
+ return slim.conv2d(
49
+ input_,
50
+ output_dim,
51
+ ks,
52
+ s,
53
+ padding=padding,
54
+ activation_fn=None,
55
+ weights_initializer=tf.keras.initializers.TruncatedNormal(stddev=stddev),
56
+ biases_initializer=None,
57
+ )
58
+
59
+
60
+ def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"):
61
+ with tf.compat.v1.variable_scope(name):
62
+ return slim.conv2d_transpose(
63
+ input_,
64
+ output_dim,
65
+ ks,
66
+ s,
67
+ padding="SAME",
68
+ activation_fn=None,
69
+ weights_initializer=tf.keras.initializers.TruncatedNormal(stddev=stddev),
70
+ biases_initializer=None,
71
+ )
72
+
73
+
74
+ def dilated_conv2d(
75
+ input_, output_dim, ks=3, s=2, stddev=0.02, padding="SAME", name="conv2d"
76
+ ):
77
+ with tf.compat.v1.variable_scope(name):
78
+ batch, in_height, in_width, in_channels = [int(d) for d in input_.get_shape()]
79
+ filter = tf.compat.v1.get_variable(
80
+ "filter",
81
+ [ks, ks, in_channels, output_dim],
82
+ dtype=tf.float32,
83
+ initializer=tf.random_normal_initializer(0, stddev),
84
+ )
85
+ conv = tf.nn.atrous_conv2d(input_, filter, rate=s, padding=padding, name=name)
86
+
87
+ return conv
88
+
89
+
90
+ def one_step(x, ch, kernel, stride, name):
91
+ return lrelu(
92
+ instance_norm(
93
+ conv2d(x, ch, kernel, stride, name=name + "_first_c"), name + "_first_bn"
94
+ )
95
+ )
96
+
97
+
98
+ def one_step_dilated(x, ch, kernel, stride, name):
99
+ return lrelu(
100
+ instance_norm(
101
+ dilated_conv2d(x, ch, kernel, stride, name=name + "_first_c"),
102
+ name + "_first_bn",
103
+ )
104
+ )
105
+
106
+
107
+ def num_steps(x, ch, kernel, stride, num_steps, name):
108
+ for i in range(num_steps):
109
+ x = lrelu(
110
+ instance_norm(
111
+ conv2d(x, ch, kernel, stride, name=name + "_c_" + str(i)),
112
+ name + "_bn_" + str(i),
113
+ )
114
+ )
115
+ return x
116
+
117
+
118
+ def one_step_noins(x, ch, kernel, stride, name):
119
+ return lrelu(conv2d(x, ch, kernel, stride, name=name + "_first_c"))
120
+
121
+
122
+ def num_steps_noins(x, ch, kernel, stride, num_steps, name):
123
+
124
+ for i in range(num_steps):
125
+ x = lrelu(conv2d(x, ch, kernel, stride, name=name + "_c_" + str(i)))
126
+ return x
127
+
128
+
129
+ def dis_down(images, kernel_size, stride, n_scale, ch, name):
130
+ backpack = images[0]
131
+ for i in range(n_scale):
132
+ if i == n_scale - 1:
133
+ images[i] = num_steps(
134
+ backpack, ch, kernel_size, stride, n_scale, name + str(i)
135
+ )
136
+ else:
137
+ images[i] = one_step_dilated(
138
+ images[i + 1], ch, kernel_size, 1, name + str(i)
139
+ )
140
+ return images
141
+
142
+
143
+ def dis_down_noins(images, kernel_size, stride, n_scale, ch, name):
144
+ backpack = images[0]
145
+ for i in range(n_scale):
146
+ if i == n_scale - 1:
147
+ images[i] = num_steps_noins(
148
+ backpack, ch, kernel_size, stride, n_scale, name + str(i)
149
+ )
150
+ else:
151
+ images[i] = one_step_noins(images[i + 1], ch, kernel_size, 1, name + str(i))
152
+ return images
153
+
154
+
155
+ def final_conv(images, n_scale, name):
156
+ for i in range(n_scale):
157
+ images[i] = conv2d(images[i], 1, s=1, name=name + str(i))
158
+ return images
159
+
160
+
161
+ def lrelu(x, leak=0.2, name="lrelu"):
162
+ return tf.maximum(x, leak * x)
163
+
164
+
165
+ def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
166
+ with tf.compat.v1.variable_scope(scope or "Linear"):
167
+ matrix = tf.compat.v1.get_variable(
168
+ "Matrix",
169
+ [input_.get_shape()[-1], output_size],
170
+ tf.float32,
171
+ tf.random_normal_initializer(stddev=stddev),
172
+ )
173
+ bias = tf.compat.v1.get_variable(
174
+ "bias", [output_size], initializer=tf.constant_initializer(bias_start)
175
+ )
176
+ if with_w:
177
+ return tf.matmul(input_, matrix) + bias, matrix, bias
178
+ else:
179
+ return tf.matmul(input_, matrix) + bias
180
+
181
+
182
+ def get_ones_like(logit):
183
+ target = []
184
+ for i in range(len(logit)):
185
+ target.append(tf.ones_like(logit[i]))
186
+ return target
187
+
188
+
189
+ def get_zeros_like(logit):
190
+ target = []
191
+ for i in range(len(logit)):
192
+ target.append(tf.zeros_like(logit[i]))
193
+ return target
194
+
195
+
196
+ def conv(
197
+ x,
198
+ channels,
199
+ kernel=4,
200
+ stride=2,
201
+ pad=0,
202
+ pad_type="zero",
203
+ use_bias=True,
204
+ scope="conv_0",
205
+ ):
206
+ with tf.compat.v1.variable_scope(scope):
207
+ if pad_type == "zero":
208
+ x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
209
+ if pad_type == "reflect":
210
+ x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode="REFLECT")
211
+
212
+ x = tf.compat.v1.layers.conv2d(
213
+ inputs=x,
214
+ filters=channels,
215
+ kernel_size=kernel,
216
+ kernel_initializer=weight_init,
217
+ kernel_regularizer=weight_regularizer,
218
+ strides=stride,
219
+ use_bias=use_bias,
220
+ )
221
+
222
+ return x
223
+
224
+
225
+ def reduce_sum(input_tensor, axis=None, keepdims=False):
226
+ try:
227
+ return tf.reduce_sum(input_tensor, axis=axis, keepdims=keepdims)
228
+ except:
229
+ return tf.reduce_sum(input_tensor, axis=axis, keep_dims=keepdims)
230
+
231
+
232
+ def get_shape(inputs, name=None):
233
+ name = "shape" if name is None else name
234
+ with tf.name_scope(name):
235
+ static_shape = inputs.get_shape().as_list()
236
+ dynamic_shape = tf.shape(inputs)
237
+ shape = []
238
+ for i, dim in enumerate(static_shape):
239
+ dim = dim if dim is not None else dynamic_shape[i]
240
+ shape.append(dim)
241
+ return shape
242
+
243
+
244
+ def show_all_variables():
245
+ model_vars = tf.compat.v1.trainable_variables()
246
+ slim.model_analyzer.analyze_vars(model_vars, print_info=True)
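`discriminator` in `models.py` returns one logit map per scale, which is why the `*_list` criteria in `loss_utils.py` average over a list, with `get_ones_like` / `get_zeros_like` supplying matching targets. A dummy-tensor illustration (shapes are arbitrary):

```python
import tensorflow as tf
from ops import get_ones_like
from loss_utils import mae_criterion_list

# Two "scales" of discriminator logits, as returned by models.discriminator
logits = [tf.fill([1, 30, 62, 1], 0.8), tf.fill([1, 14, 30, 1], 0.8)]
print(float(mae_criterion_list(logits, get_ones_like(logits))))  # (0.8 - 1)^2 = 0.04
```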
parser.py ADDED
@@ -0,0 +1,18 @@
1
+ import json
2
+ import shutil
3
+
4
+
5
+ with open('C:/jg/github_code/ForkGAN/bdd100k/labels/bdd100k_labels_images_train.json') as json_file:
6
+ json_data = json.load(json_file)
7
+
8
+ for item in json_data:
9
+ item_path = 'C:/jg/github_code/ForkGAN/bdd100k/images/100k/train/'+ item['name']
10
+ print(item['name'])
11
+ if item['attributes']['timeofday'] == 'daytime':
12
+ shutil.copy(item_path, 'C:/jg/github_code/ForkGAN/bdd100k/images/daytime/'+item['name'])
13
+
14
+ elif item['attributes']['timeofday'] == 'night':
15
+ shutil.copy(item_path, 'C:/jg/github_code/ForkGAN/bdd100k/images/night/'+item['name'])
16
+
17
+ else:
18
+ shutil.copy(item_path, 'C:/jg/github_code/ForkGAN/bdd100k/images/else/' + item['name'])
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ pillow==6.0.0
2
+ scipy==1.1.0
3
+ numpy
4
+ matplotlib
utils.py ADDED
@@ -0,0 +1,182 @@
1
+ # import scipy.misc
2
+ from PIL import Image
3
+ import numpy as np
4
+ import copy
5
+ import os
6
+
7
+
8
+ class ImagePool(object):
9
+ def __init__(self, maxsize=50):
10
+ self.maxsize = maxsize
11
+ self.num_img = 0
12
+ self.images = []
13
+
14
+ def __call__(self, image):
15
+ if self.maxsize <= 0:
16
+ return image
17
+ if self.num_img < self.maxsize:
18
+ self.images.append(image)
19
+ self.num_img += 1
20
+ return image
21
+ if np.random.rand() > 0.5:
22
+ idx = int(np.random.rand() * self.maxsize)
23
+ tmp1 = copy.copy(self.images[idx])[0]
24
+ self.images[idx][0] = image[0]
25
+ idx = int(np.random.rand() * self.maxsize)
26
+ tmp2 = copy.copy(self.images[idx])[1]
27
+ self.images[idx][1] = image[1]
28
+ return [tmp1, tmp2]
29
+ else:
30
+ return image
31
+
32
+
33
+ def load_test_data(image_path, fine_size=256):
34
+ img = Image.open(image_path)
35
+ img = img.resize((fine_size * 2, fine_size))
36
+ img = np.array(img)
37
+ # Normalize image to the range [-1, 1]
38
+ img = img / 127.5 - 1
39
+
40
+ return img
41
+
42
+
43
+ def check_folder(path):
44
+ if not os.path.exists(path):
45
+ os.makedirs(path, exist_ok=True)
46
+
47
+
48
+ def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
49
+ img_A = Image.open(image_path[0])
50
+ img_B = Image.open(image_path[1])
51
+
52
+ if not is_testing:
53
+ # Resize images using PIL
54
+ img_A = img_A.resize((load_size * 2, load_size))
55
+ img_B = img_B.resize((load_size * 2, load_size))
56
+
57
+ # Random crop
58
+ h1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size)))
59
+ w1 = int(np.ceil(np.random.uniform(1e-2, (load_size - fine_size) * 2)))
60
+ img_A = np.array(img_A.crop((w1, h1, w1 + fine_size * 2, h1 + fine_size)))
61
+ img_B = np.array(img_B.crop((w1, h1, w1 + fine_size * 2, h1 + fine_size)))
62
+
63
+ # Random horizontal flip
64
+ if np.random.random() > 0.5:
65
+ img_A = np.fliplr(img_A)
66
+ img_B = np.fliplr(img_B)
67
+ else:
68
+ # Resize images using PIL for testing
69
+ img_A = img_A.resize((fine_size * 2, fine_size))
70
+ img_B = img_B.resize((fine_size * 2, fine_size))
71
+
72
+ # Convert to arrays (the testing branch still holds PIL Images here),
+ # then normalize to the range [-1, 1]
73
+ img_A = np.array(img_A, dtype=np.float64) / 127.5 - 1.0
74
+ img_B = np.array(img_B, dtype=np.float64) / 127.5 - 1.0
75
+
76
+ # Concatenate images along the channel axis
77
+ img_AB = np.concatenate((img_A, img_B), axis=2)
78
+
79
+ return img_AB
80
+
81
+
82
+ # -----------------------------
83
+
84
+
85
+ def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
86
+ return transform(
87
+ load_image(image_path, is_grayscale), image_size, is_crop, resize_w
88
+ )
89
+
90
+
91
+ def save_images(images, size, image_path):
92
+ return imsave(images, size, image_path)
93
+
94
+
95
+ def load_image(path, is_grayscale=False):
96
+ if is_grayscale:
97
+ return np.array(Image.open(path).convert("L")).astype(np.float)
98
+ else:
99
+ return np.array(Image.open(path).convert("RGB")).astype(np.float)
100
+
101
+
102
+ def merge_images(images, size):
103
+ return inverse_transform(images)
104
+
105
+
106
+ def merge(images, size):
107
+ h, w = images.shape[1], images.shape[2]
108
+ img = np.zeros((h * size[0], w * size[1], 3))
109
+ for idx, image in enumerate(images):
110
+ i = idx % size[1]
111
+ j = idx // size[1]
112
+ img[j * h : j * h + h, i * w : i * w + w, :] = image
113
+
114
+ return img
115
+
116
+
117
+ def imsave(image, size, path):
118
+ # Convert images to uint8 format and adjust the range
119
+ image = ((image + 1.0) * 127.5).astype(np.uint8)
120
+
121
+ # Merge images
122
+ # merged_image = merge(images, size).astype(np.uint8)
123
+
124
+ # Create a PIL Image from the numpy array
125
+ pil_image = Image.fromarray(image)
126
+
127
+ # Save the image using PIL
128
+ pil_image.save(path)
129
+
130
+ return None
131
+
132
+
133
+ def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
134
+ if crop_w is None:
135
+ crop_w = crop_h
136
+ h, w = x.shape[:2]
137
+ j = int(round((h - crop_h) / 2.0))
138
+ i = int(round((w - crop_w) / 2.0))
139
+
140
+ # Use PIL for resizing
141
+ cropped_image = Image.fromarray(x[j : j + crop_h, i : i + crop_w].astype(np.uint8))
142
+ cropped_image = cropped_image.resize((resize_w, resize_h))
143
+
144
+ return np.array(cropped_image) / 127.5 - 1.0
145
+
146
+
147
+ def transform(image, npx=64, is_crop=True, resize_w=64):
148
+ # npx: # of pixels width/height of image
149
+ if is_crop:
150
+ cropped_image = center_crop(image, npx, resize_w=resize_w)
151
+ else:
152
+ cropped_image = image
153
+ return np.array(cropped_image) / 127.5 - 1.0
154
+
155
+
156
+ def inverse_transform(images):
157
+ return (images + 1.0) / 2.0
158
+
159
+
160
+ def norm_img(img):
161
+ img = img / np.linalg.norm(img)
162
+ img = (img * 2.0) - 1.0
163
+
164
+ return img
165
+
166
+
167
+ def set_path(args, experiment_name):
168
+ args.checkpoint_dir = f"./check/{experiment_name}"
169
+ args.sample_dir = f"./check/{experiment_name}/sample"
170
+ if args.which_direction == "AtoB":
171
+ args.test_dir = f"./check/{experiment_name}/testa2b"
172
+ else:
173
+ args.test_dir = f"./check/{experiment_name}/testb2a"
174
+ args.conf_dir = f"./check/{experiment_name}/conf"
175
+ if not os.path.exists(args.checkpoint_dir):
176
+ os.makedirs(args.checkpoint_dir)
177
+ if not os.path.exists(args.sample_dir):
178
+ os.makedirs(args.sample_dir)
179
+ if not os.path.exists(args.test_dir):
180
+ os.makedirs(args.test_dir)
181
+ if not os.path.exists(args.conf_dir):
182
+ os.makedirs(args.conf_dir)
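For reference, `ImagePool` above is the CycleGAN-style history buffer: while filling it passes inputs straight through, and once full it returns a stored fake pair half the time, which stabilizes discriminator training. A tiny sketch of the pass-through phase:

```python
import numpy as np
from utils import ImagePool

pool = ImagePool(maxsize=50)
fakes = [np.zeros((1, 4, 4, 3)), np.ones((1, 4, 4, 3))]  # [fake_A, fake_B], as in AUGAN.train
out = pool(fakes)  # pool not yet full: the pair passes straight through
assert out is fakes
```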