File size: 4,786 Bytes
212d7be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import numpy as np
import matplotlib.pyplot as plt
import cv2
import snowy
import os
def get_resized_image(img, size):
    """Resize ``img`` so that its shorter side becomes ``size`` pixels long.

    The longer side is scaled by the same ratio and rounded up. A
    two-dimensional input is first expanded to three identical channels.
    When the resized array has dtype ``float32`` it is clipped in place to
    the range ``[0, 1]``.

    Args:
        img: Image array indexed as ``(rows, cols)`` or ``(rows, cols, ch)``.
        size: Target length of the shorter side.

    Returns:
        The resized image array.
    """
    if img.ndim == 2:
        # Replicate the single plane along a new last axis.
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)

    rows, cols = img.shape[0], img.shape[1]
    if rows < cols:
        # Height is the short side: pin it to ``size`` and scale the width.
        scale = rows / size
        target = (int(np.ceil(cols / scale)), size)
    else:
        # Width is the short side (or the image is square).
        scale = cols / size
        target = (size, int(np.ceil(rows / scale)))

    # ``target`` is passed in (width, height) order.
    img = cv2.resize(img, target, interpolation=cv2.INTER_AREA)

    if img.dtype == np.float32:
        np.clip(img, 0, 1, out=img)

    return img
def get_sketch_image(img, sketcher, mult_val):
    """Ask ``sketcher`` for a sketch of ``img``.

    Args:
        img: Image handed to ``sketcher.get_sketch_with_resize``.
        sketcher: Object exposing ``get_sketch_with_resize(img, mult=...)``.
        mult_val: Forwarded as the ``mult`` keyword when truthy; any falsy
            value (``None``, ``0``) leaves the sketcher's own default.

    Returns:
        Whatever ``sketcher.get_sketch_with_resize`` returns.
    """
    extra_kwargs = {'mult': mult_val} if mult_val else {}
    return sketcher.get_sketch_with_resize(img, **extra_kwargs)
def get_dfm_image(sketch):
    """Derive a distance-field map from a sketch via ``snowy``.

    Args:
        sketch: Two-dimensional sketch array.
            NOTE(review): presumably floats in [0, 1] with 1 as background
            - confirm against the sketcher's output.

    Returns:
        The result of ``snowy.unitize(snowy.generate_sdf(mask))`` with
        singleton axes squeezed away.
    """
    # Boolean mask of every pixel where ``1 - sketch`` is non-zero, with a
    # trailing axis added before it is handed to snowy.
    mask = np.expand_dims(1 - sketch, 2) != 0
    distance_field = snowy.generate_sdf(mask)
    normalised = snowy.unitize(distance_field)
    return normalised.squeeze()
def get_sketch(image, sketcher, dfm, mult=None):
    """Produce an 8-bit sketch and, optionally, its distance-field map.

    Args:
        image: Source image passed to ``get_sketch_image``.
        sketcher: Sketch generator passed to ``get_sketch_image``.
        dfm: When truthy, also compute the map with ``get_dfm_image``.
        mult: Optional multiplier forwarded to ``get_sketch_image``.

    Returns:
        ``(sketch_image, dfm_image)``, both scaled by 255 and cast to
        ``uint8``; ``dfm_image`` is ``None`` when ``dfm`` is falsy.
    """
    raw_sketch = get_sketch_image(image, sketcher, mult)

    if dfm:
        # The map is built from the unscaled sketch, before the uint8 cast.
        dfm_image = (get_dfm_image(raw_sketch) * 255).astype('uint8')
    else:
        dfm_image = None

    sketch_image = (raw_sketch * 255).astype('uint8')
    return sketch_image, dfm_image
def get_sketches(image, sketcher, mult_list, dfm):
    """Lazily yield one ``get_sketch`` result per entry of ``mult_list``.

    Args:
        image: Source image shared by every sketch.
        sketcher: Sketch generator forwarded to ``get_sketch``.
        mult_list: Iterable of multipliers, consumed in order.
        dfm: Forwarded to ``get_sketch``.

    Yields:
        The ``(sketch_image, dfm_image)`` pair for each multiplier.
    """
    yield from (
        get_sketch(image, sketcher, dfm, multiplier)
        for multiplier in mult_list
    )
