import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from . import thops
from . import modules
from . import utils
from models.transformer import BasicTransformerModelCausal


def nan_throw(tensor, name="tensor"):
    """Print a warning (and the tensor) if it contains NaNs or infs."""
    stop = False
    if (tensor != tensor).any():
        print(name + " has nans")
        stop = True
    if torch.isinf(tensor).any():
        print(name + " has infs")
        stop = True
    if stop:
        print(name + ": " + str(tensor))
        # raise ValueError(name + ' contains nans or infs')


def f(in_channels, out_channels, hidden_channels, cond_channels, network_model, num_layers):
    """Build the coupling network that maps (z1, cond) to the coupling parameters."""
    if network_model == "transformer":
        return BasicTransformerModelCausal(out_channels, in_channels + cond_channels, 10,
                                           hidden_channels, num_layers,
                                           use_pos_emb=True, input_length=70)
    if network_model == "LSTM":
        return modules.LSTM(in_channels + cond_channels, hidden_channels, out_channels, num_layers)
    if network_model == "GRU":
        return modules.GRU(in_channels + cond_channels, hidden_channels, out_channels, num_layers)
    if network_model == "FF":
        return nn.Sequential(
            nn.Linear(in_channels + cond_channels, hidden_channels),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(inplace=False),
            modules.LinearZeroInit(hidden_channels, out_channels))
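
# --- Illustrative usage sketch (not part of the original model code) ---
# A minimal check of the coupling-network factory above, using the "FF"
# variant because it is plain PyTorch with no recurrent state. The sizes
# below are made up for the example; nn.Linear acts on the last dimension,
# so the FF network expects input shaped (batch, time, in + cond channels).
def _demo_coupling_net():
    net = f(in_channels=5, out_channels=5, hidden_channels=32,
            cond_channels=4, network_model="FF", num_layers=2)
    x = torch.randn(8, 70, 5 + 4)    # (batch, time, in + cond channels)
    y = net(x)                       # -> (8, 70, 5)
    # The last layer is zero-initialized (per the LinearZeroInit name),
    # so y should be all zeros before any training.
    assert y.shape == (8, 70, 5)
    return y
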
class FlowStep(nn.Module):
    """One step of flow: actnorm -> permutation -> (conditional) coupling."""

    FlowCoupling = ["additive", "affine"]
    NetworkModel = ["transformer", "LSTM", "GRU", "FF"]
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
    }

    def __init__(self, in_channels, hidden_channels, cond_channels,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 network_model="LSTM",
                 num_layers=2,
                 LU_decomposed=False):
        # check configuration
        assert flow_coupling in FlowStep.FlowCoupling, \
            "flow_coupling should be in `{}`".format(FlowStep.FlowCoupling)
        assert network_model in FlowStep.NetworkModel, \
            "network_model should be in `{}`".format(FlowStep.NetworkModel)
        assert flow_permutation in FlowStep.FlowPermutation, \
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.network_model = network_model
        # 1. actnorm
        self.actnorm = modules.ActNorm2d(in_channels, actnorm_scale)
        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = modules.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)
        elif flow_permutation == "shuffle":
            self.shuffle = modules.Permute2d(in_channels, shuffle=True)
        else:
            self.reverse = modules.Permute2d(in_channels, shuffle=False)
        # 3. coupling: the network sees the first half of the channels plus
        # the conditioning, and predicts the transform of the second half
        # (twice the channels in the affine case, for shift and scale).
        if flow_coupling == "additive":
            self.f = f(in_channels // 2, in_channels - in_channels // 2,
                       hidden_channels, cond_channels, network_model, num_layers)
        elif flow_coupling == "affine":
            print("affine: in_channels = " + str(in_channels))
            self.f = f(in_channels // 2, 2 * (in_channels - in_channels // 2),
                       hidden_channels, cond_channels, network_model, num_layers)
            print("Flowstep affine layer: " + str(in_channels))

    def init_lstm_hidden(self):
        if self.network_model == "LSTM" or self.network_model == "GRU":
            self.f.init_hidden()

    def forward(self, input, cond, logdet=None, reverse=False):
        if not reverse:
            return self.normal_flow(input, cond, logdet)
        else:
            return self.reverse_flow(input, cond, logdet)

    def normal_flow(self, input, cond, logdet):
        # 1. actnorm
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)
        # 3. coupling
        z1, z2 = thops.split_feature(z, "split")
        z1_cond = torch.cat((z1, cond), dim=1)
        if self.flow_coupling == "additive":
            z2 = z2 + self.f(z1_cond)
        elif self.flow_coupling == "affine":
            # the transformer expects (time, batch, channels); the other
            # networks expect (batch, time, channels)
            if self.network_model == "transformer":
                h = self.f(z1_cond.permute(2, 0, 1)).permute(1, 2, 0)
            else:
                h = self.f(z1_cond.permute(0, 2, 1)).permute(0, 2, 1)
            shift, scale = thops.split_feature(h, "cross")
            scale = torch.sigmoid(scale + 2.) + 1e-6
            z2 = z2 + shift
            z2 = z2 * scale
            logdet = thops.sum(torch.log(scale), dim=[1, 2]) + logdet
        z = thops.cat_feature(z1, z2)
        return z, cond, logdet

    def reverse_flow(self, input, cond, logdet):
        # 1. coupling (inverted)
        z1, z2 = thops.split_feature(input, "split")
        z1_cond = torch.cat((z1, cond), dim=1)
        if self.flow_coupling == "additive":
            z2 = z2 - self.f(z1_cond)
        elif self.flow_coupling == "affine":
            # mirror the permutes used in normal_flow so the transformer
            # path stays invertible
            if self.network_model == "transformer":
                h = self.f(z1_cond.permute(2, 0, 1)).permute(1, 2, 0)
            else:
                h = self.f(z1_cond.permute(0, 2, 1)).permute(0, 2, 1)
            shift, scale = thops.split_feature(h, "cross")
            nan_throw(shift, "shift")
            nan_throw(scale, "scale")
            nan_throw(z2, "z2 unscaled")
            scale = torch.sigmoid(scale + 2.) + 1e-6
            z2 = z2 / scale
            z2 = z2 - shift
            logdet = -thops.sum(torch.log(scale), dim=[1, 2]) + logdet
        z = thops.cat_feature(z1, z2)
        # 2. permute (inverted)
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)
        nan_throw(z, "z permute_" + str(self.flow_permutation))
        # 3. actnorm (inverted)
        z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
        return z, cond, logdet
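
# --- Illustrative round-trip check (not part of the original model code) ---
# A sketch verifying that a single FlowStep inverts itself, assuming the
# local modules (ActNorm2d, InvertibleConv1x1) operate on (batch, channels,
# time) tensors, as the 3-D shapes elsewhere in this file suggest. The
# affine/"FF" combination is used so no recurrent hidden state is needed.
def _check_flowstep_inverse():
    step = FlowStep(in_channels=10, hidden_channels=32, cond_channels=4,
                    flow_permutation="invconv", flow_coupling="affine",
                    network_model="FF")
    x = torch.randn(8, 10, 16)      # (batch, channels, time)
    cond = torch.randn(8, 4, 16)
    logdet = torch.zeros(8)
    # first forward pass also triggers actnorm's data-dependent init
    z, _, _ = step(x, cond, logdet, reverse=False)
    x_rec, _, _ = step(z, cond, torch.zeros(8), reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-4), "FlowStep round trip failed"
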
class FlowNet(nn.Module):
    """A stack of K conditional FlowSteps."""

    def __init__(self, x_channels, hidden_channels, cond_channels, K,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 network_model="LSTM",
                 num_layers=2,
                 LU_decomposed=False):
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.K = K
        N = cond_channels
        for _ in range(K):
            self.layers.append(
                FlowStep(in_channels=x_channels,
                         hidden_channels=hidden_channels,
                         cond_channels=N,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         network_model=network_model,
                         num_layers=num_layers,
                         LU_decomposed=LU_decomposed))
            self.output_shapes.append([-1, x_channels, 1])

    def init_lstm_hidden(self):
        for layer in self.layers:
            if isinstance(layer, FlowStep):
                layer.init_lstm_hidden()

    def forward(self, z, cond, logdet=0., reverse=False, eps_std=None):
        if not reverse:
            for layer in self.layers:
                z, cond, logdet = layer(z, cond, logdet, reverse=False)
            return z, logdet
        else:
            for layer in reversed(self.layers):
                z, cond, logdet = layer(z, cond, logdet=0, reverse=True)
            return z


class Glow(nn.Module):
    def __init__(self, x_channels, cond_channels, opt):
        super().__init__()
        self.flow = FlowNet(x_channels=x_channels,
                            hidden_channels=opt.dhid,
                            cond_channels=cond_channels,
                            K=opt.glow_K,
                            actnorm_scale=opt.actnorm_scale,
                            flow_permutation=opt.flow_permutation,
                            flow_coupling=opt.flow_coupling,
                            network_model=opt.network_model,
                            num_layers=opt.num_layers,
                            LU_decomposed=opt.LU_decomposed)
        self.opt = opt

        # register prior shape
        self.z_shape = [opt.batch_size, x_channels, 1]
        if opt.flow_dist == "normal":
            self.distribution = modules.GaussianDiag()
        elif opt.flow_dist == "studentT":
            self.distribution = modules.StudentT(opt.flow_dist_param, x_channels)

    def init_lstm_hidden(self):
        self.flow.init_lstm_hidden()

    def forward(self, x=None, cond=None, z=None,
                eps_std=None, reverse=False, output_length=1):
        if not reverse:
            return self.normal_flow(x, cond)
        else:
            return self.reverse_flow(z, cond, eps_std, output_length=output_length)

    def normal_flow(self, x, cond):
        n_timesteps = thops.timesteps(x)  # number of timesteps (size of dim 2)
        logdet = torch.zeros_like(x[:, 0, 0])
        # encode
        z, objective = self.flow(x, cond, logdet=logdet, reverse=False)
        # prior
        objective += self.distribution.logp(z)
        # negative log-likelihood in bits per timestep
        nll = (-objective) / float(np.log(2.) * n_timesteps)
        return z, nll

    def reverse_flow(self, z, cond, eps_std, output_length=1):
        with torch.no_grad():
            z_shape = list(self.z_shape)  # copy so the registered shape is not mutated
            z_shape[-1] = output_length
            if z is None:
                z = self.distribution.sample(z_shape, eps_std, device=cond.device)
            x = self.flow(z, cond, eps_std=eps_std, reverse=True)
        return x

    def set_actnorm_init(self, inited=True):
        for name, m in self.named_modules():
            if m.__class__.__name__.find("ActNorm") >= 0:
                m.inited = inited

    @staticmethod
    def loss_generative(nll):
        # Generative loss
        return torch.mean(nll)
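
# --- Illustrative end-to-end sketch (not part of the original model code) ---
# Builds a tiny Glow and runs one encoding pass and one sampling pass. The
# SimpleNamespace stands in for the real `opt` object; its field names mirror
# the attributes read in Glow.__init__, but every value here is a made-up
# hyperparameter chosen only so the example runs quickly on CPU.
if __name__ == "__main__":
    from types import SimpleNamespace
    opt = SimpleNamespace(dhid=32, glow_K=4, actnorm_scale=1.0,
                          flow_permutation="invconv", flow_coupling="affine",
                          network_model="FF", num_layers=2,
                          LU_decomposed=False, batch_size=8,
                          flow_dist="normal", flow_dist_param=None)
    model = Glow(x_channels=10, cond_channels=4, opt=opt)
    x = torch.randn(8, 10, 16)       # (batch, x_channels, time)
    cond = torch.randn(8, 4, 16)     # (batch, cond_channels, time)
    z, nll = model(x=x, cond=cond, reverse=False)   # encode -> latent, bits/timestep
    print("loss:", Glow.loss_generative(nll).item())
    x_gen = model(z=None, cond=cond, eps_std=1.0, reverse=True,
                  output_length=16)                 # sample from the prior
    print("sampled:", tuple(x_gen.shape))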