# models/flowplusplus/flowplusplus.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.flowplusplus.act_norm import ActNorm, BatchNorm
from models.flowplusplus.inv_conv import InvConv, InvertibleConv1x1
from models.flowplusplus.nn import GatedConv
from models.flowplusplus.coupling import Coupling
from models.util import channelwise, checkerboard, Flip, safe_log, squeeze, unsqueeze
from models.moglow.modules import GaussianDiag, StudentT


class FlowPlusPlus(nn.Module):
"""Flow++ Model
Based on the paper:
"Flow++: Improving Flow-Based Generative Models
with Variational Dequantization and Architecture Design"
by Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
(https://openreview.net/forum?id=Hyg74h05tX).
    Args:
        scales (tuple or list): Number of each type of coupling layer in each
            scale. Each scale is a 2-tuple of the form
            (num_channelwise, num_checkerboard).
        in_shape (tuple): Shape of a single input, as (channels, height, width).
        cond_dim (int): Number of channels (feature dimension) of the
            conditioning input.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t networks of
            `Coupling` layers.
        num_components (int): Number of components in the mixture used by the
            coupling transforms.
        use_attn (bool): Use attention in the coupling layers.
        drop_prob (float): Dropout probability.
        norm_layer (str or None): Per-step normalization: "batchnorm",
            "actnorm", or None for none.
        flow_dist (str): Base distribution, "normal" or "studentT".
    """
def __init__(self,
scales=((0, 4), (2, 3)),
in_shape=(3, 32, 32),
cond_dim=0,
mid_channels=96,
num_blocks=10,
num_components=32,
use_attn=True,
use_logmix=True,
use_transformer_nn=False,
use_pos_emb=False,
use_rel_pos_emb=False,
num_heads=10,
drop_prob=0.2,
norm_layer=None,
cond_concat_dims=True,
cond_seq_len=1,
flow_dist="normal",
flow_dist_param=50,
bn_momentum=0.1):
super(FlowPlusPlus, self).__init__()
# Register bounds to pre-process images, not learnable
self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))
self.flows = _FlowStep(scales=scales,
in_shape=in_shape,
cond_dim=cond_dim,
mid_channels=mid_channels,
num_blocks=num_blocks,
num_components=num_components,
use_attn=use_attn,
use_logmix=use_logmix,
use_transformer_nn=use_transformer_nn,
use_pos_emb=use_pos_emb,
use_rel_pos_emb=use_rel_pos_emb,
num_heads=num_heads,
drop_prob=drop_prob,
norm_layer=norm_layer,
cond_concat_dims=cond_concat_dims,
cond_seq_len=cond_seq_len,
bn_momentum=bn_momentum)
        if flow_dist == "normal":
            self.distribution = GaussianDiag()
        elif flow_dist == "studentT":
            in_channels, in_height, in_width = in_shape
            self.distribution = StudentT(flow_dist_param, in_channels)
        else:
            raise ValueError("Unsupported flow_dist: {}".format(flow_dist))

    def forward(self, x, cond, reverse=False):
        # Inputs are sequences of shape (batch, time, channels); convert them to
        # the (batch, channels, height=time, width=1) layout used by the flow steps.
        if cond is not None:
            cond = cond.permute(0, 2, 1).unsqueeze(3)
        if not reverse:
            if x is not None:
                x = x.permute(0, 2, 1).unsqueeze(3)
        else:
            # Reverse pass: draw z from the base distribution, then invert the flow.
            c, h, w = self.flows.z_dim()
            eps_std = 1.0
            assert w == 1  # sampling assumes a width-1 (sequence) layout
            x = self.distribution.sample((cond.size(0), c, h), eps_std, device=cond.device).type_as(cond)
            x = x.unsqueeze(-1)
        sldj = torch.zeros(x.size(0), device=x.device)
        x, sldj = self.flows(x, cond, sldj, reverse)
        if reverse:
            if x is not None:
                x = x.squeeze(3).permute(0, 2, 1)
        return x, sldj

    def loss_generative(self, z, sldj):
        """Negative log-likelihood loss under the model's base distribution.

        Returns the NLL in bits, averaged over the batch and normalized by the
        time dimension of `z`. See also Equation (3) in the RealNVP paper:
        https://arxiv.org/abs/1605.08803
        """
        prior_ll = self.distribution.logp(z)
        prior_ll = prior_ll.flatten(1).sum(-1)
        ll = prior_ll + sldj
        nll = -ll.mean() / float(np.log(2.) * z.size(2) * z.size(3))
        return nll
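
# A sketch of the quantity loss_generative returns, assuming z has shape
# (batch, channels, time, 1) as produced by FlowPlusPlus.forward:
#
#     log p(x) = log p_Z(z) + sldj                     (change of variables)
#     nll      = -mean_over_batch[ log p(x) ] / (ln 2 * time * 1)
#
# i.e. the negative log-likelihood in bits, normalized by the number of time
# steps (but not by the channel count).

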
class _FlowStep(nn.Module):
"""Recursive builder for a Flow++ model.
Each `_FlowStep` corresponds to a single scale in Flow++.
The constructor is recursively called to build a full model.
    Args:
        scales (tuple): Number of each type of coupling layer in each scale.
            Each scale is a 2-tuple of the form (num_channelwise, num_checkerboard).
        in_shape (tuple): Shape of the input to this scale, as
            (channels, height, width).
        cond_dim (int): Number of channels (feature dimension) of the
            conditioning input.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t networks of
            `Coupling` layers.
        num_components (int): Number of components in the mixture.
        use_attn (bool): Use attention in the coupling layers.
        drop_prob (float): Dropout probability.
    """
    def __init__(self, scales, in_shape, cond_dim, mid_channels, num_blocks,
                 num_components, use_attn, use_logmix, use_transformer_nn,
                 use_pos_emb, use_rel_pos_emb, num_heads, drop_prob, norm_layer,
                 bn_momentum, cond_concat_dims, cond_seq_len):
super(_FlowStep, self).__init__()
in_channels, in_height, in_width = in_shape
num_channelwise, num_checkerboard = scales[0]
channels = []
for i in range(num_channelwise):
            new_channels = in_channels // 2
            out_channels = in_channels - new_channels
if norm_layer == "batchnorm":
channels += [BatchNorm(in_channels, bn_momentum)]
elif norm_layer == "actnorm":
channels += [ActNorm(in_channels)]
if cond_concat_dims:
c_in_channels = new_channels + cond_dim
seq_length = in_height
else:
c_in_channels = new_channels
seq_length = in_height + cond_seq_len
channels += [InvertibleConv1x1(in_channels)]
channels += [Coupling(in_channels=c_in_channels,
cond_dim=cond_dim,
out_channels=out_channels,
mid_channels=mid_channels,
num_blocks=num_blocks,
num_components=num_components,
use_attn=use_attn,
use_logmix=use_logmix,
use_transformer_nn=use_transformer_nn,
use_pos_emb=use_pos_emb,
use_rel_pos_emb=use_rel_pos_emb,
num_heads=num_heads,
seq_length=seq_length,
output_length=in_height,
concat_dims=cond_concat_dims,
                                  drop_prob=drop_prob)]
            # Note: a Flip() layer is omitted here. Flip currently does not work
            # with an odd number of channels, and it may be unnecessary given the
            # channel mixing provided by the InvertibleConv1x1 layers.
checkers = []
if cond_concat_dims:
c_in_channels = new_channels + cond_dim
seq_length = in_height
else:
c_in_channels = new_channels
seq_length = in_height + cond_seq_len
for i in range(num_checkerboard):
if norm_layer == "batchnorm":
checkers += [BatchNorm(in_channels, bn_momentum)]
elif norm_layer == "actnorm":
checkers += [ActNorm(in_channels)]
checkers += [InvertibleConv1x1(in_channels)]
checkers += [Coupling(in_channels=c_in_channels,
out_channels=in_channels,
mid_channels=mid_channels,
num_blocks=num_blocks,
num_components=num_components,
use_attn=use_attn,
use_logmix=use_logmix,
use_transformer_nn=use_transformer_nn,
use_pos_emb=use_pos_emb,
use_rel_pos_emb=use_rel_pos_emb,
num_heads=num_heads,
seq_length=seq_length,
output_length=in_height,
concat_dims=cond_concat_dims,
                                  drop_prob=drop_prob)]
            # Note: a Flip() layer is omitted here as well (see the note above).
self.channels = nn.ModuleList(channels) if channels else None
self.checkers = nn.ModuleList(checkers) if checkers else None
if len(scales) <= 1:
self.next = None
else:
next_shape = (in_channels, in_height // 2, in_width)
self.next = _FlowStep(scales=scales[1:],
in_shape=next_shape,
cond_dim=2*cond_dim,
mid_channels=mid_channels,
num_blocks=num_blocks,
num_components=num_components,
use_attn=use_attn,
use_logmix=use_logmix,
use_transformer_nn=use_transformer_nn,
use_pos_emb=use_pos_emb,
use_rel_pos_emb=use_rel_pos_emb,
num_heads=num_heads,
                                  norm_layer=norm_layer,
                                  bn_momentum=bn_momentum,
                                  cond_concat_dims=cond_concat_dims,
                                  cond_seq_len=cond_seq_len,
drop_prob=drop_prob)
self.z_shape = (in_channels, in_height, in_width)
def z_dim(self):
return self.z_shape
def forward(self, x, cond, sldj, reverse=False):
if reverse:
if self.next is not None:
x = squeeze(x)
cond = squeeze(cond)
x, x_split = x.chunk(2, dim=1)
x, sldj = self.next(x, cond, sldj, reverse)
x = torch.cat((x, x_split), dim=1)
x = unsqueeze(x)
cond = unsqueeze(cond)
if self.checkers:
x = checkerboard(x)
for flow in reversed(self.checkers):
x, sldj = flow(x, cond, sldj, reverse)
x = checkerboard(x, reverse=True)
if self.channels:
x = channelwise(x)
for flow in reversed(self.channels):
x, sldj = flow(x, cond, sldj, reverse)
x = channelwise(x, reverse=True)
else:
if self.channels:
x = channelwise(x)
for flow in self.channels:
                    x, sldj = flow(x, cond, sldj, reverse)
x = channelwise(x, reverse=True)
if self.checkers:
x = checkerboard(x)
for flow in self.checkers:
x, sldj = flow(x, cond, sldj, reverse)
x = checkerboard(x, reverse=True)
if self.next is not None:
                # Recurse into the next scale: after squeezing, half of the channels
                # are passed to the deeper flow steps while the other half (x_split)
                # is carried through unchanged.
x = squeeze(x)
cond = squeeze(cond)
x, x_split = x.chunk(2, dim=1)
x, sldj = self.next(x, cond, sldj, reverse)
x = torch.cat((x, x_split), dim=1)
x = unsqueeze(x)
return x, sldj
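

# Minimal usage sketch. It assumes the sibling modules imported at the top of this
# file are available on the path, that inputs are sequences of shape
# (batch, time, channels) with conditioning of shape (batch, time, cond_dim), and
# that the hyperparameter values below are illustrative only.
if __name__ == "__main__":
    batch, time_steps, channels, cond_dim = 4, 16, 6, 8
    model = FlowPlusPlus(scales=((2, 0), (2, 0)),  # channelwise couplings only
                         in_shape=(channels, time_steps, 1),
                         cond_dim=cond_dim,
                         mid_channels=32,
                         num_blocks=2,
                         num_components=4,
                         use_attn=False,
                         flow_dist="normal")

    x = torch.randn(batch, time_steps, channels)
    cond = torch.randn(batch, time_steps, cond_dim)

    # Forward pass: data -> latent z, accumulating the sum of log-determinants.
    z, sldj = model(x, cond, reverse=False)
    nll_bits = model.loss_generative(z, sldj)

    # Reverse pass: sample z from the base distribution and invert the flow to
    # generate a sequence conditioned on `cond`.
    samples, _ = model(None, cond, reverse=True)
    print(nll_bits.item(), samples.shape)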