def create_resized_dataset(source_path, target_path, side_size):
    """Write a resized PNG copy of every image found in ``source_path``.

    Each readable file is resized with ``get_resized_image`` and stored in
    ``target_path`` as ``<stem>.png``. A file whose PNG already exists is
    skipped. Processing is best effort: a file that cannot be read, resized
    or written is reported on stdout and the loop continues.

    Args:
        source_path: Directory containing the input images.
        target_path: Directory that receives the PNG files.
        side_size: Size argument forwarded to ``get_resized_image``.
    """
    for image_name in os.listdir(source_path):
        # splitext keeps extension-less names whole; slicing at rfind('.')
        # would cut off their last character because rfind returns -1.
        stem = os.path.splitext(image_name)[0]
        new_path = os.path.join(target_path, stem + '.png')
        if os.path.exists(new_path):
            continue

        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise ValueError('image could not be read')
            image = get_resized_image(image, side_size)
            # imwrite signals failure through its return value.
            if not cv2.imwrite(new_path, image):
                raise IOError('image could not be written')
        except Exception:
            # Exception (not a bare except) so Ctrl-C still aborts the run.
            print('Failed to process {}'.format(image_name))
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm=False):
    """Generate sketch PNGs for every image found in ``source_path``.

    For each readable image and each entry of ``mult_list`` a file
    ``<stem>_<index>.png`` is written to ``target_path``; when ``dfm`` is
    truthy a companion ``<stem>_<index>_dfm.png`` is written as well.
    Processing is best effort: an image that cannot be read, sketched or
    written is reported on stdout and the loop continues.

    Args:
        source_path: Directory containing the input images.
        target_path: Directory that receives the sketch files.
        sketcher: Sketch generator forwarded to ``get_sketches``.
        mult_list: Multipliers forwarded to ``get_sketches``; the position
            of each one becomes the ``<index>`` in the file name.
        dfm: Whether to also write the distance-field maps.
    """
    for image_name in os.listdir(source_path):
        # splitext keeps extension-less names whole; slicing at rfind('.')
        # would cut off their last character because rfind returns -1.
        stem = os.path.splitext(image_name)[0]

        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise ValueError('image could not be read')

            sketches = get_sketches(image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                base_name = '{}_{}'.format(stem, number)

                sketch_file = os.path.join(target_path, base_name + '.png')
                # imwrite signals failure through its return value.
                if not cv2.imwrite(sketch_file, sketch_image):
                    raise IOError('sketch could not be written')

                if dfm:
                    dfm_file = os.path.join(target_path, base_name + '_dfm.png')
                    if not cv2.imwrite(dfm_file, dfm_image):
                        raise IOError('dfm image could not be written')
        except Exception:
            # Exception (not a bare except) so Ctrl-C still aborts the run.
            print('Failed to process {}'.format(image_name))
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm=False):
    """Build a paired colour / sketch dataset from the images in ``source_path``.

    Two sub-directories of ``target_path`` are created when missing:
    ``color`` receives each image resized with ``get_resized_image`` as
    ``<stem>.png``; ``bw`` receives one ``<stem>_<index>.png`` sketch of the
    resized image per entry of ``mult_list`` and, when ``dfm`` is truthy, a
    matching ``<stem>_<index>_dfm.png``. Processing is best effort: an image
    that cannot be read, processed or written is reported on stdout and the
    loop continues.

    Args:
        source_path: Directory containing the input images.
        target_path: Root directory of the generated dataset.
        sketcher: Sketch generator forwarded to ``get_sketches``.
        mult_list: Multipliers forwarded to ``get_sketches``; the position
            of each one becomes the ``<index>`` in the file name.
        side_size: Size argument forwarded to ``get_resized_image``.
        dfm: Whether to also write the distance-field maps.
    """
    images = os.listdir(source_path)

    color_path = os.path.join(target_path, 'color')
    sketch_path = os.path.join(target_path, 'bw')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(color_path, exist_ok=True)
    os.makedirs(sketch_path, exist_ok=True)

    for image_name in images:
        # splitext keeps extension-less names whole; slicing at rfind('.')
        # would cut off their last character because rfind returns -1.
        stem = os.path.splitext(image_name)[0]

        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                raise ValueError('image could not be read')

            resized_image = get_resized_image(image, side_size)
            color_file = os.path.join(color_path, stem + '.png')
            # imwrite signals failure through its return value.
            if not cv2.imwrite(color_file, resized_image):
                raise IOError('colour image could not be written')

            sketches = get_sketches(resized_image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                base_name = '{}_{}'.format(stem, number)

                sketch_file = os.path.join(sketch_path, base_name + '.png')
                if not cv2.imwrite(sketch_file, sketch_image):
                    raise IOError('sketch could not be written')

                if dfm:
                    dfm_file = os.path.join(sketch_path, base_name + '_dfm.png')
                    if not cv2.imwrite(dfm_file, dfm_image):
                        raise IOError('dfm image could not be written')
        except Exception:
            # Exception (not a bare except) so Ctrl-C still aborts the run.
            print('Failed to process {}'.format(image_name))
|