import numpy as np
import matplotlib.pyplot as plt
import cv2
import snowy
import os


def get_resized_image(img, size):
    """Resize *img* so that its shorter side becomes *size* pixels.

    The aspect ratio is preserved; the longer side is scaled by the same
    factor and rounded up. Two-dimensional (grayscale) inputs are first
    expanded to three identical channels. Float32 results are clipped in
    place to ``[0, 1]``.

    Args:
        img: Image array (2-D or 3-D), e.g. as returned by ``cv2.imread``.
        size: Target length, in pixels, of the shorter side.

    Returns:
        The resized image array.
    """
    if len(img.shape) == 2:
        # Promote grayscale to 3 identical channels so every image has the
        # same layout downstream.
        img = np.repeat(np.expand_dims(img, 2), 3, 2)

    if img.shape[0] < img.shape[1]:
        # Landscape: height is the shorter side.
        ratio = img.shape[0] / size
        width = int(np.ceil(img.shape[1] / ratio))
        # cv2.resize takes dsize as (width, height).
        img = cv2.resize(img, (width, size), interpolation=cv2.INTER_AREA)
    else:
        # Portrait or square: width is the shorter (or equal) side.
        ratio = img.shape[1] / size
        height = int(np.ceil(img.shape[0] / ratio))
        img = cv2.resize(img, (size, height), interpolation=cv2.INTER_AREA)

    if img.dtype == 'float32':
        np.clip(img, 0, 1, out=img)

    return img


def get_sketch_image(img, sketcher, mult_val):
    """Run *sketcher* on *img*.

    Args:
        img: Input image array.
        sketcher: Object exposing ``get_sketch_with_resize(img, mult=...)``.
        mult_val: Forwarded as ``mult=`` when truthy; otherwise the
            sketcher's own default is used.

    Returns:
        Whatever ``sketcher.get_sketch_with_resize`` returns (presumably a
        float image in ``[0, 1]`` -- callers scale it by 255; confirm).
    """
    if mult_val:
        return sketcher.get_sketch_with_resize(img, mult=mult_val)
    return sketcher.get_sketch_with_resize(img)


def get_dfm_image(sketch):
    """Compute a normalised distance-field image from *sketch*.

    The mask passed to ``snowy.generate_sdf`` is ``sketch != 1`` (i.e. every
    pixel that is not exactly 1 -- presumably the non-white stroke pixels;
    confirm the sketch value range). The field is normalised with
    ``snowy.unitize`` and the trailing channel axis is squeezed away.
    """
    mask = np.expand_dims(1 - sketch, 2) != 0
    return snowy.unitize(snowy.generate_sdf(mask)).squeeze()


def get_sketch(image, sketcher, dfm, mult=None):
    """Return ``(sketch, dfm_image)`` as uint8 arrays for one multiplier.

    Args:
        image: Input image array.
        sketcher: Sketch generator (see ``get_sketch_image``).
        dfm: When true, also compute the distance-field image.
        mult: Optional multiplier forwarded to the sketcher.

    Returns:
        Tuple ``(sketch_image, dfm_image)``; both are scaled by 255 and cast
        to uint8. ``dfm_image`` is ``None`` when *dfm* is false.
    """
    sketch_image = get_sketch_image(image, sketcher, mult)

    dfm_image = None
    if dfm:
        # The distance field is computed on the float sketch, before the
        # uint8 conversion below.
        dfm_image = (get_dfm_image(sketch_image) * 255).astype('uint8')

    sketch_image = (sketch_image * 255).astype('uint8')
    return sketch_image, dfm_image


def get_sketches(image, sketcher, mult_list, dfm):
    """Yield ``get_sketch(image, sketcher, dfm, mult)`` for each *mult*."""
    for mult in mult_list:
        yield get_sketch(image, sketcher, dfm, mult)


def _file_stem(image_name):
    """Return *image_name* without its final extension.

    ``os.path.splitext`` keeps names without a dot intact, whereas slicing
    at ``str.rfind('.')`` (which returns -1) would drop the last character.
    """
    return os.path.splitext(image_name)[0]


def _read_image(path):
    """Read an image with OpenCV, raising instead of returning ``None``."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals unreadable / non-image files by returning None.
        raise ValueError('could not read image {}'.format(path))
    return image


def _write_image(path, image):
    """Write *image* to *path*, raising if ``cv2.imwrite`` reports failure."""
    if not cv2.imwrite(path, image):
        raise IOError('could not write image {}'.format(path))


def _write_sketches(image, stem, target_dir, sketcher, mult_list, dfm):
    """Save one sketch (and optionally one DFM image) per *mult_list* entry.

    Files are named ``<stem>_<index>.png`` and ``<stem>_<index>_dfm.png``.
    """
    sketches = get_sketches(image, sketcher, mult_list, dfm)
    for number, (sketch_image, dfm_image) in enumerate(sketches):
        base = '{}_{}'.format(stem, number)
        _write_image(os.path.join(target_dir, base + '.png'), sketch_image)
        if dfm:
            _write_image(os.path.join(target_dir, base + '_dfm.png'), dfm_image)


def create_resized_dataset(source_path, target_path, side_size):
    """Resize every image in *source_path* and save it as PNG in *target_path*.

    Images whose output file already exists are skipped, so an interrupted
    run can be resumed. Failures are reported per image and do not stop the
    run.
    """
    os.makedirs(target_path, exist_ok=True)
    for image_name in os.listdir(source_path):
        new_path = os.path.join(target_path, _file_stem(image_name) + '.png')
        if os.path.exists(new_path):
            continue  # already produced by an earlier run
        try:
            image = _read_image(os.path.join(source_path, image_name))
            _write_image(new_path, get_resized_image(image, side_size))
        except Exception as err:
            # Best-effort batch job: report and move on to the next file.
            print('Failed to process {}: {}'.format(image_name, err))


def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm = False):
    """Generate sketches for every image in *source_path* into *target_path*.

    One sketch per entry of *mult_list* is written as
    ``<stem>_<index>.png``; when *dfm* is true a distance-field image
    ``<stem>_<index>_dfm.png`` is written alongside. Failures are reported
    per image and do not stop the run.
    """
    os.makedirs(target_path, exist_ok=True)
    for image_name in os.listdir(source_path):
        try:
            image = _read_image(os.path.join(source_path, image_name))
            _write_sketches(image, _file_stem(image_name), target_path,
                            sketcher, mult_list, dfm)
        except Exception as err:
            print('Failed to process {}: {}'.format(image_name, err))


def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm = False):
    """Build a paired colour/sketch dataset from the images in *source_path*.

    Each image is resized with ``get_resized_image`` and written to
    ``<target_path>/color/<stem>.png``. Sketches (and optional DFM images)
    of the *resized* image are written to ``<target_path>/bw`` using the
    same naming scheme as ``create_sketches_dataset``. Failures are
    reported per image and do not stop the run.
    """
    color_path = os.path.join(target_path, 'color')
    sketch_path = os.path.join(target_path, 'bw')
    os.makedirs(color_path, exist_ok=True)
    os.makedirs(sketch_path, exist_ok=True)

    for image_name in os.listdir(source_path):
        stem = _file_stem(image_name)
        try:
            image = _read_image(os.path.join(source_path, image_name))
            resized_image = get_resized_image(image, side_size)
            _write_image(os.path.join(color_path, stem + '.png'), resized_image)
            _write_sketches(resized_image, stem, sketch_path,
                            sketcher, mult_list, dfm)
        except Exception as err:
            print('Failed to process {}: {}'.format(image_name, err))