File size: 3,299 Bytes
3c149ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
import re

import imageio
import matplotlib.pyplot as plt
import moviepy.editor as mvp
import numpy as np
import pydiffvg
import torch
from IPython.display import Image as Image_colab
from IPython.display import display, SVG
from PIL import Image

# Command-line interface. Both options are needed to locate the results
# directory (output_sketches/<target stem>/) and pick the run whose filename
# contains "<num_strokes>strokes", so they are required: without them the
# script would otherwise fail later with an unhelpful TypeError/IndexError.
parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str, required=True,
                    help="target image file, located in <target_images>")
parser.add_argument("--num_strokes", type=int, required=True,
                    help="stroke count of the run to use (matched against "
                         "'<N>strokes' in the best-sketch SVG filename)")
args = parser.parse_args()


def read_svg(path_svg, multiply=False):
    """Rasterize an SVG with pydiffvg and alpha-composite it over white.

    Args:
        path_svg: Path of the SVG file to render.
        multiply: If True, render at 2x: the canvas width/height are doubled
            and every shape's control points and stroke width are scaled by 2.

    Returns:
        A 3-channel image tensor (presumably (canvas_height, canvas_width, 3))
        with the rendered colours blended over a white background. It lives on
        whatever device pydiffvg rendered to.
    """
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2
    _render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = _render(canvas_width,  # width
                  canvas_height,  # height
                  2,   # num_samples_x
                  2,   # num_samples_y
                  0,   # seed
                  None,  # no background image
                  *scene_args)
    # Channels 0-2 are colour and channel 3 is alpha; blend over white.
    rgb = img[:, :, :3]
    alpha = img[:, :, 3:4]
    # Broadcasting (1 - alpha) across the colour channels is equivalent to
    # multiplying an all-ones image by it, but never allocates a tensor on a
    # separately chosen device, so it cannot mismatch img's device.
    return alpha * rgb + (1 - alpha)


# Script body: find the best sketch saved for --target_file / --num_strokes
# under ./output_sketches/<target stem>/, save a 2x PNG render of it, and, if
# the run's config.npy exists, render the logged SVGs up to the lowest-loss
# checkpoint into PNG frames plus an animated GIF.

best_suffix = "_best.svg"


def _render_svg_to_pil(svg_path):
    """Render ``svg_path`` at 2x with read_svg and return an 8-bit RGB PIL image."""
    pixels = read_svg(svg_path, multiply=True).cpu().numpy()
    return Image.fromarray((pixels * 255).astype('uint8'), 'RGB')


abs_path = os.path.abspath(os.getcwd())

result_path = f"{abs_path}/output_sketches/{os.path.splitext(args.target_file)[0]}"
# Require "<N>strokes" not preceded by another digit, so N=6 does not also
# match a "16strokes" run. Sort because os.listdir order is arbitrary and the
# chosen file should be deterministic.
strokes_pattern = re.compile(
    rf"(?<!\d){re.escape(str(args.num_strokes))}strokes")
svg_files = sorted(
    f for f in os.listdir(result_path)
    if f.endswith(best_suffix) and strokes_pattern.search(f))
if not svg_files:
    raise FileNotFoundError(
        f"no '*{best_suffix}' sketch with {args.num_strokes} strokes "
        f"found in {result_path}")
svg_output_path = f"{result_path}/{svg_files[0]}"

# The run directory has the SVG's filename minus the "_best.svg" suffix.
best_sketch_dir = svg_files[0][:-len(best_suffix)]
cur_path = f"{result_path}/{best_sketch_dir}"
target_path = f"{cur_path}/input.png"

sketch_res = _render_svg_to_pil(svg_output_path)

# Context manager closes the file handle; resize() returns a loaded copy.
with Image.open(target_path) as target_im:
    input_im = target_im.resize((224, 224))
display(input_im)
display(SVG(svg_output_path))

sketches = []
sketch_res.save(f"{cur_path}/final_sketch.png")
print(f"You can download the result sketch from {cur_path}/final_sketch.png")

os.makedirs(f"{cur_path}/svg_to_png", exist_ok=True)
config_path = f"{cur_path}/config.npy"
if os.path.exists(config_path):
    # config.npy appears to hold a pickled mapping wrapped in a numpy object
    # array, hence allow_pickle and the [()] unwrap. Only load trusted files.
    config = np.load(config_path, allow_pickle=True)[()]
    inter = config["save_interval"]
    loss_eval = np.array(config['loss_eval'])
    # NOTE(review): assumes loss_eval[k] belongs to the SVG saved at iteration
    # k * save_interval; confirm against the training script.
    best_idx = int(np.argmin(loss_eval))
    intervals = list(range(0, (best_idx + 1) * inter, inter))
    for i_ in intervals:
        sketch = _render_svg_to_pil(f"{cur_path}/svg_logs/svg_iter{i_}.svg")
        sketch.save(f"{cur_path}/svg_to_png/iter_{int(i_):04}.png")
        sketches.append(sketch)
    imageio.mimsave(f"{cur_path}/sketch.gif", sketches)

print(cur_path)