File size: 3,987 Bytes
bc32eea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils import weight_norm
from models.util import concat_elu #,WNConv2d
from models.transformer import BasicTransformerModel

class TransformerNN(nn.Module):
    """Neural network used to parametrize the transformations of an MLCoupling.

    An `NN` is a stack of blocks, where each block consists of the following
    two layers connected in a residual fashion:
      1. Conv: input -> nonlinearit -> conv3x3 -> nonlinearity -> gate
      2. Attn: input -> conv1x1 -> multihead self-attention -> gate,
    where gate refers to a 1×1 convolution that doubles the number of channels,
    followed by a gated linear unit (Dauphin et al., 2016).
    The convolutional layer is identical to the one used by PixelCNN++
    (Salimans et al., 2017), and the multi-head self attention mechanism we
    use is identical to the one in the Transformer (Vaswani et al., 2017).

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        num_channels (int): Number of channels in each block of the network.
        num_blocks (int): Number of blocks in the network.
        num_components (int): Number of components in the mixture.
        drop_prob (float): Dropout probability.
        use_attn (bool): Use attention in each block.
        aux_channels (int): Number of channels in optional auxiliary input.
    """
    def __init__(self, in_channels, out_channels, num_channels, num_layers, num_heads, num_components, drop_prob, use_pos_emb, use_rel_pos_emb, input_length, concat_dims, output_length):
        #import pdb;pdb.set_trace()
        super(TransformerNN, self).__init__()
        self.k = num_components  # k = number of mixture components
        # import pdb;pdb.set_trace()
        self.transformer = BasicTransformerModel(out_channels * (2 + 3 * self.k), in_channels, num_heads, num_channels, num_layers, drop_prob, use_pos_emb=use_pos_emb, use_rel_pos_emb=use_rel_pos_emb, input_length=input_length)
        self.rescale = weight_norm(Rescale(out_channels))
        self.out_channels = out_channels
        self.concat_dims = concat_dims
        self.output_length = output_length

    def forward(self, x, aux=None):
        b, c, h, w = x.size()
        # import pdb;pdb.set_trace()
        x = x.squeeze(-1) # only squeeze the w dimension (important coz otherwise it would squeeze batch dim if theres only one element in minibatch..
        # import pdb;pdb.set_trace()
        x = x.permute(2,0,1)
        # import pdb;pdb.set_trace()
        if self.concat_dims:
            x = self.transformer(x)
            # x = torch.mean(self.transformer(x), dim=0, keepdim=True)
            # x = 0.5*x + 0.5*torch.mean(x, dim=0, keepdim=True)
            # x = self.transformer(x)[:1]
        else:
            x = self.transformer(x)[:self.output_length]
        # import pdb;pdb.set_trace()
        x = x.permute(1,2,0)
        # Split into components and post-process
        if self.concat_dims:
            x = x.view(b, -1, self.out_channels, h, w)
            # x = x.view(b, -1, self.out_channels, 1, w)
        else:
            x = x.view(b, -1, self.out_channels, self.output_length, w)
        s, t, pi, mu, scales = x.split((1, 1, self.k, self.k, self.k), dim=1)
        s = self.rescale(torch.tanh(s.squeeze(1)))
        t = t.squeeze(1)
        scales = scales.clamp(min=-7)  # From the code in original Flow++ paper

        return s, t, pi, mu, scales

class Rescale(nn.Module):
    """Per-channel rescaling. Need a proper `nn.Module` so we can wrap it
    with `torch.nn.utils.weight_norm`.
    Args:
        num_channels (int): Number of channels in the input.
    """

    def __init__(self, num_channels):
        super(Rescale, self).__init__()
        self.weight = nn.Parameter(torch.ones(num_channels, 1, 1))

    def forward(self, x):
        x = self.weight * x
        return x