"""Losses for the generative models and baselines."""
import torch as th
import numpy as np

import ttools.modules.image_operators as imops


class KLDivergence(th.nn.Module):
    """
    Args:
        min_value(float): the loss is clipped so that value below this
            number don't affect the optimization.
    """
    def __init__(self, min_value=0.2):
        super(KLDivergence, self).__init__()
        self.min_value = min_value

    def forward(self, mu, log_sigma):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over all elements.
        loss = -0.5 * (1.0 + log_sigma - mu.pow(2) - log_sigma.exp())
        loss = loss.mean()
        # Clip from below so already-small KL terms stop driving the optimization.
        loss = th.max(loss, self.min_value*th.ones_like(loss))
        return loss
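

# A minimal usage sketch (not part of the original module): how KLDivergence
# would typically consume the mean and log standard deviation produced by a
# VAE encoder. The shapes and variable names below are illustrative assumptions.
def _example_kl_usage():
    mu = th.zeros(8, 128)         # hypothetical latent means, (batch, z_dim)
    log_sigma = th.zeros(8, 128)  # hypothetical latent log standard deviations
    kl = KLDivergence(min_value=0.2)
    return kl(mu, log_sigma)      # clipped from below at min_value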


class MultiscaleMSELoss(th.nn.Module):
    """MSE summed over the levels of a Gaussian pyramid.

    Args:
        channels(int): number of channels in the compared images.
    """
    def __init__(self, channels=3):
        super(MultiscaleMSELoss, self).__init__()
        self.blur = imops.GaussianBlur(1, channels=channels)

    def forward(self, im, target):
        bs, c, h, w = im.shape
        num_levels = max(int(np.ceil(np.log2(h))) - 2, 1)

        losses = []
        for lvl in range(num_levels):
            loss = th.nn.functional.mse_loss(im, target)
            losses.append(loss)
            # Blur, then downsample both images by 2x for the next pyramid level.
            im = th.nn.functional.interpolate(self.blur(im),
                                              scale_factor=0.5,
                                              mode="nearest")
            target = th.nn.functional.interpolate(self.blur(target),
                                                  scale_factor=0.5,
                                                  mode="nearest")

        losses = th.stack(losses)
        return losses.sum()
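

# A minimal usage sketch (not part of the original module): the loss compares a
# prediction and a target across the pyramid levels. The shapes are illustrative
# assumptions; any (batch, channels, height, width) pair of tensors works.
def _example_multiscale_mse_usage():
    im = th.rand(2, 3, 64, 64)      # hypothetical prediction
    target = th.rand(2, 3, 64, 64)  # hypothetical reference image
    crit = MultiscaleMSELoss(channels=3)
    return crit(im, target)         # sum of the per-level MSE terms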


def gaussian_pdfs(dx, dy, params):
    """Returns the pdf at (dx, dy) for each bivariate Gaussian in the mixture.

    Args:
        dx, dy: offsets at which the densities are evaluated.
        params: tensor whose last dimension holds, for each Gaussian,
            (mu_x, mu_y, log_sigma_x, log_sigma_y, rho_xy before the tanh).
    """
    dx = dx.unsqueeze(-1)  # replicate dx, dy to evaluate all pdfs at once
    dy = dy.unsqueeze(-1)

    mu_x = params[..., 0]
    mu_y = params[..., 1]
    sigma_x = params[..., 2].exp()
    sigma_y = params[..., 3].exp()
    rho_xy = th.tanh(params[..., 4])

    x = ((dx-mu_x) / sigma_x).pow(2)
    y = ((dy-mu_y) / sigma_y).pow(2)

    xy = (dx-mu_x)*(dy-mu_y) / (sigma_x * sigma_y)
    arg = x + y - 2.0*rho_xy*xy
    pdf = th.exp(-arg / (2*(1.0 - rho_xy.pow(2))))
    norm = 2.0 * np.pi * sigma_x * sigma_y * (1.0 - rho_xy.pow(2)).sqrt()

    return pdf / norm
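

# A small sanity-check sketch (not part of the original module): a single
# standard bivariate Gaussian (mu = 0, log-sigma = 0, rho = 0) evaluated at the
# origin should give a density of 1 / (2*pi), approximately 0.159.
def _example_gaussian_pdf_check():
    dx = th.zeros(1)
    dy = th.zeros(1)
    params = th.zeros(1, 1, 5)  # (..., n_gaussians, 5) parameter layout
    return gaussian_pdfs(dx, dy, params)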


class GaussianMixtureReconstructionLoss(th.nn.Module):
    """Negative log-likelihood of stroke offsets under a Gaussian mixture,
    plus a cross-entropy term on the pen state.

    Args:
        eps(float): small constant that keeps the log argument strictly positive.
    """
    def __init__(self, eps=1e-5):
        super(GaussianMixtureReconstructionLoss, self).__init__()

    def forward(self, pen_logits, mixture_logits, gaussian_params, targets):
        dx = targets[..., 0]
        dy = targets[..., 1]
        pen_state = targets[..., 2:].argmax(-1)  # target index

        # Likelihood loss on the stroke position
        # No need to predict accurate pen position for end-of-sequence tokens
        valid_stroke = (targets[..., -1] != 1.0).float()
        mixture_weights = th.nn.functional.softmax(mixture_logits, -1)
        pdfs = gaussian_pdfs(dx, dy, gaussian_params)
        position_loss = - th.log(self.eps + (pdfs * mixture_weights).sum(-1))

        # Normalize by the number of valid (non end-of-sequence) strokes
        position_loss = (position_loss*valid_stroke).sum() / valid_stroke.sum()

        # Classification loss for the stroke mode
        pen_loss = th.nn.functional.cross_entropy(pen_logits.view(-1, 3),
                                                  pen_state.view(-1))

        return position_loss + pen_loss
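

# A minimal usage sketch (not part of the original module): hypothetical
# sketch-RNN style decoder outputs for a batch of 4 sequences of length 16 with
# an 8-component mixture. All names and shapes here are illustrative assumptions.
def _example_mixture_loss_usage():
    bs, seq, n_gaussians = 4, 16, 8
    pen_logits = th.randn(bs, seq, 3)
    mixture_logits = th.randn(bs, seq, n_gaussians)
    gaussian_params = th.randn(bs, seq, n_gaussians, 5)
    # Targets hold (dx, dy) offsets followed by a one-hot 3-way pen state.
    pen_state = th.nn.functional.one_hot(th.randint(0, 3, (bs, seq)), 3).float()
    targets = th.cat([th.randn(bs, seq, 2), pen_state], -1)
    crit = GaussianMixtureReconstructionLoss()
    return crit(pen_logits, mixture_logits, gaussian_params, targets)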