import torch
from torch import nn
from models import BaseModel
from .util.generation import autoregressive_generation_multimodal
from .moglow.models import Glow
class MoglowModel(BaseModel):
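    """Glow-based MoGlow baseline: a normalizing flow over one output modality,
    conditioned on concatenated sliding windows of the input modalities, with an
    optional LSTM conditioning network. Training maximizes the likelihood of the
    target frames; generation samples the flow autoregressively (see `generate`).
    """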
    def __init__(self, opt):
        super().__init__(opt)
        input_seq_lens = self.input_seq_lens
        dins = self.dins
        douts = self.douts
        # import pdb;pdb.set_trace()
        # conditioning dimension: each input modality contributes dins[i] features for each
        # of the input_seq_lens[i] frames in its context window (hard-coded for two input mods)
        cond_dim = dins[0]*input_seq_lens[0] + dins[1]*input_seq_lens[1]
        output_dim = douts[0]
        self.network_model = self.opt.network_model
        glow = Glow(output_dim, cond_dim, self.opt)
        setattr(self, "net"+"_glow", glow)
        self.inputs = []
        self.targets = []
        self.criterion = nn.MSELoss()
        # self.has_initialized = False

    def parse_base_arguments(self):
        super().parse_base_arguments()
        self.input_seq_lens = [int(x) for x in str(self.opt.input_seq_lens).split(",")]
        self.output_seq_lens = [int(x) for x in str(self.opt.output_seq_lens).split(",")]
        if self.opt.phase == "inference" and self.opt.network_model == "LSTM":
            self.input_lengths = [int(x) for x in self.opt.input_seq_lens.split(",")]
            self.output_lengths = [int(x) for x in self.opt.output_seq_lens.split(",")]
        else:
            self.input_lengths = [int(x) for x in self.opt.input_lengths.split(",")]
            self.output_lengths = [int(x) for x in self.opt.output_lengths.split(",")]
        # broadcast a single time offset to all modalities; otherwise the counts must match
        if len(self.output_time_offsets) < len(self.output_mods):
            if len(self.output_time_offsets) == 1:
                self.output_time_offsets = self.output_time_offsets*len(self.output_mods)
            else:
                raise Exception("number of output_time_offsets doesn't match number of output_mods")
        if len(self.input_time_offsets) < len(self.input_mods):
            if len(self.input_time_offsets) == 1:
                self.input_time_offsets = self.input_time_offsets*len(self.input_mods)
            else:
                raise Exception("number of input_time_offsets doesn't match number of input_mods")

    def name(self):
        return "Moglow"

    @staticmethod
    def modify_commandline_options(parser, opt):
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--input_seq_lens', type=str, default="10,11")
        parser.add_argument('--output_seq_lens', type=str, default="1")
        parser.add_argument('--glow_K', type=int, default=16)
        parser.add_argument('--actnorm_scale', type=float, default=1.0)
        parser.add_argument('--flow_permutation', type=str, default="invconv")
        parser.add_argument('--flow_dist', type=str, default="normal")
        parser.add_argument('--flow_dist_param', type=int, default=50)
        parser.add_argument('--flow_coupling', type=str, default="affine")
        parser.add_argument('--num_layers', type=int, default=2)
        parser.add_argument('--network_model', type=str, default="LSTM")
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--LU_decomposed', action='store_true')
        return parser
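
    # The options above are the Glow/MoGlow hyper-parameters: --glow_K is the number of flow
    # steps, --flow_permutation and --flow_coupling pick the permutation (e.g. invertible 1x1
    # convolution) and coupling type, and --network_model selects the conditioning encoder
    # (an LSTM by default). A hypothetical invocation (the script name and any --model flag
    # are assumptions, not taken from this file) might look like:
    #   python train.py --model moglow --glow_K 16 --flow_coupling affine \
    #       --network_model LSTM --input_seq_lens 10,11 --output_seq_lens 1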

    def forward(self, data, eps_std=1.0):
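        """Sample one output sequence from the flow.

        `data` is a list with one tensor per input modality; each tensor is windowed with
        `concat_sequence` and the windows are concatenated along the feature dimension to
        form the Glow conditioning. The flow is then run in reverse (z -> x), with `eps_std`
        acting as the sampling temperature (standard deviation of the latent noise).
        """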
        # import pdb;pdb.set_trace()
        for i, mod in enumerate(self.input_mods):
            input_ = data[i]
            input_ = input_.permute(1,0,2)
            input_ = self.concat_sequence(self.input_seq_lens[i], input_)
            input_ = input_.permute(0,2,1)
            data[i] = input_
        # run the flow in reverse (z -> x) conditioned on the concatenated input windows
        outputs = self.net_glow(z=None, cond=torch.cat(data, dim=1), eps_std=eps_std, reverse=True)
        # import pdb;pdb.set_trace()
        return [outputs.permute(0,2,1)]

    def generate(self, features, teacher_forcing=False, ground_truth=False):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()
        output_seq = autoregressive_generation_multimodal(features, self, autoreg_mods=self.output_mods, teacher_forcing=teacher_forcing, ground_truth=ground_truth)
        return output_seq
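
    # The hooks below appear to be PyTorch Lightning callbacks (BaseModel is assumed to be a
    # LightningModule, consistent with the use of self.log and training_step): at the start of
    # each stage and each batch, the LSTM conditioning network's hidden state is reset so that
    # state does not leak between sequences.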

    def on_test_start(self):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def on_train_start(self):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        if self.network_model == "LSTM":
            self.net_glow.init_lstm_hidden()

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        if self.network_model == "LSTM":
            # self.zero_grad()
            self.net_glow.init_lstm_hidden()

    def concat_sequence(self, seqlen, data):
        # NOTE: this could be done as preprocessing on the dataset to make it a bit more
        # efficient, but we are only going to use this for the baseline moglow, so I thought
        # it wasn't worth it to put it there.
        """
        Concatenates each sliding window of `seqlen` consecutive frames into a single feature vector.
        """
        n_batch, n_timesteps, n_feats = data.shape  # renamed from `nn` to avoid shadowing torch.nn
        L = n_timesteps - (seqlen - 1)
        # import pdb;pdb.set_trace()
        inds = torch.zeros((L, seqlen), dtype=torch.long)
        # create indices for the sequences we want
        rng = torch.arange(0, n_timesteps, dtype=torch.long)
        for ii in range(0, seqlen):
            # print(rng[ii:(n_timesteps-(seqlen-ii-1))].shape)
            # inds[:, ii] = torch.transpose(rng[ii:(n_timesteps-(seqlen-ii-1))], 0, 1)
            inds[:, ii] = rng[ii:(n_timesteps-(seqlen-ii-1))]
        # slice each sample into L sequences and store as new samples
        cc = data[:, inds, :].clone()
        # print("cc: " + str(cc.shape))
        # reshape all timesteps and features into one dimension per sample
        dd = cc.reshape((n_batch, L, seqlen*n_feats))
        # print("dd: " + str(dd.shape))
        return dd
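
    # Worked example of the windowing above: with seqlen=3 and data of shape
    # (batch, n_timesteps=5, n_feats), L = 5 - (3 - 1) = 3 and `inds` holds the rows
    # [0,1,2], [1,2,3], [2,3,4], so the result has shape (batch, 3, 3*n_feats):
    # each output frame packs three consecutive input frames into one feature vector.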

    def set_inputs(self, data):
        self.inputs = []
        self.targets = []
        # inputs and targets are stored channels-first, i.e. (batch, features, time),
        # which is the layout the flow consumes (conditioning is concatenated along dim=1)
        for i, mod in enumerate(self.input_mods):
            input_ = data["in_"+mod]
            if self.input_seq_lens[i] > 1:
                # input_ = input_.permute(0,2,1)
                input_ = self.concat_sequence(self.input_seq_lens[i], input_)
                input_ = input_.permute(0,2,1)
            else:
                input_ = input_.permute(0,2,1)
            self.inputs.append(input_)
        for i, mod in enumerate(self.output_mods):
            target_ = data["out_"+mod]
            if self.output_seq_lens[i] > 1:
                # target_ = target_.permute(0,2,1)
                target_ = self.concat_sequence(self.output_seq_lens[i], target_)
                target_ = target_.permute(0,2,1)
            else:
                target_ = target_.permute(0,2,1)
                # target_ = target_.permute(2,0,1)
            self.targets.append(target_)

    def training_step(self, batch, batch_idx):
        self.set_inputs(batch)
        # forward pass through the flow (x -> z), returning the negative log-likelihood term
        z, nll = self.net_glow(x=self.targets[0], cond=torch.cat(self.inputs, dim=1))
        # output = self.net_glow(z=None, cond=torch.cat(self.inputs, dim=1), eps_std=1.0, reverse=True, output_length=self.output_lengths[0])
        nll_loss = Glow.loss_generative(nll)
        # mse_loss = self.criterion(output, self.targets[0])
        # loss = 0.1*nll_loss + 100*mse_loss
        loss = nll_loss
        # loss = mse_loss
        # print(nll_loss)
        # print(mse_loss)
        self.log('nll_loss', nll_loss)
        self.log('loss', loss)
        # self.log('mse_loss', mse_loss)
        # import pdb;pdb.set_trace()
        # if not self.has_initialized:
        #     self.has_initialized = True
        #     return torch.tensor(0.0, dtype=torch.float32, requires_grad=True)
        # else:
        #     print(loss)
        return loss
        # return torch.tensor(0.0, dtype=torch.float32, requires_grad=True)

    # To help debug XLA stuff, like missing ops, or data loading/compiling bottlenecks,
    # see https://youtu.be/iwtpwQRdb3Y?t=1056
    # def on_epoch_end(self):
    #     xm.master_print(met.metrics_report())

    # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
    #                    optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    #     optimizer.zero_grad()
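

# A minimal usage sketch (hypothetical: `opt` is assumed to come from the project's option
# parser and `dataloader` to yield dicts with "in_<mod>" / "out_<mod>" tensors of shape
# (batch, time, features); neither is defined in this file):
#
#   model = MoglowModel(opt)
#   for batch_idx, batch in enumerate(dataloader):
#       loss = model.training_step(batch, batch_idx)  # maximum-likelihood training of the flow
#   ...
#   output_seq = model.generate(features)             # autoregressive sampling at inference time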