myTest01 / models /flowplusplus /transformer_nn.py
meng2003's picture
Upload 85 files
bc32eea
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from models.util import concat_elu #,WNConv2d
from models.transformer import BasicTransformerModel
class TransformerNN(nn.Module):
    """Transformer that parametrizes the transformations of a coupling layer.

    The 4-D input ``(b, c, h, 1)`` is treated as a length-``h`` sequence of
    ``c``-dimensional features: it is rearranged to ``(h, b, c)``, passed
    through a ``BasicTransformerModel`` and the result is split along the
    feature axis into ``2 + 3 * num_components`` groups of ``out_channels``
    channels each: a scale ``s``, a shift ``t`` and the three mixture
    parameter sets ``pi``, ``mu`` and ``scales``.

    Args:
        in_channels (int): Forwarded to ``BasicTransformerModel`` as its
            second positional argument (presumably the per-step input
            feature size, i.e. ``c`` -- confirm against that class).
        out_channels (int): Number of channels in each output group. The
            transformer is built with
            ``out_channels * (2 + 3 * num_components)`` as its first
            positional argument.
        num_channels (int): Forwarded to ``BasicTransformerModel`` (fourth
            positional argument).
        num_layers (int): Forwarded to ``BasicTransformerModel`` (fifth
            positional argument).
        num_heads (int): Forwarded to ``BasicTransformerModel`` (third
            positional argument).
        num_components (int): Number of components ``k`` in the mixture.
        drop_prob (float): Dropout probability, forwarded to
            ``BasicTransformerModel``.
        use_pos_emb: Forwarded to ``BasicTransformerModel`` as a keyword.
        use_rel_pos_emb: Forwarded to ``BasicTransformerModel`` as a keyword.
        input_length: Forwarded to ``BasicTransformerModel`` as a keyword.
        concat_dims: If truthy, the whole transformer output sequence is
            used and reshaped with the input height ``h``. Otherwise only
            the first ``output_length`` sequence positions are kept.
        output_length (int): Number of leading sequence positions kept when
            ``concat_dims`` is falsy.
    """

    def __init__(self, in_channels, out_channels, num_channels, num_layers,
                 num_heads, num_components, drop_prob, use_pos_emb,
                 use_rel_pos_emb, input_length, concat_dims, output_length):
        super().__init__()
        self.k = num_components  # k = number of mixture components
        self.transformer = BasicTransformerModel(
            out_channels * (2 + 3 * self.k),
            in_channels,
            num_heads,
            num_channels,
            num_layers,
            drop_prob,
            use_pos_emb=use_pos_emb,
            use_rel_pos_emb=use_rel_pos_emb,
            input_length=input_length,
        )
        # Learnable per-channel scale applied to tanh(s); wrapped in
        # weight_norm exactly as before so parameter names are unchanged.
        self.rescale = weight_norm(Rescale(out_channels))
        self.out_channels = out_channels
        self.concat_dims = concat_dims
        self.output_length = output_length

    def forward(self, x, aux=None):
        """Compute the coupling parameters for ``x``.

        Args:
            x (torch.Tensor): Input of shape ``(b, c, h, 1)``.
            aux: Ignored by this network; kept so that existing callers
                that pass it keep working.

        Returns:
            tuple: ``(s, t, pi, mu, scales)`` where, with ``L`` equal to
            ``h`` when ``concat_dims`` is truthy and ``output_length``
            otherwise, ``s`` and ``t`` have shape
            ``(b, out_channels, L, 1)`` and ``pi``, ``mu`` and ``scales``
            have shape ``(b, k, out_channels, L, 1)``. ``s`` has been passed
            through ``tanh`` and the rescale module; ``scales`` is clamped
            from below at -7.

        Raises:
            ValueError: If ``x`` is not 4-D with a trailing dimension of 1.
        """
        # The permute below needs a 3-D tensor, so the trailing dimension
        # must be exactly 1. Fail early with a readable message instead of
        # the opaque permute error raised further down otherwise.
        if x.dim() != 4 or x.size(-1) != 1:
            raise ValueError(
                "TransformerNN expects input of shape (b, c, h, 1), got "
                f"{tuple(x.shape)}"
            )
        b, _, h, w = x.size()

        # Only squeeze the w dimension: a bare squeeze() would also drop the
        # batch dimension when the minibatch holds a single element.
        seq = x.squeeze(-1).permute(2, 0, 1)  # (h, b, c)

        out = self.transformer(seq)
        if self.concat_dims:
            length = h
        else:
            out = out[:self.output_length]
            length = self.output_length

        out = out.permute(1, 2, 0)  # (b, features, length)

        # Split into components and post-process.
        out = out.view(b, -1, self.out_channels, length, w)
        s, t, pi, mu, scales = out.split(
            (1, 1, self.k, self.k, self.k), dim=1
        )
        s = self.rescale(torch.tanh(s.squeeze(1)))
        t = t.squeeze(1)
        scales = scales.clamp(min=-7)  # From the code in original Flow++ paper
        return s, t, pi, mu, scales
class Rescale(nn.Module):
    """Multiply the input by a learnable per-channel factor.

    Implemented as a full ``nn.Module`` whose parameter is named ``weight``
    so that it can be wrapped with ``torch.nn.utils.weight_norm``.

    Args:
        num_channels (int): Number of channels in the input.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Shape (num_channels, 1, 1), initialised to ones.
        initial_scale = torch.ones(num_channels, 1, 1)
        self.weight = nn.Parameter(initial_scale)

    def forward(self, x):
        """Return ``x`` scaled by the per-channel weight (broadcast)."""
        return self.weight * x