File size: 3,812 Bytes
e17e8cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

from src.preprocessing.deepsvg.deepsvg_config.config import _Config
from src.preprocessing.deepsvg.deepsvg_models.model import SVGTransformer
from src.preprocessing.deepsvg.deepsvg_models.loss import SVGLoss
from src.preprocessing.deepsvg.deepsvg_models.model_config import *
from src.preprocessing.deepsvg.deepsvg_svglib.svg import SVG
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
from src.preprocessing.deepsvg.deepsvg_svglib.svglib_utils import make_grid
from src.preprocessing.deepsvg.deepsvg_svglib.geom import Bbox
from src.preprocessing.deepsvg.deepsvg_utils.utils import batchify, linear

import torchvision.transforms.functional as TF
import torch.optim.lr_scheduler as lr_scheduler
import random


class ModelConfig(Hierarchical):
    """Model configuration used by this training setup.

    Currently inherits every setting unchanged from ``Hierarchical``; this
    subclass is the place to override model hyper-parameters.
    """

    def __init__(self):
        super(ModelConfig, self).__init__()


class Config(_Config):
    """
    Overriding default training config.

    Hyper-parameters that scale with the number of GPUs (dataloader workers,
    learning rate, batch size) are multiplied by ``num_gpus``.
    """
    def __init__(self, num_gpus=1):
        super().__init__(num_gpus=num_gpus)

        # Model
        self.model_cfg = ModelConfig()
        self.model_args = self.model_cfg.get_model_args()

        # Dataset
        self.filter_category = None

        self.train_ratio = 1.0

        self.max_num_groups = 8
        self.max_total_len = 50

        # Dataloader
        self.loader_num_workers = 4 * num_gpus

        # Training
        self.num_epochs = 50
        self.val_every = 1000

        # Optimization
        self.learning_rate = 1e-3 * num_gpus
        self.batch_size = 60 * num_gpus
        self.grad_clip = 1.0

    def make_schedulers(self, optimizers, epoch_size):
        """Build the learning-rate schedulers.

        Args:
            optimizers: sequence containing exactly one optimizer.
            epoch_size: number of scheduler steps per epoch (presumably the
                number of batches per epoch -- TODO confirm against caller).

        Returns:
            A one-element list holding a ``StepLR`` that multiplies the
            learning rate by 0.9 roughly every 2.5 epochs.
        """
        optimizer, = optimizers
        # StepLR expects an integer step_size. A fractional value (e.g. 7.5
        # when epoch_size is odd) makes its ``last_epoch % step_size == 0``
        # check fire only on even multiples, i.e. half as often as intended.
        # Clamp to 1 so an empty epoch cannot cause a modulo-by-zero.
        step_size = max(1, int(2.5 * epoch_size))
        return [lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.9)]

    def make_model(self):
        """Instantiate the SVGTransformer for ``self.model_cfg``."""
        return SVGTransformer(self.model_cfg)

    def make_losses(self):
        """Return the list of loss modules (a single ``SVGLoss``)."""
        return [SVGLoss(self.model_cfg)]

    def get_weights(self, step, epoch):
        """Return the loss weights for the current training step.

        The KL weight is computed by ``linear(0, 10, step, 0, 10000)``,
        presumably interpolating from 0 to 10 over steps 0..10000 (see
        ``linear``). All other weights are constant. ``epoch`` is unused.
        """
        return {
            "kl_tolerance": 0.1,
            "loss_kl_weight": linear(0, 10, step, 0, 10000),
            "loss_hierarch_weight": 1.0,
            "loss_cmd_weight": 1.0,
            "loss_args_weight": 2.0,
            "loss_visibility_weight": 1.0
        }

    def set_train_vars(self, train_vars, dataloader):
        """Store up to 10 random training samples for later visualization.

        Each sample is fetched with the model-argument keys plus
        ``"tensor_grouped"`` and stored in ``train_vars.x_inputs_train``.
        """
        dataset = dataloader.dataset
        # random.sample raises ValueError when k exceeds the population size,
        # so cap k for datasets with fewer than 10 samples.
        num_samples = min(10, len(dataset))
        train_vars.x_inputs_train = [dataset.get(idx, [*self.model_args, "tensor_grouped"])
                                     for idx in random.sample(range(len(dataset)), k=num_samples)]

    def visualize(self, model, output, train_vars, step, epoch, summary_writer, visualization_dir):
        """Log reconstructions of the samples chosen in ``set_train_vars``.

        For each stored sample, greedily decode a prediction and write a
        [prediction, ground truth] image grid to ``summary_writer`` under the
        tag ``reconstructions_train/<i>``. ``output``, ``epoch`` and
        ``visualization_dir`` are not used here.
        """
        device = next(model.parameters()).device
        # The model may be wrapped (e.g. nn.DataParallel); unwrap via
        # ``.module`` when present, otherwise use the model directly.
        net = getattr(model, "module", model)

        # Reconstruction
        for i, data in enumerate(train_vars.x_inputs_train):
            model_args = batchify((data[key] for key in self.model_args), device)
            commands_y, args_y = net.greedy_sample(*model_args)
            tensor_pred = SVGTensor.from_cmd_args(commands_y[0].cpu(), args_y[0].cpu())

            try:
                svg_path_sample = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True).normalize().split_paths().set_color("random")
            except Exception:
                # Converting a predicted tensor into an SVG may fail (e.g. on
                # degenerate predictions); skip that sample rather than abort
                # logging. Unlike a bare except, this still lets
                # KeyboardInterrupt/SystemExit propagate.
                continue

            tensor_target = data["tensor_grouped"][0].copy().drop_sos().unpad()
            svg_path_gt = SVG.from_tensor(tensor_target.data, viewbox=Bbox(256)).normalize().split_paths().set_color("random")

            img = make_grid([svg_path_sample, svg_path_gt]).draw(do_display=False, return_png=True, fill=False, with_points=False)
            summary_writer.add_image(f"reconstructions_train/{i}", TF.to_tensor(img), step)