"""Show and export the best sketch found for a target image.

The script looks in ``<cwd>/output_sketches/<target name>/`` for the
``*best.svg`` file that matches ``--num_strokes``, displays the input image
and the sketch, saves the sketch as ``final_sketch.png`` and, when the run's
``config.npy`` is present, rasterises the logged intermediate SVGs into PNG
frames and an animated GIF.
"""
import argparse
import os
import re

import imageio
import matplotlib.pyplot as plt
import moviepy.editor as mvp
import numpy as np
import pydiffvg
import torch
from IPython.display import Image as Image_colab
from IPython.display import display, SVG
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str,
                    help="target image file, located in ")
parser.add_argument("--num_strokes", type=int)
args = parser.parse_args()


def read_svg(path_svg, multiply=False):
    """Rasterise an SVG file with pydiffvg and blend it over white.

    Args:
        path_svg: Path of the SVG file to load.
        multiply: When true, the canvas size, every path's points and every
            path's stroke width are doubled before rendering.

    Returns:
        A ``(height, width, 3)`` tensor holding the rendered colour channels
        alpha-blended onto a white background. Callers in this script scale
        it by 255, so values are presumably in ``[0, 1]`` (not checked here).
    """
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2

    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = pydiffvg.RenderFunction.apply(
        canvas_width,   # width
        canvas_height,  # height
        2,              # num_samples_x
        2,              # num_samples_y
        0,              # seed
        None,           # no background image
        *scene_args)

    # Channel 3 is used as alpha. Broadcasting (1 - alpha) over the colour
    # channels is the white-background term; doing it this way avoids
    # allocating a ones tensor on a separately guessed device.
    alpha = img[:, :, 3:4]
    return alpha * img[:, :, :3] + (1 - alpha)


def _svg_to_pil(path_svg):
    """Render *path_svg* at double size and return it as an RGB PIL image."""
    pixels = read_svg(path_svg, multiply=True).cpu().numpy()
    return Image.fromarray((pixels * 255).astype('uint8'), 'RGB')


abs_path = os.path.abspath(os.getcwd())
result_path = f"{abs_path}/output_sketches/{os.path.splitext(args.target_file)[0]}"

# os.listdir order is unspecified; sort so the chosen file is reproducible
# when several runs match.
svg_files = [f for f in sorted(os.listdir(result_path))
             if "best.svg" in f and f"{args.num_strokes}strokes" in f]
if not svg_files:
    raise FileNotFoundError(
        f"no '*best.svg' file for {args.num_strokes} strokes in {result_path}")
svg_output_path = f"{result_path}/{svg_files[0]}"

# The run directory is the file name up to its *last* "_best". Using the last
# occurrence keeps names that themselves contain "_best" working.
best_sketch_dir = svg_files[0].rpartition("_best")[0]
cur_path = f"{result_path}/{best_sketch_dir}"
target_path = f"{cur_path}/input.png"

sketch_res = _svg_to_pil(svg_output_path)

input_im = Image.open(target_path).resize((224, 224))
display(input_im)
display(SVG(svg_output_path))

sketch_res.save(f"{cur_path}/final_sketch.png")
print(f"You can download the result sketch from {cur_path}/final_sketch.png")

os.makedirs(f"{cur_path}/svg_to_png", exist_ok=True)

sketches = []
if os.path.exists(f"{cur_path}/config.npy"):
    # SECURITY: allow_pickle=True executes pickled code on load; only use
    # this on config files produced by your own runs.
    config = np.load(f"{cur_path}/config.npy", allow_pickle=True)[()]
    inter = config["save_interval"]
    loss_eval = np.array(config['loss_eval'])
    inds = np.argsort(loss_eval)
    # Iterations from 0 up to and including the evaluation with lowest loss.
    intervals = list(range(0, (inds[0] + 1) * inter, inter))
    for i_ in intervals:
        sketch = _svg_to_pil(f"{cur_path}/svg_logs/svg_iter{i_}.svg")
        sketch.save(f"{cur_path}/svg_to_png/iter_{int(i_):04}.png")
        sketches.append(sketch)
    # NOTE(review): the original file's indentation was lost; the GIF export
    # and final print are assumed to belong to this branch, since an empty
    # frame list cannot be written.
    imageio.mimsave(f"{cur_path}/sketch.gif", sketches)
    print(cur_path)