File size: 6,128 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# -*- coding: utf-8 -*-
# Author: ximing
# Description: DiffVG painterly rendering pipeline
# Copyright (c) 2023, XiMing Xing.
# License: MIT License

import shutil
from pathlib import Path
from functools import partial
from typing import AnyStr
from PIL import Image

from tqdm.auto import tqdm
import torch
from torchvision import transforms

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.diffvg import Painter, PainterOptimizer
from pytorch_svgrender.plt import plot_img, plot_couple
from pytorch_svgrender.libs.metric.lpips_origin import LPIPS


class DiffVGPipeline(ModelState):
    """Fit DiffVG vector paths to a target raster image by gradient descent.

    Settings are read from ``self.x_cfg`` (set up by ``ModelState``; presumably
    ``args.x``): ``path_type``, ``num_paths``, ``num_iter``, ``lr_base``,
    ``lr_schedule``, ``max_width``, ``loss_type`` and ``perceptual.lpips_net``.
    PNG/SVG snapshots go under ``self.result_path``. When ``args.mv`` is set,
    frames are also dumped and assembled into an MP4 with ffmpeg.
    """

    # Path types accepted by the Painter.
    _SUPPORTED_PATH_TYPES = ('unclosed', 'closed')
    # Reconstruction losses understood by ``_reconstruction_loss``.
    _SUPPORTED_LOSSES = ('l1', 'l2', 'lpips', 'l2+lpips')

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-{args.x.path_type}" \
                  f"-P{args.x.num_paths}"
        super().__init__(args, log_path_suffix=logdir_)

        # Explicit check instead of `assert` so it survives `python -O`.
        if self.x_cfg.path_type not in self._SUPPORTED_PATH_TYPES:
            raise ValueError(
                f"path_type must be one of {self._SUPPORTED_PATH_TYPES}, "
                f"got {self.x_cfg.path_type!r}"
            )

        # create log dir (only the main process touches the filesystem)
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            if self.accelerator.is_main_process:
                self.frame_log_dir.mkdir(parents=True, exist_ok=True)

    def target_file_preprocess(self, tar_path):
        """Load ``tar_path`` as RGB and return a batched tensor on ``self.device``.

        ``ToTensor`` plus ``unsqueeze(0)`` yields a ``(1, 3, H, W)`` float tensor.
        """
        process_comp = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        tar_pil = Image.open(tar_path).convert("RGB")  # open file
        target_img = process_comp(tar_pil)  # preprocess
        target_img = target_img.to(self.device)
        return target_img

    def _reconstruction_loss(self, raster_img, target_img, perceptual_loss_fn=None):
        """Return the scalar loss between render and target for ``x_cfg.loss_type``.

        ``perceptual_loss_fn`` is required for the ``'lpips'`` and ``'l2+lpips'``
        losses.

        Raises:
            ValueError: if ``x_cfg.loss_type`` is not a supported loss.
        """
        loss_type = self.x_cfg.loss_type
        if loss_type == 'l1':
            return torch.nn.functional.l1_loss(raster_img, target_img)
        if loss_type == 'lpips':
            return perceptual_loss_fn(raster_img, target_img).mean()
        if loss_type == 'l2':  # MSE loss
            return torch.nn.functional.mse_loss(raster_img, target_img)
        if loss_type == 'l2+lpips':  # MSE loss plus LPIPS perceptual loss
            lpips = perceptual_loss_fn(raster_img, target_img).mean()
            loss_mse = torch.nn.functional.mse_loss(raster_img, target_img)
            return loss_mse + lpips
        raise ValueError(
            f"loss_type must be one of {self._SUPPORTED_LOSSES}, got {loss_type!r}"
        )

    def _make_video(self):
        """Assemble the logged frames into ``live_rendering.mp4`` (best effort).

        A missing ffmpeg binary or a non-zero exit code is reported and does not
        raise, so the caller can still finish its cleanup.
        """
        import subprocess

        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            self.print("ffmpeg not found on PATH; skipping video export.")
            return
        result = subprocess.run([
            ffmpeg,
            "-framerate", f"{self.args.framerate}",
            "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
            "-vb", "20M",
            (self.result_path / "live_rendering.mp4").as_posix()
        ], check=False)
        if result.returncode != 0:
            self.print(f"ffmpeg exited with code {result.returncode}; video export may have failed.")

    def painterly_rendering(self, img_path: AnyStr):
        """Optimize vector paths so their rasterization matches ``img_path``.

        Args:
            img_path: path to the target raster image.

        Raises:
            ValueError: if ``x_cfg.loss_type`` is unsupported. This is checked
                before any work is done.
            FileNotFoundError: if ``img_path`` does not exist.
        """
        # Validate the loss up front instead of crashing with an unbound local
        # inside the first optimization step.
        if self.x_cfg.loss_type not in self._SUPPORTED_LOSSES:
            raise ValueError(
                f"loss_type must be one of {self._SUPPORTED_LOSSES}, "
                f"got {self.x_cfg.loss_type!r}"
            )
        is_main = self.accelerator.is_main_process

        # load target file
        target_file = Path(img_path)
        if not target_file.exists():
            raise FileNotFoundError(f"{target_file} is not exist!")
        if is_main:
            shutil.copy(target_file, self.result_path)  # copy target file
        target_img = self.target_file_preprocess(target_file.as_posix())
        self.print(f"load image from: '{target_file.as_posix()}'")

        # init Painter; canvas_size is [width, height] taken from the (N, C, H, W) tensor
        renderer = Painter(target_img,
                           self.args.diffvg,
                           canvas_size=[target_img.shape[3], target_img.shape[2]],
                           path_type=self.x_cfg.path_type,
                           max_width=self.x_cfg.max_width,
                           device=self.device)
        init_img = renderer.init_image(num_paths=self.x_cfg.num_paths)
        self.print("init_image shape: ", init_img.shape)
        if is_main:
            plot_img(init_img, self.result_path, fname="init_img")

        # init Painter Optimizer; stroke parameters are only trainable for open paths
        num_iter = self.x_cfg.num_iter
        optimizer = PainterOptimizer(renderer,
                                     num_iter,
                                     self.x_cfg.lr_base,
                                     trainable_stroke=self.x_cfg.path_type == 'unclosed')
        optimizer.init_optimizer()

        # Set Loss: the LPIPS network is only built when a perceptual loss is requested
        perceptual_loss_fn = None
        if self.x_cfg.loss_type in ['lpips', 'l2+lpips']:
            lpips_loss_fn = LPIPS(net=self.x_cfg.perceptual.lpips_net).to(self.device)
            perceptual_loss_fn = partial(lpips_loss_fn.forward, return_per_layer=False, normalize=False)

        with tqdm(initial=self.step, total=num_iter, disable=not is_main) as pbar:
            while self.step < num_iter:
                raster_img = renderer.get_image(self.step).to(self.device)

                if self.make_video and is_main and \
                        (self.step % self.args.framefreq == 0 or self.step == num_iter - 1):
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # Reconstruction Loss
                loss_recon = self._reconstruction_loss(raster_img, target_img, perceptual_loss_fn)

                # total loss
                loss = loss_recon

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_recon: {loss_recon.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                if self.step % self.args.save_step == 0 and is_main:
                    plot_couple(target_img,
                                raster_img,
                                self.step,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # end rendering: only the main process writes final artifacts
        if is_main:
            renderer.save_svg(self.result_path / "final_svg.svg")
            if self.make_video:
                self._make_video()

        self.close(msg="painterly rendering complete.")