File size: 13,340 Bytes
bc32eea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.flowplusplus.act_norm import ActNorm, BatchNorm
from models.flowplusplus.inv_conv import InvConv, InvertibleConv1x1
from models.flowplusplus.nn import GatedConv
from models.flowplusplus.coupling import Coupling
from models.util import channelwise, checkerboard, Flip, safe_log, squeeze, unsqueeze

from models.moglow.modules import GaussianDiag, StudentT

class FlowPlusPlus(nn.Module):
    """Flow++ model adapted to conditional sequence data.

    Based on the paper:
    "Flow++: Improving Flow-Based Generative Models
        with Variational Dequantization and Architecture Design"
    by Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
    (https://openreview.net/forum?id=Hyg74h05tX).

    Data and conditioning are given as 3-D tensors laid out as
    ``(batch, length, channels)`` and are permuted internally to
    ``(batch, channels, length, 1)`` before entering the `_FlowStep` stack.

    Args:
        scales (tuple or list): Number of each type of coupling layer in each
            scale. Each scale is a 2-tuple of the form
            (num_channelwise, num_checkerboard).
        in_shape (tuple): ``(channels, height, width)`` of one sample in the
            internal layout. ``width`` must be 1 for sampling (reverse mode).
        cond_dim (int): Number of conditioning channels at the first scale.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t network of
            `Coupling` layers.
        num_components (int): Number of components in the mixture.
        use_attn, use_logmix, use_transformer_nn, use_pos_emb,
        use_rel_pos_emb, num_heads, drop_prob: Forwarded to every `Coupling`.
        norm_layer (str or None): ``"batchnorm"`` or ``"actnorm"`` to add a
            normalization layer to each flow step; any other value adds none.
        cond_concat_dims (bool): Forwarded to `Coupling` as ``concat_dims``;
            also selects how coupling input sizes are computed.
        cond_seq_len (int): Conditioning sequence length, used when
            ``cond_concat_dims`` is False.
        flow_dist (str): Base distribution, ``"normal"`` or ``"studentT"``.
        flow_dist_param (float): First argument to `StudentT` (presumably its
            degrees of freedom — TODO confirm).
        bn_momentum (float): Momentum of `BatchNorm` layers.

    Raises:
        ValueError: If ``flow_dist`` is not a supported distribution name.
    """
    def __init__(self,
                 scales=((0, 4), (2, 3)),
                 in_shape=(3, 32, 32),
                 cond_dim=0,
                 mid_channels=96,
                 num_blocks=10,
                 num_components=32,
                 use_attn=True,
                 use_logmix=True,
                 use_transformer_nn=False,
                 use_pos_emb=False,
                 use_rel_pos_emb=False,
                 num_heads=10,
                 drop_prob=0.2,
                 norm_layer=None,
                 cond_concat_dims=True,
                 cond_seq_len=1,
                 flow_dist="normal",
                 flow_dist_param=50,
                 bn_momentum=0.1):
        super(FlowPlusPlus, self).__init__()
        # Fail fast on an unknown distribution instead of leaving
        # `self.distribution` unset until the first forward/loss call.
        if flow_dist not in ("normal", "studentT"):
            raise ValueError(
                f"Unsupported flow_dist {flow_dist!r}; "
                f"expected 'normal' or 'studentT'")

        # Registered bounds buffer (not learnable). It is not read anywhere in
        # this file but is kept so existing checkpoints still load.
        self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))
        self.flows = _FlowStep(scales=scales,
                               in_shape=in_shape,
                               cond_dim=cond_dim,
                               mid_channels=mid_channels,
                               num_blocks=num_blocks,
                               num_components=num_components,
                               use_attn=use_attn,
                               use_logmix=use_logmix,
                               use_transformer_nn=use_transformer_nn,
                               use_pos_emb=use_pos_emb,
                               use_rel_pos_emb=use_rel_pos_emb,
                               num_heads=num_heads,
                               drop_prob=drop_prob,
                               norm_layer=norm_layer,
                               cond_concat_dims=cond_concat_dims,
                               cond_seq_len=cond_seq_len,
                               bn_momentum=bn_momentum)
        if flow_dist == "normal":
            self.distribution = GaussianDiag()
        else:  # "studentT" (validated above)
            in_channels, _, _ = in_shape
            self.distribution = StudentT(flow_dist_param, in_channels)

    def forward(self, x, cond, reverse=False):
        """Run the flow in the normalizing or the generative direction.

        Args:
            x: Data of shape ``(batch, length, channels)`` when ``reverse`` is
                False. Ignored when ``reverse`` is True: a latent is sampled
                from ``self.distribution`` instead.
            cond: Conditioning tensor in the same 3-D layout as ``x``. Required
                when ``reverse`` is True, since its batch size, device and dtype
                drive the sampling.
            reverse (bool): If True, map latents to data (sampling).

        Returns:
            tuple: ``(x, sldj)``. In the forward direction ``x`` is the latent
            in the internal ``(batch, channels, length, 1)`` layout; in reverse
            it is permuted back to ``(batch, length, channels)``. ``sldj`` is the
            per-sample log-determinant accumulator returned by the flows.
        """
        if cond is not None:
            cond = cond.permute(0, 2, 1).unsqueeze(3)

        if not reverse:
            if x is not None:
                x = x.permute(0, 2, 1).unsqueeze(3)
        else:
            c, h, w = self.flows.z_dim()
            # The latent is sampled as (batch, c, h) and given a trailing
            # singleton width, so only width-1 latents are supported.
            assert w == 1
            eps_std = 1.0
            x = self.distribution.sample((cond.size(0), c, h), eps_std,
                                         device=cond.device).type_as(cond)
            x = x.unsqueeze(-1)

        sldj = torch.zeros(x.size(0), device=x.device)
        x, sldj = self.flows(x, cond, sldj, reverse)

        if reverse and x is not None:
            x = x.squeeze(3).permute(0, 2, 1)

        return x, sldj

    def loss_generative(self, z, sldj):
        """Negative log-likelihood in bits, normalized by the latent's spatial size.

        Args:
            z: Latent returned by :meth:`forward` with ``reverse=False``, in the
                4-D ``(batch, channels, length, 1)`` layout.
            sldj: Per-sample log-determinant returned alongside ``z``.

        Returns:
            Scalar tensor equal to
            ``-mean(log p(z) + sldj) / (ln 2 * z.size(2) * z.size(3))``.
            The normalization deliberately excludes the channel dimension.

        See Also:
            Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
        """
        # Sum the prior log-density over every non-batch element.
        prior_ll = self.distribution.logp(z).flatten(1).sum(-1)
        ll = prior_ll + sldj
        return -ll.mean() / float(np.log(2.) * z.size(2) * z.size(3))
        
