File size: 3,894 Bytes
bc32eea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import math
import torch
import torch.nn as nn

from models.flowplusplus import log_dist as logistic
from models.flowplusplus.nn import NN
from models.flowplusplus.transformer_nn import TransformerNN

class Coupling(nn.Module):
    """Mixture-of-Logistics coupling layer in Flow++, extended with a conditioning input.

    The layer receives ``x = (x_change, x_id)``. ``x_id`` is returned unchanged;
    together with the conditioning tensor ``cond`` it is fed to a transformation
    network (``NN`` or ``TransformerNN``) that predicts ``a, b, pi, mu, s``.
    ``x_change`` is then transformed as::

        scale = sigmoid(a + 2) + (1 - sigmoid(2))
        h     = logistic.inverse(exp(logistic.mixture_log_cdf(x_change, pi, mu, s)))
                if use_logmix else x_change
        y     = (h + b) * scale

    and the per-sample log-determinant of the Jacobian is accumulated into ``sldj``.

    Args:
        in_channels (int): If ``concat_dims`` is True, the number of input
            channels of the transformation network (``x_id`` and ``cond``
            concatenated along dim 1). Otherwise, the number of input features
            of ``input_encoder``, i.e. the size of dim 1 of ``x_id``.
        cond_dim (int): Output size of ``input_encoder``. When ``concat_dims``
            is False this is also the number of input channels of the
            transformation network (and the size of dim 1 of ``cond``).
        out_channels (int): Passed to the transformation network as its
            ``out_channels``.
        mid_channels (int): Number of channels in the transformation network.
        num_blocks (int): Number of blocks in the transformation network.
        num_components (int): Number of components in the logistic mixture.
        drop_prob (float): Dropout probability.
        seq_length (int): ``input_length`` of ``TransformerNN``; unused by ``NN``.
        output_length (int): ``output_length`` of ``TransformerNN``; unused by ``NN``.
        use_attn (bool): Use attention in the ``NN`` blocks; unused by ``TransformerNN``.
        use_logmix (bool): Apply the mixture-of-logistics CDF followed by
            ``logistic.inverse`` before the affine step. If False the layer is
            a plain affine coupling and ``pi, mu, s`` are ignored.
        use_transformer_nn (bool): Use ``TransformerNN`` instead of ``NN``.
        use_pos_emb (bool): Forwarded to ``TransformerNN``.
        use_rel_pos_emb (bool): Forwarded to ``TransformerNN``.
        num_heads (int): Number of attention heads in ``TransformerNN``.
        aux_channels (int): Number of channels in optional auxiliary input
            (``NN`` only).
        concat_dims (bool): If True, concatenate ``x_id`` and ``cond`` along
            dim 1. If False, first project ``x_id`` to ``cond_dim`` features
            with ``input_encoder`` and concatenate with ``cond`` along dim 2.
    """
    def __init__(self, in_channels, cond_dim, out_channels, mid_channels, num_blocks, num_components, drop_prob, seq_length, output_length,
                 use_attn=True, use_logmix=True, use_transformer_nn=False, use_pos_emb=False, use_rel_pos_emb=False, num_heads=10, aux_channels=None, concat_dims=True):
        super().__init__()

        # Channel count the transformation network actually receives: with
        # concat_dims=False, x_id is projected to cond_dim features and stacked
        # with cond along dim 2, so dim 1 of the network input is cond_dim.
        nn_in_channels = in_channels if concat_dims else cond_dim

        if use_transformer_nn:
            self.nn = TransformerNN(nn_in_channels, out_channels, mid_channels, num_blocks, num_heads, num_components,
                                    drop_prob=drop_prob, use_pos_emb=use_pos_emb, use_rel_pos_emb=use_rel_pos_emb,
                                    input_length=seq_length, concat_dims=concat_dims, output_length=output_length)
        else:
            self.nn = NN(nn_in_channels, out_channels, mid_channels, num_blocks, num_components, drop_prob, use_attn, aux_channels)

        if not concat_dims:
            self.input_encoder = nn.Linear(in_channels, cond_dim)
        self.use_logmix = use_logmix
        # With this offset, scale == 1 exactly when the network outputs a == 0.
        self.offset = 2.0
        self.sigmoid_offset = 1 - 1 / (1 + math.exp(-self.offset))
        self.cond_dim = cond_dim
        self.concat_dims = concat_dims

    def forward(self, x, cond, sldj=None, reverse=False, aux=None):
        """Apply the coupling transform, or its inverse when ``reverse`` is True.

        Args:
            x (tuple): ``(x_change, x_id)``. ``x_change`` is transformed;
                ``x_id`` is returned unchanged and conditions the network.
            cond (torch.Tensor): Conditioning tensor, concatenated with ``x_id``
                along dim 1 if ``concat_dims`` else along dim 2. When
                ``concat_dims`` is False, ``x_id`` must be 4-D (it is permuted
                with ``(0, 2, 3, 1)``).
            sldj (torch.Tensor, optional): Running sum of log-determinants with
                one entry per batch element. If None, accumulation starts at zero.
            reverse (bool): If True, invert the transform.
            aux: Forwarded unchanged to the transformation network.

        Returns:
            tuple: ``((out, x_id), sldj)``. ``out`` is the transformed
            ``x_change``. ``sldj`` has this layer's log-determinant added
            (forward) or subtracted (reverse).
        """
        x_change, x_id = x

        if sldj is None:
            # Treat a missing running sum as zero (one entry per batch element)
            # so the None default is actually usable.
            sldj = x_change.new_zeros(x_change.size(0))

        if self.concat_dims:
            x_id_cond = torch.cat((x_id, cond), dim=1)
        else:
            # Project x_id's dim-1 features to cond_dim (the Linear acts on the
            # last axis, so move dim 1 there and back), then stack with cond on dim 2.
            x_id_enc = self.input_encoder(x_id.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            x_id_cond = torch.cat((x_id_enc, cond), dim=2)

        a, b, pi, mu, s = self.nn(x_id_cond, aux)

        # scale = sigmoid(a + offset) + (1 - sigmoid(offset)): exactly 1 at a == 0
        # and bounded in (1 - sigmoid(offset), 2 - sigmoid(offset)), so always > 0.
        scale = torch.sigmoid(a + self.offset) + self.sigmoid_offset
        log_scale = torch.log(scale)

        if reverse:
            # Undo the affine step first, then (optionally) the mixture step.
            out = x_change / scale - b
            if self.use_logmix:
                out, scale_ldj = logistic.inverse(out, reverse=True)
                # NOTE(review): a clamp of `out` to [1e-5, 1 - 1e-5] was disabled
                # here; confirm mixture_inv_cdf is stable for values near 0 and 1.
                out = logistic.mixture_inv_cdf(out, pi, mu, s)
                logistic_ldj = logistic.mixture_log_pdf(out, pi, mu, s)
                # NOTE(review): the sign of scale_ldj relies on the convention of
                # logistic.inverse(..., reverse=True); confirm it mirrors the
                # forward-direction term so forward/reverse sldj cancel.
                sldj = sldj - (log_scale + scale_ldj + logistic_ldj).flatten(1).sum(-1)
            else:
                sldj = sldj - log_scale.flatten(1).sum(-1)
        else:
            if self.use_logmix:
                out = logistic.mixture_log_cdf(x_change, pi, mu, s).exp()
                out, scale_ldj = logistic.inverse(out)
                logistic_ldj = logistic.mixture_log_pdf(x_change, pi, mu, s)
                sldj = sldj + (logistic_ldj + scale_ldj + log_scale).flatten(1).sum(-1)
            else:
                out = x_change
                sldj = sldj + log_scale.flatten(1).sum(-1)

            out = (out + b) * scale

        return (out, x_id), sldj