import torch
import torch.nn as nn
import numpy as np

from models.flowplusplus.act_norm import ActNorm, BatchNorm
from models.flowplusplus.inv_conv import InvertibleConv1x1
from models.flowplusplus.coupling import Coupling
from models.util import channelwise, checkerboard, squeeze, unsqueeze
from models.moglow.modules import GaussianDiag, StudentT


class FlowPlusPlus(nn.Module):
    """Flow++ model.

    Based on the paper:
    "Flow++: Improving Flow-Based Generative Models with Variational
    Dequantization and Architecture Design"
    by Jonathan Ho, Xi Chen, Aravind Srinivas, Yan Duan, Pieter Abbeel
    (https://openreview.net/forum?id=Hyg74h05tX).

    Args:
        scales (tuple or list): Number of each type of coupling layer in each
            scale. Each scale is a 2-tuple of the form
            (num_channelwise, num_checkerboard).
        in_shape (tuple): Input shape as (channels, height, width). Sequence
            data is laid out as (features, time, 1).
        cond_dim (int): Number of channels in the conditioning information.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t networks of
            `Coupling` layers.
        num_components (int): Number of components in the mixture used by the
            coupling transforms.
        use_attn (bool): Use attention in the coupling layers.
        drop_prob (float): Dropout probability.
        norm_layer (str or None): Per-step normalization: "batchnorm",
            "actnorm", or None.
        cond_concat_dims (bool): Concatenate conditioning along the channel
            dimension (True) or the time dimension (False).
        flow_dist (str): Base distribution, "normal" or "studentT".
        flow_dist_param (int): Degrees of freedom for the Student-t base
            distribution.
    """
    def __init__(self, scales=((0, 4), (2, 3)), in_shape=(3, 32, 32), cond_dim=0,
                 mid_channels=96, num_blocks=10, num_components=32, use_attn=True,
                 use_logmix=True, use_transformer_nn=False, use_pos_emb=False,
                 use_rel_pos_emb=False, num_heads=10, drop_prob=0.2, norm_layer=None,
                 cond_concat_dims=True, cond_seq_len=1, flow_dist="normal",
                 flow_dist_param=50, bn_momentum=0.1):
        super(FlowPlusPlus, self).__init__()
        # Register bounds used to pre-process inputs; not learnable.
        self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))
        self.flows = _FlowStep(scales=scales,
                               in_shape=in_shape,
                               cond_dim=cond_dim,
                               mid_channels=mid_channels,
                               num_blocks=num_blocks,
                               num_components=num_components,
                               use_attn=use_attn,
                               use_logmix=use_logmix,
                               use_transformer_nn=use_transformer_nn,
                               use_pos_emb=use_pos_emb,
                               use_rel_pos_emb=use_rel_pos_emb,
                               num_heads=num_heads,
                               drop_prob=drop_prob,
                               norm_layer=norm_layer,
                               cond_concat_dims=cond_concat_dims,
                               cond_seq_len=cond_seq_len,
                               bn_momentum=bn_momentum)
        if flow_dist == "normal":
            self.distribution = GaussianDiag()
        elif flow_dist == "studentT":
            in_channels, in_height, in_width = in_shape
            self.distribution = StudentT(flow_dist_param, in_channels)
        else:
            raise ValueError('flow_dist must be "normal" or "studentT", '
                             'got {}'.format(flow_dist))

    def forward(self, x, cond, reverse=False):
        if cond is not None:
            # (batch, time, features) -> (batch, features, time, 1)
            cond = cond.permute(0, 2, 1).unsqueeze(3)
        if not reverse:
            if x is not None:
                x = x.permute(0, 2, 1).unsqueeze(3)
        else:
            # In reverse mode, draw a latent from the base distribution and
            # decode it conditioned on `cond`.
            c, h, w = self.flows.z_dim()
            assert w == 1
            eps_std = 1.0
            x = self.distribution.sample((cond.size(0), c, h), eps_std,
                                         device=cond.device).type_as(cond)
            x = x.unsqueeze(-1)
        sldj = torch.zeros(x.size(0), device=x.device)
        x, sldj = self.flows(x, cond, sldj, reverse)
        if reverse and x is not None:
            # (batch, features, time, 1) -> (batch, time, features)
            x = x.squeeze(3).permute(0, 2, 1)
        return x, sldj

    def loss_generative(self, z, sldj):
        """Negative log-likelihood loss, in bits, under the base distribution.

        Args:
            z (torch.Tensor): Latent codes of shape (batch, channels, time, 1).
            sldj (torch.Tensor): Sum of log-determinants of the Jacobians,
                one per batch element.

        See Also:
            Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
        """
        prior_ll = self.distribution.logp(z)
        prior_ll = prior_ll.flatten(1).sum(-1)
        ll = prior_ll + sldj
        # Mean NLL converted to bits, normalized by the time dimensions of z
        # (z.size(2) * z.size(3)), not by the channel dimension.
        nll = -ll.mean() / float(np.log(2.) * z.size(2) * z.size(3))
        return nll
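
# Minimal usage sketch (not part of the model). The shapes and sizes below
# are illustrative assumptions, not the repo's training configuration; it
# shows the expected (batch, time, features) layout of `x` and `cond`, the
# forward (encoding) pass with its NLL loss, and reverse-mode sampling.
def _example_usage():
    batch, seq_len, feat_dim, cond_feat = 4, 16, 24, 8  # assumed sizes
    model = FlowPlusPlus(in_shape=(feat_dim, seq_len, 1), cond_dim=cond_feat)
    x = torch.randn(batch, seq_len, feat_dim)       # data: (batch, time, features)
    cond = torch.randn(batch, seq_len, cond_feat)   # conditioning, same layout
    z, sldj = model(x, cond, reverse=False)         # encode: data -> latent
    nll = model.loss_generative(z, sldj)            # NLL in bits
    samples, _ = model(None, cond, reverse=True)    # decode: sample given cond
    return nll, samples
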
class _FlowStep(nn.Module):
    """Recursive builder for a Flow++ model.

    Each `_FlowStep` corresponds to a single scale in Flow++. The constructor
    is recursively called to build a full model.

    Args:
        scales (tuple): Number of each type of coupling layer in each scale.
            Each scale is a 2-tuple of the form
            (num_channelwise, num_checkerboard).
        in_shape (tuple): Input shape at this scale as
            (channels, height, width).
        cond_dim (int): Number of channels in the conditioning information.
        mid_channels (int): Number of channels in the intermediate layers.
        num_blocks (int): Number of residual blocks in the s and t networks of
            `Coupling` layers.
        num_components (int): Number of components in the mixture.
        use_attn (bool): Use attention in the coupling layers.
        drop_prob (float): Dropout probability.
    """
    def __init__(self, scales, in_shape, cond_dim, mid_channels, num_blocks,
                 num_components, use_attn, use_logmix, use_transformer_nn,
                 use_pos_emb, use_rel_pos_emb, num_heads, drop_prob, norm_layer,
                 bn_momentum, cond_concat_dims, cond_seq_len):
        super(_FlowStep, self).__init__()
        in_channels, in_height, in_width = in_shape
        num_channelwise, num_checkerboard = scales[0]

        # Split sizes and conditioning layout are the same for every coupling
        # at this scale, so compute them once; the checkerboard couplings
        # below need them even when num_channelwise == 0.
        new_channels = in_channels // 2
        out_channels = in_channels - new_channels
        if cond_concat_dims:
            # Conditioning is concatenated along the channel dimension.
            c_in_channels = new_channels + cond_dim
            seq_length = in_height
        else:
            # Conditioning is concatenated along the time dimension.
            c_in_channels = new_channels
            seq_length = in_height + cond_seq_len

        channels = []
        for _ in range(num_channelwise):
            if norm_layer == "batchnorm":
                channels += [BatchNorm(in_channels, bn_momentum)]
            elif norm_layer == "actnorm":
                channels += [ActNorm(in_channels)]
            channels += [InvertibleConv1x1(in_channels)]
            channels += [Coupling(in_channels=c_in_channels,
                                  cond_dim=cond_dim,
                                  out_channels=out_channels,
                                  mid_channels=mid_channels,
                                  num_blocks=num_blocks,
                                  num_components=num_components,
                                  use_attn=use_attn,
                                  use_logmix=use_logmix,
                                  use_transformer_nn=use_transformer_nn,
                                  use_pos_emb=use_pos_emb,
                                  use_rel_pos_emb=use_rel_pos_emb,
                                  num_heads=num_heads,
                                  seq_length=seq_length,
                                  output_length=in_height,
                                  concat_dims=cond_concat_dims,
                                  drop_prob=drop_prob)]
            # Flip() is omitted here: it currently does not work with an odd
            # number of channels, and may be unnecessary anyway since the
            # InvertibleConv1x1 layers already mix channels.
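
        # Checkerboard couplings: checkerboard() repartitions the tensor so
        # each coupling conditions one half of the time steps on the other
        # (the exact layout is defined by models.util.checkerboard). Unlike
        # the channelwise couplings above, these transform all `in_channels`
        # channels at once (out_channels=in_channels).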
        checkers = []
        for _ in range(num_checkerboard):
            if norm_layer == "batchnorm":
                checkers += [BatchNorm(in_channels, bn_momentum)]
            elif norm_layer == "actnorm":
                checkers += [ActNorm(in_channels)]
            checkers += [InvertibleConv1x1(in_channels)]
            checkers += [Coupling(in_channels=c_in_channels,
                                  cond_dim=cond_dim,
                                  out_channels=in_channels,
                                  mid_channels=mid_channels,
                                  num_blocks=num_blocks,
                                  num_components=num_components,
                                  use_attn=use_attn,
                                  use_logmix=use_logmix,
                                  use_transformer_nn=use_transformer_nn,
                                  use_pos_emb=use_pos_emb,
                                  use_rel_pos_emb=use_rel_pos_emb,
                                  num_heads=num_heads,
                                  seq_length=seq_length,
                                  output_length=in_height,
                                  concat_dims=cond_concat_dims,
                                  drop_prob=drop_prob)]

        self.channels = nn.ModuleList(channels) if channels else None
        self.checkers = nn.ModuleList(checkers) if checkers else None

        if len(scales) <= 1:
            self.next = None
        else:
            # The next scale sees half the time resolution: squeeze() folds
            # pairs of time steps into channels (which doubles cond_dim), and
            # chunk() keeps half of the squeezed channels for the recursion.
            next_shape = (in_channels, in_height // 2, in_width)
            self.next = _FlowStep(scales=scales[1:],
                                  in_shape=next_shape,
                                  cond_dim=2 * cond_dim,
                                  mid_channels=mid_channels,
                                  num_blocks=num_blocks,
                                  num_components=num_components,
                                  use_attn=use_attn,
                                  use_logmix=use_logmix,
                                  use_transformer_nn=use_transformer_nn,
                                  use_pos_emb=use_pos_emb,
                                  use_rel_pos_emb=use_rel_pos_emb,
                                  num_heads=num_heads,
                                  norm_layer=norm_layer,
                                  bn_momentum=bn_momentum,
                                  cond_concat_dims=cond_concat_dims,
                                  cond_seq_len=cond_seq_len,
                                  drop_prob=drop_prob)

        self.z_shape = (in_channels, in_height, in_width)

    def z_dim(self):
        return self.z_shape

    def forward(self, x, cond, sldj, reverse=False):
        if reverse:
            # Invert the forward pass: undo the next scale first, then the
            # checkerboard couplings, then the channelwise couplings.
            if self.next is not None:
                x = squeeze(x)
                cond = squeeze(cond)
                x, x_split = x.chunk(2, dim=1)
                x, sldj = self.next(x, cond, sldj, reverse)
                x = torch.cat((x, x_split), dim=1)
                x = unsqueeze(x)
                cond = unsqueeze(cond)

            if self.checkers:
                x = checkerboard(x)
                for flow in reversed(self.checkers):
                    x, sldj = flow(x, cond, sldj, reverse)
                x = checkerboard(x, reverse=True)

            if self.channels:
                x = channelwise(x)
                for flow in reversed(self.channels):
                    x, sldj = flow(x, cond, sldj, reverse)
                x = channelwise(x, reverse=True)
        else:
            if self.channels:
                x = channelwise(x)
                for flow in self.channels:
                    x, sldj = flow(x, cond, sldj, reverse)
                x = channelwise(x, reverse=True)

            if self.checkers:
                x = checkerboard(x)
                for flow in self.checkers:
                    x, sldj = flow(x, cond, sldj, reverse)
                x = checkerboard(x, reverse=True)

            if self.next is not None:
                # Squeeze time into channels and recurse on half of the
                # channels; the other half (x_split) is passed through
                # unchanged and re-attached afterwards.
                x = squeeze(x)
                cond = squeeze(cond)
                x, x_split = x.chunk(2, dim=1)
                x, sldj = self.next(x, cond, sldj, reverse)
                x = torch.cat((x, x_split), dim=1)
                x = unsqueeze(x)

        return x, sldj
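

# Standalone sketch of the multi-scale squeeze/split pattern used in
# `_FlowStep.forward`, written with plain tensor ops. The reshapes below use
# an assumed pixel-shuffle-style layout; the repo's own squeeze()/unsqueeze()
# helpers in models.util may order elements differently, but the structure
# (squeeze -> chunk -> recurse -> cat -> unsqueeze) is the same, and the
# round trip must be the identity for the flow to stay invertible.
def _squeeze_split_roundtrip():
    x = torch.randn(2, 6, 8, 1)  # (batch, channels, time, 1)
    b, c, h, w = x.shape
    # "squeeze": fold pairs of time steps into the channel dimension.
    y = x.view(b, c, h // 2, 2, w).permute(0, 1, 3, 2, 4).reshape(b, 2 * c, h // 2, w)
    kept, split = y.chunk(2, dim=1)      # recurse on `kept`, pass `split` through
    y = torch.cat((kept, split), dim=1)  # re-attach after the (identity) recursion
    # "unsqueeze": invert the squeeze exactly.
    x_rec = y.view(b, c, 2, h // 2, w).permute(0, 1, 3, 2, 4).reshape(b, c, h, w)
    assert torch.equal(x, x_rec)
    return x_rec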