soiz1 committed
Commit 55438d7 · verified · Parent: 893a056

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,4 @@
+ model
+ logs
+ data
+ load_model.py
README.md CHANGED
@@ -1,12 +1,85 @@
- ---
- title: Watermark Removal
- emoji: 🐨
- colorFrom: purple
- colorTo: purple
- sdk: gradio
- sdk_version: 5.49.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Chimzuruoke Okafor
+
+ # Watermark-Removal
+
+ <p>
+ <a href="https://pepy.tech/project/prompttools" target="_blank"><img src="https://static.pepy.tech/badge/prompttools" alt="Total Downloads"/></a>
+ <a href="https://github.com/hegelai/watermark-removal"><img src="https://img.shields.io/github/stars/zuruoke/watermark-removal" /></a>
+ <a href="https://twitter.com/zuruoke_okafor"><img src="https://img.shields.io/twitter/follow/Zuruoke_Okafor?style=social"></a>
+ </p>
+
+ ![version](https://img.shields.io/badge/version-v1.0.0-green.svg?style=plastic)
+ ![tensorflow](https://img.shields.io/badge/tensorflow-v1.15.0-green.svg?style=plastic)
+ ![license](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=plastic)
+
+ <a href="https://coff.ee/zuruokeokafor" target="_blank">
+ <img src="https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png"
+ alt="Buy Me A Coffee"
+ style="height: 41px;width:174px;box-shadow:0px 3px 2px rgba(190,190,190,0.5);" />
+ </a>
+
+ An open-source project that uses a machine-learning-based image inpainting method to remove watermarks from images, producing results that are virtually indistinguishable from the ground-truth version of the image.
+
+ This project was inspired by [Contextual Attention](https://arxiv.org/abs/1801.07892) (CVPR 2018) and [Gated Convolution](https://arxiv.org/abs/1806.03589) (ICCV 2019 Oral).
+
+ And also a shoutout to [Chu-Tak Li](https://chutakcode.wixsite.com/website) for his [Medium article series](https://towardsdatascience.com/10-papers-you-must-read-for-deep-image-inpainting-2e41c589ced0), which gave me a deep insight into the image inpainting papers listed above.
+
+ <img src="https://user-images.githubusercontent.com/51057490/140277713-c7d6e2b9-db62-4793-823a-25ed0c4e2771.png" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277781-5b5218bb-9044-4ec9-a349-eea93bc56d4a.png" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277929-3f187647-0e63-4bcb-b9f1-472f7558aae5.jpeg" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277957-6ddb7dec-25c8-42f1-8e39-be491d4f2248.png" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277983-265a1c9e-6093-4154-8252-838baca21c41.jpeg" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278002-56c4ae3d-6bfb-4ba3-aa02-7bd28474bfdf.png" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278030-d2a962ce-3722-43f1-b1bd-0ffde2aa7026.jpeg" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278040-10e401d7-4b7d-4d81-91fe-e9f01ef4ce7f.png" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278017-34862de0-86eb-40f0-b04b-7dc02fe38a77.jpeg" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278011-e0ae9ed0-e4ed-44ed-a9ac-28eb8456797a.png" width="45%" />
+
+ ## Run
+
+ - Use [Google Colab](https://research.google.com/colaboratory/).
+
+ - First of all, clone this repo:
+
+ !git clone https://github.com/zuruoke/watermark-removal
+
+ - Change directory into the repo:
+
+ !cd watermark-removal
+
+ - Google Colab ships with the latest TensorFlow 2.x while this project uses 1.15.0, so downgrade to TensorFlow 1.15.0 and restart the runtime (`although newer versions of Google Colab do not require restarting the runtime`).
+
+ !pip install tensorflow==1.15.0
+
+ - Install the TensorFlow toolkit [neuralgym](https://github.com/JiahuiYu/neuralgym).
+
+ !pip install git+https://github.com/JiahuiYu/neuralgym
+
+ - Download the model directory using this [link](https://drive.google.com/drive/folders/1xRV4EdjJuAfsX9pQme6XeoFznKXG0ptJ?usp=sharing) and put it under `model/` (rename `checkpoint.txt` to `checkpoint`, because Google Drive sometimes adds .txt automatically after download).
+
+ And you're all set!
+
+ - Now remove the watermark from the image by running the `main.py` file:
+
+ !python main.py --image path-to-input-image --output path-to-output-image --checkpoint_dir model/ --watermark_type istock
+
+ ## Citing
+
+ ```
+ @article{yu2018generative,
+ title={Generative Image Inpainting with Contextual Attention},
+ author={Yu, Jiahui and Lin, Zhe and Yang, Jimei and Shen, Xiaohui and Lu, Xin and Huang, Thomas S},
+ journal={arXiv preprint arXiv:1801.07892},
+ year={2018}
+ }
+
+ @article{yu2018free,
+ title={Free-Form Image Inpainting with Gated Convolution},
+ author={Yu, Jiahui and Lin, Zhe and Yang, Jimei and Shen, Xiaohui and Lu, Xin and Huang, Thomas S},
+ journal={arXiv preprint arXiv:1806.03589},
+ year={2018}
+ }
+ ```
+
+ <p align="center">
+ <a href="https://star-history.com/#zuruoke/watermark-removal">
+ <picture>
+ <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=zuruoke/watermark-removal&type=Date&theme=dark" />
+ <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=zuruoke/watermark-removal&type=Date" />
+ <img alt="Star History Chart" src="https://api.star-history.com/svg?repos=zuruoke/watermark-removal&type=Date" />
+ </picture>
+ </a>
+ </p>
+
+ ## © Chimzuruoke Okafor
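The model-download step in the README above notes that Google Drive sometimes appends `.txt` to the `checkpoint` file. A minimal, hypothetical helper for normalizing the layout under `model/` (the default `--checkpoint_dir` used by `main.py`; the directory name is an assumption taken from that step) might look like this:

```
import os

# Undo the ".txt" suffix Google Drive sometimes adds on download, so that
# TensorFlow can find the checkpoint index file under model/.
ckpt_dir = "model"
renamed = os.path.join(ckpt_dir, "checkpoint.txt")
expected = os.path.join(ckpt_dir, "checkpoint")
if os.path.exists(renamed) and not os.path.exists(expected):
    os.rename(renamed, expected)
print(sorted(os.listdir(ckpt_dir)))  # should list "checkpoint" plus the weight files
```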
batch_test.py ADDED
@@ -0,0 +1,92 @@
+ import time
+ import os
+ import argparse
+
+ import cv2
+ import numpy as np
+ import tensorflow as tf
+ import neuralgym as ng
+
+ from inpaint_model import InpaintCAModel
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--flist', default='', type=str,
+ help='The filenames of image to be processed: input, mask, output.')
+ parser.add_argument(
+ '--image_height', default=-1, type=int,
+ help='The height of images should be defined, otherwise batch mode is not'
+ ' supported.')
+ parser.add_argument(
+ '--image_width', default=-1, type=int,
+ help='The width of images should be defined, otherwise batch mode is not'
+ ' supported.')
+ parser.add_argument(
+ '--checkpoint_dir', default='', type=str,
+ help='The directory of tensorflow checkpoint.')
+
+
+ if __name__ == "__main__":
+ FLAGS = ng.Config('inpaint.yml')
+ ng.get_gpus(1)
+ # os.environ['CUDA_VISIBLE_DEVICES'] =''
+ args = parser.parse_args()
+
+ sess_config = tf.ConfigProto()
+ sess_config.gpu_options.allow_growth = True
+ sess = tf.Session(config=sess_config)
+
+ model = InpaintCAModel()
+ input_image_ph = tf.placeholder(
+ tf.float32, shape=(1, args.image_height, args.image_width*2, 3))
+ output = model.build_server_graph(FLAGS, input_image_ph)
+ output = (output + 1.) * 127.5
+ output = tf.reverse(output, [-1])
+ output = tf.saturate_cast(output, tf.uint8)
+ vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ assign_ops = []
+ for var in vars_list:
+ vname = var.name
+ from_name = vname
+ var_value = tf.contrib.framework.load_variable(
+ args.checkpoint_dir, from_name)
+ assign_ops.append(tf.assign(var, var_value))
+ sess.run(assign_ops)
+ print('Model loaded.')
+
+ with open(args.flist, 'r') as f:
+ lines = f.read().splitlines()
+ t = time.time()
+ for line in lines:
+ # for i in range(100):
+ image, mask, out = line.split()
+ base = os.path.basename(mask)
+
+ image = cv2.imread(image)
+ mask = cv2.imread(mask)
+ image = cv2.resize(image, (args.image_width, args.image_height))
+ mask = cv2.resize(mask, (args.image_width, args.image_height))
+ # cv2.imwrite(out, image*(1-mask/255.) + mask)
+ # # continue
+ # image = np.zeros((128, 256, 3))
+ # mask = np.zeros((128, 256, 3))
+
+ assert image.shape == mask.shape
+
+ h, w, _ = image.shape
+ grid = 4
+ image = image[:h//grid*grid, :w//grid*grid, :]
+ mask = mask[:h//grid*grid, :w//grid*grid, :]
+ print('Shape of image: {}'.format(image.shape))
+
+ image = np.expand_dims(image, 0)
+ mask = np.expand_dims(mask, 0)
+ input_image = np.concatenate([image, mask], axis=2)
+
+ # load pretrained model
+ result = sess.run(output, feed_dict={input_image_ph: input_image})
+ print('Processed: {}'.format(out))
+ cv2.imwrite(out, result[0][:, :, ::-1])
+
+ print('Time total: {}'.format(time.time() - t))
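`batch_test.py` reads the `--flist` file and splits every line into an input path, a mask path, and an output path. A short sketch of producing such a file list (the paths below are placeholders, not files shipped with this repo):

```
# Write a file list in the "<input> <mask> <output>" per-line format that
# batch_test.py expects (it calls line.split() into image, mask, out).
rows = [
    ("data/images/001.png", "data/masks/001.png", "results/001.png"),
    ("data/images/002.png", "data/masks/002.png", "results/002.png"),
]
with open("examples.flist", "w") as f:
    for image_path, mask_path, out_path in rows:
        f.write("{} {} {}\n".format(image_path, mask_path, out_path))
```

It would then be run roughly as `python batch_test.py --flist examples.flist --image_height 256 --image_width 256 --checkpoint_dir model/`; height and width must be given because the placeholder shape is fixed when the graph is built.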
guided_batch_test.py ADDED
@@ -0,0 +1,95 @@
+ import time
+ import os
+ import argparse
+
+ import cv2
+ import numpy as np
+ import tensorflow as tf
+ import neuralgym as ng
+
+ from inpaint_model import InpaintCAModel
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--flist', default='', type=str,
+ help='The filenames of image to be processed: input, mask, output.')
+ parser.add_argument(
+ '--image_height', default=-1, type=int,
+ help='The height of images should be defined, otherwise batch mode is not'
+ ' supported.')
+ parser.add_argument(
+ '--image_width', default=-1, type=int,
+ help='The width of images should be defined, otherwise batch mode is not'
+ ' supported.')
+ parser.add_argument(
+ '--checkpoint_dir', default='', type=str,
+ help='The directory of tensorflow checkpoint.')
+
+
+ if __name__ == "__main__":
+ ng.get_gpus(1)
+ # os.environ['CUDA_VISIBLE_DEVICES'] =''
+ args = parser.parse_args()
+
+ sess_config = tf.ConfigProto()
+ sess_config.gpu_options.allow_growth = True
+ sess = tf.Session(config=sess_config)
+
+ model = InpaintCAModel()
+ input_image_ph = tf.placeholder(
+ tf.float32, shape=(1, args.image_height, args.image_width*3, 3))
+ output = model.build_server_graph(input_image_ph)
+ output = (output + 1.) * 127.5
+ output = tf.reverse(output, [-1])
+ output = tf.saturate_cast(output, tf.uint8)
+ vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ assign_ops = []
+ for var in vars_list:
+ vname = var.name
+ from_name = vname
+ var_value = tf.contrib.framework.load_variable(
+ args.checkpoint_dir, from_name)
+ assign_ops.append(tf.assign(var, var_value))
+ sess.run(assign_ops)
+ print('Model loaded.')
+
+ with open(args.flist, 'r') as f:
+ lines = f.read().splitlines()
+ t = time.time()
+ for line in lines:
+ # for i in range(100):
+ image, mask, out = line.split()
+ base = os.path.basename(mask)
+
+ guidance = cv2.imread(image[:-4] + '_edge.jpg')
+ image = cv2.imread(image)
+ mask = cv2.imread(mask)
+ image = cv2.resize(image, (args.image_width, args.image_height))
+ guidance = cv2.resize(guidance, (args.image_width, args.image_height))
+ mask = cv2.resize(mask, (args.image_width, args.image_height))
+ # cv2.imwrite(out, image*(1-mask/255.) + mask)
+ # # continue
+ # image = np.zeros((128, 256, 3))
+ # mask = np.zeros((128, 256, 3))
+
+ assert image.shape == mask.shape
+
+ h, w, _ = image.shape
+ grid = 4
+ image = image[:h//grid*grid, :w//grid*grid, :]
+ mask = mask[:h//grid*grid, :w//grid*grid, :]
+ guidance = guidance[:h//grid*grid, :w//grid*grid, :]
+ print('Shape of image: {}'.format(image.shape))
+
+ image = np.expand_dims(image, 0)
+ guidance = np.expand_dims(guidance, 0)
+ mask = np.expand_dims(mask, 0)
+ input_image = np.concatenate([image, guidance, mask], axis=2)
+
+ # load pretrained model
+ result = sess.run(output, feed_dict={input_image_ph: input_image})
+ print('Processed: {}'.format(out))
+ cv2.imwrite(out, result[0][:, :, ::-1])
+
+ print('Time total: {}'.format(time.time() - t))
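The guided variant additionally loads an edge-guidance image for every input, derived from the input filename as `image[:-4] + '_edge.jpg'`. One plausible way to produce that file is to reuse the same Laplacian settings as `image2edge` in `inpaint_ops.py`; this is an assumption about how the guidance could be generated, and the path is a placeholder:

```
import cv2

# Create the "<name>_edge.jpg" guidance file that guided_batch_test.py looks
# for next to each input image; the Laplacian parameters mirror image2edge().
input_path = "data/images/001.png"
img = cv2.imread(input_path)
edge = cv2.Laplacian(img, cv2.CV_64F, ksize=3, scale=2)
cv2.imwrite(input_path[:-4] + "_edge.jpg", edge)
```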
inpaint.yml ADDED
@@ -0,0 +1,70 @@
+ # =========================== Basic Settings ===========================
+ # machine info
+ num_gpus_per_job: 1 # number of gpus each job needs
+ num_cpus_per_job: 4 # number of cpus each job needs
+ num_hosts_per_job: 1
+ memory_per_job: 32 # memory (GB) each job needs
+ gpu_type: 'nvidia-tesla-p100'
+
+ # parameters
+ name: places2_gated_conv_v100 # any name
+ model_restore: '' # logs/places2_gated_conv
+ dataset: 'celebahq' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes'
+ random_crop: False # Set to False when the dataset is 'celebahq', meaning images are only resized to img_shapes instead of cropping img_shapes from a larger raw image. Cropping is useful when you train on images with varying resolutions, such as places2; in those cases set random_crop to True.
+ val: False # True if you want to view validation results in tensorboard
+ log_dir: logs/full_model_celeba_hq_256
+
+ gan: 'sngan'
+ gan_loss_alpha: 1
+ gan_with_mask: True
+ discounted_mask: True
+ random_seed: False
+ padding: 'SAME'
+
+ # training
+ train_spe: 4000
+ max_iters: 100000000
+ viz_max_out: 10
+ val_psteps: 2000
+
+ # data
+ data_flist:
+ # https://github.com/jiahuiyu/progressive_growing_of_gans_tf
+ celebahq: [
+ 'data/celeba_hq/train_shuffled.flist',
+ 'data/celeba_hq/validation_static_view.flist'
+ ]
+ # http://mmlab.ie.cuhk.edu.hk/projects/celeba.html, please use random_crop: True
+ celeba: [
+ 'data/celeba/train_shuffled.flist',
+ 'data/celeba/validation_static_view.flist'
+ ]
+ # http://places2.csail.mit.edu/, please download the high-resolution dataset and use random_crop: True
+ places2: [
+ 'data/places2/train_shuffled.flist',
+ 'data/places2/validation_static_view.flist'
+ ]
+ # http://www.image-net.org/, please use random_crop: True
+ imagenet: [
+ 'data/imagenet/train_shuffled.flist',
+ 'data/imagenet/validation_static_view.flist',
+ ]
+
+ static_view_size: 30
+ img_shapes: [256, 256, 3]
+ height: 128
+ width: 128
+ max_delta_height: 32
+ max_delta_width: 32
+ batch_size: 16
+ vertical_margin: 0
+ horizontal_margin: 0
+
+ # loss
+ ae_loss: True
+ l1_loss: True
+ l1_loss_alpha: 1.
+
+ # to tune
+ guided: False
+ edge_threshold: 0.6
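This configuration is consumed through neuralgym's `Config` object: `main.py` and `batch_test.py` load it as `FLAGS = ng.Config('inpaint.yml')`, and the inference path only reads a handful of fields. A small sketch of inspecting those fields, assuming attribute-style access as used throughout `inpaint_model.py`:

```
import neuralgym as ng

# Load the configuration the same way main.py and batch_test.py do.
FLAGS = ng.Config('inpaint.yml')

print(FLAGS.guided)          # False  -> build_server_graph expects [image | mask]
print(FLAGS.edge_threshold)  # 0.6    -> only used when guided is True
print(FLAGS.padding)         # 'SAME' -> padding mode for build_inpaint_net
print(FLAGS.img_shapes)      # [256, 256, 3]
```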
inpaint_model.py ADDED
@@ -0,0 +1,297 @@
1
+ """ common model for DCGAN """
2
+ import logging
3
+
4
+ import cv2
5
+ import neuralgym as ng
6
+ import tensorflow as tf
7
+ from tensorflow.contrib.framework.python.ops import arg_scope
8
+
9
+ from neuralgym.models import Model
10
+ from neuralgym.ops.summary_ops import scalar_summary, images_summary
11
+ from neuralgym.ops.summary_ops import gradients_summary
12
+ from neuralgym.ops.layers import flatten, resize
13
+ from neuralgym.ops.gan_ops import gan_hinge_loss
14
+ from neuralgym.ops.gan_ops import random_interpolates
15
+
16
+ from inpaint_ops import gen_conv, gen_deconv, dis_conv
17
+ from inpaint_ops import random_bbox, bbox2mask, local_patch, brush_stroke_mask
18
+ from inpaint_ops import resize_mask_like, contextual_attention
19
+
20
+
21
+ logger = logging.getLogger()
22
+
23
+
24
+ class InpaintCAModel(Model):
25
+ def __init__(self):
26
+ super().__init__('InpaintCAModel')
27
+
28
+ def build_inpaint_net(self, x, mask, reuse=False,
29
+ training=True, padding='SAME', name='inpaint_net'):
30
+ """Inpaint network.
31
+
32
+ Args:
33
+ x: incomplete image, [-1, 1]
34
+ mask: mask region {0, 1}
35
+ Returns:
36
+ [-1, 1] as predicted image
37
+ """
38
+ xin = x
39
+ offset_flow = None
40
+ ones_x = tf.ones_like(x)[:, :, :, 0:1]
41
+ x = tf.concat([x, ones_x, ones_x*mask], axis=3)
42
+
43
+ # two stage network
44
+ cnum = 48
45
+ with tf.variable_scope(name, reuse=reuse), \
46
+ arg_scope([gen_conv, gen_deconv],
47
+ training=training, padding=padding):
48
+ # stage1
49
+ x = gen_conv(x, cnum, 5, 1, name='conv1')
50
+ x = gen_conv(x, 2*cnum, 3, 2, name='conv2_downsample')
51
+ x = gen_conv(x, 2*cnum, 3, 1, name='conv3')
52
+ x = gen_conv(x, 4*cnum, 3, 2, name='conv4_downsample')
53
+ x = gen_conv(x, 4*cnum, 3, 1, name='conv5')
54
+ x = gen_conv(x, 4*cnum, 3, 1, name='conv6')
55
+ mask_s = resize_mask_like(mask, x)
56
+ x = gen_conv(x, 4*cnum, 3, rate=2, name='conv7_atrous')
57
+ x = gen_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous')
58
+ x = gen_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous')
59
+ x = gen_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous')
60
+ x = gen_conv(x, 4*cnum, 3, 1, name='conv11')
61
+ x = gen_conv(x, 4*cnum, 3, 1, name='conv12')
62
+ x = gen_deconv(x, 2*cnum, name='conv13_upsample')
63
+ x = gen_conv(x, 2*cnum, 3, 1, name='conv14')
64
+ x = gen_deconv(x, cnum, name='conv15_upsample')
65
+ x = gen_conv(x, cnum//2, 3, 1, name='conv16')
66
+ x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
67
+ x = tf.nn.tanh(x)
68
+ x_stage1 = x
69
+
70
+ # stage2, paste result as input
71
+ x = x*mask + xin[:, :, :, 0:3]*(1.-mask)
72
+ x.set_shape(xin[:, :, :, 0:3].get_shape().as_list())
73
+ # conv branch
74
+ # xnow = tf.concat([x, ones_x, ones_x*mask], axis=3)
75
+ xnow = x
76
+ x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
77
+ x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
78
+ x = gen_conv(x, 2*cnum, 3, 1, name='xconv3')
79
+ x = gen_conv(x, 2*cnum, 3, 2, name='xconv4_downsample')
80
+ x = gen_conv(x, 4*cnum, 3, 1, name='xconv5')
81
+ x = gen_conv(x, 4*cnum, 3, 1, name='xconv6')
82
+ x = gen_conv(x, 4*cnum, 3, rate=2, name='xconv7_atrous')
83
+ x = gen_conv(x, 4*cnum, 3, rate=4, name='xconv8_atrous')
84
+ x = gen_conv(x, 4*cnum, 3, rate=8, name='xconv9_atrous')
85
+ x = gen_conv(x, 4*cnum, 3, rate=16, name='xconv10_atrous')
86
+ x_hallu = x
87
+ # attention branch
88
+ x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
89
+ x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
90
+ x = gen_conv(x, 2*cnum, 3, 1, name='pmconv3')
91
+ x = gen_conv(x, 4*cnum, 3, 2, name='pmconv4_downsample')
92
+ x = gen_conv(x, 4*cnum, 3, 1, name='pmconv5')
93
+ x = gen_conv(x, 4*cnum, 3, 1, name='pmconv6',
94
+ activation=tf.nn.relu)
95
+ x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
96
+ x = gen_conv(x, 4*cnum, 3, 1, name='pmconv9')
97
+ x = gen_conv(x, 4*cnum, 3, 1, name='pmconv10')
98
+ pm = x
99
+ x = tf.concat([x_hallu, pm], axis=3)
100
+
101
+ x = gen_conv(x, 4*cnum, 3, 1, name='allconv11')
102
+ x = gen_conv(x, 4*cnum, 3, 1, name='allconv12')
103
+ x = gen_deconv(x, 2*cnum, name='allconv13_upsample')
104
+ x = gen_conv(x, 2*cnum, 3, 1, name='allconv14')
105
+ x = gen_deconv(x, cnum, name='allconv15_upsample')
106
+ x = gen_conv(x, cnum//2, 3, 1, name='allconv16')
107
+ x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
108
+ x = tf.nn.tanh(x)
109
+ x_stage2 = x
110
+ return x_stage1, x_stage2, offset_flow
111
+
112
+ def build_sn_patch_gan_discriminator(self, x, reuse=False, training=True):
113
+ with tf.variable_scope('sn_patch_gan', reuse=reuse):
114
+ cnum = 64
115
+ x = dis_conv(x, cnum, name='conv1', training=training)
116
+ x = dis_conv(x, cnum*2, name='conv2', training=training)
117
+ x = dis_conv(x, cnum*4, name='conv3', training=training)
118
+ x = dis_conv(x, cnum*4, name='conv4', training=training)
119
+ x = dis_conv(x, cnum*4, name='conv5', training=training)
120
+ x = dis_conv(x, cnum*4, name='conv6', training=training)
121
+ x = flatten(x, name='flatten')
122
+ return x
123
+
124
+ def build_gan_discriminator(
125
+ self, batch, reuse=False, training=True):
126
+ with tf.variable_scope('discriminator', reuse=reuse):
127
+ d = self.build_sn_patch_gan_discriminator(
128
+ batch, reuse=reuse, training=training)
129
+ return d
130
+
131
+ def build_graph_with_losses(
132
+ self, FLAGS, batch_data, training=True, summary=False,
133
+ reuse=False):
134
+ if FLAGS.guided:
135
+ batch_data, edge = batch_data
136
+ edge = edge[:, :, :, 0:1] / 255.
137
+ edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
138
+ batch_pos = batch_data / 127.5 - 1.
139
+ # generate mask, 1 represents masked point
140
+ bbox = random_bbox(FLAGS)
141
+ regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
142
+ irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
143
+ mask = tf.cast(
144
+ tf.logical_or(
145
+ tf.cast(irregular_mask, tf.bool),
146
+ tf.cast(regular_mask, tf.bool),
147
+ ),
148
+ tf.float32
149
+ )
150
+
151
+ batch_incomplete = batch_pos*(1.-mask)
152
+ if FLAGS.guided:
153
+ edge = edge * mask
154
+ xin = tf.concat([batch_incomplete, edge], axis=3)
155
+ else:
156
+ xin = batch_incomplete
157
+ x1, x2, offset_flow = self.build_inpaint_net(
158
+ xin, mask, reuse=reuse, training=training,
159
+ padding=FLAGS.padding)
160
+ batch_predicted = x2
161
+ losses = {}
162
+ # apply mask and complete image
163
+ batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
164
+ # local patches
165
+ losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
166
+ losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
167
+ if summary:
168
+ scalar_summary('losses/ae_loss', losses['ae_loss'])
169
+ if FLAGS.guided:
170
+ viz_img = [
171
+ batch_pos,
172
+ batch_incomplete + edge,
173
+ batch_complete]
174
+ else:
175
+ viz_img = [batch_pos, batch_incomplete, batch_complete]
176
+ if offset_flow is not None:
177
+ viz_img.append(
178
+ resize(offset_flow, scale=4,
179
+ func=tf.image.resize_bilinear))
180
+ images_summary(
181
+ tf.concat(viz_img, axis=2),
182
+ 'raw_incomplete_predicted_complete', FLAGS.viz_max_out)
183
+
184
+ # gan
185
+ batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
186
+ if FLAGS.gan_with_mask:
187
+ batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])], axis=3)
188
+ if FLAGS.guided:
189
+ # conditional GANs
190
+ batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
191
+ # wgan with gradient penalty
192
+ if FLAGS.gan == 'sngan':
193
+ pos_neg = self.build_gan_discriminator(batch_pos_neg, training=training, reuse=reuse)
194
+ pos, neg = tf.split(pos_neg, 2)
195
+ g_loss, d_loss = gan_hinge_loss(pos, neg)
196
+ losses['g_loss'] = g_loss
197
+ losses['d_loss'] = d_loss
198
+ else:
199
+ raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
200
+ if summary:
201
+ # summary the magnitude of gradients from different losses w.r.t. predicted image
202
+ gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
203
+ gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
204
+ # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
205
+ gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
206
+ losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
207
+ if FLAGS.ae_loss:
208
+ losses['g_loss'] += losses['ae_loss']
209
+ g_vars = tf.get_collection(
210
+ tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
211
+ d_vars = tf.get_collection(
212
+ tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
213
+ return g_vars, d_vars, losses
214
+
215
+ def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
216
+ """
217
+ """
218
+ if FLAGS.guided:
219
+ batch_data, edge = batch_data
220
+ edge = edge[:, :, :, 0:1] / 255.
221
+ edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
222
+ regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
223
+ irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
224
+ mask = tf.cast(
225
+ tf.logical_or(
226
+ tf.cast(irregular_mask, tf.bool),
227
+ tf.cast(regular_mask, tf.bool),
228
+ ),
229
+ tf.float32
230
+ )
231
+
232
+ batch_pos = batch_data / 127.5 - 1.
233
+ batch_incomplete = batch_pos*(1.-mask)
234
+ if FLAGS.guided:
235
+ edge = edge * mask
236
+ xin = tf.concat([batch_incomplete, edge], axis=3)
237
+ else:
238
+ xin = batch_incomplete
239
+ # inpaint
240
+ x1, x2, offset_flow = self.build_inpaint_net(
241
+ xin, mask, reuse=True,
242
+ training=False, padding=FLAGS.padding)
243
+ batch_predicted = x2
244
+ # apply mask and reconstruct
245
+ batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
246
+ # global image visualization
247
+ if FLAGS.guided:
248
+ viz_img = [
249
+ batch_pos,
250
+ batch_incomplete + edge,
251
+ batch_complete]
252
+ else:
253
+ viz_img = [batch_pos, batch_incomplete, batch_complete]
254
+ if offset_flow is not None:
255
+ viz_img.append(
256
+ resize(offset_flow, scale=4,
257
+ func=tf.image.resize_bilinear))
258
+ images_summary(
259
+ tf.concat(viz_img, axis=2),
260
+ name+'_raw_incomplete_complete', FLAGS.viz_max_out)
261
+ return batch_complete
262
+
263
+ def build_static_infer_graph(self, FLAGS, batch_data, name):
264
+ """
265
+ """
266
+ # generate mask, 1 represents masked point
267
+ bbox = (tf.constant(FLAGS.height//2), tf.constant(FLAGS.width//2),
268
+ tf.constant(FLAGS.height), tf.constant(FLAGS.width))
269
+ return self.build_infer_graph(FLAGS, batch_data, bbox, name)
270
+
271
+
272
+ def build_server_graph(self, FLAGS, batch_data, reuse=False, is_training=False):
273
+ """
274
+ """
275
+ # generate mask, 1 represents masked point
276
+ if FLAGS.guided:
277
+ batch_raw, edge, masks_raw = tf.split(batch_data, 3, axis=2)
278
+ edge = edge[:, :, :, 0:1] / 255.
279
+ edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
280
+ else:
281
+ batch_raw, masks_raw = tf.split(batch_data, 2, axis=2)
282
+ masks = tf.cast(masks_raw[0:1, :, :, 0:1] > 127.5, tf.float32)
283
+
284
+ batch_pos = batch_raw / 127.5 - 1.
285
+ batch_incomplete = batch_pos * (1. - masks)
286
+ if FLAGS.guided:
287
+ edge = edge * masks[:, :, :, 0:1]
288
+ xin = tf.concat([batch_incomplete, edge], axis=3)
289
+ else:
290
+ xin = batch_incomplete
291
+ # inpaint
292
+ x1, x2, flow = self.build_inpaint_net(
293
+ xin, masks, reuse=reuse, training=is_training)
294
+ batch_predict = x2
295
+ # apply mask and reconstruct
296
+ batch_complete = batch_predict*masks + batch_incomplete*(1-masks)
297
+ return batch_complete
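`build_server_graph` splits its input along the width axis (`tf.split(batch_data, 2, axis=2)`, or into three parts when `FLAGS.guided` is set), so callers are expected to concatenate the image and its mask side by side. A sketch of assembling that layout, mirroring what `batch_test.py` and `preprocess_image.py` do (file paths are placeholders):

```
import cv2
import numpy as np

# Assemble the [image | mask] layout expected by build_server_graph().
# White pixels (value > 127.5) in the mask mark the region to be inpainted.
image = cv2.imread("input.png")   # (H, W, 3)
mask = cv2.imread("mask.png")     # (H, W, 3), same size as the image
assert image.shape == mask.shape

image = np.expand_dims(image, 0)  # (1, H, W, 3)
mask = np.expand_dims(mask, 0)    # (1, H, W, 3)
input_image = np.concatenate([image, mask], axis=2)  # (1, H, 2*W, 3)
# main.py wraps this array in tf.constant and passes it to build_server_graph.
```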
inpaint_ops.py ADDED
@@ -0,0 +1,553 @@
1
+ import logging
2
+ import math
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from tensorflow.contrib.framework.python.ops import add_arg_scope
8
+ from PIL import Image, ImageDraw
9
+
10
+ from neuralgym.ops.layers import resize
11
+ from neuralgym.ops.layers import *
12
+ from neuralgym.ops.loss_ops import *
13
+ from neuralgym.ops.gan_ops import *
14
+ from neuralgym.ops.summary_ops import *
15
+
16
+
17
+ logger = logging.getLogger()
18
+ np.random.seed(2018)
19
+
20
+
21
+ @add_arg_scope
22
+ def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv',
23
+ padding='SAME', activation=tf.nn.elu, training=True):
24
+ """Define conv for generator.
25
+
26
+ Args:
27
+ x: Input.
28
+ cnum: Channel number.
29
+ ksize: Kernel size.
30
+ Stride: Convolution stride.
31
+ rate: Rate for dilated conv.
32
+ name: Name of layers.
33
+ padding: Padding mode; defaults to 'SAME'.
34
+ activation: Activation function after convolution.
35
+ training: If current graph is for training or inference, used for bn.
36
+
37
+ Returns:
38
+ tf.Tensor: output
39
+
40
+ """
41
+ assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
42
+ if padding == 'SYMMETRIC' or padding == 'REFLECT':
43
+ p = int(rate*(ksize-1)/2)
44
+ x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
45
+ padding = 'VALID'
46
+ x = tf.layers.conv2d(
47
+ x, cnum, ksize, stride, dilation_rate=rate,
48
+ activation=None, padding=padding, name=name)
49
+ if cnum == 3 or activation is None:
50
+ # conv for output
51
+ return x
52
+ x, y = tf.split(x, 2, 3)
53
+ x = activation(x)
54
+ y = tf.nn.sigmoid(y)
55
+ x = x * y
56
+ return x
57
+
58
+
59
+ @add_arg_scope
60
+ def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
61
+ """Define deconv for generator.
62
+ The deconv is defined to be a x2 resize_nearest_neighbor operation with
63
+ additional gen_conv operation.
64
+
65
+ Args:
66
+ x: Input.
67
+ cnum: Channel number.
68
+ name: Name of layers.
69
+ training: If current graph is for training or inference, used for bn.
70
+
71
+ Returns:
72
+ tf.Tensor: output
73
+
74
+ """
75
+ with tf.variable_scope(name):
76
+ x = resize(x, func=tf.image.resize_nearest_neighbor)
77
+ x = gen_conv(
78
+ x, cnum, 3, 1, name=name+'_conv', padding=padding,
79
+ training=training)
80
+ return x
81
+
82
+
83
+ @add_arg_scope
84
+ def dis_conv(x, cnum, ksize=5, stride=2, name='conv', training=True):
85
+ """Define conv for discriminator.
86
+ Activation is set to leaky_relu.
87
+
88
+ Args:
89
+ x: Input.
90
+ cnum: Channel number.
91
+ ksize: Kernel size.
92
+ Stride: Convolution stride.
93
+ name: Name of layers.
94
+ training: If current graph is for training or inference, used for bn.
95
+
96
+ Returns:
97
+ tf.Tensor: output
98
+
99
+ """
100
+ x = conv2d_spectral_norm(x, cnum, ksize, stride, 'SAME', name=name)
101
+ x = tf.nn.leaky_relu(x)
102
+ return x
103
+
104
+
105
+ def random_bbox(FLAGS):
106
+ """Generate a random tlhw.
107
+
108
+ Returns:
109
+ tuple: (top, left, height, width)
110
+
111
+ """
112
+ img_shape = FLAGS.img_shapes
113
+ img_height = img_shape[0]
114
+ img_width = img_shape[1]
115
+ maxt = img_height - FLAGS.vertical_margin - FLAGS.height
116
+ maxl = img_width - FLAGS.horizontal_margin - FLAGS.width
117
+ t = tf.random_uniform(
118
+ [], minval=FLAGS.vertical_margin, maxval=maxt, dtype=tf.int32)
119
+ l = tf.random_uniform(
120
+ [], minval=FLAGS.horizontal_margin, maxval=maxl, dtype=tf.int32)
121
+ h = tf.constant(FLAGS.height)
122
+ w = tf.constant(FLAGS.width)
123
+ return (t, l, h, w)
124
+
125
+
126
+ def bbox2mask(FLAGS, bbox, name='mask'):
127
+ """Generate mask tensor from bbox.
128
+
129
+ Args:
130
+ bbox: tuple, (top, left, height, width)
131
+
132
+ Returns:
133
+ tf.Tensor: output with shape [1, H, W, 1]
134
+
135
+ """
136
+ def npmask(bbox, height, width, delta_h, delta_w):
137
+ mask = np.zeros((1, height, width, 1), np.float32)
138
+ h = np.random.randint(delta_h//2+1)
139
+ w = np.random.randint(delta_w//2+1)
140
+ mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
141
+ bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
142
+ return mask
143
+ with tf.variable_scope(name), tf.device('/cpu:0'):
144
+ img_shape = FLAGS.img_shapes
145
+ height = img_shape[0]
146
+ width = img_shape[1]
147
+ mask = tf.py_func(
148
+ npmask,
149
+ [bbox, height, width,
150
+ FLAGS.max_delta_height, FLAGS.max_delta_width],
151
+ tf.float32, stateful=False)
152
+ mask.set_shape([1] + [height, width] + [1])
153
+ return mask
154
+
155
+
156
+ def brush_stroke_mask(FLAGS, name='mask'):
157
+ """Generate mask tensor from bbox.
158
+
159
+ Returns:
160
+ tf.Tensor: output with shape [1, H, W, 1]
161
+
162
+ """
163
+ min_num_vertex = 4
164
+ max_num_vertex = 12
165
+ mean_angle = 2*math.pi / 5
166
+ angle_range = 2*math.pi / 15
167
+ min_width = 12
168
+ max_width = 40
169
+ def generate_mask(H, W):
170
+ average_radius = math.sqrt(H*H+W*W) / 8
171
+ mask = Image.new('L', (W, H), 0)
172
+
173
+ for _ in range(np.random.randint(1, 4)):
174
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
175
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
176
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
177
+ angles = []
178
+ vertex = []
179
+ for i in range(num_vertex):
180
+ if i % 2 == 0:
181
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
182
+ else:
183
+ angles.append(np.random.uniform(angle_min, angle_max))
184
+
185
+ h, w = mask.size
186
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
187
+ for i in range(num_vertex):
188
+ r = np.clip(
189
+ np.random.normal(loc=average_radius, scale=average_radius//2),
190
+ 0, 2*average_radius)
191
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
192
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
193
+ vertex.append((int(new_x), int(new_y)))
194
+
195
+ draw = ImageDraw.Draw(mask)
196
+ width = int(np.random.uniform(min_width, max_width))
197
+ draw.line(vertex, fill=1, width=width)
198
+ for v in vertex:
199
+ draw.ellipse((v[0] - width//2,
200
+ v[1] - width//2,
201
+ v[0] + width//2,
202
+ v[1] + width//2),
203
+ fill=1)
204
+
205
+ if np.random.normal() > 0:
206
+ mask.transpose(Image.FLIP_LEFT_RIGHT)
207
+ if np.random.normal() > 0:
208
+ mask.transpose(Image.FLIP_TOP_BOTTOM)
209
+ mask = np.asarray(mask, np.float32)
210
+ mask = np.reshape(mask, (1, H, W, 1))
211
+ return mask
212
+ with tf.variable_scope(name), tf.device('/cpu:0'):
213
+ img_shape = FLAGS.img_shapes
214
+ height = img_shape[0]
215
+ width = img_shape[1]
216
+ mask = tf.py_func(
217
+ generate_mask,
218
+ [height, width],
219
+ tf.float32, stateful=True)
220
+ mask.set_shape([1] + [height, width] + [1])
221
+ return mask
222
+
223
+
224
+ def local_patch(x, bbox):
225
+ """Crop local patch according to bbox.
226
+
227
+ Args:
228
+ x: input
229
+ bbox: (top, left, height, width)
230
+
231
+ Returns:
232
+ tf.Tensor: local patch
233
+
234
+ """
235
+ x = tf.image.crop_to_bounding_box(x, bbox[0], bbox[1], bbox[2], bbox[3])
236
+ return x
237
+
238
+
239
+ def resize_mask_like(mask, x):
240
+ """Resize mask like shape of x.
241
+
242
+ Args:
243
+ mask: Original mask.
244
+ x: To shape of x.
245
+
246
+ Returns:
247
+ tf.Tensor: resized mask
248
+
249
+ """
250
+ mask_resize = resize(
251
+ mask, to_shape=x.get_shape().as_list()[1:3],
252
+ func=tf.image.resize_nearest_neighbor)
253
+ return mask_resize
254
+
255
+
256
+ def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
257
+ fuse_k=3, softmax_scale=10., training=True, fuse=True):
258
+ """ Contextual attention layer implementation.
259
+
260
+ Contextual attention is first introduced in publication:
261
+ Generative Image Inpainting with Contextual Attention, Yu et al.
262
+
263
+ Args:
264
+ f: Input feature to match (foreground).
265
+ b: Input feature to match against (background).
266
+ mask: Input mask for t, indicating patches not available.
267
+ ksize: Kernel size for contextual attention.
268
+ stride: Stride for extracting patches from t.
269
+ rate: Dilation for matching.
270
+ softmax_scale: Scaled softmax for attention.
271
+ training: Indicating if current graph is training or inference.
272
+
273
+ Returns:
274
+ tf.Tensor: output
275
+
276
+ """
277
+ # get shapes
278
+ raw_fs = tf.shape(f)
279
+ raw_int_fs = f.get_shape().as_list()
280
+ raw_int_bs = b.get_shape().as_list()
281
+ # extract patches from background with stride and rate
282
+ kernel = 2*rate
283
+ raw_w = tf.extract_image_patches(
284
+ b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
285
+ raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
286
+ raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
287
+ # downscaling foreground option: downscaling both foreground and
288
+ # background for matching and use original background for reconstruction.
289
+ f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
290
+ b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor) # https://github.com/tensorflow/tensorflow/issues/11651
291
+ if mask is not None:
292
+ mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
293
+ fs = tf.shape(f)
294
+ int_fs = f.get_shape().as_list()
295
+ f_groups = tf.split(f, int_fs[0], axis=0)
296
+ # from t(H*W*C) to w(b*k*k*c*h*w)
297
+ bs = tf.shape(b)
298
+ int_bs = b.get_shape().as_list()
299
+ w = tf.extract_image_patches(
300
+ b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
301
+ w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
302
+ w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
303
+ # process mask
304
+ if mask is None:
305
+ mask = tf.zeros([1, bs[1], bs[2], 1])
306
+ m = tf.extract_image_patches(
307
+ mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
308
+ m = tf.reshape(m, [1, -1, ksize, ksize, 1])
309
+ m = tf.transpose(m, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
310
+ m = m[0]
311
+ mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)
312
+ w_groups = tf.split(w, int_bs[0], axis=0)
313
+ raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
314
+ y = []
315
+ offsets = []
316
+ k = fuse_k
317
+ scale = softmax_scale
318
+ fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
319
+ for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
320
+ # conv for compare
321
+ wi = wi[0]
322
+ wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
323
+ yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")
324
+
325
+ # conv implementation for fuse scores to encourage large patches
326
+ if fuse:
327
+ yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
328
+ yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
329
+ yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
330
+ yi = tf.transpose(yi, [0, 2, 1, 4, 3])
331
+ yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
332
+ yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
333
+ yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
334
+ yi = tf.transpose(yi, [0, 2, 1, 4, 3])
335
+ yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
336
+
337
+ # softmax to match
338
+ yi *= mm # mask
339
+ yi = tf.nn.softmax(yi*scale, 3)
340
+ yi *= mm # mask
341
+
342
+ offset = tf.argmax(yi, axis=3, output_type=tf.int32)
343
+ offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
344
+ # deconv for patch pasting
345
+ # 3.1 paste center
346
+ wi_center = raw_wi[0]
347
+ yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
348
+ y.append(yi)
349
+ offsets.append(offset)
350
+ y = tf.concat(y, axis=0)
351
+ y.set_shape(raw_int_fs)
352
+ offsets = tf.concat(offsets, axis=0)
353
+ offsets.set_shape(int_bs[:3] + [2])
354
+ # case1: visualize optical flow: minus current position
355
+ h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
356
+ w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
357
+ offsets = offsets - tf.concat([h_add, w_add], axis=3)
358
+ # to flow image
359
+ flow = flow_to_image_tf(offsets)
360
+ # # case2: visualize which pixels are attended
361
+ # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
362
+ if rate != 1:
363
+ flow = resize(flow, scale=rate, func=tf.image.resize_bilinear)
364
+ return y, flow
365
+
366
+
367
+ def test_contextual_attention(args):
368
+ """Test contextual attention layer with 3-channel image input
369
+ (instead of n-channel feature).
370
+
371
+ """
372
+ import cv2
373
+ import os
374
+ # run on cpu
375
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
376
+
377
+ rate = 2
378
+ stride = 1
379
+ grid = rate*stride
380
+
381
+ b = cv2.imread(args.imageA)
382
+ b = cv2.resize(b, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
383
+ h, w, _ = b.shape
384
+ b = b[:h//grid*grid, :w//grid*grid, :]
385
+ b = np.expand_dims(b, 0)
386
+ logger.info('Size of imageA: {}'.format(b.shape))
387
+
388
+ f = cv2.imread(args.imageB)
389
+ h, w, _ = f.shape
390
+ f = f[:h//grid*grid, :w//grid*grid, :]
391
+ f = np.expand_dims(f, 0)
392
+ logger.info('Size of imageB: {}'.format(f.shape))
393
+
394
+ with tf.Session() as sess:
395
+ bt = tf.constant(b, dtype=tf.float32)
396
+ ft = tf.constant(f, dtype=tf.float32)
397
+
398
+ yt, flow = contextual_attention(
399
+ ft, bt, stride=stride, rate=rate,
400
+ training=False, fuse=False)
401
+ y = sess.run(yt)
402
+ cv2.imwrite(args.imageOut, y[0])
403
+
404
+
405
+ def make_color_wheel():
406
+ RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
407
+ ncols = RY + YG + GC + CB + BM + MR
408
+ colorwheel = np.zeros([ncols, 3])
409
+ col = 0
410
+ # RY
411
+ colorwheel[0:RY, 0] = 255
412
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
413
+ col += RY
414
+ # YG
415
+ colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
416
+ colorwheel[col:col+YG, 1] = 255
417
+ col += YG
418
+ # GC
419
+ colorwheel[col:col+GC, 1] = 255
420
+ colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
421
+ col += GC
422
+ # CB
423
+ colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
424
+ colorwheel[col:col+CB, 2] = 255
425
+ col += CB
426
+ # BM
427
+ colorwheel[col:col+BM, 2] = 255
428
+ colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
429
+ col += + BM
430
+ # MR
431
+ colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
432
+ colorwheel[col:col+MR, 0] = 255
433
+ return colorwheel
434
+
435
+
436
+ COLORWHEEL = make_color_wheel()
437
+
438
+
439
+ def compute_color(u,v):
440
+ h, w = u.shape
441
+ img = np.zeros([h, w, 3])
442
+ nanIdx = np.isnan(u) | np.isnan(v)
443
+ u[nanIdx] = 0
444
+ v[nanIdx] = 0
445
+ # colorwheel = COLORWHEEL
446
+ colorwheel = make_color_wheel()
447
+ ncols = np.size(colorwheel, 0)
448
+ rad = np.sqrt(u**2+v**2)
449
+ a = np.arctan2(-v, -u) / np.pi
450
+ fk = (a+1) / 2 * (ncols - 1) + 1
451
+ k0 = np.floor(fk).astype(int)
452
+ k1 = k0 + 1
453
+ k1[k1 == ncols+1] = 1
454
+ f = fk - k0
455
+ for i in range(np.size(colorwheel,1)):
456
+ tmp = colorwheel[:, i]
457
+ col0 = tmp[k0-1] / 255
458
+ col1 = tmp[k1-1] / 255
459
+ col = (1-f) * col0 + f * col1
460
+ idx = rad <= 1
461
+ col[idx] = 1-rad[idx]*(1-col[idx])
462
+ notidx = np.logical_not(idx)
463
+ col[notidx] *= 0.75
464
+ img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
465
+ return img
466
+
467
+
468
+
469
+ def flow_to_image(flow):
470
+ """Transfer flow map to image.
471
+ Part of code forked from flownet.
472
+ """
473
+ out = []
474
+ maxu = -999.
475
+ maxv = -999.
476
+ minu = 999.
477
+ minv = 999.
478
+ maxrad = -1
479
+ for i in range(flow.shape[0]):
480
+ u = flow[i, :, :, 0]
481
+ v = flow[i, :, :, 1]
482
+ idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
483
+ u[idxunknow] = 0
484
+ v[idxunknow] = 0
485
+ maxu = max(maxu, np.max(u))
486
+ minu = min(minu, np.min(u))
487
+ maxv = max(maxv, np.max(v))
488
+ minv = min(minv, np.min(v))
489
+ rad = np.sqrt(u ** 2 + v ** 2)
490
+ maxrad = max(maxrad, np.max(rad))
491
+ u = u/(maxrad + np.finfo(float).eps)
492
+ v = v/(maxrad + np.finfo(float).eps)
493
+ img = compute_color(u, v)
494
+ out.append(img)
495
+ return np.float32(np.uint8(out))
496
+
497
+
498
+ def flow_to_image_tf(flow, name='flow_to_image'):
499
+ """Tensorflow ops for computing flow to image.
500
+ """
501
+ with tf.variable_scope(name), tf.device('/cpu:0'):
502
+ img = tf.py_func(flow_to_image, [flow], tf.float32, stateful=False)
503
+ img.set_shape(flow.get_shape().as_list()[0:-1]+[3])
504
+ img = img / 127.5 - 1.
505
+ return img
506
+
507
+
508
+ def highlight_flow(flow):
509
+ """Convert flow into middlebury color code image.
510
+ """
511
+ out = []
512
+ s = flow.shape
513
+ for i in range(flow.shape[0]):
514
+ img = np.ones((s[1], s[2], 3)) * 144.
515
+ u = flow[i, :, :, 0]
516
+ v = flow[i, :, :, 1]
517
+ for h in range(s[1]):
518
+ for w in range(s[2]):
519
+ ui = u[h,w]
520
+ vi = v[h,w]
521
+ img[ui, vi, :] = 255.
522
+ out.append(img)
523
+ return np.float32(np.uint8(out))
524
+
525
+
526
+ def highlight_flow_tf(flow, name='flow_to_image'):
527
+ """Tensorflow ops for highlight flow.
528
+ """
529
+ with tf.variable_scope(name), tf.device('/cpu:0'):
530
+ img = tf.py_func(highlight_flow, [flow], tf.float32, stateful=False)
531
+ img.set_shape(flow.get_shape().as_list()[0:-1]+[3])
532
+ img = img / 127.5 - 1.
533
+ return img
534
+
535
+
536
+ def image2edge(image):
537
+ """Convert image to edges.
538
+ """
539
+ out = []
540
+ for i in range(image.shape[0]):
541
+ img = cv2.Laplacian(image[i, :, :, :], cv2.CV_64F, ksize=3, scale=2)
542
+ out.append(img)
543
+ return np.float32(np.uint8(out))
544
+
545
+
546
+ if __name__ == "__main__":
547
+ import argparse
548
+ parser = argparse.ArgumentParser()
549
+ parser.add_argument('--imageA', default='', type=str, help='Image A as background patches to reconstruct image B.')
550
+ parser.add_argument('--imageB', default='', type=str, help='Image B is reconstructed with image A.')
551
+ parser.add_argument('--imageOut', default='result.png', type=str, help='Image B is reconstructed with image A.')
552
+ args = parser.parse_args()
553
+ test_contextual_attention(args)
main.py ADDED
@@ -0,0 +1,58 @@
+ import argparse
+
+ from PIL import Image
+ import cv2
+ import numpy as np
+ from preprocess_image import preprocess_image
+ import tensorflow as tf
+ import neuralgym as ng
+
+ from inpaint_model import InpaintCAModel
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--image', default='', type=str,
+ help='The filename of image to be completed.')
+ parser.add_argument('--output', default='output.png', type=str,
+ help='Where to write output.')
+ parser.add_argument('--watermark_type', default='istock', type=str,
+ help='The watermark type')
+ parser.add_argument('--checkpoint_dir', default='model/', type=str,
+ help='The directory of tensorflow checkpoint.')
+
+ #checkpoint_dir = 'model/'
+
+
+ if __name__ == "__main__":
+ FLAGS = ng.Config('inpaint.yml')
+ # ng.get_gpus(1)
+ args, unknown = parser.parse_known_args()
+
+ model = InpaintCAModel()
+ image = Image.open(args.image)
+ input_image = preprocess_image(image, args.watermark_type)
+ tf.reset_default_graph()
+
+ sess_config = tf.ConfigProto()
+ sess_config.gpu_options.allow_growth = True
+ if (input_image.shape != (0,)):
+ with tf.Session(config=sess_config) as sess:
+ input_image = tf.constant(input_image, dtype=tf.float32)
+ output = model.build_server_graph(FLAGS, input_image)
+ output = (output + 1.) * 127.5
+ output = tf.reverse(output, [-1])
+ output = tf.saturate_cast(output, tf.uint8)
+ # load pretrained model
+ vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+ assign_ops = []
+ for var in vars_list:
+ vname = var.name
+ from_name = vname
+ var_value = tf.contrib.framework.load_variable(
+ args.checkpoint_dir, from_name)
+ assign_ops.append(tf.assign(var, var_value))
+ sess.run(assign_ops)
+ print('Model loaded.')
+ result = sess.run(output)
+ cv2.imwrite(args.output, cv2.cvtColor(
+ result[0][:, :, ::-1], cv2.COLOR_BGR2RGB))
+ print('image saved to {}'.format(args.output))
preprocess_image.py ADDED
@@ -0,0 +1,53 @@
+ import numpy as np
+ from PIL import Image
+ import cv2
+
+
+ def preprocess_image(image, watermark_type):
+ image_type: str = ''
+ preprocessed_mask_image = np.array([])
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+ image = np.array(image)
+ image_h = image.shape[0]
+ image_w = image.shape[1]
+ aspectRatioImage = image_w / image_h
+ print("image size: {}".format(image.shape))
+
+ if image_w > image_h:
+ image_type = "landscape"
+ elif image_w == image_h:
+ image_type = "landscape"
+ else:
+ image_type = "potrait"
+
+ mask_image = Image.open(
+ "utils/{}/{}/mask.png".format(watermark_type, image_type))
+ if mask_image.mode != "RGB":
+ mask_image = mask_image.convert("RGB")
+ mask_image = np.array(mask_image)
+ print("mask image size: {}".format(mask_image.shape))
+
+ aspectRatioMaskImage = mask_image.shape[1] / mask_image.shape[0]
+ upperBoundAspectRatio = 1.05 * aspectRatioMaskImage
+ lowerBoundAspectRatio = 0.95 * aspectRatioMaskImage
+
+ if aspectRatioImage >= lowerBoundAspectRatio and aspectRatioImage <= upperBoundAspectRatio:
+ preprocessed_mask_image = cv2.resize(mask_image, (image_w, image_h))
+ print(preprocessed_mask_image.shape)
+ else:
+ print("Image size not supported!!!")
+
+ if (preprocessed_mask_image.shape != (0,)):
+ assert image.shape == preprocessed_mask_image.shape
+ grid = 8
+ image = image[:image_h//grid*grid, :image_w//grid*grid, :]
+ preprocessed_mask_image = preprocessed_mask_image[:image_h //
+ grid*grid, :image_w//grid*grid, :]
+ image = np.expand_dims(image, 0)
+ preprocessed_mask_image = np.expand_dims(preprocessed_mask_image, 0)
+ input_image = np.concatenate([image, preprocessed_mask_image], axis=2)
+ return input_image
+
+ else:
+ return preprocessed_mask_image
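`main.py` drives this function directly: it opens the input with PIL, lets `preprocess_image` pick the bundled mask for the watermark type based on aspect ratio, and only builds the graph when a usable array comes back. A usage sketch (the input path is a placeholder):

```
from PIL import Image

from preprocess_image import preprocess_image

# Same call main.py makes: choose the bundled istock mask (landscape or
# "potrait", as spelled in the code) and concatenate it with the image.
image = Image.open("photo.jpg")
input_image = preprocess_image(image, "istock")

if input_image.shape != (0,):
    # Roughly (1, H, 2*W, 3): watermarked image and mask side by side,
    # ready for InpaintCAModel.build_server_graph().
    print(input_image.shape)
else:
    print("Aspect ratio does not match the bundled mask; image not supported.")
```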
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ tensorflow==1.15.5
+ opencv-python==4.9.0.80
utils/istock/landscape/mask.png ADDED