class _FlowStep(nn.Module):
    """Recursive builder for a Flow++ model.

    Each `_FlowStep` corresponds to a single scale in Flow++. The constructor
    recursively builds the next scale from ``scales[1:]``. That scale receives
    one half of the squeezed input. It is declared with shape
    ``(in_channels, in_height // 2, in_width)`` and ``2 * cond_dim``
    conditioning channels.

    Args:
        scales (tuple): Number of each type of coupling layer in each scale.
            Each scale is a 2-tuple of the form (num_channelwise, num_checkerboard).
        in_shape (tuple): ``(channels, height, width)`` of the input to this scale.
        cond_dim (int): Number of conditioning channels at this scale.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t network of
            `Coupling` layers.
        num_components (int): Number of components in the mixture.
        use_attn (bool): Use attention in the coupling layers.
        use_logmix, use_transformer_nn, use_pos_emb, use_rel_pos_emb, num_heads:
            Forwarded unchanged to every `Coupling`.
        drop_prob (float): Dropout probability.
        norm_layer (str or None): ``"batchnorm"`` or ``"actnorm"`` to insert a
            normalization layer before each invertible 1x1 conv; any other
            value inserts none.
        bn_momentum (float): Momentum for `BatchNorm`.
        cond_concat_dims (bool): If True, couplings take
            ``in_channels // 2 + cond_dim`` input channels over ``in_height``
            steps. Otherwise they take ``in_channels // 2`` channels over
            ``in_height + cond_seq_len`` steps (presumably conditioning is
            concatenated along the sequence in that case — TODO confirm).
        cond_seq_len (int): Conditioning sequence length.
    """
    def __init__(self, scales, in_shape, cond_dim, mid_channels, num_blocks, num_components, use_attn, use_logmix, use_transformer_nn, use_pos_emb, use_rel_pos_emb, num_heads, drop_prob, norm_layer, bn_momentum, cond_concat_dims, cond_seq_len):
        super(_FlowStep, self).__init__()
        in_channels, in_height, in_width = in_shape
        num_channelwise, num_checkerboard = scales[0]

        # Computed once, outside both loops, so the sizes are also defined
        # for scales with no channelwise couplings (e.g. a (0, 4) scale).
        new_channels, c_in_channels, seq_length = self._coupling_dims(
            in_channels, in_height, cond_dim, cond_concat_dims, cond_seq_len)

        # Keyword arguments common to every coupling at this scale. Both
        # coupling kinds are sized with `cond_dim`, so both receive it.
        coupling_kwargs = dict(cond_dim=cond_dim,
                               mid_channels=mid_channels,
                               num_blocks=num_blocks,
                               num_components=num_components,
                               use_attn=use_attn,
                               use_logmix=use_logmix,
                               use_transformer_nn=use_transformer_nn,
                               use_pos_emb=use_pos_emb,
                               use_rel_pos_emb=use_rel_pos_emb,
                               num_heads=num_heads,
                               seq_length=seq_length,
                               output_length=in_height,
                               concat_dims=cond_concat_dims,
                               drop_prob=drop_prob)

        # Flip() is intentionally omitted from both stacks. It does not
        # support an odd number of channels, and the invertible 1x1 conv
        # already mixes channels.
        channels = []
        for _ in range(num_channelwise):
            channels += self._norm_layers(norm_layer, in_channels, bn_momentum)
            channels += [InvertibleConv1x1(in_channels),
                         Coupling(in_channels=c_in_channels,
                                  out_channels=in_channels - new_channels,
                                  **coupling_kwargs)]

        checkers = []
        for _ in range(num_checkerboard):
            checkers += self._norm_layers(norm_layer, in_channels, bn_momentum)
            checkers += [InvertibleConv1x1(in_channels),
                         Coupling(in_channels=c_in_channels,
                                  out_channels=in_channels,
                                  **coupling_kwargs)]

        self.channels = nn.ModuleList(channels) if channels else None
        self.checkers = nn.ModuleList(checkers) if checkers else None

        if len(scales) <= 1:
            self.next = None
        else:
            next_shape = (in_channels, in_height // 2, in_width)
            self.next = _FlowStep(scales=scales[1:],
                                  in_shape=next_shape,
                                  cond_dim=2*cond_dim,
                                  mid_channels=mid_channels,
                                  num_blocks=num_blocks,
                                  num_components=num_components,
                                  use_attn=use_attn,
                                  use_logmix=use_logmix,
                                  use_transformer_nn=use_transformer_nn,
                                  use_pos_emb=use_pos_emb,
                                  use_rel_pos_emb=use_rel_pos_emb,
                                  num_heads=num_heads,
                                  norm_layer=norm_layer,
                                  bn_momentum=bn_momentum,
                                  cond_concat_dims=cond_concat_dims,
                                  cond_seq_len=cond_seq_len,
                                  drop_prob=drop_prob)

        self.z_shape = (in_channels, in_height, in_width)

    @staticmethod
    def _coupling_dims(in_channels, in_height, cond_dim, cond_concat_dims, cond_seq_len):
        """Return ``(new_channels, c_in_channels, seq_length)`` for this scale's couplings.

        ``new_channels`` is ``in_channels // 2``. When ``cond_concat_dims`` is
        True the coupling input has ``new_channels + cond_dim`` channels over
        ``in_height`` steps. Otherwise it has ``new_channels`` channels over
        ``in_height + cond_seq_len`` steps.
        """
        new_channels = in_channels // 2
        if cond_concat_dims:
            return new_channels, new_channels + cond_dim, in_height
        return new_channels, new_channels, in_height + cond_seq_len

    @staticmethod
    def _norm_layers(norm_layer, in_channels, bn_momentum):
        """Return the optional normalization layer for one flow step, as a list."""
        if norm_layer == "batchnorm":
            return [BatchNorm(in_channels, bn_momentum)]
        if norm_layer == "actnorm":
            return [ActNorm(in_channels)]
        return []

    def z_dim(self):
        """Return ``(channels, height, width)`` of this scale's input/latent."""
        return self.z_shape

    def forward(self, x, cond, sldj, reverse=False):
        """Apply this scale and, recursively, all coarser scales.

        Forward direction: channelwise couplings, then checkerboard couplings,
        then the next scale on one half of ``squeeze(x)``. The reverse
        direction undoes these steps in the opposite order, running each
        layer list backwards.

        Args:
            x: Input tensor for this scale (assumed 4-D, channels on dim 1).
            cond: Conditioning tensor; squeezed alongside ``x`` before being
                passed to the next scale.
            sldj: Per-sample log-determinant accumulator, updated by each flow.
            reverse (bool): Direction of the flow.

        Returns:
            tuple: ``(x, sldj)``.
        """
        if reverse:
            if self.next is not None:
                x = squeeze(x)
                cond = squeeze(cond)
                # Only the first half is transformed by the coarser scales;
                # the second half is passed through unchanged.
                x, x_split = x.chunk(2, dim=1)
                x, sldj = self.next(x, cond, sldj, reverse)
                x = torch.cat((x, x_split), dim=1)
                x = unsqueeze(x)
                cond = unsqueeze(cond)

            if self.checkers:
                x = checkerboard(x)
                for flow in reversed(self.checkers):
                    x, sldj = flow(x, cond, sldj, reverse)
                x = checkerboard(x, reverse=True)

            if self.channels:
                x = channelwise(x)
                for flow in reversed(self.channels):
                    x, sldj = flow(x, cond, sldj, reverse)
                x = channelwise(x, reverse=True)
        else:
            if self.channels:
                x = channelwise(x)
                for flow in self.channels:
                    x, sldj = flow(x, cond, sldj, reverse)
                x = channelwise(x, reverse=True)

            if self.checkers:
                x = checkerboard(x)
                for flow in self.checkers:
                    x, sldj = flow(x, cond, sldj, reverse)
                x = checkerboard(x, reverse=True)

            if self.next is not None:
                x = squeeze(x)
                cond = squeeze(cond)
                # Only the first half continues to the coarser scales.
                x, x_split = x.chunk(2, dim=1)
                x, sldj = self.next(x, cond, sldj, reverse)
                x = torch.cat((x, x_split), dim=1)
                x = unsqueeze(x)

        return x, sldj