File size: 2,695 Bytes
e17e8cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""This code is taken from <https://github.com/alexandre01/deepsvg>
by Alexandre Carlier, Martin Danelljan, Alexandre Alahi and Radu Timofte
from the paper <https://arxiv.org/pdf/2007.11301.pdf>
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from src.preprocessing.deepsvg.deepsvg_difflib.tensor import SVGTensor
from .model_utils import _get_padding_mask, _get_visibility_mask
from .model_config import _DefaultConfig


class SVGLoss(nn.Module):
    """Weighted sum of the loss terms used to train the SVG model.

    Depending on the configuration, the total is made of:
      * a KL term (only when ``cfg.use_vae`` is truthy),
      * a visibility cross-entropy (only when ``cfg.decode_stages == 2``),
      * a command cross-entropy and an argument cross-entropy (always).
    """

    def __init__(self, cfg: _DefaultConfig):
        """Store the config and register the command/argument mask buffer.

        Args:
            cfg: model configuration; ``rel_targets`` and ``args_dim`` are read
                here, ``use_vae``, ``decode_stages`` and ``n_commands`` are
                read in ``forward``.
        """
        super().__init__()

        self.cfg = cfg

        # Number of classes each argument logit vector is reshaped to.
        if cfg.rel_targets:
            self.args_dim = 2 * cfg.args_dim
        else:
            self.args_dim = cfg.args_dim + 1

        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("cmd_args_mask", SVGTensor.CMD_ARGS_MASK)

    def forward(self, output, labels, weights):
        """Compute the total weighted loss together with its components.

        Args:
            output: mapping read for ``"tgt_commands"``, ``"tgt_args"``,
                ``"command_logits"`` and ``"args_logits"``; additionally
                ``"mu"`` / ``"logsigma"`` when ``cfg.use_vae`` is set and
                ``"visibility_logits"`` when ``cfg.decode_stages == 2``.
            labels: not used by this method.
            weights: mapping read for ``"loss_cmd_weight"`` and
                ``"loss_args_weight"``; additionally ``"kl_tolerance"`` /
                ``"loss_kl_weight"`` (VAE) and ``"loss_visibility_weight"``
                (2-stage decoding).

        Returns:
            dict with ``"loss"``, ``"loss_cmd"`` and ``"loss_args"``, plus
            ``"loss_kl"`` and/or ``"loss_visibility"`` when those terms are
            enabled.
        """
        total = 0.
        terms = {}

        # --- KL term (VAE only) ---
        if self.cfg.use_vae:
            mu = output["mu"]
            logsigma = output["logsigma"]
            kl_inner = 1 + logsigma - mu.pow(2) - torch.exp(logsigma)
            # Floor the KL value at the configured tolerance.
            kl = (-0.5 * torch.mean(kl_inner)).clamp(min=weights["kl_tolerance"])

            total += weights["loss_kl_weight"] * kl
            terms["loss_kl"] = kl

        # --- Targets and masks ---
        tgt_commands = output["tgt_commands"]
        tgt_args = output["tgt_args"]

        visible = _get_visibility_mask(tgt_commands, seq_dim=-1)
        not_padded = _get_padding_mask(tgt_commands, seq_dim=-1, extended=True)
        padding_mask = not_padded * visible.unsqueeze(-1)

        command_logits = output["command_logits"]
        args_logits = output["args_logits"]

        # --- Visibility term (2-stage decoding only) ---
        if self.cfg.decode_stages == 2:
            vis_logits = output["visibility_logits"]
            vis_loss = F.cross_entropy(
                vis_logits.reshape(-1, 2),
                visible.reshape(-1).long(),
            )

            total += weights["loss_visibility_weight"] * vis_loss
            terms["loss_visibility"] = vis_loss

        # --- Command and argument terms ---
        # Drop the first sequence position of the targets and of the mask
        # (presumably a start-of-sequence token the logits do not cover -
        # TODO confirm against the decoder).
        tgt_commands = tgt_commands[..., 1:]
        tgt_args = tgt_args[..., 1:, :]
        keep_cmd = padding_mask[..., 1:].bool()

        # Look up, for every target command id, its row of the
        # command/argument mask buffer.
        keep_args = self.cmd_args_mask[tgt_commands.long()].bool()

        cmd_loss = F.cross_entropy(
            command_logits[keep_cmd].reshape(-1, self.cfg.n_commands),
            tgt_commands[keep_cmd].reshape(-1).long(),
        )
        args_loss = F.cross_entropy(
            args_logits[keep_args].reshape(-1, self.args_dim),
            tgt_args[keep_args].reshape(-1).long() + 1,  # shift due to -1 PAD_VAL
        )

        total += (
            weights["loss_cmd_weight"] * cmd_loss
            + weights["loss_args_weight"] * args_loss
        )

        terms["loss"] = total
        terms["loss_cmd"] = cmd_loss
        terms["loss_args"] = args_loss

        return terms