"""Cartoonize images with the White-Box Cartoonization generator (TF1 graph mode)."""
import os

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import wbc.guided_filter as guided_filter
import wbc.network as network


def resize_crop(image):
    """Downscale *image* so its shorter side is 720 px, then crop to multiples of 8.

    Images whose shorter side is already <= 720 px are not resized. The crop
    keeps the top-left region. The multiple-of-8 requirement presumably
    matches the generator's down/up-sampling factor (TODO confirm against
    wbc.network).

    Args:
        image: array of shape (H, W) or (H, W, C), e.g. as returned by
            cv2.imread.

    Returns:
        The resized and cropped image.

    Raises:
        ValueError: if the image is smaller than 8 px in either dimension,
            which would otherwise produce an empty crop.
    """
    h, w = np.shape(image)[:2]
    if min(h, w) > 720:
        if h > w:
            h, w = int(720 * h / w), 720
        else:
            h, w = 720, int(720 * w / h)
        # cv2.resize takes (width, height); INTER_AREA is suited to shrinking.
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    h, w = (h // 8) * 8, (w // 8) * 8
    if h == 0 or w == 0:
        raise ValueError('image too small to crop to a multiple of 8: {}'.format(np.shape(image)))
    return image[:h, :w]


def _build_graph():
    """Build generator + guided filter; return (input placeholder, output tensor, generator saver)."""
    input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)
    # Only variables whose name contains 'generator' are restored from the checkpoint.
    gene_vars = [var for var in tf.trainable_variables() if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)
    return input_photo, final_out, saver


def _new_session():
    """Create a TF session that grows GPU memory on demand instead of reserving it all."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)


def _restore_generator(sess, saver, model_path):
    """Initialise all variables, then load the generator weights from the latest checkpoint in *model_path*."""
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_path))


def _cartoonize_file(sess, input_photo, final_out, load_path, save_path):
    """Read one image, run it through the graph and write the result; raise on any failure."""
    image = cv2.imread(load_path)
    if image is None:
        raise IOError('could not read image')
    image = resize_crop(image)
    # Map uint8 pixels [0, 255] to [-1, 1]; channel order is cv2's BGR and is
    # written back unchanged.
    batch_image = image.astype(np.float32) / 127.5 - 1
    batch_image = np.expand_dims(batch_image, axis=0)
    output = sess.run(final_out, feed_dict={input_photo: batch_image})
    output = (output[0] + 1) * 127.5
    output = np.clip(output, 0, 255).astype(np.uint8)
    if not cv2.imwrite(save_path, output):
        raise IOError('could not write image')


def _cartoonize_file_safe(sess, input_photo, final_out, load_path, save_path):
    """Best-effort wrapper around _cartoonize_file: report failures and return success as a bool."""
    try:
        _cartoonize_file(sess, input_photo, final_out, load_path, save_path)
    except Exception as err:  # per-file best effort; Ctrl+C still interrupts
        print('cartoonize {} failed: {}'.format(load_path, err))
        return False
    return True


def _cartoonize_folder(sess, input_photo, final_out, load_folder, save_folder):
    """Cartoonize every regular file in *load_folder* into *save_folder* (same file names)."""
    os.makedirs(save_folder, exist_ok=True)
    names = sorted(
        name for name in os.listdir(load_folder)
        if os.path.isfile(os.path.join(load_folder, name))
    )
    for name in tqdm(names):
        _cartoonize_file_safe(
            sess, input_photo, final_out,
            os.path.join(load_folder, name), os.path.join(save_folder, name),
        )


def cartoonize(load_folder, save_folder, model_path):
    """Build the model, restore it from *model_path* and cartoonize all images in *load_folder*.

    Results are written to *save_folder* under the same file names. Files that
    fail are reported and skipped. The TF session is closed on return.
    """
    print(model_path)
    input_photo, final_out, saver = _build_graph()
    with _new_session() as sess:
        _restore_generator(sess, saver, model_path)
        _cartoonize_folder(sess, input_photo, final_out, load_folder, save_folder)


class Cartoonize:
    """Cartoonizer that restores the model once and can be reused for many calls.

    Call close() to release the underlying TF session when finished.
    """

    def __init__(self, model_path):
        print(model_path)
        self.input_photo, self.final_out, saver = _build_graph()
        self.sess = _new_session()
        try:
            _restore_generator(self.sess, saver, model_path)
        except Exception:
            # Don't leak the session if the checkpoint cannot be loaded.
            self.sess.close()
            raise

    def run(self, load_folder, save_folder):
        """Cartoonize every file in *load_folder* into *save_folder* (same file names)."""
        _cartoonize_folder(self.sess, self.input_photo, self.final_out, load_folder, save_folder)

    def run_sigle(self, load_path, save_path):
        """Cartoonize a single image; return True on success, False (after printing) on failure."""
        return _cartoonize_file_safe(self.sess, self.input_photo, self.final_out, load_path, save_path)

    # Correctly spelled alias; run_sigle is kept for existing callers.
    run_single = run_sigle

    def close(self):
        """Release the TF session held by this instance."""
        self.sess.close()


if __name__ == '__main__':
    model_path = 'saved_models'
    load_folder = 'test_images'
    save_folder = 'cartoonized_images'
    os.makedirs(save_folder, exist_ok=True)
    cartoonize(load_folder, save_folder, model_path)