Upload 28 files
- .gitattributes +8 -0
- AUGAN.py +738 -0
- README.md +125 -0
- __pycache__/AUGAN.cpython-36.pyc +0 -0
- __pycache__/loss_utils.cpython-36.pyc +0 -0
- __pycache__/models.cpython-36.pyc +0 -0
- __pycache__/ops.cpython-36.pyc +0 -0
- __pycache__/utils.cpython-36.pyc +0 -0
- assets/augan_alderley.png +0 -0
- assets/augan_bdd.png +0 -0
- assets/augan_model.png +0 -0
- assets/augan_result.png +3 -0
- assets/augan_uncer.png +3 -0
- cc.sh +7 -0
- check.zip +3 -0
- datasets/swim/testA/GP010594_frame_000017_rgb_anon.png +3 -0
- datasets/swim/testA/GP010594_frame_000021_rgb_anon.png +3 -0
- datasets/swim/testA/GP010594_frame_000087_rgb_anon.png +3 -0
- datasets/swim/testB/GOPR0351_frame_000159_rgb_ref_anon.png +3 -0
- datasets/swim/testB/GOPR0351_frame_000161_rgb_ref_anon.png +3 -0
- datasets/swim/testB/GOPR0355_frame_000138_rgb_ref_anon.png +3 -0
- inference.py +0 -0
- loss_utils.py +51 -0
- main.py +193 -0
- models.py +178 -0
- ops.py +246 -0
- parser.py +18 -0
- requirements.txt +4 -0
- utils.py +182 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/augan_result.png filter=lfs diff=lfs merge=lfs -text
+assets/augan_uncer.png filter=lfs diff=lfs merge=lfs -text
+datasets/swim/testA/GP010594_frame_000017_rgb_anon.png filter=lfs diff=lfs merge=lfs -text
+datasets/swim/testA/GP010594_frame_000021_rgb_anon.png filter=lfs diff=lfs merge=lfs -text
+datasets/swim/testA/GP010594_frame_000087_rgb_anon.png filter=lfs diff=lfs merge=lfs -text
+datasets/swim/testB/GOPR0351_frame_000159_rgb_ref_anon.png filter=lfs diff=lfs merge=lfs -text
+datasets/swim/testB/GOPR0351_frame_000161_rgb_ref_anon.png filter=lfs diff=lfs merge=lfs -text
+datasets/swim/testB/GOPR0355_frame_000138_rgb_ref_anon.png filter=lfs diff=lfs merge=lfs -text
AUGAN.py
ADDED
@@ -0,0 +1,738 @@
from collections import namedtuple
from models import generator_resnet, discriminator
from utils import *
from loss_utils import *
from ops import *
import time
import matplotlib.pyplot as plt
from glob import glob


class AUGAN(object):
    def __init__(self, sess, args):
        self.sess = sess
        self.batch_size = args.batch_size
        self.image_size = args.fine_size
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.conf_lambda = args.conf_lambda
        self.dataset_dir = args.dataset_dir
        self.n_d = args.n_d
        self.n_scale = args.n_scale
        self.ndf = args.ndf
        self.load_size = args.load_size
        self.fine_size = args.fine_size
        self.generator = generator_resnet
        self.discriminator = discriminator
        if args.use_lsgan:
            self.criterionGAN = mae_criterion
            self.criterionGAN_list = mae_criterion_list
        else:
            self.criterionGAN = sce_criterion
            self.criterionGAN_list = sce_criterion_list

        self.use_uncertainty = args.use_uncertainty

        OPTIONS = namedtuple(
            "OPTIONS",
            "batch_size image_size gf_dim df_dim output_c_dim is_training",
        )
        self.options = OPTIONS._make(
            (
                args.batch_size,
                args.fine_size,
                args.ngf,
                args.ndf // args.n_d,
                args.output_nc,
                args.phase == "train",
            )
        )
        self.save_conf = args.save_conf
        self._build_model()
        self.saver = tf.compat.v1.train.Saver()
        self.pool = ImagePool(args.max_size)

    def _build_model(self):
        self.real_data = tf.compat.v1.placeholder(
            tf.float32,
            [
                self.batch_size,
                self.image_size,
                self.image_size * 2,
                self.input_c_dim + self.output_c_dim,
            ],
            name="real_A_and_B_images",
        )

        self.real_A = self.real_data[:, :, :, : self.input_c_dim]
        self.real_B = self.real_data[
            :, :, :, self.input_c_dim : self.input_c_dim + self.output_c_dim
        ]

        A_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
        B_label = np.zeros([1, 1, 1, 2], dtype=np.float32)
        A_label[:, :, :, 0] = 1.0
        B_label[:, :, :, 1] = 1.0
        self.A_label = tf.convert_to_tensor(A_label)
        self.B_label = tf.convert_to_tensor(B_label)

        (
            self.fake_B,
            self.rec_realA,
            self.realA_percep,
            self.transA_percep,
            self.pred_confA,
        ) = self.generator(
            self.real_A, self.options, transfer=True, reuse=False, name="generatorA2B"
        )
        self.fake_A_, self.rec_fakeB, self.fakeB_percep, _, _ = self.generator(
            self.fake_B, self.options, transfer=False, reuse=False, name="generatorB2A"
        )
        self.fake_A, self.rec_realB, self.realB_percep, _, _ = self.generator(
            self.real_B, self.options, transfer=False, reuse=True, name="generatorB2A"
        )
        self.fake_B_, self.rec_fakeA, self.fakeA_percep, self.trans_fakeA_percep, _ = (
            self.generator(
                self.fake_A,
                self.options,
                transfer=True,
                reuse=True,
                name="generatorA2B",
            )
        )

        self.g_adv_total = 0.0
        self.g_adv = 0.0
        self.g_adv_rec = 0.0
        self.g_adv_recfake = 0.0

        self.percep_loss = tf.reduce_mean(
            tf.abs(
                tf.reduce_mean(self.transA_percep, axis=3)
                - tf.reduce_mean(self.fakeB_percep, axis=3)
            )
        ) + tf.reduce_mean(
            tf.abs(
                tf.reduce_mean(self.realB_percep, axis=3)
                - tf.reduce_mean(self.fakeA_percep, axis=3)
            )
        )

        for i in range(self.n_d):
            self.DB_fake = self.discriminator(
                self.fake_B, self.options, reuse=False, name=str(i) + "_discriminatorB"
            )
            self.DA_fake = self.discriminator(
                self.fake_A, self.options, reuse=False, name=str(i) + "_discriminatorA"
            )

            self.g_adv_total += self.criterionGAN_list(
                self.DA_fake, get_ones_like(self.DA_fake)
            ) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))

            self.g_adv += self.criterionGAN_list(
                self.DA_fake, get_ones_like(self.DA_fake)
            ) + self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))

        self.g_loss_a2b = (
            self.criterionGAN_list(self.DB_fake, get_ones_like(self.DB_fake))
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        )
        self.g_loss_b2a = (
            self.criterionGAN_list(self.DA_fake, get_ones_like(self.DA_fake))
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_)
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        )

        self.g_A_recon_loss = self.L1_lambda * abs_criterion(
            self.rec_realA, self.real_A
        )
        self.g_B_recon_loss = self.L1_lambda * abs_criterion(
            self.rec_realB, self.real_B
        )
        if self.use_uncertainty:
            self.g_A_cycle_loss = self.conf_lambda * conf_criterion(
                self.real_A, self.fake_A_, self.pred_confA
            )
        else:
            self.g_A_cycle_loss = self.L1_lambda * abs_criterion(
                self.real_A, self.fake_A_
            )
        self.g_B_cycle_loss = self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)

        self.g_loss = (
            self.g_adv_total
            + self.g_A_recon_loss
            + self.g_B_recon_loss
            + self.g_A_cycle_loss
            + self.g_B_cycle_loss
            + self.percep_loss
        )

        self.g_rec_real = abs_criterion(self.rec_realA, self.real_A) + abs_criterion(
            self.rec_realB, self.real_B
        )
        self.g_rec_cycle = abs_criterion(self.real_A, self.fake_A_) + abs_criterion(
            self.real_B, self.fake_B_
        )

        self.fake_A_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="fake_A_sample",
        )
        self.fake_B_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="fake_B_sample",
        )
        self.rec_A_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_A_sample",
        )
        self.rec_B_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_B_sample",
        )
        self.rec_fakeA_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_fakeA_sample",
        )
        self.rec_fakeB_sample = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="rec_fakeB_sample",
        )

        self.d_loss_item = []
        self.d_loss_item_rec = []
        self.d_loss_item_recfake = []

        for i in range(self.n_d):
            self.DB_real = self.discriminator(
                self.real_B, self.options, reuse=True, name=str(i) + "_discriminatorB"
            )
            self.DA_real = self.discriminator(
                self.real_A, self.options, reuse=True, name=str(i) + "_discriminatorA"
            )
            self.DB_fake_sample = self.discriminator(
                self.fake_B_sample,
                self.options,
                reuse=True,
                name=str(i) + "_discriminatorB",
            )
            self.DA_fake_sample = self.discriminator(
                self.fake_A_sample,
                self.options,
                reuse=True,
                name=str(i) + "_discriminatorA",
            )
            self.db_loss_real = self.criterionGAN_list(
                self.DB_real, get_ones_like(self.DB_real)
            )
            self.db_loss_fake = self.criterionGAN_list(
                self.DB_fake_sample, get_zeros_like(self.DB_fake_sample)
            )
            self.db_loss = self.db_loss_real * 0.5 + self.db_loss_fake * 0.5
            self.da_loss_real = self.criterionGAN_list(
                self.DA_real, get_ones_like(self.DA_real)
            )
            self.da_loss_fake = self.criterionGAN_list(
                self.DA_fake_sample, get_zeros_like(self.DA_fake_sample)
            )
            self.da_loss = self.da_loss_real * 0.5 + self.da_loss_fake * 0.5
            self.d_loss = self.da_loss + self.db_loss
            self.d_loss_item.append(self.d_loss)

        self.g_loss_a2b_sum = tf.compat.v1.summary.scalar("g_loss_a2b", self.g_loss_a2b)
        self.g_loss_b2a_sum = tf.compat.v1.summary.scalar("g_loss_b2a", self.g_loss_b2a)
        self.g_loss_sum = tf.compat.v1.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.compat.v1.summary.merge(
            [self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum]
        )
        self.db_loss_sum = tf.compat.v1.summary.scalar("db_loss", self.db_loss)
        self.da_loss_sum = tf.compat.v1.summary.scalar("da_loss", self.da_loss)
        self.d_loss_sum = tf.compat.v1.summary.scalar("d_loss", self.d_loss)
        self.db_loss_real_sum = tf.compat.v1.summary.scalar(
            "db_loss_real", self.db_loss_real
        )
        self.db_loss_fake_sum = tf.compat.v1.summary.scalar(
            "db_loss_fake", self.db_loss_fake
        )
        self.da_loss_real_sum = tf.compat.v1.summary.scalar(
            "da_loss_real", self.da_loss_real
        )
        self.da_loss_fake_sum = tf.compat.v1.summary.scalar(
            "da_loss_fake", self.da_loss_fake
        )
        self.d_sum = tf.compat.v1.summary.merge(
            [
                self.da_loss_sum,
                self.da_loss_real_sum,
                self.da_loss_fake_sum,
                self.db_loss_sum,
                self.db_loss_real_sum,
                self.db_loss_fake_sum,
                self.d_loss_sum,
            ]
        )

        self.test_A = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.input_c_dim],
            name="test_A",
        )
        self.test_B = tf.compat.v1.placeholder(
            tf.float32,
            [self.batch_size, self.image_size, self.image_size * 2, self.output_c_dim],
            name="test_B",
        )

        (
            self.testB,
            self.rec_testA,
            self.testA_percep,
            self.trans_testA_percep,
            self.test_pred_confA,
        ) = self.generator(
            self.test_A, self.options, transfer=True, reuse=True, name="generatorA2B"
        )
        self.rec_cycle_A, self.refine_testB, self.testB_percep, _, _ = self.generator(
            self.testB, self.options, transfer=False, reuse=True, name="generatorB2A"
        )

        self.testA, self.rec_testB, _, _, _ = self.generator(
            self.test_B, self.options, transfer=False, reuse=True, name="generatorB2A"
        )
        self.rec_cycle_B, self.refine_testA, _, _, _ = self.generator(
            self.testA, self.options, True, True, name="generatorA2B"
        )

        t_vars = tf.compat.v1.trainable_variables()

        self.g_vars = [var for var in t_vars if "generator" in var.name]
        self.p_vars = [var for var in t_vars if "percep" in var.name]
        self.d_vars_item = []
        for i in range(self.n_d):
            self.d_vars = [
                var for var in t_vars if str(i) + "_discriminator" in var.name
            ]
            self.d_vars_item.append(self.d_vars)

    def train(self, args):

        self.lr = tf.compat.v1.placeholder(tf.float32, None, name="learning_rate")

        ### generator
        # graph-mode optimizer: a Keras optimizer cannot minimize a symbolic loss
        # tensor inside a tf.compat.v1 Session
        self.g_optim = tf.compat.v1.train.AdamOptimizer(
            self.lr, beta1=args.beta1
        ).minimize(self.g_loss, var_list=self.g_vars)

        ### discriminators
        self.d_optim_item = []
        for i in range(self.n_d):
            # each discriminator minimizes its own loss over its own variables
            self.d_optim = tf.compat.v1.train.AdamOptimizer(
                self.lr, beta1=args.beta1
            ).minimize(self.d_loss_item[i], var_list=self.d_vars_item[i])
            self.d_optim_item.append(self.d_optim)

        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        self.writer = tf.compat.v1.summary.FileWriter(
            os.path.join(args.checkpoint_dir, "logs"), self.sess.graph
        )

        counter = 1
        start_time = time.time()

        if args.continue_train:
            if self.load(args.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        print("Training.........................")
        for epoch in range(args.epoch):
            dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainA"))
            dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/trainB"))
            if (len(dataA) == 0) or (len(dataB) == 0):
                raise Exception("No files found in the dataset")
            else:
                print(
                    "Data found in the dataset. length of A: ",
                    len(dataA),
                    " B: ",
                    len(dataB),
                )
            np.random.shuffle(dataA)
            np.random.shuffle(dataB)
            batch_idxs = (
                min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
            )
            # keep lr constant for the first epoch_step epochs, then decay linearly
            lr = (
                args.lr
                if epoch < args.epoch_step
                else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step)
            )

            for idx in range(0, batch_idxs):
                print("Epoch: [%2d] [%4d/%4d] " % (epoch, idx, batch_idxs))
                batch_files = list(
                    zip(
                        dataA[idx * self.batch_size : (idx + 1) * self.batch_size],
                        dataB[idx * self.batch_size : (idx + 1) * self.batch_size],
                    )
                )
                batch_images = [
                    load_train_data(batch_file, args.load_size, args.fine_size)
                    for batch_file in batch_files
                ]
                batch_images = np.array(batch_images).astype(np.float32)
                # Update G network and record fake outputs
                print("Training G network----------------------")
                (
                    fake_A,
                    fake_B,
                    rec_A,
                    rec_B,
                    rec_fake_A,
                    rec_fake_B,
                    _,
                    g_loss,
                    gan_loss,
                    percep,
                    g_adv,
                    g_A_recon_loss,
                    g_B_recon_loss,
                    g_A_cycle_loss,
                    g_B_cycle_loss,
                    summary_str,
                ) = self.sess.run(
                    [
                        self.fake_A,
                        self.fake_B,
                        self.rec_realA,
                        self.rec_realB,
                        self.rec_fakeA,
                        self.rec_fakeB,
                        self.g_optim,
                        self.g_loss,
                        self.g_adv_total,
                        self.percep_loss,
                        self.g_adv,
                        self.g_A_recon_loss,
                        self.g_B_recon_loss,
                        self.g_A_cycle_loss,
                        self.g_B_cycle_loss,
                        self.g_sum,
                    ],
                    feed_dict={self.real_data: batch_images, self.lr: lr},
                )
                self.writer.add_summary(summary_str, counter)
                [fake_A, fake_B] = self.pool([fake_A, fake_B])

                # Update D network
                print("Training D network----------------------")
                loss_print = []
                for i in range(self.n_d):
                    _, d_loss, d_sum = self.sess.run(
                        [self.d_optim_item[i], self.d_loss_item[i], self.d_sum],
                        feed_dict={
                            self.real_data: batch_images,
                            self.fake_A_sample: fake_A,
                            self.fake_B_sample: fake_B,
                            self.lr: lr,
                        },
                    )

                    loss_print.append(d_loss)

                counter += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f g_loss: %4.4f gan:%4.4f adv:%4.4f g_percep:%4.4f "
                    % (
                        epoch,
                        idx,
                        batch_idxs,
                        time.time() - start_time,
                        g_loss,
                        gan_loss,
                        g_adv,
                        percep,
                    )
                )

                if np.mod(counter, args.print_freq) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)

                if np.mod(counter, args.save_freq) == 2:
                    self.save(args.checkpoint_dir, counter)

    def save(self, checkpoint_dir, step):
        model_name = "cyclegan.model"
        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(
            self.sess, os.path.join(checkpoint_dir, model_name), global_step=step
        )

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoint...")

        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def sample_model(self, sample_dir, epoch, idx):
        dataA = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
        dataB = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
        if (len(dataA) == 0) or (len(dataB) == 0):
            raise Exception("No files found in the test directory")
        np.random.shuffle(dataA)
        np.random.shuffle(dataB)
        batch_files = list(zip(dataA[: self.batch_size], dataB[: self.batch_size]))
        sample_images = [
            load_train_data(batch_file, self.load_size, self.fine_size, is_testing=True)
            for batch_file in batch_files
        ]
        sample_images = np.array(sample_images).astype(np.float32)

        fake_A, fake_B = self.sess.run(
            [self.fake_A, self.fake_B], feed_dict={self.real_data: sample_images}
        )
        real_A = sample_images[:, :, :, :3]
        real_B = sample_images[:, :, :, 3:]

        merge_A = np.concatenate([real_B, fake_A], axis=2)
        merge_B = np.concatenate([real_A, fake_B], axis=2)
        check_folder("./{}/{:02d}".format(sample_dir, epoch))
        save_images(
            merge_A,
            [self.batch_size, 1],
            "./{}/{:02d}/A_{:04d}.jpg".format(sample_dir, epoch, idx),
        )
        save_images(
            merge_B,
            [self.batch_size, 1],
            "./{}/{:02d}/B_{:04d}.jpg".format(sample_dir, epoch, idx),
        )

    def test(self, args):
        total_time = 0

        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        if args.which_direction == "AtoB":
            sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testA"))
        elif args.which_direction == "BtoA":
            sample_files = glob("./datasets/{}/*.*".format(self.dataset_dir + "/testB"))
        else:
            raise Exception("--which_direction must be AtoB or BtoA")

        if len(sample_files) == 0:
            raise Exception("No files found in the test directory")

        # print(sample_files)

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
        out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
            (
                self.testB,
                self.refine_testB,
                self.test_A,
                self.rec_testA,
                self.rec_cycle_A,
                self.testA_percep,
                self.test_pred_confA,
            )
            if args.which_direction == "AtoB"
            else (
                self.testA,
                self.refine_testA,
                self.test_B,
                self.rec_testB,
                self.rec_cycle_B,
                self.testB_percep,
                self.test_pred_confA,
            )
        )
        for sample_file in sample_files:
            # print('Processing image: ' + sample_file)
            sample_image = [load_test_data(sample_file, args.fine_size)]
            start_time = time.time()
            sample_image = np.array(sample_image).astype(np.float32)
            image_path = os.path.join(
                args.test_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )
            ori_path = os.path.join(
                args.test_dir,
                "{0}_{1}".format("ori", os.path.basename(sample_file)),
            )
            conf_path = os.path.join(
                args.conf_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )

            (fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
            end_time = time.time()
            # merge = np.concatenate([sample_image, fake_img], axis=2)
            save_images(fake_img[0], [1], image_path)
            save_images(sample_image[0], [1], ori_path)
            # save_images(merge, [1, 1], image_path)
            total_time = total_time + (end_time - start_time)

            if args.save_conf:
                if args.which_direction != "AtoB":
                    raise Exception(
                        "--conf map can only be estimated in the AtoB direction"
                    )

                conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
                conf_img_sq = np.squeeze(conf_img)
                plt.imshow(
                    conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
                )
                plt.savefig(conf_path)
        print(
            f"Average time taken to convert images: {total_time/len(sample_files)} seconds"
        )

    def convert(self, args, datadir="./inf_data"):
        total_time = 0

        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            raise Exception("-- Cannot Load Model. Train or Add model first")

        if args.which_direction not in ("AtoB", "BtoA"):
            raise Exception("--which_direction must be AtoB or BtoA")
        # glob the files inside the directory (mirrors the pattern used in test());
        # globbing the bare directory path would match only the directory itself
        sample_files = glob(os.path.join(datadir, "*.*"))

        print(sample_files)

        out_var, refine_var, in_var, rec_var, cycle_var, percep_var, conf_var = (
            (
                self.testB,
                self.refine_testB,
                self.test_A,
                self.rec_testA,
                self.rec_cycle_A,
                self.testA_percep,
                self.test_pred_confA,
            )
            if args.which_direction == "AtoB"
            else (
                self.testA,
                self.refine_testA,
                self.test_B,
                self.rec_testB,
                self.rec_cycle_B,
                self.testB_percep,
                self.test_pred_confA,
            )
        )
        for sample_file in sample_files:
            print("Processing image: " + sample_file)
            sample_image = [load_test_data(sample_file, args.fine_size)]
            start_time = time.time()
            sample_image = np.array(sample_image).astype(np.float32)
            image_path = os.path.join(
                args.test_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )
            conf_path = os.path.join(
                args.conf_dir,
                "{0}_{1}".format(args.which_direction, os.path.basename(sample_file)),
            )

            (fake_img,) = self.sess.run([out_var], feed_dict={in_var: sample_image})
            end_time = time.time()
            merge = np.concatenate([sample_image, fake_img], axis=2)
            save_images(merge, [1, 1], image_path)
            total_time = total_time + (end_time - start_time)
            print(f"Time taken to convert image: {end_time - start_time} seconds")

            if args.save_conf:
                if args.which_direction != "AtoB":
                    raise Exception(
                        "--conf map can only be estimated in the AtoB direction"
                    )

                conf_img = self.sess.run(conf_var, feed_dict={in_var: sample_image})
                conf_img_sq = np.squeeze(conf_img)
                plt.imshow(
                    conf_img_sq, cmap="plasma", interpolation="nearest", alpha=1.0
                )
                plt.savefig(conf_path)
        print(
            f"Average time taken to convert images: {total_time/len(sample_files)} seconds"
        )

    def convert_image(self, args, input_image_path, output_dir):
        # initialize first, then restore the checkpoint on top; run everything in
        # the model's existing session (opening a fresh tf.Session here would
        # discard the restored weights)
        init_op = tf.compat.v1.global_variables_initializer()
        self.sess.run(init_op)
        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
        # Load the input image
        input_image = [load_test_data(input_image_path, self.fine_size)]
        input_image = np.array(input_image).astype(np.float32)

        # Get the generator output
        if args.which_direction == "AtoB":
            out_var = self.testB
            in_var = self.test_A
        else:
            out_var = self.testA
            in_var = self.test_B

        # Run the model to obtain the converted image
        start_time = time.time()
        converted_image = self.sess.run(out_var, feed_dict={in_var: input_image})
        end_time = time.time()

        # Save the converted image
        output_image_path = os.path.join(output_dir, os.path.basename(input_image_path))
        merge = np.concatenate([input_image, converted_image], axis=2)
        save_images(merge, [1, 1], output_image_path)

        # Print the time taken
        print(f"Time taken to convert image: {end_time - start_time} seconds")
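The learning-rate schedule inside `AUGAN.train` is easy to miss: it holds `args.lr` constant for the first `epoch_step` epochs and then decays it linearly to zero. A minimal standalone sketch of that schedule using the repo's default flags (`epoch=20`, `epoch_step=10`, `lr=2e-4`):

```python
# Standalone sketch of the linear lr decay in AUGAN.train (defaults from main.py).
def scheduled_lr(base_lr, epoch, total_epochs=20, epoch_step=10):
    if epoch < epoch_step:
        return base_lr  # constant warm phase
    # linear decay to zero over the remaining epochs
    return base_lr * (total_epochs - epoch) / (total_epochs - epoch_step)

print([round(scheduled_lr(2e-4, e), 6) for e in range(0, 20, 4)])
# [0.0002, 0.0002, 0.0002, 0.00016, 8e-05]
```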
README.md
ADDED
@@ -0,0 +1,125 @@
# Adverse Weather Image Translation with Asymmetric and Uncertainty-aware GAN (AU-GAN)
Official Tensorflow implementation of [Adverse Weather Image Translation with Asymmetric and Uncertainty-aware GAN](https://www.bmvc2021-virtualconference.com/assets/papers/1443.pdf) (AU-GAN)\
Jeong-gi Kwak, Youngsaeng Jin, Yuanming Li, Dongsik Yoon, Donghyeon Kim and Hanseok Ko </br>
*British Machine Vision Conference (BMVC), 2021*
</br>

## Intro

### Night → Day ([BDD100K](https://bdd-data.berkeley.edu/))
<img src="./assets/augan_bdd.png" width="800">

### Rainy night → Day ([Alderley](https://wiki.qut.edu.au/pages/viewpage.action?pageId=181178395))
<img src="./assets/augan_alderley.png" width="800">
</br>


## Architecture
<img src="./assets/augan_model.png" width="800">
Our generator has an asymmetric structure for editing day→night and night→day.
Please refer to our paper for details.

## **Envs**

```bash

git clone https://github.com/jgkwak95/AU-GAN.git
cd AU-GAN

# Create virtual environment
conda create -y --name augan python=3.6.7
conda activate augan

conda install tensorflow-gpu==1.14.0  # Tensorflow 1.14
pip install --no-cache-dir -r requirements.txt

```

## **Preparing datasets**

**Night → Day** </br>
The [Berkeley DeepDrive dataset](https://bdd-data.berkeley.edu/) contains 100,000 high-resolution images of urban roads for autonomous driving.</br></br>
**Rainy night → Day** </br>
The [Alderley dataset](https://wiki.qut.edu.au/pages/viewpage.action?pageId=181178395) consists of images of two domains,
rainy night and daytime. It was collected while driving the same route in each weather condition.</br>
</br>
Please download the datasets and then construct them following [ForkGAN](https://github.com/zhengziqiang/ForkGAN).

## Pretrained Model

Download the pretrained model for BDD100K (256x512) [here](https://drive.google.com/file/d/1rvIF3yE9MwPWj0kD4IEstETyMQXYAHzr/view?usp=sharing) and unzip it to ./check/bdd_exp/bdd100k_256/

## Training

```bash

# Alderley (256x512)
python main_uncer.py --dataset_dir alderley \
                     --phase train \
                     --experiment_name alderley_exp \
                     --batch_size 8 \
                     --load_size 286 \
                     --fine_size 256 \
                     --use_uncertainty True

```

```bash

# BDD100k (256x512)
python main_uncer.py --dataset_dir bdd100k \
                     --phase train \
                     --experiment_name bdd_exp \
                     --batch_size 8 \
                     --load_size 286 \
                     --fine_size 256 \
                     --use_uncertainty True

```

## Test

```bash

# Alderley (256x512)
python main_uncer.py --dataset_dir alderley \
                     --phase test \
                     --experiment_name alderley_exp \
                     --batch_size 1 \
                     --load_size 286 \
                     --fine_size 256

```

```bash

# BDD100k (256x512)
python main_uncer.py --dataset_dir bdd100k \
                     --phase test \
                     --experiment_name bdd_exp \
                     --batch_size 1 \
                     --load_size 286 \
                     --fine_size 256

```
## Additional results
<img src="./assets/augan_result.png" width="800">

More results in the [paper](https://www.bmvc2021-virtualconference.com/assets/papers/1443.pdf) and [supplementary]()

## Uncertainty map
<img src="./assets/augan_uncer.png" width="800">

## **Citation**
If our code is helpful for your research, please cite our paper:
```
@article{kwak2021adverse,
  title={Adverse weather image translation with asymmetric and uncertainty-aware GAN},
  author={Kwak, Jeong-gi and Jin, Youngsaeng and Li, Yuanming and Yoon, Dongsik and Kim, Donghyeon and Ko, Hanseok},
  journal={arXiv preprint arXiv:2112.04283},
  year={2021}
}
```
## Acknowledgments
Our code is built upon the [ForkGAN](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123480154.pdf) implementation.
__pycache__/AUGAN.cpython-36.pyc
ADDED
Binary file (14.8 kB)

__pycache__/loss_utils.cpython-36.pyc
ADDED
Binary file (1.62 kB)

__pycache__/models.cpython-36.pyc
ADDED
Binary file (4.3 kB)

__pycache__/ops.cpython-36.pyc
ADDED
Binary file (6.78 kB)

__pycache__/utils.cpython-36.pyc
ADDED
Binary file (4.68 kB)

assets/augan_alderley.png
ADDED
assets/augan_bdd.png
ADDED
assets/augan_model.png
ADDED
assets/augan_result.png
ADDED
Git LFS Details
assets/augan_uncer.png
ADDED
Git LFS Details
cc.sh
ADDED
@@ -0,0 +1,7 @@
python main.py --dataset_dir swim \
       --phase test \
       --experiment_name bdd_exp \
       --batch_size 1 \
       --which_direction BtoA \
       --load_size 286 \
       --fine_size 256
check.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1fc7f6d4f5f9c503bc69e1fdb454d8cc4f652d8c8966875dcd46aab83cdd0ff4
size 173513070
datasets/swim/testA/GP010594_frame_000017_rgb_anon.png
ADDED
Git LFS Details
datasets/swim/testA/GP010594_frame_000021_rgb_anon.png
ADDED
Git LFS Details
datasets/swim/testA/GP010594_frame_000087_rgb_anon.png
ADDED
Git LFS Details
datasets/swim/testB/GOPR0351_frame_000159_rgb_ref_anon.png
ADDED
Git LFS Details
datasets/swim/testB/GOPR0351_frame_000161_rgb_ref_anon.png
ADDED
Git LFS Details
datasets/swim/testB/GOPR0355_frame_000138_rgb_ref_anon.png
ADDED
Git LFS Details
inference.py
ADDED
File without changes
loss_utils.py
ADDED
@@ -0,0 +1,51 @@
import tensorflow as tf

epsilon = 1e-7


def conf_criterion_lp(im1, im2, conf_sigma):  # factorized Laplacian distribution
    loss = tf.abs(im1 - im2)
    if conf_sigma is not None:
        loss = loss * 2 / (conf_sigma + epsilon) + tf.math.log(conf_sigma * 2 + epsilon)
        loss = tf.reduce_mean(loss)
    else:
        loss = tf.reduce_mean(loss)

    return loss


def conf_criterion(im1, im2, conf_sigma):  # Gaussian distribution
    loss = tf.abs(im1 - im2)
    if conf_sigma is not None:
        loss = tf.math.exp(-conf_sigma) * 5 * loss + conf_sigma / 2
        loss = tf.reduce_mean(loss)
    else:
        loss = tf.reduce_mean(loss)

    return loss


def abs_criterion(in_, target):
    return tf.reduce_mean(tf.abs(in_ - target))


def mae_criterion(in_, target):
    return tf.reduce_mean((in_ - target) ** 2)


def sce_criterion(logits, labels):
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
    )


def mae_criterion_list(in_, target):
    loss = 0.0
    for i in range(len(target)):
        loss += tf.reduce_mean((in_[i] - target[i]) ** 2)
    return loss / len(target)


def sce_criterion_list(logits, labels):
    loss = 0.0
    for i in range(len(labels)):
        loss += tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits[i], labels=labels[i])
        )
    return loss / len(labels)
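For intuition about `conf_criterion` (the uncertainty-aware cycle loss used when `--use_uncertainty` is set): the `tf.math.exp(-conf_sigma)` factor down-weights the reconstruction residual wherever the network predicts high uncertainty, while the `conf_sigma / 2` term penalizes claiming high uncertainty everywhere. A NumPy sketch with hypothetical pixel values:

```python
# Numeric sketch of conf_criterion's trade-off (hypothetical residual/sigma values).
import numpy as np

def conf_loss(residual, sigma):
    # mirrors loss_utils.conf_criterion for scalar inputs
    return np.exp(-sigma) * 5 * np.abs(residual) + sigma / 2

for sigma in (0.0, 1.0, 3.0):
    print(f"sigma={sigma}: loss={conf_loss(0.4, sigma):.3f}")
# sigma=0.0: loss=2.000  -> confident pixels pay the full residual penalty
# sigma=1.0: loss=1.236  -> moderate uncertainty attenuates the penalty
# sigma=3.0: loss=1.600  -> but ever-larger sigma is penalized again
```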
main.py
ADDED
@@ -0,0 +1,193 @@
import argparse
import tensorflow as tf
import os
from utils import *
from AUGAN import AUGAN
from ops import *
import time

parser = argparse.ArgumentParser(description="")
parser.add_argument(
    "--dataset_dir", dest="dataset_dir", default="bdd100k", help="path of the dataset"
)
parser.add_argument(
    "--experiment_name",
    dest="experiment_name",
    type=str,
    default="bdd_exp",
    help="name of experiment",
)
parser.add_argument("--epoch", dest="epoch", type=int, default=20, help="# of epoch")
parser.add_argument(
    "--epoch_step",
    dest="epoch_step",
    type=int,
    default=10,
    help="# of epoch to decay lr",
)
parser.add_argument(
    "--batch_size", dest="batch_size", type=int, default=1, help="# images in batch"
)
parser.add_argument(
    "--train_size",
    dest="train_size",
    type=int,
    default=1e8,
    help="# images used to train",
)
parser.add_argument(
    "--load_size",
    dest="load_size",
    type=int,
    default=286,
    help="scale images to this size",
)
parser.add_argument(
    "--fine_size",
    dest="fine_size",
    type=int,
    default=256,
    help="then crop to this size",
)
parser.add_argument(
    "--ngf",
    dest="ngf",
    type=int,
    default=64,
    help="# of gen filters in first conv layer",
)
parser.add_argument(
    "--ndf",
    dest="ndf",
    type=int,
    default=64,
    help="# of discri filters in first conv layer",
)
parser.add_argument(
    "--n_d", dest="n_d", type=int, default=2, help="# of discriminators"
)
parser.add_argument(
    "--n_scale", dest="n_scale", type=int, default=2, help="# of scales"
)
parser.add_argument(
    "--gpu", dest="gpu", type=int, default=0, help="index of gpu device"
)
parser.add_argument(
    "--input_nc", dest="input_nc", type=int, default=3, help="# of input image channels"
)
parser.add_argument(
    "--output_nc",
    dest="output_nc",
    type=int,
    default=3,
    help="# of output image channels",
)
parser.add_argument(
    "--lr", dest="lr", type=float, default=0.0002, help="initial learning rate for adam"
)
parser.add_argument(
    "--beta1", dest="beta1", type=float, default=0.5, help="momentum term of adam"
)
parser.add_argument(
    "--which_direction", dest="which_direction", default="AtoB", help="AtoB or BtoA"
)
parser.add_argument("--phase", dest="phase", default="test", help="train, test")
parser.add_argument(
    "--save_freq",
    dest="save_freq",
    type=int,
    default=1000,
    help="save a model every save_freq iterations",
)
parser.add_argument(
    "--print_freq",
    dest="print_freq",
    type=int,
    default=100,
    help="print the debug information every print_freq iterations",
)
parser.add_argument(
    "--L1_lambda",
    dest="L1_lambda",
    type=float,
    default=10.0,
    help="weight on L1 term in objective",
)
parser.add_argument(
    "--conf_lambda",
    dest="conf_lambda",
    type=float,
    default=1.0,
    help="weight on the uncertainty (confidence) term in objective",
)
parser.add_argument(
    "--use_resnet",
    dest="use_resnet",
    type=bool,
    default=True,
    help="generator network uses residual blocks",
)
parser.add_argument(
    "--use_lsgan",
    dest="use_lsgan",
    type=bool,
    default=True,
    help="gan loss defined in lsgan",
)
parser.add_argument(
    "--use_uncertainty",
    dest="use_uncertainty",
    type=bool,
    default=True,
    help="use the uncertainty-aware cycle loss",
)
parser.add_argument(
    "--max_size",
    dest="max_size",
    type=int,
    default=50,
    help="max size of image pool, 0 means do not use image pool",
)
parser.add_argument(
    "--continue_train",
    dest="continue_train",
    type=bool,
    default=False,
    help="if continue training, load the latest model: 1: true, 0: false",
)
parser.add_argument(
    "--save_conf",
    dest="save_conf",
    type=bool,
    default=False,
    help="save conf map in test phase",
)
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)


def main(_):

    set_path(args, args.experiment_name)

    tfconfig = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=tfconfig) as sess:
        model = AUGAN(sess, args)
        # show_all_variables()

        if args.phase == "train":
            model.train(args)
        elif args.phase == "test":
            model.test(args)
        elif args.phase == "convert":
            model.convert_image(args, "inf_data/b1ca2e5d-84cf9134.jpg", "out")
        else:
            raise Exception("Give a phase")


if __name__ == "__main__":
    tf.compat.v1.app.run()
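One caveat worth knowing about this parser (kept as in the original): `type=bool` does not parse flag strings the way the README commands suggest, because `bool` of any non-empty string is `True`. Passing `--use_uncertainty False` therefore still yields `True`; only the defaults behave as documented. A minimal demonstration:

```python
# Demonstrates the argparse type=bool pitfall present in the flags above.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--use_uncertainty", type=bool, default=True)
print(p.parse_args(["--use_uncertainty", "False"]).use_uncertainty)  # True, not False
print(p.parse_args([]).use_uncertainty)  # True (the default)
```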
models.py
ADDED
@@ -0,0 +1,178 @@
from collections import namedtuple
from utils import *
from ops import *
import time
from glob import glob


def gaussian_noise_layer(input_layer, std):
    noise = tf.random.normal(
        shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32
    )
    return input_layer + noise


def generator_resnet(image, options, transfer=False, reuse=False, name="generator"):
    with tf.compat.v1.variable_scope(name):
        if reuse:
            tf.compat.v1.get_variable_scope().reuse_variables()
        else:
            assert tf.compat.v1.get_variable_scope().reuse is False

        def residule_block_dilated(x, dim, ks=3, s=1, name="res", down=False):
            if down:
                dim = dim * 2
            y = instance_norm(
                dilated_conv2d(x, dim, ks, s, padding="SAME", name=name + "_c1"),
                name + "_bn1",
            )
            y = tf.nn.relu(y)
            y = instance_norm(
                dilated_conv2d(y, dim, ks, s, padding="SAME", name=name + "_c2"),
                name + "_bn2",
            )
            out = y + x
            if down:
                out = tf.nn.relu(
                    instance_norm(
                        conv2d(out, dim // 2, 3, 1, name=name + "_down_c"),
                        name + "_in_down",
                    )
                )
            return out

        def residual_block(x_init, dim, ks=3, s=1, name="resblock", down=False):
            with tf.compat.v1.variable_scope(name):
                if down:
                    dim = dim * 2

                with tf.compat.v1.variable_scope("res1"):
                    x = instance_norm(
                        conv2d(x_init, dim, ks, s, padding="SAME", name=name + "_c1"),
                        name + "_in1",
                    )
                    x = tf.nn.relu(x)

                with tf.compat.v1.variable_scope("res2"):
                    x = instance_norm(
                        conv2d(x, dim, ks, s, padding="SAME", name=name + "_c2"),
                        name + "_in2",
                    )

                out = x + x_init

                if down:
                    out = tf.nn.relu(
                        instance_norm(
                            conv2d(out, dim // 2, 3, 1, name=name + "_down_c"),
                            name + "_in_down",
                        )
                    )
                return out

        ### Encoder architecture
        c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
        c1 = tf.nn.relu(
            instance_norm(
                conv2d(c0, options.gf_dim, 7, 1, padding="VALID", name="g_e1_c"),
                "g_e1_bn",
            )
        )
        c2 = tf.nn.relu(
            instance_norm(
                conv2d(c1, options.gf_dim * 2, 3, 2, name="g_e2_c"), "g_e2_bn"
            )
        )
        c3 = tf.nn.relu(
            instance_norm(
                conv2d(c2, options.gf_dim * 4, 3, 2, name="g_e3_c"), "g_e3_bn"
            )
        )
        r1 = residule_block_dilated(c3, options.gf_dim * 4, name="g_r1")
        r2 = residule_block_dilated(r1, options.gf_dim * 4, name="g_r2")
        r3 = residule_block_dilated(r2, options.gf_dim * 4, name="g_r3")
        r4 = residule_block_dilated(r3, options.gf_dim * 4, name="g_r4")
        # r5 = residule_block_dilated(r4, options.gf_dim * 4, name='g_r5')

        if transfer:
            t1 = residual_block(r4, options.gf_dim * 4, name="g_t1")
            t2 = residual_block(t1, options.gf_dim * 4, name="g_t2")
            t3 = residual_block(t2, options.gf_dim * 4, name="g_t3")
            t4 = residual_block(t3, options.gf_dim * 4, name="g_t4")
            # feature = tf.concat([r4, t4], axis=3, name='g_concat')
            # down = True
            feature = t4
        else:
            feature = r4
            t4 = None
            down = False

        ### translation decoder architecture
        r6 = residule_block_dilated(feature, options.gf_dim * 4, name="g_r6")
        r7 = residule_block_dilated(r6, options.gf_dim * 4, name="g_r7")
        r8 = residule_block_dilated(r7, options.gf_dim * 4, name="g_r8")
        r9 = residule_block_dilated(r8, options.gf_dim * 4, name="g_r9")
        d1 = deconv2d(r9, options.gf_dim * 2, 3, 2, name="g_d1_dc")
        d1 = tf.nn.relu(instance_norm(d1, "g_d1_bn"))
        d2 = deconv2d(d1, options.gf_dim, 3, 2, name="g_d2_dc")
        d2 = tf.nn.relu(instance_norm(d2, "g_d2_bn"))
        d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
        pred = tf.nn.tanh(
            conv2d(d2, options.output_c_dim, 7, 1, padding="VALID", name="g_pred_c")
        )

        ### reconstruction decoder architecture
        r5 = gaussian_noise_layer(r4, 0.02)
        r6_rec = residule_block_dilated(r5, options.gf_dim * 4, name="g_r6_rec")
        r6_rec = gaussian_noise_layer(r6_rec, 0.02)
        r7_rec = residule_block_dilated(r6_rec, options.gf_dim * 4, name="g_r7_rec")
        r8_rec = residule_block_dilated(r7_rec, options.gf_dim * 4, name="g_r8_rec")
        r9_rec = residule_block_dilated(r8_rec, options.gf_dim * 4, name="g_r9_rec")
        d1_rec = deconv2d(r9_rec, options.gf_dim * 2, 3, 2, name="g_d1_dc_rec")
        d1_rec = tf.nn.relu(instance_norm(d1_rec, "g_d1_bn_rec"))
        d2_rec = deconv2d(d1_rec, options.gf_dim, 3, 2, name="g_d2_dc_rec")
        d2_rec = tf.nn.relu(instance_norm(d2_rec, "g_d2_bn_rec"))
        d2_rec = tf.pad(d2_rec, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
        pred_rec = tf.nn.tanh(
            conv2d(
                d2_rec, options.output_c_dim, 7, 1, padding="VALID", name="g_pred_c_rec"
            )
        )

        ## confidence prediction

        if transfer:
            d_conf = deconv2d(d1, options.gf_dim, 3, 2, name="g_d_dc_conf")
            d_conf = tf.nn.relu(instance_norm(d_conf, "g_d_bn_conf"))
            d_conf = tf.pad(d_conf, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
            pred_conf = tf.nn.softplus(
                conv2d(d_conf, 1, 7, 1, padding="VALID", name="g_pred_c_conf")
            )
        else:
            pred_conf = None

        return pred, pred_rec, r4, t4, pred_conf


def discriminator(image, options, n_scale=2, reuse=False, name="discriminator"):
    images = []
    for i in range(n_scale):
        images.append(
            tf.compat.v1.image.resize_bicubic(
                image, [get_shape(image)[1] // (2**i), get_shape(image)[2] // (2**i)]
            )
        )
    with tf.compat.v1.variable_scope(name):
        if reuse:
            tf.compat.v1.get_variable_scope().reuse_variables()
        else:
            assert tf.compat.v1.get_variable_scope().reuse is False
        images = dis_down(images, 4, 2, n_scale, options.df_dim, "d_h0_conv_scale_")
        images = dis_down(images, 4, 2, n_scale, options.df_dim * 2, "d_h1_conv_scale_")
        images = dis_down(images, 4, 2, n_scale, options.df_dim * 4, "d_h2_conv_scale_")
        images = dis_down(images, 4, 2, n_scale, options.df_dim * 8, "d_h3_conv_scale_")
        images = final_conv(images, n_scale, "d_pred_scale_")
        return images
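`discriminator` critiques each image at `n_scale` resolutions, feeding a bicubically downsampled pyramid through per-scale conv stacks. A standalone TF2-eager sketch of just the pyramid construction (the repo itself builds it in graph mode with `tf.compat.v1.image.resize_bicubic`):

```python
# Eager-mode sketch of the n_scale input pyramid the discriminator consumes.
import tensorflow as tf

def build_pyramid(image, n_scale=2):
    n, h, w, c = image.shape
    return [
        tf.image.resize(image, [h // (2 ** i), w // (2 ** i)], method="bicubic")
        for i in range(n_scale)
    ]

pyr = build_pyramid(tf.zeros([1, 256, 512, 3]))
print([tuple(p.shape) for p in pyr])  # [(1, 256, 512, 3), (1, 128, 256, 3)]
```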
ops.py
ADDED
@@ -0,0 +1,246 @@
import tensorflow as tf

# import tensorflow.contrib.slim as slim
import tf_slim as slim
import math
import pprint

pp = pprint.PrettyPrinter()
get_stddev = lambda x, k_h, k_w: 1 / math.sqrt(k_w * k_h * x.get_shape()[-1])
# import tensorflow.contrib as tf_contrib

# weight_init = tf_contrib.layers.xavier_initializer()
weight_init = tf.initializers.GlorotUniform()
weight_regularizer = None


def batch_norm(x, name="batch_norm"):
    # return tf.contrib.layers.batch_norm(
    #     x, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope=name
    # )
    return tf.keras.layers.BatchNormalization(
        momentum=0.9, epsilon=1e-5, scale=True, name=name
    )(x)


def instance_norm(input, name="instance_norm"):
    with tf.compat.v1.variable_scope(name):
        depth = input.get_shape()[3]
        scale = tf.compat.v1.get_variable(
            "scale",
            [depth],
            initializer=tf.keras.initializers.RandomNormal(
                mean=1.0, stddev=0.02, seed=None
            ),
        )
        offset = tf.compat.v1.get_variable(
            "offset", [depth], initializer=tf.constant_initializer(0.0)
        )
        # per-sample, per-channel statistics over the spatial axes only
        mean, variance = tf.nn.moments(input, axes=[1, 2], keepdims=True)
        epsilon = 1e-5
        inv = tf.math.rsqrt(variance + epsilon)
        normalized = (input - mean) * inv
        return scale * normalized + offset
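# Illustrative note (not part of the original file): unlike batch_norm above,
# instance_norm computes one mean/variance per (sample, channel) over the
# spatial axes only, so its behaviour is independent of batch size, e.g.:
#
#     x = tf.compat.v1.placeholder(tf.float32, [None, 256, 512, 3])
#     y = instance_norm(x, name="in_demo")  # normalized per sample and channel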
def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding="SAME", name="conv2d"):
    with tf.compat.v1.variable_scope(name):
        return slim.conv2d(
            input_,
            output_dim,
            ks,
            s,
            padding=padding,
            activation_fn=None,
            weights_initializer=tf.keras.initializers.TruncatedNormal(stddev=stddev),
            biases_initializer=None,
        )


def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"):
    with tf.compat.v1.variable_scope(name):
        return slim.conv2d_transpose(
            input_,
            output_dim,
            ks,
            s,
            padding="SAME",
            activation_fn=None,
            weights_initializer=tf.keras.initializers.TruncatedNormal(stddev=stddev),
            biases_initializer=None,
        )


def dilated_conv2d(
    input_, output_dim, ks=3, s=2, stddev=0.02, padding="SAME", name="conv2d"
):
    with tf.compat.v1.variable_scope(name):
        # Only the channel count is needed; avoid int() on a possibly-None batch dim.
        in_channels = input_.get_shape().as_list()[-1]
        filter = tf.compat.v1.get_variable(
            "filter",
            [ks, ks, in_channels, output_dim],
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(0, stddev),
        )
        # note: s is used as the dilation rate here, not as a stride
        conv = tf.nn.atrous_conv2d(input_, filter, rate=s, padding=padding, name=name)

        return conv


def one_step(x, ch, kernel, stride, name):
    return lrelu(
        instance_norm(
            conv2d(x, ch, kernel, stride, name=name + "_first_c"), name + "_first_bn"
        )
    )


def one_step_dilated(x, ch, kernel, stride, name):
    return lrelu(
        instance_norm(
            dilated_conv2d(x, ch, kernel, stride, name=name + "_first_c"),
            name + "_first_bn",
        )
    )


def num_steps(x, ch, kernel, stride, num_steps, name):
    for i in range(num_steps):
        x = lrelu(
            instance_norm(
                conv2d(x, ch, kernel, stride, name=name + "_c_" + str(i)),
                name + "_bn_" + str(i),
            )
        )
    return x


def one_step_noins(x, ch, kernel, stride, name):
    return lrelu(conv2d(x, ch, kernel, stride, name=name + "_first_c"))


def num_steps_noins(x, ch, kernel, stride, num_steps, name):
    for i in range(num_steps):
        x = lrelu(conv2d(x, ch, kernel, stride, name=name + "_c_" + str(i)))
    return x


def dis_down(images, kernel_size, stride, n_scale, ch, name):
    # The finest scale (index 0) is downsampled num_steps times; every coarser
    # scale is refined from its next-finer neighbour with a dilated conv.
    backpack = images[0]
    for i in range(n_scale):
        if i == n_scale - 1:
            images[i] = num_steps(
                backpack, ch, kernel_size, stride, n_scale, name + str(i)
            )
        else:
            images[i] = one_step_dilated(
                images[i + 1], ch, kernel_size, 1, name + str(i)
            )
    return images


def dis_down_noins(images, kernel_size, stride, n_scale, ch, name):
    backpack = images[0]
    for i in range(n_scale):
        if i == n_scale - 1:
            images[i] = num_steps_noins(
                backpack, ch, kernel_size, stride, n_scale, name + str(i)
            )
        else:
            images[i] = one_step_noins(images[i + 1], ch, kernel_size, 1, name + str(i))
    return images


def final_conv(images, n_scale, name):
    for i in range(n_scale):
        images[i] = conv2d(images[i], 1, s=1, name=name + str(i))
    return images


def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak * x)


def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    with tf.compat.v1.variable_scope(scope or "Linear"):
        matrix = tf.compat.v1.get_variable(
            "Matrix",
            [input_.get_shape()[-1], output_size],
            tf.float32,
            tf.random_normal_initializer(stddev=stddev),
        )
        bias = tf.compat.v1.get_variable(
            "bias", [output_size], initializer=tf.constant_initializer(bias_start)
        )
        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias


def get_ones_like(logit):
    target = []
    for i in range(len(logit)):
        target.append(tf.ones_like(logit[i]))
    return target


def get_zeros_like(logit):
    target = []
    for i in range(len(logit)):
        target.append(tf.zeros_like(logit[i]))
    return target


def conv(
    x,
    channels,
    kernel=4,
    stride=2,
    pad=0,
    pad_type="zero",
    use_bias=True,
    scope="conv_0",
):
    with tf.compat.v1.variable_scope(scope):
        if pad_type == "zero":
            x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
        if pad_type == "reflect":
            x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode="REFLECT")

        x = tf.compat.v1.layers.conv2d(
            inputs=x,
            filters=channels,
            kernel_size=kernel,
            kernel_initializer=weight_init,
            kernel_regularizer=weight_regularizer,
            strides=stride,
            use_bias=use_bias,
        )

        return x


def reduce_sum(input_tensor, axis=None, keepdims=False):
    # TF 1.x spelled the argument keep_dims; fall back for older versions
    try:
        return tf.reduce_sum(input_tensor, axis=axis, keepdims=keepdims)
    except TypeError:
        return tf.reduce_sum(input_tensor, axis=axis, keep_dims=keepdims)


def get_shape(inputs, name=None):
    # Prefer static dims; fall back to dynamic ones where the graph leaves them None
    name = "shape" if name is None else name
    with tf.name_scope(name):
        static_shape = inputs.get_shape().as_list()
        dynamic_shape = tf.shape(inputs)
        shape = []
        for i, dim in enumerate(static_shape):
            dim = dim if dim is not None else dynamic_shape[i]
            shape.append(dim)
        return shape


def show_all_variables():
    model_vars = tf.compat.v1.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)
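These ops are thin compat.v1 wrappers and compose directly inside a graph-mode variable scope; one_step above is exactly this composition. A minimal standalone sketch (shapes and scope names are illustrative only):

    import tensorflow as tf
    from ops import conv2d, instance_norm, lrelu

    tf.compat.v1.disable_eager_execution()
    x = tf.compat.v1.placeholder(tf.float32, [1, 256, 512, 3])
    with tf.compat.v1.variable_scope("demo"):
        h = lrelu(instance_norm(conv2d(x, 64, ks=4, s=2, name="c0"), "bn0"))
    # the stride-2 SAME conv halves H and W: h has shape [1, 128, 256, 64]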
parser.py
ADDED
@@ -0,0 +1,18 @@
# One-off helper: split the BDD100K training images into daytime / night / else
# folders according to the "timeofday" attribute in the label file.
import json
import shutil


with open('C:/jg/github_code/ForkGAN/bdd100k/labels/bdd100k_labels_images_train.json') as json_file:
    json_data = json.load(json_file)

for item in json_data:
    item_path = 'C:/jg/github_code/ForkGAN/bdd100k/images/100k/train/' + item['name']
    print(item['name'])
    if item['attributes']['timeofday'] == 'daytime':
        shutil.copy(item_path, 'C:/jg/github_code/ForkGAN/bdd100k/images/daytime/' + item['name'])
    elif item['attributes']['timeofday'] == 'night':
        shutil.copy(item_path, 'C:/jg/github_code/ForkGAN/bdd100k/images/night/' + item['name'])
    else:
        shutil.copy(item_path, 'C:/jg/github_code/ForkGAN/bdd100k/images/else/' + item['name'])
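For reference, the script only relies on each label record providing a name and an attributes.timeofday field; a minimal example of the expected record shape (values illustrative, other BDD100K fields omitted):

    [
      {
        "name": "0000f77c-6257be58.jpg",
        "attributes": {"timeofday": "daytime"}
      }
    ]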
requirements.txt
ADDED
@@ -0,0 +1,4 @@
pillow==6.0.0
scipy==1.1.0
numpy
matplotlib
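Note that ops.py also imports tensorflow and tf_slim, which are not listed here; a working environment therefore needs something like:

    pip install -r requirements.txt
    pip install tensorflow tf-slim   # versions to match your setup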
utils.py
ADDED
@@ -0,0 +1,182 @@
# import scipy.misc
from PIL import Image
import numpy as np
import copy
import os


class ImagePool(object):
    """History buffer of generated [fake_A, fake_B] pairs, used to feed the
    discriminators a mix of current and past generator outputs."""

    def __init__(self, maxsize=50):
        self.maxsize = maxsize
        self.num_img = 0
        self.images = []

    def __call__(self, image):
        if self.maxsize <= 0:
            return image
        if self.num_img < self.maxsize:
            self.images.append(image)
            self.num_img += 1
            return image
        if np.random.rand() > 0.5:
            # Swap the incoming pair against randomly chosen historical fakes
            idx = int(np.random.rand() * self.maxsize)
            tmp1 = copy.copy(self.images[idx])[0]
            self.images[idx][0] = image[0]
            idx = int(np.random.rand() * self.maxsize)
            tmp2 = copy.copy(self.images[idx])[1]
            self.images[idx][1] = image[1]
            return [tmp1, tmp2]
        else:
            return image
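# Usage sketch (illustrative; the session/tensor names are assumptions, not
# from this file): the pool is queried once per training step, so the
# discriminators sometimes see historical fakes instead of the newest ones:
#
#     pool = ImagePool(maxsize=50)
#     fake_A, fake_B = sess.run([gen_fake_A, gen_fake_B], feed_dict=...)
#     hist_A, hist_B = pool([fake_A, fake_B])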
def load_test_data(image_path, fine_size=256):
    img = Image.open(image_path)
    # PIL resize takes (width, height); test images use a 2:1 aspect ratio
    img = img.resize((fine_size * 2, fine_size))
    img = np.array(img)
    # Normalize image to the range [-1, 1]
    img = img / 127.5 - 1

    return img


def check_folder(path):
    if not os.path.exists(path):
        os.mkdir(path)


def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False):
    img_A = Image.open(image_path[0])
    img_B = Image.open(image_path[1])

    if not is_testing:
        # Resize images using PIL
        img_A = img_A.resize((load_size * 2, load_size))
        img_B = img_B.resize((load_size * 2, load_size))

        # Random crop (the same offsets are applied to both images)
        h1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size)))
        w1 = int(np.ceil(np.random.uniform(1e-2, (load_size - fine_size) * 2)))
        img_A = np.array(img_A.crop((w1, h1, w1 + fine_size * 2, h1 + fine_size)))
        img_B = np.array(img_B.crop((w1, h1, w1 + fine_size * 2, h1 + fine_size)))

        # Random horizontal flip
        if np.random.random() > 0.5:
            img_A = np.fliplr(img_A)
            img_B = np.fliplr(img_B)
    else:
        # Resize images using PIL for testing; np.array() converts the PIL
        # images so the normalization below works in both branches
        img_A = np.array(img_A.resize((fine_size * 2, fine_size)))
        img_B = np.array(img_B.resize((fine_size * 2, fine_size)))

    # Normalize images to the range [-1, 1]
    img_A = img_A / 127.5 - 1.0
    img_B = img_B / 127.5 - 1.0

    # Concatenate images along the channel axis
    img_AB = np.concatenate((img_A, img_B), axis=2)

    return img_AB
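# For RGB inputs and the defaults above, img_AB has shape
# (fine_size, fine_size * 2, 6): channels 0-2 hold image A, channels 3-5 image B.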
# -----------------------------


def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale=False):
    return transform(
        load_image(image_path, is_grayscale), image_size, is_crop, resize_w
    )


def save_images(images, size, image_path):
    return imsave(images, size, image_path)


def load_image(path, is_grayscale=False):
    if is_grayscale:
        return np.array(Image.open(path).convert("L")).astype(np.float64)
    else:
        return np.array(Image.open(path).convert("RGB")).astype(np.float64)


def merge_images(images, size):
    return inverse_transform(images)


def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j * h : j * h + h, i * w : i * w + w, :] = image

    return img


def imsave(image, size, path):
    # Convert the image to uint8 format, mapping [-1, 1] back to [0, 255]
    image = ((image + 1.0) * 127.5).astype(np.uint8)

    # Merge images
    # merged_image = merge(images, size).astype(np.uint8)

    # Create a PIL Image from the numpy array
    pil_image = Image.fromarray(image)

    # Save the image using PIL
    pil_image.save(path)

    return None


def center_crop(x, crop_h, crop_w=None, resize_h=64, resize_w=64):
    # crop_w defaults to None so transform() can call this with crop_h only
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.0))
    i = int(round((w - crop_w) / 2.0))

    # Use PIL for resizing
    cropped_image = Image.fromarray(x[j : j + crop_h, i : i + crop_w].astype(np.uint8))
    cropped_image = cropped_image.resize((resize_w, resize_h))

    return np.array(cropped_image) / 127.5 - 1.0


def transform(image, npx=64, is_crop=True, resize_w=64):
    # npx: # of pixels width/height of image
    if is_crop:
        # center_crop already rescales to [-1, 1]; do not normalize twice
        return center_crop(image, npx, resize_w=resize_w)
    return np.array(image) / 127.5 - 1.0


def inverse_transform(images):
    return (images + 1.0) / 2.0


def norm_img(img):
    img = img / np.linalg.norm(img)
    img = (img * 2.0) - 1.0

    return img


def set_path(args, experiment_name):
    args.checkpoint_dir = f"./check/{experiment_name}"
    args.sample_dir = f"./check/{experiment_name}/sample"
    if args.which_direction == "AtoB":
        args.test_dir = f"./check/{experiment_name}/testa2b"
    else:
        args.test_dir = f"./check/{experiment_name}/testb2a"
    args.conf_dir = f"./check/{experiment_name}/conf"
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)
    if not os.path.exists(args.conf_dir):
        os.makedirs(args.conf_dir)
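set_path thus lays out one directory tree per experiment; for example, with experiment_name "swim" and which_direction "AtoB" it creates:

    ./check/swim/            # checkpoints
    ./check/swim/sample/     # training samples
    ./check/swim/testa2b/    # A-to-B test outputs
    ./check/swim/conf/       # confidence outputs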