Upload folder using huggingface_hub

- .gitignore +4 -0
- README.md +85 -12
- batch_test.py +92 -0
- guided_batch_test.py +95 -0
- inpaint.yml +70 -0
- inpaint_model.py +297 -0
- inpaint_ops.py +553 -0
- main.py +58 -0
- preprocess_image.py +53 -0
- requirements.txt +2 -0
- utils/istock/landscape/mask.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
model
logs
data
load_model.py
README.md
CHANGED
@@ -1,12 +1,85 @@
Chimzuruoke Okafor

# Watermark-Removal

<p>
<a href="https://pepy.tech/project/prompttools" target="_blank"><img src="https://static.pepy.tech/badge/prompttools" alt="Total Downloads"/></a>
<a href="https://github.com/hegelai/watermark-removal"><img src="https://img.shields.io/github/stars/zuruoke/watermark-removal" /></a>
<a href="https://twitter.com/zuruoke_okafor"><img src="https://img.shields.io/twitter/follow/Zuruoke_Okafor?style=social"></a>
</p>





<a href="https://coff.ee/zuruokeokafor" target="_blank">
<img src="https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png"
alt="Buy Me A Coffee"
style="height: 41px;width:174px;box-shadow:0px 3px 2px rgba(190,190,190,0.5);" />
</a>

An open source project that uses a machine-learning-based image inpainting method to remove watermarks from images, producing results that are virtually indistinguishable from the ground-truth versions.

This project was inspired by [Contextual Attention](https://arxiv.org/abs/1801.07892) (CVPR 2018) and [Gated Convolution](https://arxiv.org/abs/1806.03589) (ICCV 2019 Oral).

A shoutout as well to [Chu-Tak Li](https://chutakcode.wixsite.com/website) for his [Medium article series](https://towardsdatascience.com/10-papers-you-must-read-for-deep-image-inpainting-2e41c589ced0), which gave me deep insight into the image inpainting papers cited above.

<img src="https://user-images.githubusercontent.com/51057490/140277713-c7d6e2b9-db62-4793-823a-25ed0c4e2771.png" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277781-5b5218bb-9044-4ec9-a349-eea93bc56d4a.png" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277929-3f187647-0e63-4bcb-b9f1-472f7558aae5.jpeg" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277957-6ddb7dec-25c8-42f1-8e39-be491d4f2248.png" width="45%"/> <img src="https://user-images.githubusercontent.com/51057490/140277983-265a1c9e-6093-4154-8252-838baca21c41.jpeg" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278002-56c4ae3d-6bfb-4ba3-aa02-7bd28474bfdf.png" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278030-d2a962ce-3722-43f1-b1bd-0ffde2aa7026.jpeg" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278040-10e401d7-4b7d-4d81-91fe-e9f01ef4ce7f.png" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278017-34862de0-86eb-40f0-b04b-7dc02fe38a77.jpeg" width="45%" /> <img src="https://user-images.githubusercontent.com/51057490/140278011-e0ae9ed0-e4ed-44ed-a9ac-28eb8456797a.png" width="45%" />

## Run

- Use [Google Colab](https://research.google.com/colaboratory/)

- First, clone this repo:

!git clone https://github.com/zuruoke/watermark-removal

- Change directory into the repo:

!cd watermark-removal

- Google Colab ships with the latest TensorFlow 2.x, while this project requires 1.15.0, so downgrade to TensorFlow 1.15.0 and restart the runtime (recent versions of Google Colab no longer require the restart):

!pip install tensorflow==1.15.0

- Install the TensorFlow toolkit [neuralgym](https://github.com/JiahuiYu/neuralgym):

!pip install git+https://github.com/JiahuiYu/neuralgym

- Download the model directory using this [link](https://drive.google.com/drive/folders/1xRV4EdjJuAfsX9pQme6XeoFznKXG0ptJ?usp=sharing) and put it under `model/` (rename `checkpoint.txt` to `checkpoint`, because Google Drive sometimes appends .txt after download).

And you're all set!

- Now remove the watermark from an image by running the `main.py` file:

!python main.py --image path-to-input-image --output path-to-output-image --checkpoint_dir model/ --watermark_type istock

## Citing

```
@article{yu2018generative,
  title={Generative Image Inpainting with Contextual Attention},
  author={Yu, Jiahui and Lin, Zhe and Yang, Jimei and Shen, Xiaohui and Lu, Xin and Huang, Thomas S},
  journal={arXiv preprint arXiv:1801.07892},
  year={2018}
}

@article{yu2018free,
  title={Free-Form Image Inpainting with Gated Convolution},
  author={Yu, Jiahui and Lin, Zhe and Yang, Jimei and Shen, Xiaohui and Lu, Xin and Huang, Thomas S},
  journal={arXiv preprint arXiv:1806.03589},
  year={2018}
}
```

<p align="center">
<a href="https://star-history.com/#zuruoke/watermark-removal">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=zuruoke/watermark-removal&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=zuruoke/watermark-removal&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=zuruoke/watermark-removal&type=Date" />
</picture>
</a>
</p>

## © Chimzuruoke Okafor
batch_test.py
ADDED
@@ -0,0 +1,92 @@
import time
import os
import argparse

import cv2
import numpy as np
import tensorflow as tf
import neuralgym as ng

from inpaint_model import InpaintCAModel


parser = argparse.ArgumentParser()
parser.add_argument(
    '--flist', default='', type=str,
    help='The filenames of images to be processed: input, mask, output.')
parser.add_argument(
    '--image_height', default=-1, type=int,
    help='The height of images should be defined, otherwise batch mode is not'
    ' supported.')
parser.add_argument(
    '--image_width', default=-1, type=int,
    help='The width of images should be defined, otherwise batch mode is not'
    ' supported.')
parser.add_argument(
    '--checkpoint_dir', default='', type=str,
    help='The directory of tensorflow checkpoint.')


if __name__ == "__main__":
    FLAGS = ng.Config('inpaint.yml')
    ng.get_gpus(1)
    # os.environ['CUDA_VISIBLE_DEVICES'] = ''
    args = parser.parse_args()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    model = InpaintCAModel()
    input_image_ph = tf.placeholder(
        tf.float32, shape=(1, args.image_height, args.image_width*2, 3))
    output = model.build_server_graph(FLAGS, input_image_ph)
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        vname = var.name
        from_name = vname
        var_value = tf.contrib.framework.load_variable(
            args.checkpoint_dir, from_name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')

    with open(args.flist, 'r') as f:
        lines = f.read().splitlines()
    t = time.time()
    for line in lines:
        image, mask, out = line.split()
        base = os.path.basename(mask)

        image = cv2.imread(image)
        mask = cv2.imread(mask)
        image = cv2.resize(image, (args.image_width, args.image_height))
        mask = cv2.resize(mask, (args.image_width, args.image_height))

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 4
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, mask], axis=2)

        result = sess.run(output, feed_dict={input_image_ph: input_image})
        print('Processed: {}'.format(out))
        cv2.imwrite(out, result[0][:, :, ::-1])

    print('Time total: {}'.format(time.time() - t))
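batch_test.py consumes a plain-text file list in which each line carries three whitespace-separated paths: input image, mask image, and output destination. Below is a minimal sketch of preparing such a list; every file name and directory in it is illustrative, not part of the repo.

```python
# Hypothetical helper: build an flist for batch_test.py. Each line is
# "<input> <mask> <output>", whitespace-separated, so the paths themselves
# must not contain spaces. All paths below are illustrative.
import os

pairs = [
    ('examples/case1.png', 'examples/case1_mask.png', 'out/case1.png'),
    ('examples/case2.png', 'examples/case2_mask.png', 'out/case2.png'),
]
os.makedirs('out', exist_ok=True)
with open('test.flist', 'w') as f:
    for image, mask, out in pairs:
        f.write('{} {} {}\n'.format(image, mask, out))
```

Because the placeholder shape is static, every image and mask is resized to the given dimensions, so fixed --image_height and --image_width are required, for example:

python batch_test.py --flist test.flist --image_height 512 --image_width 680 --checkpoint_dir model/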
guided_batch_test.py
ADDED
@@ -0,0 +1,95 @@
import time
import os
import argparse

import cv2
import numpy as np
import tensorflow as tf
import neuralgym as ng

from inpaint_model import InpaintCAModel


parser = argparse.ArgumentParser()
parser.add_argument(
    '--flist', default='', type=str,
    help='The filenames of images to be processed: input, mask, output.')
parser.add_argument(
    '--image_height', default=-1, type=int,
    help='The height of images should be defined, otherwise batch mode is not'
    ' supported.')
parser.add_argument(
    '--image_width', default=-1, type=int,
    help='The width of images should be defined, otherwise batch mode is not'
    ' supported.')
parser.add_argument(
    '--checkpoint_dir', default='', type=str,
    help='The directory of tensorflow checkpoint.')


if __name__ == "__main__":
    # build_server_graph() in inpaint_model.py takes FLAGS as its first
    # argument, so load the config here the same way batch_test.py does.
    FLAGS = ng.Config('inpaint.yml')
    ng.get_gpus(1)
    # os.environ['CUDA_VISIBLE_DEVICES'] = ''
    args = parser.parse_args()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    model = InpaintCAModel()
    input_image_ph = tf.placeholder(
        tf.float32, shape=(1, args.image_height, args.image_width*3, 3))
    output = model.build_server_graph(FLAGS, input_image_ph)
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        vname = var.name
        from_name = vname
        var_value = tf.contrib.framework.load_variable(
            args.checkpoint_dir, from_name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')

    with open(args.flist, 'r') as f:
        lines = f.read().splitlines()
    t = time.time()
    for line in lines:
        image, mask, out = line.split()
        base = os.path.basename(mask)

        guidance = cv2.imread(image[:-4] + '_edge.jpg')
        image = cv2.imread(image)
        mask = cv2.imread(mask)
        image = cv2.resize(image, (args.image_width, args.image_height))
        guidance = cv2.resize(guidance, (args.image_width, args.image_height))
        mask = cv2.resize(mask, (args.image_width, args.image_height))

        assert image.shape == mask.shape

        h, w, _ = image.shape
        grid = 4
        image = image[:h//grid*grid, :w//grid*grid, :]
        mask = mask[:h//grid*grid, :w//grid*grid, :]
        guidance = guidance[:h//grid*grid, :w//grid*grid, :]
        print('Shape of image: {}'.format(image.shape))

        image = np.expand_dims(image, 0)
        guidance = np.expand_dims(guidance, 0)
        mask = np.expand_dims(mask, 0)
        input_image = np.concatenate([image, guidance, mask], axis=2)

        result = sess.run(output, feed_dict={input_image_ph: input_image})
        print('Processed: {}'.format(out))
        cv2.imwrite(out, result[0][:, :, ::-1])

    print('Time total: {}'.format(time.time() - t))
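guided_batch_test.py derives the guidance path from each input as image[:-4] + '_edge.jpg', so an edge map must sit next to every input image before the script runs, and inpaint.yml must set guided: True for build_server_graph to consume the extra channel. Below is a minimal sketch of producing those maps; it assumes the Laplacian encoding used by image2edge() in inpaint_ops.py is the expected guidance format, and the paths are illustrative.

```python
# Hypothetical pre-step for guided_batch_test.py: write an edge map next to
# each input image as <name>_edge.jpg, the path the script derives via
# image[:-4] + '_edge.jpg'.
import cv2

def write_edge_map(image_path):
    img = cv2.imread(image_path)
    # Laplacian edge response, mirroring image2edge() in inpaint_ops.py
    edge = cv2.Laplacian(img, cv2.CV_64F, ksize=3, scale=2)
    cv2.imwrite(image_path[:-4] + '_edge.jpg', cv2.convertScaleAbs(edge))

write_edge_map('examples/case1.png')  # produces examples/case1_edge.jpg
```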
inpaint.yml
ADDED
@@ -0,0 +1,70 @@
# =========================== Basic Settings ===========================
# machine info
num_gpus_per_job: 1 # number of gpus each job needs
num_cpus_per_job: 4 # number of cpus each job needs
num_hosts_per_job: 1
memory_per_job: 32 # memory (GB) each job needs
gpu_type: 'nvidia-tesla-p100'

# parameters
name: places2_gated_conv_v100 # any name
model_restore: '' # logs/places2_gated_conv
dataset: 'celebahq' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes'
random_crop: False # False for 'celebahq': images are only resized to img_shapes instead of cropping img_shapes from a larger raw image. Set to True when training on images with varying resolutions, e.g. places2.
val: False # true if you want to view validation results in tensorboard
log_dir: logs/full_model_celeba_hq_256

gan: 'sngan'
gan_loss_alpha: 1
gan_with_mask: True
discounted_mask: True
random_seed: False
padding: 'SAME'

# training
train_spe: 4000
max_iters: 100000000
viz_max_out: 10
val_psteps: 2000

# data
data_flist:
  # https://github.com/jiahuiyu/progressive_growing_of_gans_tf
  celebahq: [
    'data/celeba_hq/train_shuffled.flist',
    'data/celeba_hq/validation_static_view.flist'
  ]
  # http://mmlab.ie.cuhk.edu.hk/projects/celeba.html, please use random_crop: True
  celeba: [
    'data/celeba/train_shuffled.flist',
    'data/celeba/validation_static_view.flist'
  ]
  # http://places2.csail.mit.edu/, please download the high-resolution dataset and use random_crop: True
  places2: [
    'data/places2/train_shuffled.flist',
    'data/places2/validation_static_view.flist'
  ]
  # http://www.image-net.org/, please use random_crop: True
  imagenet: [
    'data/imagenet/train_shuffled.flist',
    'data/imagenet/validation_static_view.flist',
  ]

static_view_size: 30
img_shapes: [256, 256, 3]
height: 128
width: 128
max_delta_height: 32
max_delta_width: 32
batch_size: 16
vertical_margin: 0
horizontal_margin: 0

# loss
ae_loss: True
l1_loss: True
l1_loss_alpha: 1.

# to tune
guided: False
edge_threshold: 0.6
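The scripts in this repo load inpaint.yml through neuralgym's Config, which parses the YAML and exposes each key as an attribute (FLAGS.guided, FLAGS.img_shapes, and so on). A minimal sketch of that usage, with the printed values taken from the defaults above:

```python
# Minimal sketch of how the repo's scripts consume inpaint.yml:
# ng.Config parses the YAML and exposes keys as attributes.
import neuralgym as ng

FLAGS = ng.Config('inpaint.yml')
print(FLAGS.img_shapes)      # [256, 256, 3]
print(FLAGS.guided)          # False; set True for the edge-guided scripts
print(FLAGS.edge_threshold)  # 0.6
```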
inpaint_model.py
ADDED
@@ -0,0 +1,297 @@
""" common model for DCGAN """
import logging

import cv2
import neuralgym as ng
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import arg_scope

from neuralgym.models import Model
from neuralgym.ops.summary_ops import scalar_summary, images_summary
from neuralgym.ops.summary_ops import gradients_summary
from neuralgym.ops.layers import flatten, resize
from neuralgym.ops.gan_ops import gan_hinge_loss
from neuralgym.ops.gan_ops import random_interpolates

from inpaint_ops import gen_conv, gen_deconv, dis_conv
from inpaint_ops import random_bbox, bbox2mask, local_patch, brush_stroke_mask
from inpaint_ops import resize_mask_like, contextual_attention


logger = logging.getLogger()


class InpaintCAModel(Model):
    def __init__(self):
        super().__init__('InpaintCAModel')

    def build_inpaint_net(self, x, mask, reuse=False,
                          training=True, padding='SAME', name='inpaint_net'):
        """Inpaint network.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
        Returns:
            [-1, 1] as predicted image
        """
        xin = x
        offset_flow = None
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x*mask], axis=3)

        # two stage network
        cnum = 48
        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # stage1
            x = gen_conv(x, cnum, 5, 1, name='conv1')
            x = gen_conv(x, 2*cnum, 3, 2, name='conv2_downsample')
            x = gen_conv(x, 2*cnum, 3, 1, name='conv3')
            x = gen_conv(x, 4*cnum, 3, 2, name='conv4_downsample')
            x = gen_conv(x, 4*cnum, 3, 1, name='conv5')
            x = gen_conv(x, 4*cnum, 3, 1, name='conv6')
            mask_s = resize_mask_like(mask, x)
            x = gen_conv(x, 4*cnum, 3, rate=2, name='conv7_atrous')
            x = gen_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous')
            x = gen_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous')
            x = gen_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous')
            x = gen_conv(x, 4*cnum, 3, 1, name='conv11')
            x = gen_conv(x, 4*cnum, 3, 1, name='conv12')
            x = gen_deconv(x, 2*cnum, name='conv13_upsample')
            x = gen_conv(x, 2*cnum, 3, 1, name='conv14')
            x = gen_deconv(x, cnum, name='conv15_upsample')
            x = gen_conv(x, cnum//2, 3, 1, name='conv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
            x = tf.nn.tanh(x)
            x_stage1 = x

            # stage2, paste result as input
            x = x*mask + xin[:, :, :, 0:3]*(1.-mask)
            x.set_shape(xin[:, :, :, 0:3].get_shape().as_list())
            # conv branch
            # xnow = tf.concat([x, ones_x, ones_x*mask], axis=3)
            xnow = x
            x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
            x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = gen_conv(x, 2*cnum, 3, 1, name='xconv3')
            x = gen_conv(x, 2*cnum, 3, 2, name='xconv4_downsample')
            x = gen_conv(x, 4*cnum, 3, 1, name='xconv5')
            x = gen_conv(x, 4*cnum, 3, 1, name='xconv6')
            x = gen_conv(x, 4*cnum, 3, rate=2, name='xconv7_atrous')
            x = gen_conv(x, 4*cnum, 3, rate=4, name='xconv8_atrous')
            x = gen_conv(x, 4*cnum, 3, rate=8, name='xconv9_atrous')
            x = gen_conv(x, 4*cnum, 3, rate=16, name='xconv10_atrous')
            x_hallu = x
            # attention branch
            x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
            x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = gen_conv(x, 2*cnum, 3, 1, name='pmconv3')
            x = gen_conv(x, 4*cnum, 3, 2, name='pmconv4_downsample')
            x = gen_conv(x, 4*cnum, 3, 1, name='pmconv5')
            x = gen_conv(x, 4*cnum, 3, 1, name='pmconv6',
                         activation=tf.nn.relu)
            x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
            x = gen_conv(x, 4*cnum, 3, 1, name='pmconv9')
            x = gen_conv(x, 4*cnum, 3, 1, name='pmconv10')
            pm = x
            x = tf.concat([x_hallu, pm], axis=3)

            x = gen_conv(x, 4*cnum, 3, 1, name='allconv11')
            x = gen_conv(x, 4*cnum, 3, 1, name='allconv12')
            x = gen_deconv(x, 2*cnum, name='allconv13_upsample')
            x = gen_conv(x, 2*cnum, 3, 1, name='allconv14')
            x = gen_deconv(x, cnum, name='allconv15_upsample')
            x = gen_conv(x, cnum//2, 3, 1, name='allconv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x = tf.nn.tanh(x)
            x_stage2 = x
        return x_stage1, x_stage2, offset_flow

    def build_sn_patch_gan_discriminator(self, x, reuse=False, training=True):
        with tf.variable_scope('sn_patch_gan', reuse=reuse):
            cnum = 64
            x = dis_conv(x, cnum, name='conv1', training=training)
            x = dis_conv(x, cnum*2, name='conv2', training=training)
            x = dis_conv(x, cnum*4, name='conv3', training=training)
            x = dis_conv(x, cnum*4, name='conv4', training=training)
            x = dis_conv(x, cnum*4, name='conv5', training=training)
            x = dis_conv(x, cnum*4, name='conv6', training=training)
            x = flatten(x, name='flatten')
            return x

    def build_gan_discriminator(
            self, batch, reuse=False, training=True):
        with tf.variable_scope('discriminator', reuse=reuse):
            d = self.build_sn_patch_gan_discriminator(
                batch, reuse=reuse, training=training)
            return d

    def build_graph_with_losses(
            self, FLAGS, batch_data, training=True, summary=False,
            reuse=False):
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=reuse, training=training,
            padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches
        losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
        losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            if FLAGS.guided:
                viz_img = [
                    batch_pos,
                    batch_incomplete + edge,
                    batch_complete]
            else:
                viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_bilinear))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])], axis=3)
        if FLAGS.guided:
            # conditional GANs
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if FLAGS.gan == 'sngan':
            pos_neg = self.build_gan_discriminator(batch_pos_neg, training=training, reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
        else:
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        if summary:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses

    def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
        """Build inference graph for validation visualization."""
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_pos = batch_data / 127.5 - 1.
        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=True,
            training=False, padding=FLAGS.padding)
        batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # global image visualization
        if FLAGS.guided:
            viz_img = [
                batch_pos,
                batch_incomplete + edge,
                batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_bilinear))
        images_summary(
            tf.concat(viz_img, axis=2),
            name+'_raw_incomplete_complete', FLAGS.viz_max_out)
        return batch_complete

    def build_static_infer_graph(self, FLAGS, batch_data, name):
        """Build inference graph with a fixed, centered mask."""
        # generate mask, 1 represents masked point
        bbox = (tf.constant(FLAGS.height//2), tf.constant(FLAGS.width//2),
                tf.constant(FLAGS.height), tf.constant(FLAGS.width))
        return self.build_infer_graph(FLAGS, batch_data, bbox, name)

    def build_server_graph(self, FLAGS, batch_data, reuse=False, is_training=False):
        """Build serving graph: input holds image and mask side by side along width."""
        # generate mask, 1 represents masked point
        if FLAGS.guided:
            batch_raw, edge, masks_raw = tf.split(batch_data, 3, axis=2)
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        else:
            batch_raw, masks_raw = tf.split(batch_data, 2, axis=2)
        masks = tf.cast(masks_raw[0:1, :, :, 0:1] > 127.5, tf.float32)

        batch_pos = batch_raw / 127.5 - 1.
        batch_incomplete = batch_pos * (1. - masks)
        if FLAGS.guided:
            edge = edge * masks[:, :, :, 0:1]
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, flow = self.build_inpaint_net(
            xin, masks, reuse=reuse, training=is_training)
        batch_predict = x2
        # apply mask and reconstruct
        batch_complete = batch_predict*masks + batch_incomplete*(1-masks)
        return batch_complete
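build_server_graph() expects a single [1, H, 2*W, 3] tensor holding the image and its mask side by side along the width axis, with raw 0-255 pixel values; the mask is binarized at >127.5 and the network's output comes back in [-1, 1]. Below is a minimal single-image inference sketch under those assumptions, restoring weights the same way batch_test.py does; input.png, mask.png, and model/ are illustrative names, not fixed by the repo.

```python
# Minimal single-image inference against build_server_graph(); mirrors the
# input contract used by batch_test.py. Paths are illustrative.
import cv2
import numpy as np
import tensorflow as tf
import neuralgym as ng

from inpaint_model import InpaintCAModel

FLAGS = ng.Config('inpaint.yml')
image = cv2.imread('input.png')
mask = cv2.imread('mask.png')
assert image.shape == mask.shape

# crop so H and W divide by 4 (the network downsamples twice by a factor of 2)
h, w, _ = image.shape
grid = 4
image = image[:h // grid * grid, :w // grid * grid, :]
mask = mask[:h // grid * grid, :w // grid * grid, :]

# concatenate image and mask along the width axis: [1, H, 2*W, 3]
image = np.expand_dims(image, 0)
mask = np.expand_dims(mask, 0)
input_image = np.concatenate([image, mask], axis=2).astype(np.float32)

sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
with tf.Session(config=sess_config) as sess:
    input_tensor = tf.constant(input_image, dtype=tf.float32)
    output = InpaintCAModel().build_server_graph(FLAGS, input_tensor)
    output = tf.saturate_cast((output + 1.) * 127.5, tf.uint8)
    # restore checkpoint weights variable by variable, as batch_test.py does
    assign_ops = [
        tf.assign(var, tf.contrib.framework.load_variable('model/', var.name))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]
    sess.run(assign_ops)
    result = sess.run(output)
cv2.imwrite('result.png', result[0])
```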
inpaint_ops.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow.contrib.framework.python.ops import add_arg_scope
|
| 8 |
+
from PIL import Image, ImageDraw
|
| 9 |
+
|
| 10 |
+
from neuralgym.ops.layers import resize
|
| 11 |
+
from neuralgym.ops.layers import *
|
| 12 |
+
from neuralgym.ops.loss_ops import *
|
| 13 |
+
from neuralgym.ops.gan_ops import *
|
| 14 |
+
from neuralgym.ops.summary_ops import *
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger()
|
| 18 |
+
np.random.seed(2018)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@add_arg_scope
|
| 22 |
+
def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv',
|
| 23 |
+
padding='SAME', activation=tf.nn.elu, training=True):
|
| 24 |
+
"""Define conv for generator.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
x: Input.
|
| 28 |
+
cnum: Channel number.
|
| 29 |
+
ksize: Kernel size.
|
| 30 |
+
Stride: Convolution stride.
|
| 31 |
+
Rate: Rate for or dilated conv.
|
| 32 |
+
name: Name of layers.
|
| 33 |
+
padding: Default to SYMMETRIC.
|
| 34 |
+
activation: Activation function after convolution.
|
| 35 |
+
training: If current graph is for training or inference, used for bn.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
tf.Tensor: output
|
| 39 |
+
|
| 40 |
+
"""
|
| 41 |
+
assert padding in ['SYMMETRIC', 'SAME', 'REFELECT']
|
| 42 |
+
if padding == 'SYMMETRIC' or padding == 'REFELECT':
|
| 43 |
+
p = int(rate*(ksize-1)/2)
|
| 44 |
+
x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
|
| 45 |
+
padding = 'VALID'
|
| 46 |
+
x = tf.layers.conv2d(
|
| 47 |
+
x, cnum, ksize, stride, dilation_rate=rate,
|
| 48 |
+
activation=None, padding=padding, name=name)
|
| 49 |
+
if cnum == 3 or activation is None:
|
| 50 |
+
# conv for output
|
| 51 |
+
return x
|
| 52 |
+
x, y = tf.split(x, 2, 3)
|
| 53 |
+
x = activation(x)
|
| 54 |
+
y = tf.nn.sigmoid(y)
|
| 55 |
+
x = x * y
|
| 56 |
+
return x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@add_arg_scope
|
| 60 |
+
def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
|
| 61 |
+
"""Define deconv for generator.
|
| 62 |
+
The deconv is defined to be a x2 resize_nearest_neighbor operation with
|
| 63 |
+
additional gen_conv operation.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
x: Input.
|
| 67 |
+
cnum: Channel number.
|
| 68 |
+
name: Name of layers.
|
| 69 |
+
training: If current graph is for training or inference, used for bn.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
tf.Tensor: output
|
| 73 |
+
|
| 74 |
+
"""
|
| 75 |
+
with tf.variable_scope(name):
|
| 76 |
+
x = resize(x, func=tf.image.resize_nearest_neighbor)
|
| 77 |
+
x = gen_conv(
|
| 78 |
+
x, cnum, 3, 1, name=name+'_conv', padding=padding,
|
| 79 |
+
training=training)
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@add_arg_scope
|
| 84 |
+
def dis_conv(x, cnum, ksize=5, stride=2, name='conv', training=True):
|
| 85 |
+
"""Define conv for discriminator.
|
| 86 |
+
Activation is set to leaky_relu.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
x: Input.
|
| 90 |
+
cnum: Channel number.
|
| 91 |
+
ksize: Kernel size.
|
| 92 |
+
Stride: Convolution stride.
|
| 93 |
+
name: Name of layers.
|
| 94 |
+
training: If current graph is for training or inference, used for bn.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
tf.Tensor: output
|
| 98 |
+
|
| 99 |
+
"""
|
| 100 |
+
x = conv2d_spectral_norm(x, cnum, ksize, stride, 'SAME', name=name)
|
| 101 |
+
x = tf.nn.leaky_relu(x)
|
| 102 |
+
return x
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def random_bbox(FLAGS):
|
| 106 |
+
"""Generate a random tlhw.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
tuple: (top, left, height, width)
|
| 110 |
+
|
| 111 |
+
"""
|
| 112 |
+
img_shape = FLAGS.img_shapes
|
| 113 |
+
img_height = img_shape[0]
|
| 114 |
+
img_width = img_shape[1]
|
| 115 |
+
maxt = img_height - FLAGS.vertical_margin - FLAGS.height
|
| 116 |
+
maxl = img_width - FLAGS.horizontal_margin - FLAGS.width
|
| 117 |
+
t = tf.random_uniform(
|
| 118 |
+
[], minval=FLAGS.vertical_margin, maxval=maxt, dtype=tf.int32)
|
| 119 |
+
l = tf.random_uniform(
|
| 120 |
+
[], minval=FLAGS.horizontal_margin, maxval=maxl, dtype=tf.int32)
|
| 121 |
+
h = tf.constant(FLAGS.height)
|
| 122 |
+
w = tf.constant(FLAGS.width)
|
| 123 |
+
return (t, l, h, w)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def bbox2mask(FLAGS, bbox, name='mask'):
|
| 127 |
+
"""Generate mask tensor from bbox.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
bbox: tuple, (top, left, height, width)
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
tf.Tensor: output with shape [1, H, W, 1]
|
| 134 |
+
|
| 135 |
+
"""
|
| 136 |
+
def npmask(bbox, height, width, delta_h, delta_w):
|
| 137 |
+
mask = np.zeros((1, height, width, 1), np.float32)
|
| 138 |
+
h = np.random.randint(delta_h//2+1)
|
| 139 |
+
w = np.random.randint(delta_w//2+1)
|
| 140 |
+
mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
|
| 141 |
+
bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
|
| 142 |
+
return mask
|
| 143 |
+
with tf.variable_scope(name), tf.device('/cpu:0'):
|
| 144 |
+
img_shape = FLAGS.img_shapes
|
| 145 |
+
height = img_shape[0]
|
| 146 |
+
width = img_shape[1]
|
| 147 |
+
mask = tf.py_func(
|
| 148 |
+
npmask,
|
| 149 |
+
[bbox, height, width,
|
| 150 |
+
FLAGS.max_delta_height, FLAGS.max_delta_width],
|
| 151 |
+
tf.float32, stateful=False)
|
| 152 |
+
mask.set_shape([1] + [height, width] + [1])
|
| 153 |
+
return mask
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def brush_stroke_mask(FLAGS, name='mask'):
|
| 157 |
+
"""Generate mask tensor from bbox.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
tf.Tensor: output with shape [1, H, W, 1]
|
| 161 |
+
|
| 162 |
+
"""
|
| 163 |
+
min_num_vertex = 4
|
| 164 |
+
max_num_vertex = 12
|
| 165 |
+
mean_angle = 2*math.pi / 5
|
| 166 |
+
angle_range = 2*math.pi / 15
|
| 167 |
+
min_width = 12
|
| 168 |
+
max_width = 40
|
| 169 |
+
def generate_mask(H, W):
|
| 170 |
+
average_radius = math.sqrt(H*H+W*W) / 8
|
| 171 |
+
mask = Image.new('L', (W, H), 0)
|
| 172 |
+
|
| 173 |
+
for _ in range(np.random.randint(1, 4)):
|
| 174 |
+
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
|
| 175 |
+
angle_min = mean_angle - np.random.uniform(0, angle_range)
|
| 176 |
+
angle_max = mean_angle + np.random.uniform(0, angle_range)
|
| 177 |
+
angles = []
|
| 178 |
+
vertex = []
|
| 179 |
+
for i in range(num_vertex):
|
| 180 |
+
if i % 2 == 0:
|
| 181 |
+
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
|
| 182 |
+
else:
|
| 183 |
+
angles.append(np.random.uniform(angle_min, angle_max))
|
| 184 |
+
|
| 185 |
+
h, w = mask.size
|
| 186 |
+
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
|
| 187 |
+
for i in range(num_vertex):
|
| 188 |
+
r = np.clip(
|
| 189 |
+
np.random.normal(loc=average_radius, scale=average_radius//2),
|
| 190 |
+
0, 2*average_radius)
|
| 191 |
+
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
|
| 192 |
+
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
|
| 193 |
+
vertex.append((int(new_x), int(new_y)))
|
| 194 |
+
|
| 195 |
+
draw = ImageDraw.Draw(mask)
|
| 196 |
+
width = int(np.random.uniform(min_width, max_width))
|
| 197 |
+
draw.line(vertex, fill=1, width=width)
|
| 198 |
+
for v in vertex:
|
| 199 |
+
draw.ellipse((v[0] - width//2,
|
| 200 |
+
v[1] - width//2,
|
| 201 |
+
v[0] + width//2,
|
| 202 |
+
v[1] + width//2),
|
| 203 |
+
fill=1)
|
| 204 |
+
|
| 205 |
+
if np.random.normal() > 0:
|
| 206 |
+
mask.transpose(Image.FLIP_LEFT_RIGHT)
|
| 207 |
+
if np.random.normal() > 0:
|
| 208 |
+
mask.transpose(Image.FLIP_TOP_BOTTOM)
|
| 209 |
+
mask = np.asarray(mask, np.float32)
|
| 210 |
+
mask = np.reshape(mask, (1, H, W, 1))
|
| 211 |
+
return mask
|
| 212 |
+
with tf.variable_scope(name), tf.device('/cpu:0'):
|
| 213 |
+
img_shape = FLAGS.img_shapes
|
| 214 |
+
height = img_shape[0]
|
| 215 |
+
width = img_shape[1]
|
| 216 |
+
mask = tf.py_func(
|
| 217 |
+
generate_mask,
|
| 218 |
+
[height, width],
|
| 219 |
+
tf.float32, stateful=True)
|
| 220 |
+
mask.set_shape([1] + [height, width] + [1])
|
| 221 |
+
return mask
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def local_patch(x, bbox):
|
| 225 |
+
"""Crop local patch according to bbox.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
x: input
|
| 229 |
+
bbox: (top, left, height, width)
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
tf.Tensor: local patch
|
| 233 |
+
|
| 234 |
+
"""
|
| 235 |
+
x = tf.image.crop_to_bounding_box(x, bbox[0], bbox[1], bbox[2], bbox[3])
|
| 236 |
+
return x
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def resize_mask_like(mask, x):
|
| 240 |
+
"""Resize mask like shape of x.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
mask: Original mask.
|
| 244 |
+
x: To shape of x.
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
tf.Tensor: resized mask
|
| 248 |
+
|
| 249 |
+
"""
|
| 250 |
+
mask_resize = resize(
|
| 251 |
+
mask, to_shape=x.get_shape().as_list()[1:3],
|
| 252 |
+
func=tf.image.resize_nearest_neighbor)
|
| 253 |
+
return mask_resize
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
|
| 257 |
+
fuse_k=3, softmax_scale=10., training=True, fuse=True):
|
| 258 |
+
""" Contextual attention layer implementation.
|
| 259 |
+
|
| 260 |
+
Contextual attention is first introduced in publication:
|
| 261 |
+
Generative Image Inpainting with Contextual Attention, Yu et al.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
x: Input feature to match (foreground).
|
| 265 |
+
t: Input feature for match (background).
|
| 266 |
+
mask: Input mask for t, indicating patches not available.
|
| 267 |
+
ksize: Kernel size for contextual attention.
|
| 268 |
+
stride: Stride for extracting patches from t.
|
| 269 |
+
rate: Dilation for matching.
|
| 270 |
+
softmax_scale: Scaled softmax for attention.
|
| 271 |
+
training: Indicating if current graph is training or inference.
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
tf.Tensor: output
|
| 275 |
+
|
| 276 |
+
"""
|
| 277 |
+
# get shapes
|
| 278 |
+
raw_fs = tf.shape(f)
|
| 279 |
+
raw_int_fs = f.get_shape().as_list()
|
| 280 |
+
raw_int_bs = b.get_shape().as_list()
|
| 281 |
+
# extract patches from background with stride and rate
|
| 282 |
+
kernel = 2*rate
|
| 283 |
+
raw_w = tf.extract_image_patches(
|
| 284 |
+
b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
|
| 285 |
+
raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
|
| 286 |
+
raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
|
| 287 |
+
# downscaling foreground option: downscaling both foreground and
|
| 288 |
+
# background for matching and use original background for reconstruction.
|
| 289 |
+
f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
|
| 290 |
+
b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor) # https://github.com/tensorflow/tensorflow/issues/11651
|
| 291 |
+
if mask is not None:
|
| 292 |
+
mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
|
| 293 |
+
fs = tf.shape(f)
|
| 294 |
+
int_fs = f.get_shape().as_list()
|
| 295 |
+
f_groups = tf.split(f, int_fs[0], axis=0)
|
| 296 |
+
# from t(H*W*C) to w(b*k*k*c*h*w)
|
| 297 |
+
bs = tf.shape(b)
|
| 298 |
+
int_bs = b.get_shape().as_list()
|
| 299 |
+
w = tf.extract_image_patches(
|
| 300 |
+
b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
|
| 301 |
+
w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
|
| 302 |
+
w = tf.transpose(w, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
|
| 303 |
+
# process mask
|
| 304 |
+
if mask is None:
|
| 305 |
+
mask = tf.zeros([1, bs[1], bs[2], 1])
|
| 306 |
+
m = tf.extract_image_patches(
|
| 307 |
+
mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
|
| 308 |
+
m = tf.reshape(m, [1, -1, ksize, ksize, 1])
|
| 309 |
+
m = tf.transpose(m, [0, 2, 3, 4, 1]) # transpose to b*k*k*c*hw
|
| 310 |
+
m = m[0]
|
| 311 |
+
mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)
|
| 312 |
+
w_groups = tf.split(w, int_bs[0], axis=0)
|
| 313 |
+
raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
|
| 314 |
+
y = []
|
| 315 |
+
offsets = []
|
| 316 |
+
k = fuse_k
|
| 317 |
+
scale = softmax_scale
|
| 318 |
+
fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
|
| 319 |
+
for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
|
| 320 |
+
# conv for compare
|
| 321 |
+
wi = wi[0]
|
| 322 |
+
wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
|
| 323 |
+
yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")
|
| 324 |
+
|
| 325 |
+
# conv implementation for fuse scores to encourage large patches
|
| 326 |
+
if fuse:
|
| 327 |
+
yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
|
| 328 |
+
yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
|
| 329 |
+
yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
|
| 330 |
+
yi = tf.transpose(yi, [0, 2, 1, 4, 3])
|
| 331 |
+
yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
|
| 332 |
+
yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
|
| 333 |
+
yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
|
| 334 |
+
yi = tf.transpose(yi, [0, 2, 1, 4, 3])
|
| 335 |
+
yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
|
| 336 |
+
|
| 337 |
+
# softmax to match
|
| 338 |
+
yi *= mm # mask
|
| 339 |
+
yi = tf.nn.softmax(yi*scale, 3)
|
| 340 |
+
yi *= mm # mask
|
| 341 |
+
|
| 342 |
+
offset = tf.argmax(yi, axis=3, output_type=tf.int32)
|
| 343 |
+
offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
|
| 344 |
+
# deconv for patch pasting
|
| 345 |
+
# 3.1 paste center
|
| 346 |
+
wi_center = raw_wi[0]
|
| 347 |
+
yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
|
| 348 |
+
y.append(yi)
|
| 349 |
+
offsets.append(offset)
|
| 350 |
+
y = tf.concat(y, axis=0)
|
| 351 |
+
y.set_shape(raw_int_fs)
|
| 352 |
+
offsets = tf.concat(offsets, axis=0)
|
| 353 |
+
offsets.set_shape(int_bs[:3] + [2])
|
| 354 |
+
# case1: visualize optical flow: minus current position
|
| 355 |
+
h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
|
| 356 |
+
w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
|
| 357 |
+
offsets = offsets - tf.concat([h_add, w_add], axis=3)
|
| 358 |
+
# to flow image
|
| 359 |
+
flow = flow_to_image_tf(offsets)
|
| 360 |
+
# # case2: visualize which pixels are attended
|
| 361 |
+
# flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
|
| 362 |
+
if rate != 1:
|
| 363 |
+
flow = resize(flow, scale=rate, func=tf.image.resize_bilinear)
|
| 364 |
+
return y, flow
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def test_contextual_attention(args):
|
| 368 |
+
"""Test contextual attention layer with 3-channel image input
|
| 369 |
+
(instead of n-channel feature).
|
| 370 |
+
|
| 371 |
+
"""
|
| 372 |
+
import cv2
|
| 373 |
+
import os
|
| 374 |
+
# run on cpu
|
| 375 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
| 376 |
+
|
| 377 |
+
rate = 2
|
| 378 |
+
stride = 1
|
| 379 |
+
grid = rate*stride
|
| 380 |
+
|
| 381 |
+
b = cv2.imread(args.imageA)
|
| 382 |
+
b = cv2.resize(b, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
|
| 383 |
+
h, w, _ = b.shape
|
| 384 |
+
b = b[:h//grid*grid, :w//grid*grid, :]
|
| 385 |
+
b = np.expand_dims(b, 0)
|
| 386 |
+
logger.info('Size of imageA: {}'.format(b.shape))
|
| 387 |
+
|
| 388 |
+
f = cv2.imread(args.imageB)
|
| 389 |
+
h, w, _ = f.shape
|
| 390 |
+
f = f[:h//grid*grid, :w//grid*grid, :]
|
| 391 |
+
f = np.expand_dims(f, 0)
|
| 392 |
+
logger.info('Size of imageB: {}'.format(f.shape))
|
| 393 |
+
|
| 394 |
+
with tf.Session() as sess:
|
| 395 |
+
bt = tf.constant(b, dtype=tf.float32)
|
| 396 |
+
ft = tf.constant(f, dtype=tf.float32)
|
| 397 |
+
|
| 398 |
+
yt, flow = contextual_attention(
|
| 399 |
+
ft, bt, stride=stride, rate=rate,
|
| 400 |
+
training=False, fuse=False)
|
| 401 |
+
y = sess.run(yt)
|
| 402 |
+
cv2.imwrite(args.imageOut, y[0])
+
+
+
+def make_color_wheel():
+    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
+    ncols = RY + YG + GC + CB + BM + MR
+    colorwheel = np.zeros([ncols, 3])
+    col = 0
+    # RY
+    colorwheel[0:RY, 0] = 255
+    colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
+    col += RY
+    # YG
+    colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
+    colorwheel[col:col+YG, 1] = 255
+    col += YG
+    # GC
+    colorwheel[col:col+GC, 1] = 255
+    colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
+    col += GC
+    # CB
+    colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
+    colorwheel[col:col+CB, 2] = 255
+    col += CB
+    # BM
+    colorwheel[col:col+BM, 2] = 255
+    colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
+    col += BM
+    # MR
+    colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+    colorwheel[col:col+MR, 0] = 255
+    return colorwheel
+
+
+COLORWHEEL = make_color_wheel()
+
+
+def compute_color(u, v):
+    h, w = u.shape
+    img = np.zeros([h, w, 3])
+    nanIdx = np.isnan(u) | np.isnan(v)
+    u[nanIdx] = 0
+    v[nanIdx] = 0
+    # colorwheel = COLORWHEEL
+    colorwheel = make_color_wheel()
+    ncols = np.size(colorwheel, 0)
+    rad = np.sqrt(u**2 + v**2)
+    a = np.arctan2(-v, -u) / np.pi
+    fk = (a+1) / 2 * (ncols - 1) + 1
+    k0 = np.floor(fk).astype(int)
+    k1 = k0 + 1
+    k1[k1 == ncols+1] = 1
+    f = fk - k0
+    for i in range(np.size(colorwheel, 1)):
+        tmp = colorwheel[:, i]
+        col0 = tmp[k0-1] / 255
+        col1 = tmp[k1-1] / 255
+        col = (1-f) * col0 + f * col1
+        idx = rad <= 1
+        col[idx] = 1 - rad[idx] * (1 - col[idx])
+        notidx = np.logical_not(idx)
+        col[notidx] *= 0.75
+        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
+    return img
+
+
+def flow_to_image(flow):
+    """Transfer flow map to image.
+    Part of code forked from flownet.
+    """
+    out = []
+    maxu = -999.
+    maxv = -999.
+    minu = 999.
+    minv = 999.
+    maxrad = -1
+    for i in range(flow.shape[0]):
+        u = flow[i, :, :, 0]
+        v = flow[i, :, :, 1]
+        idxunknown = (abs(u) > 1e7) | (abs(v) > 1e7)
+        u[idxunknown] = 0
+        v[idxunknown] = 0
+        maxu = max(maxu, np.max(u))
+        minu = min(minu, np.min(u))
+        maxv = max(maxv, np.max(v))
+        minv = min(minv, np.min(v))
+        rad = np.sqrt(u ** 2 + v ** 2)
+        maxrad = max(maxrad, np.max(rad))
+        u = u / (maxrad + np.finfo(float).eps)
+        v = v / (maxrad + np.finfo(float).eps)
+        img = compute_color(u, v)
+        out.append(img)
+    return np.float32(np.uint8(out))
+
+
+def flow_to_image_tf(flow, name='flow_to_image'):
+    """Tensorflow ops for computing flow to image.
+    """
+    with tf.variable_scope(name), tf.device('/cpu:0'):
+        img = tf.py_func(flow_to_image, [flow], tf.float32, stateful=False)
+        img.set_shape(flow.get_shape().as_list()[0:-1] + [3])
+        img = img / 127.5 - 1.
+        return img
+
+
+def highlight_flow(flow):
+    """Convert flow into middlebury color code image.
+    """
+    out = []
+    s = flow.shape
+    for i in range(flow.shape[0]):
+        img = np.ones((s[1], s[2], 3)) * 144.
+        u = flow[i, :, :, 0]
+        v = flow[i, :, :, 1]
+        for h in range(s[1]):
+            for w in range(s[2]):  # was range(s[1]); iterate the full width
+                ui = int(u[h, w])  # flow values are used as pixel indices
+                vi = int(v[h, w])
+                img[ui, vi, :] = 255.
+        out.append(img)
+    return np.float32(np.uint8(out))
+
+
+def highlight_flow_tf(flow, name='flow_to_image'):
+    """Tensorflow ops for highlight flow.
+    """
+    with tf.variable_scope(name), tf.device('/cpu:0'):
+        img = tf.py_func(highlight_flow, [flow], tf.float32, stateful=False)
+        img.set_shape(flow.get_shape().as_list()[0:-1] + [3])
+        img = img / 127.5 - 1.
+        return img
+
+
+def image2edge(image):
+    """Convert image to edges.
+    """
+    out = []
+    for i in range(image.shape[0]):
+        img = cv2.Laplacian(image[i, :, :, :], cv2.CV_64F, ksize=3, scale=2)
+        out.append(img)
+    return np.float32(np.uint8(out))
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--imageA', default='', type=str, help='Image A as background patches to reconstruct image B.')
+    parser.add_argument('--imageB', default='', type=str, help='Image B is reconstructed with image A.')
+    parser.add_argument('--imageOut', default='result.png', type=str, help='Where to write the reconstructed result.')
+    args = parser.parse_args()
+    test_contextual_attention(args)
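The flow-visualization helpers above are plain numpy, so they can be sanity-checked outside the TensorFlow graph. A minimal sketch, where the radial flow field is made up purely for illustration:

```python
import numpy as np

# Build a synthetic [N, H, W, 2] flow field pointing away from the center
# and render it with flow_to_image from inpaint_ops.py.
h, w = 64, 64
ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([xs - w / 2., ys - h / 2.], axis=-1)[np.newaxis]

img = flow_to_image(flow)
print(img.shape, img.dtype)  # (1, 64, 64, 3) float32, values in [0, 255]
```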
main.py
ADDED
@@ -0,0 +1,58 @@
+import argparse
+
+from PIL import Image
+import cv2
+import numpy as np
+from preprocess_image import preprocess_image
+import tensorflow as tf
+import neuralgym as ng
+
+from inpaint_model import InpaintCAModel
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--image', default='', type=str,
+                    help='The filename of image to be completed.')
+parser.add_argument('--output', default='output.png', type=str,
+                    help='Where to write output.')
+parser.add_argument('--watermark_type', default='istock', type=str,
+                    help='The watermark type.')
+parser.add_argument('--checkpoint_dir', default='model/', type=str,
+                    help='The directory of the tensorflow checkpoint.')
+
+# checkpoint_dir = 'model/'
+
+
+if __name__ == "__main__":
+    FLAGS = ng.Config('inpaint.yml')
+    # ng.get_gpus(1)
+    args, unknown = parser.parse_known_args()
+
+    model = InpaintCAModel()
+    image = Image.open(args.image)
+    input_image = preprocess_image(image, args.watermark_type)
+    tf.reset_default_graph()
+
+    sess_config = tf.ConfigProto()
+    sess_config.gpu_options.allow_growth = True
+    if input_image.shape != (0,):
+        with tf.Session(config=sess_config) as sess:
+            input_image = tf.constant(input_image, dtype=tf.float32)
+            output = model.build_server_graph(FLAGS, input_image)
+            # map the generator output from [-1, 1] back to [0, 255]
+            output = (output + 1.) * 127.5
+            output = tf.reverse(output, [-1])
+            output = tf.saturate_cast(output, tf.uint8)
+            # load pretrained model
+            vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+            assign_ops = []
+            for var in vars_list:
+                vname = var.name
+                from_name = vname
+                var_value = tf.contrib.framework.load_variable(
+                    args.checkpoint_dir, from_name)
+                assign_ops.append(tf.assign(var, var_value))
+            sess.run(assign_ops)
+            print('Model loaded.')
+            result = sess.run(output)
+            cv2.imwrite(args.output, cv2.cvtColor(
+                result[0][:, :, ::-1], cv2.COLOR_BGR2RGB))
+            print('image saved to {}'.format(args.output))
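One detail worth calling out: the generator works in a [-1, 1] value range, and the `(output + 1.) * 127.5` line maps that range back onto 8-bit pixels before the saturate-cast. A standalone sketch of just that rescaling (not part of the repository):

```python
import numpy as np

# -1 maps to 0, 0 to 127.5, and 1 to 255 before the uint8 saturate-cast.
x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print((x + 1.0) * 127.5)  # [  0.   127.5 255. ]
```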
preprocess_image.py
ADDED
@@ -0,0 +1,53 @@
+import numpy as np
+from PIL import Image
+import cv2
+
+
+def preprocess_image(image, watermark_type):
+    image_type: str = ''
+    preprocessed_mask_image = np.array([])
+    if image.mode != "RGB":
+        image = image.convert("RGB")
+    image = np.array(image)
+    image_h = image.shape[0]
+    image_w = image.shape[1]
+    aspectRatioImage = image_w / image_h
+    print("image size: {}".format(image.shape))
+
+    # square images fall through to the landscape mask
+    if image_w >= image_h:
+        image_type = "landscape"
+    else:
+        image_type = "potrait"  # spelling matches the utils/ directory name
+
+    mask_image = Image.open(
+        "utils/{}/{}/mask.png".format(watermark_type, image_type))
+    if mask_image.mode != "RGB":
+        mask_image = mask_image.convert("RGB")
+    mask_image = np.array(mask_image)
+    print("mask image size: {}".format(mask_image.shape))
+
+    # only resize the mask when the aspect ratios differ by at most 5%
+    aspectRatioMaskImage = mask_image.shape[1] / mask_image.shape[0]
+    upperBoundAspectRatio = 1.05 * aspectRatioMaskImage
+    lowerBoundAspectRatio = 0.95 * aspectRatioMaskImage
+
+    if lowerBoundAspectRatio <= aspectRatioImage <= upperBoundAspectRatio:
+        preprocessed_mask_image = cv2.resize(mask_image, (image_w, image_h))
+        print(preprocessed_mask_image.shape)
+    else:
+        print("Image size not supported!!!")
+
+    if preprocessed_mask_image.shape != (0,):
+        assert image.shape == preprocessed_mask_image.shape
+        # crop both arrays to a multiple of the grid size before batching
+        grid = 8
+        image = image[:image_h//grid*grid, :image_w//grid*grid, :]
+        preprocessed_mask_image = preprocessed_mask_image[
+            :image_h//grid*grid, :image_w//grid*grid, :]
+        image = np.expand_dims(image, 0)
+        preprocessed_mask_image = np.expand_dims(preprocessed_mask_image, 0)
+        input_image = np.concatenate([image, preprocessed_mask_image], axis=2)
+        return input_image
+    else:
+        return preprocessed_mask_image
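The final `np.concatenate(..., axis=2)` places the resized mask beside the image along the width axis, which is the side-by-side layout the inpainting graph splits back apart. A shape-only sketch with dummy arrays (no real watermark involved):

```python
import numpy as np

# Image and mask are each batched to [1, H, W, 3]; concatenating on
# axis=2 (width) yields the [1, H, 2W, 3] tensor fed to the model.
image = np.zeros((1, 256, 512, 3), dtype=np.uint8)
mask = np.zeros((1, 256, 512, 3), dtype=np.uint8)
model_input = np.concatenate([image, mask], axis=2)
print(model_input.shape)  # (1, 256, 1024, 3)
```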
requirements.txt
ADDED
@@ -0,0 +1,2 @@
+tensorflow==1.15.5
+opencv-python==4.9.0.80
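Note that TensorFlow 1.15 generally requires Python 3.7 or earlier, so `pip install -r requirements.txt` is best run in a Python 3.6/3.7 environment; `neuralgym` and `Pillow`, both imported by main.py, are not pinned here and need to be installed separately.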
utils/istock/landscape/mask.png
ADDED