# Source: mangaaa/utils/dataset_utils.py (Hugging Face file view)
# Uploaded by Keiser41 in commit 212d7be ("Upload 246 files"); size 4.79 kB; virus scan: clean.
import numpy as np
import matplotlib.pyplot as plt
import cv2
import snowy
import os
def get_resized_image(img, size):
    """Resize ``img`` so that its shorter side is exactly ``size`` pixels.

    The aspect ratio is preserved. The longer side is scaled by the same
    factor and rounded up to a whole pixel. A 2-D (single-channel) image is
    first replicated into three identical channels. Resizing uses
    ``cv2.INTER_AREA`` interpolation.

    Args:
        img: Image array, either ``(h, w)`` or ``(h, w, c)``.
        size: Target length of the shorter side, in pixels.

    Returns:
        The resized image. A ``float32`` result is clipped in place to
        ``[0, 1]``, which assumes float images are normalized to that range.
    """
    if img.ndim == 2:
        img = np.repeat(np.expand_dims(img, 2), 3, 2)
    height, width = img.shape[:2]
    # Exact integer form of ceil(long_side * size / short_side). The float
    # form ceil(long_side / (short_side / size)) can overshoot by one pixel
    # when the exact result is an integer (e.g. h=1400, w=2856, size=1000
    # yields 2041 instead of 2040) because short_side / size is rounded.
    if height < width:
        new_width, new_height = -(-width * size // height), size
    else:
        new_width, new_height = size, -(-height * size // width)
    img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
    if img.dtype == np.float32:
        np.clip(img, 0, 1, out=img)
    return img
def get_sketch_image(img, sketcher, mult_val):
    """Produce a sketch of ``img`` using ``sketcher.get_sketch_with_resize``.

    Args:
        img: Image passed unchanged to the sketcher.
        sketcher: Object exposing ``get_sketch_with_resize(img, mult=...)``.
        mult_val: Forwarded as the ``mult`` keyword only when truthy; ``None``
            or ``0`` means the sketcher is called without ``mult``.

    Returns:
        Whatever ``get_sketch_with_resize`` returns.
    """
    extra_kwargs = {'mult': mult_val} if mult_val else {}
    return sketcher.get_sketch_with_resize(img, **extra_kwargs)
def get_dfm_image(sketch):
    """Compute a distance-field map from ``sketch`` using the snowy library.

    The mask marks every pixel where ``1 - sketch`` is non-zero, i.e. pixels
    whose value is not exactly 1. A trailing channel axis is added before
    calling ``snowy.generate_sdf``. The result goes through ``snowy.unitize``
    (presumably a rescale into [0, 1]; see snowy docs), and all singleton
    axes are then squeezed away.
    """
    ink_mask = np.expand_dims(1 - sketch, 2) != 0
    distance_field = snowy.generate_sdf(ink_mask)
    normalized = snowy.unitize(distance_field)
    return normalized.squeeze()
def get_sketch(image, sketcher, dfm, mult=None):
    """Build one ``uint8`` sketch and, optionally, its distance-field map.

    Args:
        image: Source image passed to ``get_sketch_image``.
        sketcher: Sketch generator (see ``get_sketch_image``).
        dfm: When truthy, also compute ``get_dfm_image`` from the sketch.
        mult: Optional multiplier forwarded to the sketcher.

    Returns:
        ``(sketch, dfm_image)``. Both arrays are multiplied by 255 and cast
        to ``uint8``; ``dfm_image`` is ``None`` when ``dfm`` is falsy. The
        x255 scaling presumes float values in [0, 1] — TODO confirm against
        the sketcher implementation.
    """
    raw_sketch = get_sketch_image(image, sketcher, mult)
    if not dfm:
        return (raw_sketch * 255).astype('uint8'), None
    # The distance field is derived from the float sketch, before quantization.
    raw_dfm = get_dfm_image(raw_sketch)
    return (raw_sketch * 255).astype('uint8'), (raw_dfm * 255).astype('uint8')
def get_sketches(image, sketcher, mult_list, dfm):
    """Lazily yield ``get_sketch(image, sketcher, dfm, mult)`` for each ``mult``.

    Nothing is computed until the generator is iterated; each item is a
    ``(sketch, dfm_image)`` pair.
    """
    yield from (get_sketch(image, sketcher, dfm, mult) for mult in mult_list)
def create_resized_dataset(source_path, target_path, side_size):
    """Resize every image in ``source_path`` and save it as PNG in ``target_path``.

    Each readable image goes through ``get_resized_image(image, side_size)``
    and is written as ``<stem>.png``. Files whose output already exists are
    skipped, so an interrupted run can be resumed. Per-file failures are
    printed and do not stop the run.

    Args:
        source_path: Directory containing the source images.
        target_path: Output directory; created if missing (unless empty,
            which means the current directory).
        side_size: Size passed to ``get_resized_image``.
    """
    # List first so a freshly created target inside source_path is not
    # picked up as an input.
    images = os.listdir(source_path)
    if target_path:
        os.makedirs(target_path, exist_ok=True)
    for image_name in images:
        # splitext keeps extension-less names intact; slicing at rfind('.')
        # would chop off the last character when the name has no dot.
        new_image_name = os.path.splitext(image_name)[0] + '.png'
        new_path = os.path.join(target_path, new_image_name)
        if os.path.exists(new_path):
            continue
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                # cv2.imread signals an unreadable / non-image file with None.
                raise ValueError('could not read ' + image_name)
            image = get_resized_image(image, side_size)
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(new_path, image):
                raise OSError('could not write ' + new_path)
        except Exception:
            # Best effort: report and continue, but let KeyboardInterrupt /
            # SystemExit propagate (a bare except would swallow them).
            print('Failed to process {}'.format(image_name))
def create_sketches_dataset(source_path, target_path, sketcher, mult_list, dfm=False):
    """Write sketches (and optionally distance-field maps) for every image.

    For each readable image in ``source_path`` and each entry of
    ``mult_list``, the sketch is saved as ``<stem>_<n>.png``. When ``dfm`` is
    truthy, the distance-field map is also saved as ``<stem>_<n>_dfm.png``.
    ``n`` is the index into ``mult_list``. Per-file failures are printed and
    do not stop the run.

    Args:
        source_path: Directory containing the source images.
        target_path: Output directory; created if missing (unless empty,
            which means the current directory).
        sketcher: Sketch generator passed through to ``get_sketches``.
        mult_list: Multipliers, one output sketch per entry.
        dfm: Whether to also produce distance-field maps.
    """
    # List first so a freshly created target inside source_path is not
    # picked up as an input.
    images = os.listdir(source_path)
    if target_path:
        os.makedirs(target_path, exist_ok=True)
    for image_name in images:
        # splitext keeps extension-less names intact (rfind('.') == -1 would
        # chop off the last character).
        stem = os.path.splitext(image_name)[0]
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                # cv2.imread signals an unreadable / non-image file with None.
                raise ValueError('could not read ' + image_name)
            sketches = get_sketches(image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                sketch_file = os.path.join(target_path, '{}_{}.png'.format(stem, number))
                # cv2.imwrite reports failure via its return value only.
                if not cv2.imwrite(sketch_file, sketch_image):
                    raise OSError('could not write ' + sketch_file)
                if dfm:
                    dfm_file = os.path.join(target_path, '{}_{}_dfm.png'.format(stem, number))
                    if not cv2.imwrite(dfm_file, dfm_image):
                        raise OSError('could not write ' + dfm_file)
        except Exception:
            # Best effort: report and continue, but let KeyboardInterrupt /
            # SystemExit propagate (a bare except would swallow them).
            print('Failed to process {}'.format(image_name))
def create_dataset(source_path, target_path, sketcher, mult_list, side_size, dfm=False):
    """Build a paired colour / sketch dataset from the images in ``source_path``.

    For each readable image:
    - it is resized with ``get_resized_image(image, side_size)``;
    - the result is saved as ``<target_path>/color/<stem>.png``;
    - for each entry of ``mult_list``, a sketch of the *resized* image is
      saved as ``<target_path>/bw/<stem>_<n>.png``;
    - when ``dfm`` is truthy, a distance-field map is also saved as
      ``<target_path>/bw/<stem>_<n>_dfm.png``.

    Per-file failures are printed and do not stop the run.

    Args:
        source_path: Directory containing the source images.
        target_path: Root output directory; ``color`` and ``bw`` subdirectories
            are created as needed.
        sketcher: Sketch generator passed through to ``get_sketches``.
        mult_list: Multipliers, one output sketch per entry.
        side_size: Size passed to ``get_resized_image``.
        dfm: Whether to also produce distance-field maps.
    """
    # List before creating the output dirs so 'color' / 'bw' are never
    # treated as inputs when target_path == source_path.
    images = os.listdir(source_path)
    color_path = os.path.join(target_path, 'color')
    sketch_path = os.path.join(target_path, 'bw')
    os.makedirs(color_path, exist_ok=True)
    os.makedirs(sketch_path, exist_ok=True)
    for image_name in images:
        # splitext keeps extension-less names intact (rfind('.') == -1 would
        # chop off the last character).
        stem = os.path.splitext(image_name)[0]
        try:
            image = cv2.imread(os.path.join(source_path, image_name))
            if image is None:
                # cv2.imread signals an unreadable / non-image file with None.
                raise ValueError('could not read ' + image_name)
            resized_image = get_resized_image(image, side_size)
            color_file = os.path.join(color_path, stem + '.png')
            # cv2.imwrite reports failure via its return value only.
            if not cv2.imwrite(color_file, resized_image):
                raise OSError('could not write ' + color_file)
            sketches = get_sketches(resized_image, sketcher, mult_list, dfm)
            for number, (sketch_image, dfm_image) in enumerate(sketches):
                sketch_file = os.path.join(sketch_path, '{}_{}.png'.format(stem, number))
                if not cv2.imwrite(sketch_file, sketch_image):
                    raise OSError('could not write ' + sketch_file)
                if dfm:
                    dfm_file = os.path.join(sketch_path, '{}_{}_dfm.png'.format(stem, number))
                    if not cv2.imwrite(dfm_file, dfm_image):
                        raise OSError('could not write ' + dfm_file)
        except Exception:
            # Best effort: report and continue, but let KeyboardInterrupt /
            # SystemExit propagate (a bare except would swallow them).
            print('Failed to process {}'.format(image_name))