import torch
from torch import nn

from models import BaseModel
from .util.generation import autoregressive_generation_multimodal
from .moglow.models import Glow


class MoglowModel(BaseModel):
    """MoGlow baseline: a conditional normalizing flow (Glow) over the output
    modality, conditioned on concatenated windows of the input modalities."""

    def __init__(self, opt):
        super().__init__(opt)
        input_seq_lens = self.input_seq_lens
        dins = self.dins
        douts = self.douts
        # The conditioning vector concatenates a window of each input modality.
        cond_dim = dins[0] * input_seq_lens[0] + dins[1] * input_seq_lens[1]
        output_dim = douts[0]
        self.network_model = self.opt.network_model
        glow = Glow(output_dim, cond_dim, self.opt)
        # Register the flow network under the attribute name "net_glow".
        setattr(self, "net" + "_glow", glow)
        self.inputs = []
        self.targets = []
        self.criterion = nn.MSELoss()

    def parse_base_arguments(self):
        super().parse_base_arguments()
        self.input_seq_lens = [int(x) for x in str(self.opt.input_seq_lens).split(",")]
        self.output_seq_lens = [int(x) for x in str(self.opt.output_seq_lens).split(",")]
        if self.opt.phase == "inference" and self.opt.network_model == "LSTM":
            self.input_lengths = [int(x) for x in self.opt.input_seq_lens.split(",")]
            self.output_lengths = [int(x) for x in self.opt.output_seq_lens.split(",")]
        else:
            self.input_lengths = [int(x) for x in self.opt.input_lengths.split(",")]
            self.output_lengths = [int(x) for x in self.opt.output_lengths.split(",")]
        # Broadcast a single time offset to all modalities; otherwise the counts must match.
        if len(self.output_time_offsets) < len(self.output_mods):
            if len(self.output_time_offsets) == 1:
                self.output_time_offsets = self.output_time_offsets * len(self.output_mods)
            else:
                raise Exception("number of output_time_offsets doesn't match number of output_mods")
        if len(self.input_time_offsets) < len(self.input_mods):
            if len(self.input_time_offsets) == 1:
                self.input_time_offsets = self.input_time_offsets * len(self.input_mods)
            else:
                raise Exception("number of input_time_offsets doesn't match number of input_mods")

    def name(self):
        return "Moglow"

    @staticmethod
    def modify_commandline_options(parser, opt):
        # Glow architecture and conditioning-network hyperparameters.
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--input_seq_lens', type=str, default="10,11")
        parser.add_argument('--output_seq_lens', type=str, default="1")
        parser.add_argument('--glow_K', type=int, default=16)
        parser.add_argument('--actnorm_scale', type=float, default=1.0)
        parser.add_argument('--flow_permutation', type=str, default="invconv")
        parser.add_argument('--flow_dist', type=str, default="normal")
        parser.add_argument('--flow_dist_param', type=int, default=50)
        parser.add_argument('--flow_coupling', type=str, default="affine")
        parser.add_argument('--num_layers', type=int, default=2)
        parser.add_argument('--network_model', type=str, default="LSTM")
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--LU_decomposed', action='store_true')
        return parser

    def forward(self, data, eps_std=1.0):
        """Sample from the flow given the conditioning inputs (reverse pass)."""
        for i, mod in enumerate(self.input_mods):
            input_ = data[i]
            input_ = input_.permute(1, 0, 2)
            input_ = self.concat_sequence(self.input_seq_lens[i], input_)
            input_ = input_.permute(0, 2, 1)
            data[i] = input_
        outputs = self.net_glow(z=None, cond=torch.cat(data, dim=1), eps_std=eps_std, reverse=True)
        return [outputs.permute(0, 2, 1)]

    def generate(self, features, teacher_forcing=False, ground_truth=False):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()
        output_seq = autoregressive_generation_multimodal(
            features, self, autoreg_mods=self.output_mods,
            teacher_forcing=teacher_forcing, ground_truth=ground_truth)
        return output_seq

    def on_test_start(self):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def on_train_start(self):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def concat_sequence(self, seqlen, data):
        """Concatenate `seqlen` consecutive timesteps into one feature vector per
        sliding window, turning (batch, n_timesteps, n_feats) into
        (batch, n_timesteps - seqlen + 1, seqlen * n_feats).

        NOTE: this could be done as preprocessing on the dataset to make it a bit
        more efficient, but we only use it for the baseline MoGlow, so it wasn't
        worth moving it there.
        """
        n_batch, n_timesteps, n_feats = data.shape
        L = n_timesteps - (seqlen - 1)
        # Build the index matrix of shape (L, seqlen): row t holds timesteps t..t+seqlen-1.
        inds = torch.zeros((L, seqlen), dtype=torch.long)
        rng = torch.arange(0, n_timesteps, dtype=torch.long)
        for ii in range(0, seqlen):
            inds[:, ii] = rng[ii:(n_timesteps - (seqlen - ii - 1))]
        # Slice each sample into L overlapping windows.
        cc = data[:, inds, :].clone()
        # Flatten each window's timesteps and features into one dimension.
        dd = cc.reshape((n_batch, L, seqlen * n_feats))
        return dd

    def set_inputs(self, data):
        self.inputs = []
        self.targets = []
        for i, mod in enumerate(self.input_mods):
            input_ = data["in_" + mod]
            if self.input_seq_lens[i] > 1:
                input_ = self.concat_sequence(self.input_seq_lens[i], input_)
            input_ = input_.permute(0, 2, 1)
            self.inputs.append(input_)
        for i, mod in enumerate(self.output_mods):
            target_ = data["out_" + mod]
            if self.output_seq_lens[i] > 1:
                target_ = self.concat_sequence(self.output_seq_lens[i], target_)
            target_ = target_.permute(0, 2, 1)
            self.targets.append(target_)

    def training_step(self, batch, batch_idx):
        self.set_inputs(batch)
        # Forward pass through the flow: returns latents and per-sample NLL.
        z, nll = self.net_glow(x=self.targets[0], cond=torch.cat(self.inputs, dim=1))
        nll_loss = Glow.loss_generative(nll)
        # Alternative MSE-based objective from earlier experiments (kept for reference):
        # output = self.net_glow(z=None, cond=torch.cat(self.inputs, dim=1), eps_std=1.0, reverse=True, output_length=self.output_lengths[0])
        # mse_loss = self.criterion(output, self.targets[0])
        # loss = 0.1*nll_loss + 100*mse_loss
        loss = nll_loss
        self.log('nll_loss', nll_loss)
        self.log('loss', loss)
        return loss

    # To help debug XLA issues (missing ops, data-loading/compiling bottlenecks),
    # see https://youtu.be/iwtpwQRdb3Y?t=1056
    # def on_epoch_end(self):
    #     xm.master_print(met.metrics_report())

    # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
    #                    optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    #     optimizer.zero_grad()
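# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, with made-up shapes): the sliding-window
# concatenation performed by concat_sequence above, reproduced standalone so
# the resulting shapes are easy to verify. Not part of the model itself.
#
#   import torch
#   n_batch, n_timesteps, n_feats, seqlen = 4, 20, 3, 5
#   data = torch.randn(n_batch, n_timesteps, n_feats)
#   L = n_timesteps - (seqlen - 1)                      # number of windows
#   rng = torch.arange(n_timesteps)
#   inds = torch.stack([rng[ii:ii + L] for ii in range(seqlen)], dim=1)
#   windows = data[:, inds, :].reshape(n_batch, L, seqlen * n_feats)
#   assert windows.shape == (4, 16, 15)
# ---------------------------------------------------------------------------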