init
Browse files- app.py +88 -0
- packages.txt +2 -0
- requirements.txt +4 -0
- ugatit/UGATIT.py +665 -0
- ugatit/main.py +106 -0
- ugatit/ops.py +345 -0
- ugatit/utils.py +80 -0
- ugatit_test.py +372 -0
app.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
import argparse
|
5 |
+
import functools
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import sys
|
9 |
+
from typing import Callable
|
10 |
+
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import huggingface_hub
|
14 |
+
import numpy as np
|
15 |
+
import PIL.Image
|
16 |
+
|
17 |
+
from io import BytesIO
|
18 |
+
|
19 |
+
|
20 |
+
ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
|
21 |
+
TITLE = 'taki0112/UGATIT'
|
22 |
+
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
|
23 |
+
|
24 |
+
"""
|
25 |
+
ARTICLE = """
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
def parse_args() -> argparse.Namespace:
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument('--device', type=str, default='cpu')
|
32 |
+
parser.add_argument('--theme', type=str)
|
33 |
+
parser.add_argument('--live', action='store_true')
|
34 |
+
parser.add_argument('--share', action='store_true')
|
35 |
+
parser.add_argument('--port', type=int)
|
36 |
+
parser.add_argument('--disable-queue',
|
37 |
+
dest='enable_queue',
|
38 |
+
action='store_false')
|
39 |
+
parser.add_argument('--allow-flagging', type=str, default='never')
|
40 |
+
parser.add_argument('--allow-screenshot', action='store_true')
|
41 |
+
return parser.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def run(
|
46 |
+
image
|
47 |
+
) -> tuple[PIL.Image.Image]:
|
48 |
+
|
49 |
+
|
50 |
+
return PIL.Image.open(image.name)
|
51 |
+
|
52 |
+
|
53 |
+
def main():
|
54 |
+
gr.close_all()
|
55 |
+
|
56 |
+
args = parse_args()
|
57 |
+
|
58 |
+
func = functools.partial(run)
|
59 |
+
func = functools.update_wrapper(func, run)
|
60 |
+
|
61 |
+
|
62 |
+
gr.Interface(
|
63 |
+
func,
|
64 |
+
[
|
65 |
+
gr.inputs.Image(type='file', label='Input Image'),
|
66 |
+
],
|
67 |
+
[
|
68 |
+
gr.outputs.Image(
|
69 |
+
type='pil',
|
70 |
+
label='Result'),
|
71 |
+
],
|
72 |
+
#examples=examples,
|
73 |
+
theme=args.theme,
|
74 |
+
title=TITLE,
|
75 |
+
description=DESCRIPTION,
|
76 |
+
article=ARTICLE,
|
77 |
+
allow_screenshot=args.allow_screenshot,
|
78 |
+
allow_flagging=args.allow_flagging,
|
79 |
+
live=args.live,
|
80 |
+
).launch(
|
81 |
+
enable_queue=args.enable_queue,
|
82 |
+
server_port=args.port,
|
83 |
+
share=args.share,
|
84 |
+
)
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == '__main__':
|
88 |
+
main()
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python-headless==4.5.5.62
|
2 |
+
Pillow==9.0.1
|
3 |
+
scipy==1.7.3
|
4 |
+
tensorflow-gpu==1.14.0
|
ugatit/UGATIT.py
ADDED
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ugatit.ops import *
|
2 |
+
from utils import *
|
3 |
+
from glob import glob
|
4 |
+
import time
|
5 |
+
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
class UGATIT(object) :
|
9 |
+
def __init__(self, sess, args):
|
10 |
+
self.light = args.light
|
11 |
+
|
12 |
+
if self.light :
|
13 |
+
self.model_name = 'UGATIT_light'
|
14 |
+
else :
|
15 |
+
self.model_name = 'UGATIT'
|
16 |
+
|
17 |
+
self.sess = sess
|
18 |
+
self.phase = args.phase
|
19 |
+
self.checkpoint_dir = args.checkpoint_dir
|
20 |
+
self.result_dir = args.result_dir
|
21 |
+
self.log_dir = args.log_dir
|
22 |
+
self.dataset_name = args.dataset
|
23 |
+
self.augment_flag = args.augment_flag
|
24 |
+
|
25 |
+
self.epoch = args.epoch
|
26 |
+
self.iteration = args.iteration
|
27 |
+
self.decay_flag = args.decay_flag
|
28 |
+
self.decay_epoch = args.decay_epoch
|
29 |
+
|
30 |
+
self.gan_type = args.gan_type
|
31 |
+
|
32 |
+
self.batch_size = args.batch_size
|
33 |
+
self.print_freq = args.print_freq
|
34 |
+
self.save_freq = args.save_freq
|
35 |
+
|
36 |
+
self.init_lr = args.lr
|
37 |
+
self.ch = args.ch
|
38 |
+
|
39 |
+
""" Weight """
|
40 |
+
self.adv_weight = args.adv_weight
|
41 |
+
self.cycle_weight = args.cycle_weight
|
42 |
+
self.identity_weight = args.identity_weight
|
43 |
+
self.cam_weight = args.cam_weight
|
44 |
+
self.ld = args.GP_ld
|
45 |
+
self.smoothing = args.smoothing
|
46 |
+
|
47 |
+
""" Generator """
|
48 |
+
self.n_res = args.n_res
|
49 |
+
|
50 |
+
""" Discriminator """
|
51 |
+
self.n_dis = args.n_dis
|
52 |
+
self.n_critic = args.n_critic
|
53 |
+
self.sn = args.sn
|
54 |
+
|
55 |
+
self.img_size = args.img_size
|
56 |
+
self.img_ch = args.img_ch
|
57 |
+
|
58 |
+
|
59 |
+
self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
|
60 |
+
check_folder(self.sample_dir)
|
61 |
+
|
62 |
+
# self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
|
63 |
+
self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
|
64 |
+
self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
|
65 |
+
self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))
|
66 |
+
|
67 |
+
print()
|
68 |
+
|
69 |
+
print("##### Information #####")
|
70 |
+
print("# light : ", self.light)
|
71 |
+
print("# gan type : ", self.gan_type)
|
72 |
+
print("# dataset : ", self.dataset_name)
|
73 |
+
print("# max dataset number : ", self.dataset_num)
|
74 |
+
print("# batch_size : ", self.batch_size)
|
75 |
+
print("# epoch : ", self.epoch)
|
76 |
+
print("# iteration per epoch : ", self.iteration)
|
77 |
+
print("# smoothing : ", self.smoothing)
|
78 |
+
|
79 |
+
print()
|
80 |
+
|
81 |
+
print("##### Generator #####")
|
82 |
+
print("# residual blocks : ", self.n_res)
|
83 |
+
|
84 |
+
print()
|
85 |
+
|
86 |
+
print("##### Discriminator #####")
|
87 |
+
print("# discriminator layer : ", self.n_dis)
|
88 |
+
print("# the number of critic : ", self.n_critic)
|
89 |
+
print("# spectral normalization : ", self.sn)
|
90 |
+
|
91 |
+
print()
|
92 |
+
|
93 |
+
print("##### Weight #####")
|
94 |
+
print("# adv_weight : ", self.adv_weight)
|
95 |
+
print("# cycle_weight : ", self.cycle_weight)
|
96 |
+
print("# identity_weight : ", self.identity_weight)
|
97 |
+
print("# cam_weight : ", self.cam_weight)
|
98 |
+
|
99 |
+
##################################################################################
|
100 |
+
# Generator
|
101 |
+
##################################################################################
|
102 |
+
|
103 |
+
def generator(self, x_init, reuse=False, scope="generator"):
|
104 |
+
channel = self.ch
|
105 |
+
with tf.variable_scope(scope, reuse=reuse) :
|
106 |
+
x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
|
107 |
+
x = instance_norm(x, scope='ins_norm')
|
108 |
+
x = relu(x)
|
109 |
+
|
110 |
+
# Down-Sampling
|
111 |
+
for i in range(2) :
|
112 |
+
x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
|
113 |
+
x = instance_norm(x, scope='ins_norm_'+str(i))
|
114 |
+
x = relu(x)
|
115 |
+
|
116 |
+
channel = channel * 2
|
117 |
+
|
118 |
+
# Down-Sampling Bottleneck
|
119 |
+
for i in range(self.n_res):
|
120 |
+
x = resblock(x, channel, scope='resblock_' + str(i))
|
121 |
+
|
122 |
+
|
123 |
+
# Class Activation Map
|
124 |
+
cam_x = global_avg_pooling(x)
|
125 |
+
cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
|
126 |
+
x_gap = tf.multiply(x, cam_x_weight)
|
127 |
+
|
128 |
+
cam_x = global_max_pooling(x)
|
129 |
+
cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
|
130 |
+
x_gmp = tf.multiply(x, cam_x_weight)
|
131 |
+
|
132 |
+
|
133 |
+
cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
|
134 |
+
x = tf.concat([x_gap, x_gmp], axis=-1)
|
135 |
+
|
136 |
+
x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
|
137 |
+
x = relu(x)
|
138 |
+
|
139 |
+
heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
|
140 |
+
|
141 |
+
# Gamma, Beta block
|
142 |
+
gamma, beta = self.MLP(x, reuse=reuse)
|
143 |
+
|
144 |
+
# Up-Sampling Bottleneck
|
145 |
+
for i in range(self.n_res):
|
146 |
+
x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))
|
147 |
+
|
148 |
+
# Up-Sampling
|
149 |
+
for i in range(2) :
|
150 |
+
x = up_sample(x, scale_factor=2)
|
151 |
+
x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
|
152 |
+
x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
|
153 |
+
x = relu(x)
|
154 |
+
|
155 |
+
channel = channel // 2
|
156 |
+
|
157 |
+
|
158 |
+
x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
|
159 |
+
x = tanh(x)
|
160 |
+
|
161 |
+
return x, cam_logit, heatmap
|
162 |
+
|
163 |
+
def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
|
164 |
+
channel = self.ch * self.n_res
|
165 |
+
|
166 |
+
if self.light :
|
167 |
+
x = global_avg_pooling(x)
|
168 |
+
|
169 |
+
with tf.variable_scope(scope, reuse=reuse):
|
170 |
+
for i in range(2) :
|
171 |
+
x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
|
172 |
+
x = relu(x)
|
173 |
+
|
174 |
+
|
175 |
+
gamma = fully_connected(x, channel, use_bias, scope='gamma')
|
176 |
+
beta = fully_connected(x, channel, use_bias, scope='beta')
|
177 |
+
|
178 |
+
gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
|
179 |
+
beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])
|
180 |
+
|
181 |
+
return gamma, beta
|
182 |
+
|
183 |
+
##################################################################################
|
184 |
+
# Discriminator
|
185 |
+
##################################################################################
|
186 |
+
|
187 |
+
def discriminator(self, x_init, reuse=False, scope="discriminator"):
|
188 |
+
D_logit = []
|
189 |
+
D_CAM_logit = []
|
190 |
+
with tf.variable_scope(scope, reuse=reuse) :
|
191 |
+
local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
|
192 |
+
global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')
|
193 |
+
|
194 |
+
D_logit.extend([local_x, global_x])
|
195 |
+
D_CAM_logit.extend([local_cam, global_cam])
|
196 |
+
|
197 |
+
return D_logit, D_CAM_logit, local_heatmap, global_heatmap
|
198 |
+
|
199 |
+
def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
|
200 |
+
with tf.variable_scope(scope, reuse=reuse):
|
201 |
+
channel = self.ch
|
202 |
+
x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
|
203 |
+
x = lrelu(x, 0.2)
|
204 |
+
|
205 |
+
for i in range(1, self.n_dis - 1):
|
206 |
+
x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
|
207 |
+
x = lrelu(x, 0.2)
|
208 |
+
|
209 |
+
channel = channel * 2
|
210 |
+
|
211 |
+
x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
|
212 |
+
x = lrelu(x, 0.2)
|
213 |
+
|
214 |
+
channel = channel * 2
|
215 |
+
|
216 |
+
cam_x = global_avg_pooling(x)
|
217 |
+
cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
|
218 |
+
x_gap = tf.multiply(x, cam_x_weight)
|
219 |
+
|
220 |
+
cam_x = global_max_pooling(x)
|
221 |
+
cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
|
222 |
+
x_gmp = tf.multiply(x, cam_x_weight)
|
223 |
+
|
224 |
+
cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
|
225 |
+
x = tf.concat([x_gap, x_gmp], axis=-1)
|
226 |
+
|
227 |
+
x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
|
228 |
+
x = lrelu(x, 0.2)
|
229 |
+
|
230 |
+
heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
|
231 |
+
|
232 |
+
|
233 |
+
x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
|
234 |
+
|
235 |
+
return x, cam_logit, heatmap
|
236 |
+
|
237 |
+
def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
|
238 |
+
with tf.variable_scope(scope, reuse=reuse) :
|
239 |
+
channel = self.ch
|
240 |
+
x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
|
241 |
+
x = lrelu(x, 0.2)
|
242 |
+
|
243 |
+
for i in range(1, self.n_dis - 2 - 1):
|
244 |
+
x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
|
245 |
+
x = lrelu(x, 0.2)
|
246 |
+
|
247 |
+
channel = channel * 2
|
248 |
+
|
249 |
+
x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
|
250 |
+
x = lrelu(x, 0.2)
|
251 |
+
|
252 |
+
channel = channel * 2
|
253 |
+
|
254 |
+
cam_x = global_avg_pooling(x)
|
255 |
+
cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
|
256 |
+
x_gap = tf.multiply(x, cam_x_weight)
|
257 |
+
|
258 |
+
cam_x = global_max_pooling(x)
|
259 |
+
cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
|
260 |
+
x_gmp = tf.multiply(x, cam_x_weight)
|
261 |
+
|
262 |
+
cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
|
263 |
+
x = tf.concat([x_gap, x_gmp], axis=-1)
|
264 |
+
|
265 |
+
x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
|
266 |
+
x = lrelu(x, 0.2)
|
267 |
+
|
268 |
+
heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
|
269 |
+
|
270 |
+
x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
|
271 |
+
|
272 |
+
return x, cam_logit, heatmap
|
273 |
+
|
274 |
+
##################################################################################
|
275 |
+
# Model
|
276 |
+
##################################################################################
|
277 |
+
|
278 |
+
def generate_a2b(self, x_A, reuse=False):
|
279 |
+
out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
|
280 |
+
|
281 |
+
return out, cam
|
282 |
+
|
283 |
+
def generate_b2a(self, x_B, reuse=False):
|
284 |
+
out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
|
285 |
+
|
286 |
+
return out, cam
|
287 |
+
|
288 |
+
def discriminate_real(self, x_A, x_B):
|
289 |
+
real_A_logit, real_A_cam_logit, _, _ = self.discriminator(x_A, scope="discriminator_A")
|
290 |
+
real_B_logit, real_B_cam_logit, _, _ = self.discriminator(x_B, scope="discriminator_B")
|
291 |
+
|
292 |
+
return real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit
|
293 |
+
|
294 |
+
def discriminate_fake(self, x_ba, x_ab):
|
295 |
+
fake_A_logit, fake_A_cam_logit, _, _ = self.discriminator(x_ba, reuse=True, scope="discriminator_A")
|
296 |
+
fake_B_logit, fake_B_cam_logit, _, _ = self.discriminator(x_ab, reuse=True, scope="discriminator_B")
|
297 |
+
|
298 |
+
return fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit
|
299 |
+
|
300 |
+
def gradient_panalty(self, real, fake, scope="discriminator_A"):
|
301 |
+
if self.gan_type.__contains__('dragan'):
|
302 |
+
eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
|
303 |
+
_, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
|
304 |
+
x_std = tf.sqrt(x_var) # magnitude of noise decides the size of local region
|
305 |
+
|
306 |
+
fake = real + 0.5 * x_std * eps
|
307 |
+
|
308 |
+
alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
|
309 |
+
interpolated = real + alpha * (fake - real)
|
310 |
+
|
311 |
+
logit, cam_logit, _, _ = self.discriminator(interpolated, reuse=True, scope=scope)
|
312 |
+
|
313 |
+
|
314 |
+
GP = []
|
315 |
+
cam_GP = []
|
316 |
+
|
317 |
+
for i in range(2) :
|
318 |
+
grad = tf.gradients(logit[i], interpolated)[0] # gradient of D(interpolated)
|
319 |
+
grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
|
320 |
+
|
321 |
+
# WGAN - LP
|
322 |
+
if self.gan_type == 'wgan-lp' :
|
323 |
+
GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
|
324 |
+
|
325 |
+
elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
|
326 |
+
GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
|
327 |
+
|
328 |
+
for i in range(2) :
|
329 |
+
grad = tf.gradients(cam_logit[i], interpolated)[0] # gradient of D(interpolated)
|
330 |
+
grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm
|
331 |
+
|
332 |
+
# WGAN - LP
|
333 |
+
if self.gan_type == 'wgan-lp' :
|
334 |
+
cam_GP.append(self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))))
|
335 |
+
|
336 |
+
elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
|
337 |
+
cam_GP.append(self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)))
|
338 |
+
|
339 |
+
|
340 |
+
return sum(GP), sum(cam_GP)
|
341 |
+
|
342 |
+
def build_model(self):
|
343 |
+
if self.phase == 'train' :
|
344 |
+
self.lr = tf.placeholder(tf.float32, name='learning_rate')
|
345 |
+
|
346 |
+
|
347 |
+
""" Input Image"""
|
348 |
+
Image_Data_Class = ImageData(self.img_size, self.img_ch, self.augment_flag)
|
349 |
+
|
350 |
+
trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
|
351 |
+
trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
|
352 |
+
|
353 |
+
|
354 |
+
gpu_device = '/gpu:0'
|
355 |
+
trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
|
356 |
+
trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))
|
357 |
+
|
358 |
+
|
359 |
+
trainA_iterator = trainA.make_one_shot_iterator()
|
360 |
+
trainB_iterator = trainB.make_one_shot_iterator()
|
361 |
+
|
362 |
+
self.domain_A = trainA_iterator.get_next()
|
363 |
+
self.domain_B = trainB_iterator.get_next()
|
364 |
+
|
365 |
+
""" Define Generator, Discriminator """
|
366 |
+
x_ab, cam_ab = self.generate_a2b(self.domain_A) # real a
|
367 |
+
x_ba, cam_ba = self.generate_b2a(self.domain_B) # real b
|
368 |
+
|
369 |
+
x_aba, _ = self.generate_b2a(x_ab, reuse=True) # real b
|
370 |
+
x_bab, _ = self.generate_a2b(x_ba, reuse=True) # real a
|
371 |
+
|
372 |
+
x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True) # fake b
|
373 |
+
x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True) # fake a
|
374 |
+
|
375 |
+
real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(self.domain_A, self.domain_B)
|
376 |
+
fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(x_ba, x_ab)
|
377 |
+
|
378 |
+
|
379 |
+
""" Define Loss """
|
380 |
+
if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' :
|
381 |
+
GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A, fake=x_ba, scope="discriminator_A")
|
382 |
+
GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B, fake=x_ab, scope="discriminator_B")
|
383 |
+
else :
|
384 |
+
GP_A, GP_CAM_A = 0, 0
|
385 |
+
GP_B, GP_CAM_B = 0, 0
|
386 |
+
|
387 |
+
G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) + generator_loss(self.gan_type, fake_A_cam_logit))
|
388 |
+
G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) + generator_loss(self.gan_type, fake_B_cam_logit))
|
389 |
+
|
390 |
+
D_ad_loss_A = (discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) + discriminator_loss(self.gan_type, real_A_cam_logit, fake_A_cam_logit) + GP_A + GP_CAM_A)
|
391 |
+
D_ad_loss_B = (discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) + discriminator_loss(self.gan_type, real_B_cam_logit, fake_B_cam_logit) + GP_B + GP_CAM_B)
|
392 |
+
|
393 |
+
reconstruction_A = L1_loss(x_aba, self.domain_A) # reconstruction
|
394 |
+
reconstruction_B = L1_loss(x_bab, self.domain_B) # reconstruction
|
395 |
+
|
396 |
+
identity_A = L1_loss(x_aa, self.domain_A)
|
397 |
+
identity_B = L1_loss(x_bb, self.domain_B)
|
398 |
+
|
399 |
+
cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
|
400 |
+
cam_B = cam_loss(source=cam_ab, non_source=cam_bb)
|
401 |
+
|
402 |
+
Generator_A_gan = self.adv_weight * G_ad_loss_A
|
403 |
+
Generator_A_cycle = self.cycle_weight * reconstruction_B
|
404 |
+
Generator_A_identity = self.identity_weight * identity_A
|
405 |
+
Generator_A_cam = self.cam_weight * cam_A
|
406 |
+
|
407 |
+
|
408 |
+
Generator_B_gan = self.adv_weight * G_ad_loss_B
|
409 |
+
Generator_B_cycle = self.cycle_weight * reconstruction_A
|
410 |
+
Generator_B_identity = self.identity_weight * identity_B
|
411 |
+
Generator_B_cam = self.cam_weight * cam_B
|
412 |
+
|
413 |
+
|
414 |
+
Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
|
415 |
+
Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam
|
416 |
+
|
417 |
+
|
418 |
+
Discriminator_A_loss = self.adv_weight * D_ad_loss_A
|
419 |
+
Discriminator_B_loss = self.adv_weight * D_ad_loss_B
|
420 |
+
|
421 |
+
self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss('generator')
|
422 |
+
self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss('discriminator')
|
423 |
+
|
424 |
+
|
425 |
+
""" Result Image """
|
426 |
+
self.fake_A = x_ba
|
427 |
+
self.fake_B = x_ab
|
428 |
+
|
429 |
+
self.real_A = self.domain_A
|
430 |
+
self.real_B = self.domain_B
|
431 |
+
|
432 |
+
|
433 |
+
""" Training """
|
434 |
+
t_vars = tf.trainable_variables()
|
435 |
+
G_vars = [var for var in t_vars if 'generator' in var.name]
|
436 |
+
D_vars = [var for var in t_vars if 'discriminator' in var.name]
|
437 |
+
|
438 |
+
self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
|
439 |
+
self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)
|
440 |
+
|
441 |
+
|
442 |
+
"""" Summary """
|
443 |
+
self.all_G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
|
444 |
+
self.all_D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)
|
445 |
+
|
446 |
+
self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
|
447 |
+
self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
|
448 |
+
self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
|
449 |
+
self.G_A_identity = tf.summary.scalar("G_A_identity", Generator_A_identity)
|
450 |
+
self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)
|
451 |
+
|
452 |
+
self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
|
453 |
+
self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
|
454 |
+
self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
|
455 |
+
self.G_B_identity = tf.summary.scalar("G_B_identity", Generator_B_identity)
|
456 |
+
self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)
|
457 |
+
|
458 |
+
self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
|
459 |
+
self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)
|
460 |
+
|
461 |
+
self.rho_var = []
|
462 |
+
for var in tf.trainable_variables():
|
463 |
+
if 'rho' in var.name:
|
464 |
+
self.rho_var.append(tf.summary.histogram(var.name, var))
|
465 |
+
self.rho_var.append(tf.summary.scalar(var.name + "_min", tf.reduce_min(var)))
|
466 |
+
self.rho_var.append(tf.summary.scalar(var.name + "_max", tf.reduce_max(var)))
|
467 |
+
self.rho_var.append(tf.summary.scalar(var.name + "_mean", tf.reduce_mean(var)))
|
468 |
+
|
469 |
+
g_summary_list = [self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity, self.G_A_cam,
|
470 |
+
self.G_B_loss, self.G_B_gan, self.G_B_cycle, self.G_B_identity, self.G_B_cam,
|
471 |
+
self.all_G_loss]
|
472 |
+
|
473 |
+
g_summary_list.extend(self.rho_var)
|
474 |
+
d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]
|
475 |
+
|
476 |
+
self.G_loss = tf.summary.merge(g_summary_list)
|
477 |
+
self.D_loss = tf.summary.merge(d_summary_list)
|
478 |
+
|
479 |
+
else :
|
480 |
+
""" Test """
|
481 |
+
self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
|
482 |
+
self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')
|
483 |
+
|
484 |
+
|
485 |
+
self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
|
486 |
+
self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
|
487 |
+
|
488 |
+
|
489 |
+
def train(self):
|
490 |
+
# initialize all variables
|
491 |
+
tf.global_variables_initializer().run()
|
492 |
+
|
493 |
+
# saver to save model
|
494 |
+
self.saver = tf.train.Saver()
|
495 |
+
|
496 |
+
# summary writer
|
497 |
+
self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)
|
498 |
+
|
499 |
+
|
500 |
+
# restore check-point if it exits
|
501 |
+
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
|
502 |
+
if could_load:
|
503 |
+
start_epoch = (int)(checkpoint_counter / self.iteration)
|
504 |
+
start_batch_id = checkpoint_counter - start_epoch * self.iteration
|
505 |
+
counter = checkpoint_counter
|
506 |
+
print(" [*] Load SUCCESS")
|
507 |
+
else:
|
508 |
+
start_epoch = 0
|
509 |
+
start_batch_id = 0
|
510 |
+
counter = 1
|
511 |
+
print(" [!] Load failed...")
|
512 |
+
|
513 |
+
# loop for epoch
|
514 |
+
start_time = time.time()
|
515 |
+
past_g_loss = -1.
|
516 |
+
lr = self.init_lr
|
517 |
+
for epoch in range(start_epoch, self.epoch):
|
518 |
+
# lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
|
519 |
+
if self.decay_flag :
|
520 |
+
#lr = self.init_lr * pow(0.5, epoch // self.decay_epoch)
|
521 |
+
lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
|
522 |
+
for idx in range(start_batch_id, self.iteration):
|
523 |
+
train_feed_dict = {
|
524 |
+
self.lr : lr
|
525 |
+
}
|
526 |
+
|
527 |
+
# Update D
|
528 |
+
_, d_loss, summary_str = self.sess.run([self.D_optim,
|
529 |
+
self.Discriminator_loss, self.D_loss], feed_dict = train_feed_dict)
|
530 |
+
self.writer.add_summary(summary_str, counter)
|
531 |
+
|
532 |
+
# Update G
|
533 |
+
g_loss = None
|
534 |
+
if (counter - 1) % self.n_critic == 0 :
|
535 |
+
batch_A_images, batch_B_images, fake_A, fake_B, _, g_loss, summary_str = self.sess.run([self.real_A, self.real_B,
|
536 |
+
self.fake_A, self.fake_B,
|
537 |
+
self.G_optim,
|
538 |
+
self.Generator_loss, self.G_loss], feed_dict = train_feed_dict)
|
539 |
+
self.writer.add_summary(summary_str, counter)
|
540 |
+
past_g_loss = g_loss
|
541 |
+
|
542 |
+
# display training status
|
543 |
+
counter += 1
|
544 |
+
if g_loss == None :
|
545 |
+
g_loss = past_g_loss
|
546 |
+
print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))
|
547 |
+
|
548 |
+
if np.mod(idx+1, self.print_freq) == 0 :
|
549 |
+
save_images(batch_A_images, [self.batch_size, 1],
|
550 |
+
'./{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
|
551 |
+
# save_images(batch_B_images, [self.batch_size, 1],
|
552 |
+
# './{}/real_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
|
553 |
+
|
554 |
+
# save_images(fake_A, [self.batch_size, 1],
|
555 |
+
# './{}/fake_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
|
556 |
+
save_images(fake_B, [self.batch_size, 1],
|
557 |
+
'./{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
|
558 |
+
|
559 |
+
if np.mod(idx + 1, self.save_freq) == 0:
|
560 |
+
self.save(self.checkpoint_dir, counter)
|
561 |
+
|
562 |
+
|
563 |
+
|
564 |
+
# After an epoch, start_batch_id is set to zero
|
565 |
+
# non-zero value is only for the first epoch after loading pre-trained model
|
566 |
+
start_batch_id = 0
|
567 |
+
|
568 |
+
# save model for final step
|
569 |
+
self.save(self.checkpoint_dir, counter)
|
570 |
+
|
571 |
+
@property
|
572 |
+
def model_dir(self):
|
573 |
+
n_res = str(self.n_res) + 'resblock'
|
574 |
+
n_dis = str(self.n_dis) + 'dis'
|
575 |
+
|
576 |
+
if self.smoothing :
|
577 |
+
smoothing = '_smoothing'
|
578 |
+
else :
|
579 |
+
smoothing = ''
|
580 |
+
|
581 |
+
if self.sn :
|
582 |
+
sn = '_sn'
|
583 |
+
else :
|
584 |
+
sn = ''
|
585 |
+
|
586 |
+
return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
|
587 |
+
self.gan_type, n_res, n_dis,
|
588 |
+
self.n_critic,
|
589 |
+
self.adv_weight, self.cycle_weight, self.identity_weight, self.cam_weight, sn, smoothing)
|
590 |
+
|
591 |
+
def save(self, checkpoint_dir, step):
|
592 |
+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
|
593 |
+
|
594 |
+
if not os.path.exists(checkpoint_dir):
|
595 |
+
os.makedirs(checkpoint_dir)
|
596 |
+
|
597 |
+
self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
|
598 |
+
|
599 |
+
def load(self, checkpoint_dir):
|
600 |
+
print(" [*] Reading checkpoints...")
|
601 |
+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
|
602 |
+
|
603 |
+
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
|
604 |
+
if ckpt and ckpt.model_checkpoint_path:
|
605 |
+
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
|
606 |
+
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
|
607 |
+
counter = int(ckpt_name.split('-')[-1])
|
608 |
+
print(" [*] Success to read {}".format(ckpt_name))
|
609 |
+
return True, counter
|
610 |
+
else:
|
611 |
+
print(" [*] Failed to find a checkpoint")
|
612 |
+
return False, 0
|
613 |
+
|
614 |
+
def test(self):
|
615 |
+
tf.global_variables_initializer().run()
|
616 |
+
test_A_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testA'))
|
617 |
+
test_B_files = glob('./dataset/{}/*.*'.format(self.dataset_name + '/testB'))
|
618 |
+
|
619 |
+
self.saver = tf.train.Saver()
|
620 |
+
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
|
621 |
+
self.result_dir = os.path.join(self.result_dir, self.model_dir)
|
622 |
+
check_folder(self.result_dir)
|
623 |
+
|
624 |
+
if could_load :
|
625 |
+
print(" [*] Load SUCCESS")
|
626 |
+
else :
|
627 |
+
print(" [!] Load failed...")
|
628 |
+
|
629 |
+
# write html for visual comparison
|
630 |
+
index_path = os.path.join(self.result_dir, 'index.html')
|
631 |
+
index = open(index_path, 'w')
|
632 |
+
index.write("<html><body><table><tr>")
|
633 |
+
index.write("<th>name</th><th>input</th><th>output</th></tr>")
|
634 |
+
|
635 |
+
for sample_file in test_A_files : # A -> B
|
636 |
+
print('Processing A image: ' + sample_file)
|
637 |
+
sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
|
638 |
+
image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
|
639 |
+
|
640 |
+
fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image})
|
641 |
+
save_images(fake_img, [1, 1], image_path)
|
642 |
+
|
643 |
+
index.write("<td>%s</td>" % os.path.basename(image_path))
|
644 |
+
|
645 |
+
index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
|
646 |
+
'../..' + os.path.sep + sample_file), self.img_size, self.img_size))
|
647 |
+
index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
|
648 |
+
'../..' + os.path.sep + image_path), self.img_size, self.img_size))
|
649 |
+
index.write("</tr>")
|
650 |
+
|
651 |
+
for sample_file in test_B_files : # B -> A
|
652 |
+
print('Processing B image: ' + sample_file)
|
653 |
+
sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
|
654 |
+
image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
|
655 |
+
|
656 |
+
fake_img = self.sess.run(self.test_fake_A, feed_dict = {self.test_domain_B : sample_image})
|
657 |
+
|
658 |
+
save_images(fake_img, [1, 1], image_path)
|
659 |
+
index.write("<td>%s</td>" % os.path.basename(image_path))
|
660 |
+
index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(sample_file) else (
|
661 |
+
'../..' + os.path.sep + sample_file), self.img_size, self.img_size))
|
662 |
+
index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(image_path) else (
|
663 |
+
'../..' + os.path.sep + image_path), self.img_size, self.img_size))
|
664 |
+
index.write("</tr>")
|
665 |
+
index.close()
|
ugatit/main.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ugatit.UGATIT import UGATIT
|
2 |
+
import argparse
|
3 |
+
from ugatit.utils import *
|
4 |
+
|
5 |
+
"""parsing and configuration"""
|
6 |
+
|
7 |
+
def parse_args():
|
8 |
+
desc = "Tensorflow implementation of U-GAT-IT"
|
9 |
+
parser = argparse.ArgumentParser(description=desc)
|
10 |
+
parser.add_argument('--phase', type=str, default='train', help='[train / test]')
|
11 |
+
parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
|
12 |
+
parser.add_argument('--dataset', type=str, default='selfie2anime', help='dataset_name')
|
13 |
+
|
14 |
+
parser.add_argument('--epoch', type=int, default=100, help='The number of epochs to run')
|
15 |
+
parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
|
16 |
+
parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
|
17 |
+
parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq')
|
18 |
+
parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
|
19 |
+
parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')
|
20 |
+
parser.add_argument('--decay_epoch', type=int, default=50, help='decay epoch')
|
21 |
+
|
22 |
+
parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
|
23 |
+
parser.add_argument('--GP_ld', type=int, default=10, help='The gradient penalty lambda')
|
24 |
+
parser.add_argument('--adv_weight', type=int, default=1, help='Weight about GAN')
|
25 |
+
parser.add_argument('--cycle_weight', type=int, default=10, help='Weight about Cycle')
|
26 |
+
parser.add_argument('--identity_weight', type=int, default=10, help='Weight about Identity')
|
27 |
+
parser.add_argument('--cam_weight', type=int, default=1000, help='Weight about CAM')
|
28 |
+
parser.add_argument('--gan_type', type=str, default='lsgan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
|
29 |
+
|
30 |
+
parser.add_argument('--smoothing', type=str2bool, default=True, help='AdaLIN smoothing effect')
|
31 |
+
|
32 |
+
parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
|
33 |
+
parser.add_argument('--n_res', type=int, default=4, help='The number of resblock')
|
34 |
+
parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
|
35 |
+
parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')
|
36 |
+
parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')
|
37 |
+
|
38 |
+
parser.add_argument('--img_size', type=int, default=256, help='The size of image')
|
39 |
+
parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
|
40 |
+
parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')
|
41 |
+
|
42 |
+
parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
|
43 |
+
help='Directory name to save the checkpoints')
|
44 |
+
parser.add_argument('--result_dir', type=str, default='results',
|
45 |
+
help='Directory name to save the generated images')
|
46 |
+
parser.add_argument('--log_dir', type=str, default='logs',
|
47 |
+
help='Directory name to save training logs')
|
48 |
+
parser.add_argument('--sample_dir', type=str, default='samples',
|
49 |
+
help='Directory name to save the samples on training')
|
50 |
+
|
51 |
+
return check_args(parser.parse_args())
|
52 |
+
|
53 |
+
"""checking arguments"""
|
54 |
+
def check_args(args):
|
55 |
+
# --checkpoint_dir
|
56 |
+
check_folder(args.checkpoint_dir)
|
57 |
+
|
58 |
+
# --result_dir
|
59 |
+
check_folder(args.result_dir)
|
60 |
+
|
61 |
+
# --result_dir
|
62 |
+
check_folder(args.log_dir)
|
63 |
+
|
64 |
+
# --sample_dir
|
65 |
+
check_folder(args.sample_dir)
|
66 |
+
|
67 |
+
# --epoch
|
68 |
+
try:
|
69 |
+
assert args.epoch >= 1
|
70 |
+
except:
|
71 |
+
print('number of epochs must be larger than or equal to one')
|
72 |
+
|
73 |
+
# --batch_size
|
74 |
+
try:
|
75 |
+
assert args.batch_size >= 1
|
76 |
+
except:
|
77 |
+
print('batch size must be larger than or equal to one')
|
78 |
+
return args
|
79 |
+
|
80 |
+
"""main"""
|
81 |
+
def main():
|
82 |
+
# parse arguments
|
83 |
+
args = parse_args()
|
84 |
+
if args is None:
|
85 |
+
exit()
|
86 |
+
|
87 |
+
# open session
|
88 |
+
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
|
89 |
+
gan = UGATIT(sess, args)
|
90 |
+
|
91 |
+
# build graph
|
92 |
+
gan.build_model()
|
93 |
+
|
94 |
+
# show network architecture
|
95 |
+
show_all_variables()
|
96 |
+
|
97 |
+
if args.phase == 'train' :
|
98 |
+
gan.train()
|
99 |
+
print(" [*] Training finished!")
|
100 |
+
|
101 |
+
if args.phase == 'test' :
|
102 |
+
gan.test()
|
103 |
+
print(" [*] Test finished!")
|
104 |
+
|
105 |
+
if __name__ == '__main__':
|
106 |
+
main()
|
ugatit/ops.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow.contrib as tf_contrib
|
3 |
+
|
4 |
+
# Xavier : tf_contrib.layers.xavier_initializer()
|
5 |
+
# He : tf_contrib.layers.variance_scaling_initializer()
|
6 |
+
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
|
7 |
+
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)
|
8 |
+
|
9 |
+
weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
|
10 |
+
weight_regularizer = tf_contrib.layers.l2_regularizer(scale=0.0001)
|
11 |
+
|
12 |
+
##################################################################################
|
13 |
+
# Layer
|
14 |
+
##################################################################################
|
15 |
+
|
16 |
+
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
|
17 |
+
with tf.variable_scope(scope):
|
18 |
+
if pad > 0 :
|
19 |
+
if (kernel - stride) % 2 == 0:
|
20 |
+
pad_top = pad
|
21 |
+
pad_bottom = pad
|
22 |
+
pad_left = pad
|
23 |
+
pad_right = pad
|
24 |
+
|
25 |
+
else:
|
26 |
+
pad_top = pad
|
27 |
+
pad_bottom = kernel - stride - pad_top
|
28 |
+
pad_left = pad
|
29 |
+
pad_right = kernel - stride - pad_left
|
30 |
+
|
31 |
+
if pad_type == 'zero':
|
32 |
+
x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
|
33 |
+
if pad_type == 'reflect':
|
34 |
+
x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
|
35 |
+
|
36 |
+
if sn :
|
37 |
+
w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
|
38 |
+
regularizer=weight_regularizer)
|
39 |
+
x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
|
40 |
+
strides=[1, stride, stride, 1], padding='VALID')
|
41 |
+
if use_bias :
|
42 |
+
bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
|
43 |
+
x = tf.nn.bias_add(x, bias)
|
44 |
+
|
45 |
+
else :
|
46 |
+
x = tf.layers.conv2d(inputs=x, filters=channels,
|
47 |
+
kernel_size=kernel, kernel_initializer=weight_init,
|
48 |
+
kernel_regularizer=weight_regularizer,
|
49 |
+
strides=stride, use_bias=use_bias)
|
50 |
+
|
51 |
+
|
52 |
+
return x
|
53 |
+
|
54 |
+
def fully_connected_with_w(x, use_bias=True, sn=False, reuse=False, scope='linear'):
|
55 |
+
with tf.variable_scope(scope, reuse=reuse):
|
56 |
+
x = flatten(x)
|
57 |
+
bias = 0.0
|
58 |
+
shape = x.get_shape().as_list()
|
59 |
+
channels = shape[-1]
|
60 |
+
|
61 |
+
w = tf.get_variable("kernel", [channels, 1], tf.float32,
|
62 |
+
initializer=weight_init, regularizer=weight_regularizer)
|
63 |
+
|
64 |
+
if sn :
|
65 |
+
w = spectral_norm(w)
|
66 |
+
|
67 |
+
if use_bias :
|
68 |
+
bias = tf.get_variable("bias", [1],
|
69 |
+
initializer=tf.constant_initializer(0.0))
|
70 |
+
|
71 |
+
x = tf.matmul(x, w) + bias
|
72 |
+
else :
|
73 |
+
x = tf.matmul(x, w)
|
74 |
+
|
75 |
+
if use_bias :
|
76 |
+
weights = tf.gather(tf.transpose(tf.nn.bias_add(w, bias)), 0)
|
77 |
+
else :
|
78 |
+
weights = tf.gather(tf.transpose(w), 0)
|
79 |
+
|
80 |
+
return x, weights
|
81 |
+
|
82 |
+
def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
|
83 |
+
with tf.variable_scope(scope):
|
84 |
+
x = flatten(x)
|
85 |
+
shape = x.get_shape().as_list()
|
86 |
+
channels = shape[-1]
|
87 |
+
|
88 |
+
if sn:
|
89 |
+
w = tf.get_variable("kernel", [channels, units], tf.float32,
|
90 |
+
initializer=weight_init, regularizer=weight_regularizer)
|
91 |
+
if use_bias:
|
92 |
+
bias = tf.get_variable("bias", [units],
|
93 |
+
initializer=tf.constant_initializer(0.0))
|
94 |
+
|
95 |
+
x = tf.matmul(x, spectral_norm(w)) + bias
|
96 |
+
else:
|
97 |
+
x = tf.matmul(x, spectral_norm(w))
|
98 |
+
|
99 |
+
else :
|
100 |
+
x = tf.layers.dense(x, units=units, kernel_initializer=weight_init, kernel_regularizer=weight_regularizer, use_bias=use_bias)
|
101 |
+
|
102 |
+
return x
|
103 |
+
|
104 |
+
def flatten(x) :
|
105 |
+
return tf.layers.flatten(x)
|
106 |
+
|
107 |
+
##################################################################################
|
108 |
+
# Residual-block
|
109 |
+
##################################################################################
|
110 |
+
|
111 |
+
def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
|
112 |
+
with tf.variable_scope(scope):
|
113 |
+
with tf.variable_scope('res1'):
|
114 |
+
x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
|
115 |
+
x = instance_norm(x)
|
116 |
+
x = relu(x)
|
117 |
+
|
118 |
+
with tf.variable_scope('res2'):
|
119 |
+
x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
|
120 |
+
x = instance_norm(x)
|
121 |
+
|
122 |
+
return x + x_init
|
123 |
+
|
124 |
+
def adaptive_ins_layer_resblock(x_init, channels, gamma, beta, use_bias=True, smoothing=True, scope='adaptive_resblock') :
|
125 |
+
with tf.variable_scope(scope):
|
126 |
+
with tf.variable_scope('res1'):
|
127 |
+
x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
|
128 |
+
x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
|
129 |
+
x = relu(x)
|
130 |
+
|
131 |
+
with tf.variable_scope('res2'):
|
132 |
+
x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
|
133 |
+
x = adaptive_instance_layer_norm(x, gamma, beta, smoothing)
|
134 |
+
|
135 |
+
return x + x_init
|
136 |
+
|
137 |
+
|
138 |
+
##################################################################################
|
139 |
+
# Sampling
|
140 |
+
##################################################################################
|
141 |
+
|
142 |
+
def up_sample(x, scale_factor=2):
|
143 |
+
_, h, w, _ = x.get_shape().as_list()
|
144 |
+
new_size = [h * scale_factor, w * scale_factor]
|
145 |
+
return tf.image.resize_nearest_neighbor(x, size=new_size)
|
146 |
+
|
147 |
+
|
148 |
+
def global_avg_pooling(x):
|
149 |
+
gap = tf.reduce_mean(x, axis=[1, 2])
|
150 |
+
return gap
|
151 |
+
|
152 |
+
def global_max_pooling(x):
|
153 |
+
gmp = tf.reduce_max(x, axis=[1, 2])
|
154 |
+
return gmp
|
155 |
+
|
156 |
+
##################################################################################
|
157 |
+
# Activation function
|
158 |
+
##################################################################################
|
159 |
+
|
160 |
+
def lrelu(x, alpha=0.01):
|
161 |
+
# pytorch alpha is 0.01
|
162 |
+
return tf.nn.leaky_relu(x, alpha)
|
163 |
+
|
164 |
+
|
165 |
+
def relu(x):
|
166 |
+
return tf.nn.relu(x)
|
167 |
+
|
168 |
+
|
169 |
+
def tanh(x):
|
170 |
+
return tf.tanh(x)
|
171 |
+
|
172 |
+
def sigmoid(x) :
|
173 |
+
return tf.sigmoid(x)
|
174 |
+
|
175 |
+
##################################################################################
|
176 |
+
# Normalization function
|
177 |
+
##################################################################################
|
178 |
+
|
179 |
+
def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True, scope='instance_layer_norm') :
|
180 |
+
with tf.variable_scope(scope):
|
181 |
+
ch = x.shape[-1]
|
182 |
+
eps = 1e-5
|
183 |
+
|
184 |
+
ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
|
185 |
+
x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
|
186 |
+
|
187 |
+
ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
|
188 |
+
x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))
|
189 |
+
|
190 |
+
rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
|
191 |
+
|
192 |
+
if smoothing :
|
193 |
+
rho = tf.clip_by_value(rho - tf.constant(0.1), 0.0, 1.0)
|
194 |
+
|
195 |
+
x_hat = rho * x_ins + (1 - rho) * x_ln
|
196 |
+
|
197 |
+
|
198 |
+
x_hat = x_hat * gamma + beta
|
199 |
+
|
200 |
+
return x_hat
|
201 |
+
|
202 |
+
def instance_norm(x, scope='instance_norm'):
|
203 |
+
return tf_contrib.layers.instance_norm(x,
|
204 |
+
epsilon=1e-05,
|
205 |
+
center=True, scale=True,
|
206 |
+
scope=scope)
|
207 |
+
|
208 |
+
def layer_norm(x, scope='layer_norm') :
|
209 |
+
return tf_contrib.layers.layer_norm(x,
|
210 |
+
center=True, scale=True,
|
211 |
+
scope=scope)
|
212 |
+
|
213 |
+
def layer_instance_norm(x, scope='layer_instance_norm') :
|
214 |
+
with tf.variable_scope(scope):
|
215 |
+
ch = x.shape[-1]
|
216 |
+
eps = 1e-5
|
217 |
+
|
218 |
+
ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
|
219 |
+
x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps))
|
220 |
+
|
221 |
+
ln_mean, ln_sigma = tf.nn.moments(x, axes=[1, 2, 3], keep_dims=True)
|
222 |
+
x_ln = (x - ln_mean) / (tf.sqrt(ln_sigma + eps))
|
223 |
+
|
224 |
+
rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(0.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
|
225 |
+
|
226 |
+
gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
|
227 |
+
beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))
|
228 |
+
|
229 |
+
x_hat = rho * x_ins + (1 - rho) * x_ln
|
230 |
+
|
231 |
+
x_hat = x_hat * gamma + beta
|
232 |
+
|
233 |
+
return x_hat
|
234 |
+
|
235 |
+
def spectral_norm(w, iteration=1):
|
236 |
+
w_shape = w.shape.as_list()
|
237 |
+
w = tf.reshape(w, [-1, w_shape[-1]])
|
238 |
+
|
239 |
+
u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)
|
240 |
+
|
241 |
+
u_hat = u
|
242 |
+
v_hat = None
|
243 |
+
for i in range(iteration):
|
244 |
+
"""
|
245 |
+
power iteration
|
246 |
+
Usually iteration = 1 will be enough
|
247 |
+
"""
|
248 |
+
v_ = tf.matmul(u_hat, tf.transpose(w))
|
249 |
+
v_hat = tf.nn.l2_normalize(v_)
|
250 |
+
|
251 |
+
u_ = tf.matmul(v_hat, w)
|
252 |
+
u_hat = tf.nn.l2_normalize(u_)
|
253 |
+
|
254 |
+
u_hat = tf.stop_gradient(u_hat)
|
255 |
+
v_hat = tf.stop_gradient(v_hat)
|
256 |
+
|
257 |
+
sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
|
258 |
+
|
259 |
+
with tf.control_dependencies([u.assign(u_hat)]):
|
260 |
+
w_norm = w / sigma
|
261 |
+
w_norm = tf.reshape(w_norm, w_shape)
|
262 |
+
|
263 |
+
|
264 |
+
return w_norm
|
265 |
+
|
266 |
+
##################################################################################
|
267 |
+
# Loss function
|
268 |
+
##################################################################################
|
269 |
+
|
270 |
+
def L1_loss(x, y):
|
271 |
+
loss = tf.reduce_mean(tf.abs(x - y))
|
272 |
+
|
273 |
+
return loss
|
274 |
+
|
275 |
+
def cam_loss(source, non_source) :
|
276 |
+
|
277 |
+
identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(source), logits=source))
|
278 |
+
non_identity_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(non_source), logits=non_source))
|
279 |
+
|
280 |
+
loss = identity_loss + non_identity_loss
|
281 |
+
|
282 |
+
return loss
|
283 |
+
|
284 |
+
def regularization_loss(scope_name) :
|
285 |
+
"""
|
286 |
+
If you want to use "Regularization"
|
287 |
+
g_loss += regularization_loss('generator')
|
288 |
+
d_loss += regularization_loss('discriminator')
|
289 |
+
"""
|
290 |
+
collection_regularization = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
291 |
+
|
292 |
+
loss = []
|
293 |
+
for item in collection_regularization :
|
294 |
+
if scope_name in item.name :
|
295 |
+
loss.append(item)
|
296 |
+
|
297 |
+
return tf.reduce_sum(loss)
|
298 |
+
|
299 |
+
|
300 |
+
def discriminator_loss(loss_func, real, fake):
|
301 |
+
loss = []
|
302 |
+
real_loss = 0
|
303 |
+
fake_loss = 0
|
304 |
+
|
305 |
+
for i in range(2) :
|
306 |
+
if loss_func.__contains__('wgan') :
|
307 |
+
real_loss = -tf.reduce_mean(real[i])
|
308 |
+
fake_loss = tf.reduce_mean(fake[i])
|
309 |
+
|
310 |
+
if loss_func == 'lsgan' :
|
311 |
+
real_loss = tf.reduce_mean(tf.squared_difference(real[i], 1.0))
|
312 |
+
fake_loss = tf.reduce_mean(tf.square(fake[i]))
|
313 |
+
|
314 |
+
if loss_func == 'gan' or loss_func == 'dragan' :
|
315 |
+
real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real[i]), logits=real[i]))
|
316 |
+
fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake[i]), logits=fake[i]))
|
317 |
+
|
318 |
+
if loss_func == 'hinge' :
|
319 |
+
real_loss = tf.reduce_mean(relu(1.0 - real[i]))
|
320 |
+
fake_loss = tf.reduce_mean(relu(1.0 + fake[i]))
|
321 |
+
|
322 |
+
loss.append(real_loss + fake_loss)
|
323 |
+
|
324 |
+
return sum(loss)
|
325 |
+
|
326 |
+
def generator_loss(loss_func, fake):
|
327 |
+
loss = []
|
328 |
+
fake_loss = 0
|
329 |
+
|
330 |
+
for i in range(2) :
|
331 |
+
if loss_func.__contains__('wgan') :
|
332 |
+
fake_loss = -tf.reduce_mean(fake[i])
|
333 |
+
|
334 |
+
if loss_func == 'lsgan' :
|
335 |
+
fake_loss = tf.reduce_mean(tf.squared_difference(fake[i], 1.0))
|
336 |
+
|
337 |
+
if loss_func == 'gan' or loss_func == 'dragan' :
|
338 |
+
fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake[i]), logits=fake[i]))
|
339 |
+
|
340 |
+
if loss_func == 'hinge' :
|
341 |
+
fake_loss = -tf.reduce_mean(fake[i])
|
342 |
+
|
343 |
+
loss.append(fake_loss)
|
344 |
+
|
345 |
+
return sum(loss)
|
ugatit/utils.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.contrib import slim
|
3 |
+
import cv2
|
4 |
+
import os, random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class ImageData:
|
8 |
+
|
9 |
+
def __init__(self, load_size, channels, augment_flag):
|
10 |
+
self.load_size = load_size
|
11 |
+
self.channels = channels
|
12 |
+
self.augment_flag = augment_flag
|
13 |
+
|
14 |
+
def image_processing(self, filename):
|
15 |
+
x = tf.read_file(filename)
|
16 |
+
x_decode = tf.image.decode_jpeg(x, channels=self.channels)
|
17 |
+
img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
|
18 |
+
img = tf.cast(img, tf.float32) / 127.5 - 1
|
19 |
+
|
20 |
+
if self.augment_flag :
|
21 |
+
augment_size = self.load_size + (30 if self.load_size == 256 else 15)
|
22 |
+
p = random.random()
|
23 |
+
if p > 0.5:
|
24 |
+
img = augmentation(img, augment_size)
|
25 |
+
|
26 |
+
return img
|
27 |
+
|
28 |
+
def load_test_data(image_path, size=256):
|
29 |
+
img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
|
30 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
31 |
+
|
32 |
+
img = cv2.resize(img, dsize=(size, size))
|
33 |
+
|
34 |
+
img = np.expand_dims(img, axis=0)
|
35 |
+
img = img/127.5 - 1
|
36 |
+
|
37 |
+
return img
|
38 |
+
|
39 |
+
def augmentation(image, augment_size):
|
40 |
+
seed = random.randint(0, 2 ** 31 - 1)
|
41 |
+
ori_image_shape = tf.shape(image)
|
42 |
+
image = tf.image.random_flip_left_right(image, seed=seed)
|
43 |
+
image = tf.image.resize_images(image, [augment_size, augment_size])
|
44 |
+
image = tf.random_crop(image, ori_image_shape, seed=seed)
|
45 |
+
return image
|
46 |
+
|
47 |
+
def save_images(images, size, image_path):
|
48 |
+
return imsave(inverse_transform(images), size, image_path)
|
49 |
+
|
50 |
+
def inverse_transform(images):
|
51 |
+
return ((images+1.) / 2) * 255.0
|
52 |
+
|
53 |
+
|
54 |
+
def imsave(images, size, path):
|
55 |
+
images = merge(images, size)
|
56 |
+
images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR)
|
57 |
+
|
58 |
+
return cv2.imwrite(path, images)
|
59 |
+
|
60 |
+
def merge(images, size):
|
61 |
+
h, w = images.shape[1], images.shape[2]
|
62 |
+
img = np.zeros((h * size[0], w * size[1], 3))
|
63 |
+
for idx, image in enumerate(images):
|
64 |
+
i = idx % size[1]
|
65 |
+
j = idx // size[1]
|
66 |
+
img[h*j:h*(j+1), w*i:w*(i+1), :] = image
|
67 |
+
|
68 |
+
return img
|
69 |
+
|
70 |
+
def show_all_variables():
|
71 |
+
model_vars = tf.trainable_variables()
|
72 |
+
slim.model_analyzer.analyze_vars(model_vars, print_info=True)
|
73 |
+
|
74 |
+
def check_folder(log_dir):
|
75 |
+
if not os.path.exists(log_dir):
|
76 |
+
os.makedirs(log_dir)
|
77 |
+
return log_dir
|
78 |
+
|
79 |
+
def str2bool(x):
|
80 |
+
return x.lower() in ('true')
|
ugatit_test.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ugatit.ops import *
|
2 |
+
from ugatit.utils import *
|
3 |
+
from glob import glob
|
4 |
+
import time
|
5 |
+
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
|
6 |
+
import numpy as np
|
7 |
+
from ugatit.utils import *
|
8 |
+
|
9 |
+
class UgatitTest:
|
10 |
+
|
11 |
+
def __init__(self, sess):
|
12 |
+
self.light = False
|
13 |
+
|
14 |
+
if self.light:
|
15 |
+
self.model_name = 'UGATIT_light'
|
16 |
+
else:
|
17 |
+
self.model_name = 'UGATIT'
|
18 |
+
|
19 |
+
self.sess = sess
|
20 |
+
self.phase = 'test'
|
21 |
+
self.checkpoint_dir = '/home/hylee/cartoon/UGATIT/checkpoint'
|
22 |
+
self.result_dir = 'results'
|
23 |
+
self.log_dir = 'logs'
|
24 |
+
self.dataset_name = 'selfie2anime'
|
25 |
+
self.augment_flag = True
|
26 |
+
|
27 |
+
self.epoch = 100
|
28 |
+
self.iteration = 10000
|
29 |
+
self.decay_flag = True
|
30 |
+
self.decay_epoch = 50
|
31 |
+
|
32 |
+
self.gan_type = 'lsgan'
|
33 |
+
|
34 |
+
self.batch_size = 1
|
35 |
+
self.print_freq = 1000
|
36 |
+
self.save_freq = 1000
|
37 |
+
|
38 |
+
self.init_lr = 0.0001
|
39 |
+
self.ch = 64
|
40 |
+
|
41 |
+
""" Weight """
|
42 |
+
self.adv_weight = 1
|
43 |
+
self.cycle_weight = 10
|
44 |
+
self.identity_weight = 10
|
45 |
+
self.cam_weight = 1000
|
46 |
+
self.ld = 10
|
47 |
+
self.smoothing = True
|
48 |
+
|
49 |
+
""" Generator """
|
50 |
+
self.n_res = 4
|
51 |
+
|
52 |
+
""" Discriminator """
|
53 |
+
self.n_dis = 6
|
54 |
+
self.n_critic = 1
|
55 |
+
self.sn = True
|
56 |
+
|
57 |
+
self.img_size = 256
|
58 |
+
self.img_ch = 3
|
59 |
+
|
60 |
+
self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
|
61 |
+
check_folder(self.sample_dir)
|
62 |
+
|
63 |
+
# self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
|
64 |
+
self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
|
65 |
+
self.trainB_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainB'))
|
66 |
+
self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))
|
67 |
+
|
68 |
+
print()
|
69 |
+
|
70 |
+
print("##### Information #####")
|
71 |
+
print("# light : ", self.light)
|
72 |
+
print("# gan type : ", self.gan_type)
|
73 |
+
print("# dataset : ", self.dataset_name)
|
74 |
+
print("# max dataset number : ", self.dataset_num)
|
75 |
+
print("# batch_size : ", self.batch_size)
|
76 |
+
print("# epoch : ", self.epoch)
|
77 |
+
print("# iteration per epoch : ", self.iteration)
|
78 |
+
print("# smoothing : ", self.smoothing)
|
79 |
+
|
80 |
+
print()
|
81 |
+
|
82 |
+
print("##### Generator #####")
|
83 |
+
print("# residual blocks : ", self.n_res)
|
84 |
+
|
85 |
+
print()
|
86 |
+
|
87 |
+
print("##### Discriminator #####")
|
88 |
+
print("# discriminator layer : ", self.n_dis)
|
89 |
+
print("# the number of critic : ", self.n_critic)
|
90 |
+
print("# spectral normalization : ", self.sn)
|
91 |
+
|
92 |
+
print()
|
93 |
+
|
94 |
+
print("##### Weight #####")
|
95 |
+
print("# adv_weight : ", self.adv_weight)
|
96 |
+
print("# cycle_weight : ", self.cycle_weight)
|
97 |
+
print("# identity_weight : ", self.identity_weight)
|
98 |
+
print("# cam_weight : ", self.cam_weight)
|
99 |
+
|
100 |
+
##################################################################################
|
101 |
+
# Generator
|
102 |
+
##################################################################################
|
103 |
+
|
104 |
+
def generator(self, x_init, reuse=False, scope="generator"):
|
105 |
+
channel = self.ch
|
106 |
+
with tf.variable_scope(scope, reuse=reuse) :
|
107 |
+
x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', scope='conv')
|
108 |
+
x = instance_norm(x, scope='ins_norm')
|
109 |
+
x = relu(x)
|
110 |
+
|
111 |
+
# Down-Sampling
|
112 |
+
for i in range(2) :
|
113 |
+
x = conv(x, channel*2, kernel=3, stride=2, pad=1, pad_type='reflect', scope='conv_'+str(i))
|
114 |
+
x = instance_norm(x, scope='ins_norm_'+str(i))
|
115 |
+
x = relu(x)
|
116 |
+
|
117 |
+
channel = channel * 2
|
118 |
+
|
119 |
+
# Down-Sampling Bottleneck
|
120 |
+
for i in range(self.n_res):
|
121 |
+
x = resblock(x, channel, scope='resblock_' + str(i))
|
122 |
+
|
123 |
+
|
124 |
+
# Class Activation Map
|
125 |
+
cam_x = global_avg_pooling(x)
|
126 |
+
cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, scope='CAM_logit')
|
127 |
+
x_gap = tf.multiply(x, cam_x_weight)
|
128 |
+
|
129 |
+
cam_x = global_max_pooling(x)
|
130 |
+
cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, reuse=True, scope='CAM_logit')
|
131 |
+
x_gmp = tf.multiply(x, cam_x_weight)
|
132 |
+
|
133 |
+
|
134 |
+
cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
|
135 |
+
x = tf.concat([x_gap, x_gmp], axis=-1)
|
136 |
+
|
137 |
+
x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
|
138 |
+
x = relu(x)
|
139 |
+
|
140 |
+
heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
|
141 |
+
|
142 |
+
# Gamma, Beta block
|
143 |
+
gamma, beta = self.MLP(x, reuse=reuse)
|
144 |
+
|
145 |
+
# Up-Sampling Bottleneck
|
146 |
+
for i in range(self.n_res):
|
147 |
+
x = adaptive_ins_layer_resblock(x, channel, gamma, beta, smoothing=self.smoothing, scope='adaptive_resblock' + str(i))
|
148 |
+
|
149 |
+
# Up-Sampling
|
150 |
+
for i in range(2) :
|
151 |
+
x = up_sample(x, scale_factor=2)
|
152 |
+
x = conv(x, channel//2, kernel=3, stride=1, pad=1, pad_type='reflect', scope='up_conv_'+str(i))
|
153 |
+
x = layer_instance_norm(x, scope='layer_ins_norm_'+str(i))
|
154 |
+
x = relu(x)
|
155 |
+
|
156 |
+
channel = channel // 2
|
157 |
+
|
158 |
+
|
159 |
+
x = conv(x, channels=3, kernel=7, stride=1, pad=3, pad_type='reflect', scope='G_logit')
|
160 |
+
x = tanh(x)
|
161 |
+
|
162 |
+
return x, cam_logit, heatmap
|
163 |
+
|
164 |
+
def MLP(self, x, use_bias=True, reuse=False, scope='MLP'):
|
165 |
+
channel = self.ch * self.n_res
|
166 |
+
|
167 |
+
if self.light :
|
168 |
+
x = global_avg_pooling(x)
|
169 |
+
|
170 |
+
with tf.variable_scope(scope, reuse=reuse):
|
171 |
+
for i in range(2) :
|
172 |
+
x = fully_connected(x, channel, use_bias, scope='linear_' + str(i))
|
173 |
+
x = relu(x)
|
174 |
+
|
175 |
+
|
176 |
+
gamma = fully_connected(x, channel, use_bias, scope='gamma')
|
177 |
+
beta = fully_connected(x, channel, use_bias, scope='beta')
|
178 |
+
|
179 |
+
gamma = tf.reshape(gamma, shape=[self.batch_size, 1, 1, channel])
|
180 |
+
beta = tf.reshape(beta, shape=[self.batch_size, 1, 1, channel])
|
181 |
+
|
182 |
+
return gamma, beta
|
183 |
+
|
184 |
+
##################################################################################
|
185 |
+
# Discriminator
|
186 |
+
##################################################################################
|
187 |
+
|
188 |
+
def discriminator(self, x_init, reuse=False, scope="discriminator"):
|
189 |
+
D_logit = []
|
190 |
+
D_CAM_logit = []
|
191 |
+
with tf.variable_scope(scope, reuse=reuse) :
|
192 |
+
local_x, local_cam, local_heatmap = self.discriminator_local(x_init, reuse=reuse, scope='local')
|
193 |
+
global_x, global_cam, global_heatmap = self.discriminator_global(x_init, reuse=reuse, scope='global')
|
194 |
+
|
195 |
+
D_logit.extend([local_x, global_x])
|
196 |
+
D_CAM_logit.extend([local_cam, global_cam])
|
197 |
+
|
198 |
+
return D_logit, D_CAM_logit, local_heatmap, global_heatmap
|
199 |
+
|
200 |
+
def discriminator_global(self, x_init, reuse=False, scope='discriminator_global'):
|
201 |
+
with tf.variable_scope(scope, reuse=reuse):
|
202 |
+
channel = self.ch
|
203 |
+
x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
|
204 |
+
x = lrelu(x, 0.2)
|
205 |
+
|
206 |
+
for i in range(1, self.n_dis - 1):
|
207 |
+
x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
|
208 |
+
x = lrelu(x, 0.2)
|
209 |
+
|
210 |
+
channel = channel * 2
|
211 |
+
|
212 |
+
x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
|
213 |
+
x = lrelu(x, 0.2)
|
214 |
+
|
215 |
+
channel = channel * 2
|
216 |
+
|
217 |
+
cam_x = global_avg_pooling(x)
|
218 |
+
cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
|
219 |
+
x_gap = tf.multiply(x, cam_x_weight)
|
220 |
+
|
221 |
+
cam_x = global_max_pooling(x)
|
222 |
+
cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
|
223 |
+
x_gmp = tf.multiply(x, cam_x_weight)
|
224 |
+
|
225 |
+
cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
|
226 |
+
x = tf.concat([x_gap, x_gmp], axis=-1)
|
227 |
+
|
228 |
+
x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
|
229 |
+
x = lrelu(x, 0.2)
|
230 |
+
|
231 |
+
heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
|
232 |
+
|
233 |
+
|
234 |
+
x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
|
235 |
+
|
236 |
+
return x, cam_logit, heatmap
|
237 |
+
|
238 |
+
def discriminator_local(self, x_init, reuse=False, scope='discriminator_local'):
|
239 |
+
with tf.variable_scope(scope, reuse=reuse) :
|
240 |
+
channel = self.ch
|
241 |
+
x = conv(x_init, channel, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_0')
|
242 |
+
x = lrelu(x, 0.2)
|
243 |
+
|
244 |
+
for i in range(1, self.n_dis - 2 - 1):
|
245 |
+
x = conv(x, channel * 2, kernel=4, stride=2, pad=1, pad_type='reflect', sn=self.sn, scope='conv_' + str(i))
|
246 |
+
x = lrelu(x, 0.2)
|
247 |
+
|
248 |
+
channel = channel * 2
|
249 |
+
|
250 |
+
x = conv(x, channel * 2, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='conv_last')
|
251 |
+
x = lrelu(x, 0.2)
|
252 |
+
|
253 |
+
channel = channel * 2
|
254 |
+
|
255 |
+
cam_x = global_avg_pooling(x)
|
256 |
+
cam_gap_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, scope='CAM_logit')
|
257 |
+
x_gap = tf.multiply(x, cam_x_weight)
|
258 |
+
|
259 |
+
cam_x = global_max_pooling(x)
|
260 |
+
cam_gmp_logit, cam_x_weight = fully_connected_with_w(cam_x, sn=self.sn, reuse=True, scope='CAM_logit')
|
261 |
+
x_gmp = tf.multiply(x, cam_x_weight)
|
262 |
+
|
263 |
+
cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
|
264 |
+
x = tf.concat([x_gap, x_gmp], axis=-1)
|
265 |
+
|
266 |
+
x = conv(x, channel, kernel=1, stride=1, scope='conv_1x1')
|
267 |
+
x = lrelu(x, 0.2)
|
268 |
+
|
269 |
+
heatmap = tf.squeeze(tf.reduce_sum(x, axis=-1))
|
270 |
+
|
271 |
+
x = conv(x, channels=1, kernel=4, stride=1, pad=1, pad_type='reflect', sn=self.sn, scope='D_logit')
|
272 |
+
|
273 |
+
return x, cam_logit, heatmap
|
274 |
+
|
275 |
+
def generate_a2b(self, x_A, reuse=False):
|
276 |
+
out, cam, _ = self.generator(x_A, reuse=reuse, scope="generator_B")
|
277 |
+
|
278 |
+
return out, cam
|
279 |
+
|
280 |
+
def generate_b2a(self, x_B, reuse=False):
|
281 |
+
out, cam, _ = self.generator(x_B, reuse=reuse, scope="generator_A")
|
282 |
+
|
283 |
+
return out, cam
|
284 |
+
def build_model(self):
|
285 |
+
self.test_domain_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_A')
|
286 |
+
self.test_domain_B = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_domain_B')
|
287 |
+
|
288 |
+
self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
|
289 |
+
self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
|
290 |
+
|
291 |
+
@property
|
292 |
+
def model_dir(self):
|
293 |
+
n_res = str(self.n_res) + 'resblock'
|
294 |
+
n_dis = str(self.n_dis) + 'dis'
|
295 |
+
|
296 |
+
if self.smoothing:
|
297 |
+
smoothing = '_smoothing'
|
298 |
+
else:
|
299 |
+
smoothing = ''
|
300 |
+
|
301 |
+
if self.sn:
|
302 |
+
sn = '_sn'
|
303 |
+
else:
|
304 |
+
sn = ''
|
305 |
+
|
306 |
+
return "{}_{}_{}_{}_{}_{}_{}_{}_{}_{}{}{}".format(self.model_name, self.dataset_name,
|
307 |
+
self.gan_type, n_res, n_dis,
|
308 |
+
self.n_critic,
|
309 |
+
self.adv_weight, self.cycle_weight, self.identity_weight,
|
310 |
+
self.cam_weight, sn, smoothing)
|
311 |
+
|
312 |
+
def load(self, checkpoint_dir):
|
313 |
+
print(" [*] Reading checkpoints...")
|
314 |
+
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
|
315 |
+
|
316 |
+
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
|
317 |
+
if ckpt and ckpt.model_checkpoint_path:
|
318 |
+
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
|
319 |
+
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
|
320 |
+
counter = int(ckpt_name.split('-')[-1])
|
321 |
+
print(" [*] Success to read {}".format(ckpt_name))
|
322 |
+
return True, counter
|
323 |
+
else:
|
324 |
+
print(" [*] Failed to find a checkpoint")
|
325 |
+
return False, 0
|
326 |
+
|
327 |
+
def loadModel(self):
|
328 |
+
tf.global_variables_initializer().run(session=self.sess)
|
329 |
+
|
330 |
+
self.saver = tf.train.Saver()
|
331 |
+
could_load, checkpoint_counter = self.load(self.checkpoint_dir)
|
332 |
+
self.result_dir = os.path.join(self.result_dir, self.model_dir)
|
333 |
+
check_folder(self.result_dir)
|
334 |
+
|
335 |
+
if could_load:
|
336 |
+
print(" [*] Load SUCCESS")
|
337 |
+
else:
|
338 |
+
print(" [!] Load failed...")
|
339 |
+
|
340 |
+
def test(self, sample_file):
|
341 |
+
# A -> B
|
342 |
+
print('Processing A image: ' + sample_file)
|
343 |
+
sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
|
344 |
+
image_path = os.path.join(self.result_dir,'{0}'.format(os.path.basename(sample_file)))
|
345 |
+
|
346 |
+
fake_img = self.sess.run(self.test_fake_B, feed_dict = {self.test_domain_A : sample_image})
|
347 |
+
save_images(fake_img, [1, 1], image_path)
|
348 |
+
|
349 |
+
return image_path
|
350 |
+
|
351 |
+
|
352 |
+
gan = None
|
353 |
+
def main_test(img_path):
|
354 |
+
# open session
|
355 |
+
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
356 |
+
global gan
|
357 |
+
if gan is None:
|
358 |
+
gan = UgatitTest(sess)
|
359 |
+
# build graph
|
360 |
+
gan.build_model()
|
361 |
+
# show network architecture
|
362 |
+
show_all_variables()
|
363 |
+
|
364 |
+
gan.loadModel()
|
365 |
+
|
366 |
+
result = gan.test(img_path)
|
367 |
+
print(" [*] Test finished!")
|
368 |
+
print(result)
|
369 |
+
return os.path.abspath(result)
|
370 |
+
|
371 |
+
if __name__ == '__main__':
|
372 |
+
main_test('/home/hylee/cartoon/myp2c/imgs/src/im4.jpg')
|