File size: 2,420 Bytes
bc32eea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Logistic distribution functions."""
import torch
import torch.nn.functional as F

from models.util import safe_log


def _log_pdf(x, mean, log_scale):
    """Return the element-wise log density of a logistic distribution.

    The distribution has location ``mean`` and scale ``exp(log_scale)``.
    All arguments are combined with ordinary tensor broadcasting.
    """
    inv_scale = torch.exp(-log_scale)
    standardized = (x - mean) * inv_scale
    # log p = z - log(s) - 2 * log(1 + e^z), with softplus(z) = log(1 + e^z).
    return standardized - log_scale - 2 * F.softplus(standardized)


def _log_cdf(x, mean, log_scale):
    """Return the element-wise log CDF of a logistic distribution.

    The distribution has location ``mean`` and scale ``exp(log_scale)``;
    its CDF is the sigmoid of the standardized input, so the log CDF is
    evaluated directly with ``logsigmoid``.
    """
    standardized = (x - mean) * torch.exp(-log_scale)
    return F.logsigmoid(standardized)


def mixture_log_pdf(x, prior_logits, means, log_scales):
    """Return the log density of a mixture of logistic distributions.

    Mixture components live on dim 1 of ``prior_logits``, ``means`` and
    ``log_scales``. ``x`` gets a singleton axis inserted at dim 1 so it
    broadcasts against the per-component parameters, and the component
    axis is reduced with a log-sum-exp.
    """
    log_weights = F.log_softmax(prior_logits, dim=1)
    component_log_pdfs = _log_pdf(x.unsqueeze(1), means, log_scales)
    return torch.logsumexp(log_weights + component_log_pdfs, dim=1)


def mixture_log_cdf(x, prior_logits, means, log_scales):
    """Return the log CDF of a mixture of logistic distributions.

    Mixture components live on dim 1 of ``prior_logits``, ``means`` and
    ``log_scales``. ``x`` gets a singleton axis inserted at dim 1 so it
    broadcasts against the per-component parameters, and the component
    axis is reduced with a log-sum-exp.
    """
    log_weights = F.log_softmax(prior_logits, dim=1)
    component_log_cdfs = _log_cdf(x.unsqueeze(1), means, log_scales)
    return torch.logsumexp(log_weights + component_log_cdfs, dim=1)


def mixture_inv_cdf(y, prior_logits, means, log_scales,
                    eps=1e-10, max_iters=100):
    """Inverse CDF of a mixture of logistics, computed by bisection.

    Args:
        y: Tensor of CDF values to invert. Every element must lie strictly
            inside (0, 1).
        prior_logits: Mixture logits, forwarded to `mixture_log_cdf`.
        means: Component locations; the component axis is dim 1.
        log_scales: Component log-scales; the component axis is dim 1.
        eps: Stop once no element moved by more than `eps` in an iteration.
        max_iters: Hard cap on the number of bisection steps.

    Returns:
        Tensor of approximate solutions `x` with mixture CDF(x) ~= y. The
        search starts from `zeros_like(y)`; an empty `y` yields an empty
        tensor of the same shape and dtype.

    Raises:
        RuntimeError: If any element of `y` is not strictly inside (0, 1).
            NaN entries are rejected as well.
    """
    # Nothing to invert; the max-reduction in the loop would fail on an
    # empty tensor, so return early.
    if y.numel() == 0:
        return torch.zeros_like(y)

    # Element-wise check instead of y.min()/y.max(): those reductions
    # propagate NaN and every comparison against NaN is False, which would
    # let NaN inputs through and make the bisection return garbage.
    if not bool(((y > 0) & (y < 1)).all()):
        raise RuntimeError('Inverse logisitic CDF got y outside (0, 1)')

    def body(x_, lb_, ub_):
        """One bisection step: halve the bracket around the target."""
        cur_y = torch.exp(mixture_log_cdf(x_, prior_logits, means,
                                          log_scales))
        # gt == 1 where the current guess overshoots (CDF too large), so
        # the guess moves toward the lower bound and becomes the new upper
        # bound; lt == 1 handles the opposite case.
        gt = (cur_y > y).type(y.dtype)
        lt = 1 - gt
        new_x_ = gt * (x_ + lb_) / 2. + lt * (x_ + ub_) / 2.
        new_lb = gt * lb_ + lt * x_
        new_ub = gt * x_ + lt * ub_
        return new_x_, new_lb, new_ub

    x = torch.zeros_like(y)
    # Initial bracket: 20 times the summed component scales beyond the
    # outermost component means.
    max_scales = torch.sum(torch.exp(log_scales), dim=1, keepdim=True)
    lb, _ = (means - 20 * max_scales).min(dim=1)
    ub, _ = (means + 20 * max_scales).max(dim=1)
    diff = float('inf')

    i = 0
    while diff > eps and i < max_iters:
        new_x, lb, ub = body(x, lb, ub)
        # Largest movement of any element in this step.
        diff = (new_x - x).abs().max()
        x = new_x
        i += 1

    return x


def inverse(x, reverse=False):
    """Apply the inverse logistic (logit) map, or undo it.

    Args:
        x: Input tensor.
        reverse: When False, compute ``-safe_log(1/x - 1)``, i.e. the logit
            of ``x``. When True, apply the sigmoid instead.

    Returns:
        A ``(z, ldj)`` pair. ``z`` is the transformed tensor. ``ldj`` is
        ``-safe_log(x) - safe_log(1 - x)`` in the forward direction and
        ``softplus(x) + softplus(-x)`` in the reverse direction, which
        equals ``-log(sigmoid(x) * (1 - sigmoid(x)))``.
    """
    # NOTE(review): safe_log comes from models.util; presumably a
    # numerically guarded log -- confirm against that module.
    if not reverse:
        logit = -safe_log(x.reciprocal() - 1.)
        log_det = -safe_log(x) - safe_log(1. - x)
        return logit, log_det

    prob = torch.sigmoid(x)
    log_det = F.softplus(x) + F.softplus(-x)
    return prob, log_det