|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
import snowy |
|
import os |
|
|
|
|
|
def get_resized_image(img, size): |
|
if len(img.shape) == 2: |
|
img = np.repeat(np.expand_dims(img, 2), 3, 2) |
|
|
|
if (img.shape[0] < img.shape[1]): |
|
height = img.shape[0] |
|
ratio = height / size |
|
width = int(np.ceil(img.shape[1] / ratio)) |
|
img = cv2.resize(img, (width, size), interpolation = cv2.INTER_AREA) |
|
else: |
|
width = img.shape[1] |
|
ratio = width / size |
|
height = int(np.ceil(img.shape[0] / ratio)) |
|
img = cv2.resize(img, (size, height), interpolation = cv2.INTER_AREA) |
|
|
|
if (img.dtype == 'float32'): |
|
np.clip(img, 0, 1, out = img) |
|
|
|
return img |
|
|
|
|
|
def get_sketch_image(img, sketcher, mult_val): |
|
|
|
if mult_val: |
|
sketch_image = sketcher.get_sketch_with_resize(img, mult = mult_val) |
|
else: |
|
sketch_image = sketcher.get_sketch_with_resize(img) |
|
|
|
return sketch_image |
|
|
|
|
|
def get_dfm_image(sketch): |
|
dfm_image = snowy.unitize(snowy.generate_sdf(np.expand_dims(1 - sketch, 2) != 0)).squeeze() |
|
return dfm_image |
|
|
|
def get_sketch(image, sketcher, dfm, mult = None): |
|
sketch_image = get_sketch_image(image, sketcher, mult) |
|
|
|
dfm_image = None |
|
|
|
if dfm: |
|
dfm_image = get_dfm_image(sketch_image) |
|
|
|
sketch_image = (sketch_image * 255).astype('uint8') |
|
|
|
if dfm: |
|
dfm_image = (dfm_image * 255).astype('uint8') |
|
|
|
return sketch_image, dfm_image |
|
|
|
def get_sketches(image, sketcher, mult_list, dfm): |
|
for mult in mult_list: |
|
yield get_sketch(image, sketcher, dfm, mult) |
|
|
|
|
|
def create_resized_dataset(source_path, target_path, side_size): |
|
images = os.listdir(source_path) |
|
|
|
for image_name in images: |
|
|
|
new_image_name = image_name[:image_name.rfind('.')] + '.png' |
|
new_path = os.path.join(target_path, new_image_name) |
|
|
|
if not os.path.exists(new_path): |
|
try: |
|
image = cv2.imread(os.path.join(source_path, image_name)) |
|
|
|
if image is None: |
|
raise Exception() |
|
|
|
image = get_resized_image(image, side_size) |
|
|
|
cv2.imwrite(new_path, image) |
|
except: |
|
print('Failed to process {}'.format(image_name)) |
|
|
|
|
|
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False): |
|
|
|
images = os.listdir(source_path) |
|
for image_name in images: |
|
try: |
|
image = cv2.imread(os.path.join(source_path, image_name)) |
|
|
|
if image is None: |
|
raise Exception() |
|
|
|
for number, (sketch_image, dfm_image) in enumerate(get_sketches(image, sketcher, mult_list, dfm)): |
|
new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png' |
|
cv2.imwrite(os.path.join(target_path, new_sketch_name), sketch_image) |
|
|
|
if dfm: |
|
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png' |
|
cv2.imwrite(os.path.join(target_path, dfm_name), dfm_image) |
|
|
|
except: |
|
print('Failed to process {}'.format(image_name)) |
|
|
|
|
|
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False): |
|
images = os.listdir(source_path) |
|
|
|
color_path = os.path.join(target_path, 'color') |
|
sketch_path = os.path.join(target_path, 'bw') |
|
|
|
if not os.path.exists(color_path): |
|
os.makedirs(color_path) |
|
|
|
if not os.path.exists(sketch_path): |
|
os.makedirs(sketch_path) |
|
|
|
for image_name in images: |
|
new_image_name = image_name[:image_name.rfind('.')] + '.png' |
|
|
|
try: |
|
image = cv2.imread(os.path.join(source_path, image_name)) |
|
|
|
if image is None: |
|
raise Exception() |
|
|
|
resized_image = get_resized_image(image, side_size) |
|
cv2.imwrite(os.path.join(color_path, new_image_name), resized_image) |
|
|
|
for number, (sketch_image, dfm_image) in enumerate(get_sketches(resized_image, sketcher, mult_list, dfm)): |
|
new_sketch_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '.png' |
|
cv2.imwrite(os.path.join(sketch_path, new_sketch_name), sketch_image) |
|
|
|
if dfm: |
|
dfm_name = image_name[:image_name.rfind('.')] + '_' + str(number) + '_dfm.png' |
|
cv2.imwrite(os.path.join(sketch_path, dfm_name), dfm_image) |
|
|
|
except: |
|
print('Failed to process {}'.format(image_name)) |
|
|