import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from . import thops
from . import modules
from . import utils
from models.transformer import BasicTransformerModelCausal

def nan_throw(tensor, name="tensor"):
    stop = False
    if (tensor != tensor).any():
        print(name + " has nans")
        stop = True
    if torch.isinf(tensor).any():
        print(name + " has infs")
        stop = True
    if stop:
        print(name + ": " + str(tensor))
        # raise ValueError(name + ' contains nans or infs')

def f(in_channels, out_channels, hidden_channels, cond_channels, network_model, num_layers):
    if network_model == "transformer":
        # return BasicTransformerModel(out_channels, in_channels + cond_channels, 10, hidden_channels, num_layers, use_pos_emb=True)
        return BasicTransformerModelCausal(out_channels, in_channels + cond_channels, 10, hidden_channels, num_layers, use_pos_emb=True, input_length=70)
    if network_model == "LSTM":
        return modules.LSTM(in_channels + cond_channels, hidden_channels, out_channels, num_layers)
    if network_model == "GRU":
        return modules.GRU(in_channels + cond_channels, hidden_channels, out_channels, num_layers)
    if network_model == "FF":
        return nn.Sequential(
            nn.Linear(in_channels + cond_channels, hidden_channels), nn.ReLU(inplace=False),
            nn.Linear(hidden_channels, hidden_channels), nn.ReLU(inplace=False),
            modules.LinearZeroInit(hidden_channels, out_channels))
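
# Shape sketch for the coupling-network builder above. The "FF" variant maps
# (..., in_channels + cond_channels) to (..., out_channels) over the last
# dimension; the LSTM/GRU/transformer variants come from the sibling `modules`
# and `models.transformer` packages and are assumed to consume the
# (batch, time, channels) / (time, batch, channels) layouts produced by the
# permutes in FlowStep below.
#
#   net = f(in_channels=3, out_channels=3, hidden_channels=8,
#           cond_channels=4, network_model="FF", num_layers=2)
#   y = net(torch.randn(2, 10, 7))   # -> (2, 10, 3)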

class FlowStep(nn.Module):
    FlowCoupling = ["additive", "affine"]
    NetworkModel = ["transformer", "LSTM", "GRU", "FF"]
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev)
    }

    def __init__(self, in_channels, hidden_channels, cond_channels,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 network_model="LSTM",
                 num_layers=2,
                 LU_decomposed=False):
        # check configuration
        assert flow_coupling in FlowStep.FlowCoupling, \
            "flow_coupling should be in `{}`".format(FlowStep.FlowCoupling)
        assert network_model in FlowStep.NetworkModel, \
            "network_model should be in `{}`".format(FlowStep.NetworkModel)
        assert flow_permutation in FlowStep.FlowPermutation, \
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.network_model = network_model
        # 1. actnorm
        self.actnorm = modules.ActNorm2d(in_channels, actnorm_scale)
        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = modules.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)
        elif flow_permutation == "shuffle":
            self.shuffle = modules.Permute2d(in_channels, shuffle=True)
        else:
            self.reverse = modules.Permute2d(in_channels, shuffle=False)
        # 3. coupling
        if flow_coupling == "additive":
            self.f = f(in_channels // 2, in_channels - in_channels // 2, hidden_channels, cond_channels, network_model, num_layers)
        elif flow_coupling == "affine":
            print("affine: in_channels = " + str(in_channels))
            self.f = f(in_channels // 2, 2 * (in_channels - in_channels // 2), hidden_channels, cond_channels, network_model, num_layers)
            print("Flowstep affine layer: " + str(in_channels))

    def init_lstm_hidden(self):
        if self.network_model == "LSTM" or self.network_model == "GRU":
            self.f.init_hidden()

    def forward(self, input, cond, logdet=None, reverse=False):
        if not reverse:
            return self.normal_flow(input, cond, logdet)
        else:
            return self.reverse_flow(input, cond, logdet)

    def normal_flow(self, input, cond, logdet):
        # assert input.size(1) % 2 == 0
        # 1. actnorm
        # z = input
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)
        # 3. coupling
        z1, z2 = thops.split_feature(z, "split")
        z1_cond = torch.cat((z1, cond), dim=1)
        if self.flow_coupling == "additive":
            z2 = z2 + self.f(z1_cond)
        elif self.flow_coupling == "affine":
            # import pdb;pdb.set_trace()
            if self.network_model == "transformer":
                h = self.f(z1_cond.permute(2, 0, 1)).permute(1, 2, 0)
            else:
                h = self.f(z1_cond.permute(0, 2, 1)).permute(0, 2, 1)
            shift, scale = thops.split_feature(h, "cross")
            scale = torch.sigmoid(scale + 2.) + 1e-6
            z2 = z2 + shift
            z2 = z2 * scale
            logdet = thops.sum(torch.log(scale), dim=[1, 2]) + logdet
        z = thops.cat_feature(z1, z2)
        return z, cond, logdet
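
    # Affine coupling in the two directions, matching normal_flow above and
    # reverse_flow below, with scale = sigmoid(raw_scale + 2) + 1e-6:
    #   forward:  z2' = (z2 + shift) * scale,   logdet += sum(log scale)
    #   inverse:  z2  =  z2' / scale - shift,   logdet -= sum(log scale)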

    def reverse_flow(self, input, cond, logdet):
        # 1. coupling
        z1, z2 = thops.split_feature(input, "split")
        # import pdb;pdb.set_trace()
        z1_cond = torch.cat((z1, cond), dim=1)
        if self.flow_coupling == "additive":
            z2 = z2 - self.f(z1_cond)
        elif self.flow_coupling == "affine":
            # feed the coupling network with the same layout as in normal_flow
            if self.network_model == "transformer":
                h = self.f(z1_cond.permute(2, 0, 1)).permute(1, 2, 0)
            else:
                h = self.f(z1_cond.permute(0, 2, 1)).permute(0, 2, 1)
            shift, scale = thops.split_feature(h, "cross")
            nan_throw(shift, "shift")
            nan_throw(scale, "scale")
            nan_throw(z2, "z2 unscaled")
            scale = torch.sigmoid(scale + 2.) + 1e-6
            z2 = z2 / scale
            z2 = z2 - shift
            logdet = -thops.sum(torch.log(scale), dim=[1, 2]) + logdet
        z = thops.cat_feature(z1, z2)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)
        nan_throw(z, "z permute_" + str(self.flow_permutation))
        # 3. actnorm
        z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
        return z, cond, logdet

class FlowNet(nn.Module):
    def __init__(self, x_channels, hidden_channels, cond_channels, K,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 network_model="LSTM",
                 num_layers=2,
                 LU_decomposed=False):
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.K = K
        N = cond_channels
        for _ in range(K):
            self.layers.append(
                FlowStep(in_channels=x_channels,
                         hidden_channels=hidden_channels,
                         cond_channels=N,
                         actnorm_scale=actnorm_scale,
                         flow_permutation=flow_permutation,
                         flow_coupling=flow_coupling,
                         network_model=network_model,
                         num_layers=num_layers,
                         LU_decomposed=LU_decomposed))
            self.output_shapes.append(
                [-1, x_channels, 1])
        # import pdb;pdb.set_trace()

    def init_lstm_hidden(self):
        for layer in self.layers:
            if isinstance(layer, FlowStep):
                layer.init_lstm_hidden()

    def forward(self, z, cond, logdet=0., reverse=False, eps_std=None):
        if not reverse:
            for layer in self.layers:
                z, cond, logdet = layer(z, cond, logdet, reverse=False)
            return z, logdet
        else:
            for layer in reversed(self.layers):
                z, cond, logdet = layer(z, cond, logdet=0, reverse=True)
            return z

class Glow(nn.Module):
    def __init__(self, x_channels, cond_channels, opt):
        super().__init__()
        self.flow = FlowNet(x_channels=x_channels,
                            hidden_channels=opt.dhid,
                            cond_channels=cond_channels,
                            K=opt.glow_K,
                            actnorm_scale=opt.actnorm_scale,
                            flow_permutation=opt.flow_permutation,
                            flow_coupling=opt.flow_coupling,
                            network_model=opt.network_model,
                            num_layers=opt.num_layers,
                            LU_decomposed=opt.LU_decomposed)
        self.opt = opt

        # register prior hidden
        # num_device = len(utils.get_proper_device(hparams.Device.glow, False))
        # assert hparams.Train.batch_size % num_device == 0
        # self.z_shape = [opt.batch_size // num_device, x_channels, 1]
        self.z_shape = [opt.batch_size, x_channels, 1]
        if opt.flow_dist == "normal":
            self.distribution = modules.GaussianDiag()
        elif opt.flow_dist == "studentT":
            self.distribution = modules.StudentT(opt.flow_dist_param, x_channels)

    def init_lstm_hidden(self):
        self.flow.init_lstm_hidden()

    def forward(self, x=None, cond=None, z=None,
                eps_std=None, reverse=False, output_length=1):
        if not reverse:
            return self.normal_flow(x, cond)
        else:
            return self.reverse_flow(z, cond, eps_std, output_length=output_length)

    def normal_flow(self, x, cond):
        n_timesteps = thops.timesteps(x)  # number of timesteps, i.e. the size of dimension 2
        logdet = torch.zeros_like(x[:, 0, 0])
        # encode
        z, objective = self.flow(x, cond, logdet=logdet, reverse=False)
        # prior
        objective += self.distribution.logp(z)
        # return
        nll = (-objective) / float(np.log(2.) * n_timesteps)
        return z, nll
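
    # The per-sequence objective accumulated above is log p(z) plus the sum of
    # the flow-step log-determinants, so the returned value is
    #     nll = -(log p(z) + sum_k logdet_k) / (ln 2 * n_timesteps),
    # i.e. the negative log-likelihood in bits per timestep.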

    def reverse_flow(self, z, cond, eps_std, output_length=1):
        with torch.no_grad():
            # copy so that repeated calls do not mutate self.z_shape in place
            z_shape = list(self.z_shape)
            z_shape[-1] = output_length
            if z is None:
                z = self.distribution.sample(z_shape, eps_std, device=cond.device)
            x = self.flow(z, cond, eps_std=eps_std, reverse=True)
            return x

    def set_actnorm_init(self, inited=True):
        for name, m in self.named_modules():
            if m.__class__.__name__.find("ActNorm") >= 0:
                m.inited = inited

    @staticmethod
    def loss_generative(nll):
        # Generative loss
        return torch.mean(nll)
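

if __name__ == "__main__":
    # Minimal smoke-test sketch. The option fields below are only those read by
    # Glow/FlowNet/FlowStep above; real training configurations will contain
    # more. It assumes the sibling `modules`/`thops` implementations accept
    # (batch, channels, time) tensors as used throughout this file; the "FF"
    # coupling network and normal prior keep the example dependency-light.
    from types import SimpleNamespace

    opt = SimpleNamespace(
        dhid=32, glow_K=4, actnorm_scale=1.0,
        flow_permutation="invconv", flow_coupling="affine",
        network_model="FF", num_layers=2, LU_decomposed=False,
        batch_size=2, flow_dist="normal", flow_dist_param=None)

    model = Glow(x_channels=6, cond_channels=4, opt=opt)
    x = torch.randn(2, 6, 10)     # (batch, x_channels, timesteps)
    cond = torch.randn(2, 4, 10)  # (batch, cond_channels, timesteps)

    z, nll = model(x=x, cond=cond)                  # encode: data -> latent, nll
    sample = model(z=None, cond=cond, eps_std=1.0,  # decode: prior sample -> data
                   reverse=True, output_length=10)
    print(z.shape, nll.shape, sample.shape)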