Spaces:
Runtime error
Runtime error
import torch | |
from .transformer import BasicTransformerModel | |
from models import BaseModel | |
from models.flowplusplus import FlowPlusPlus | |
import ast | |
from torch import nn | |
from .util.generation import autoregressive_generation_multimodal | |
from .moglow.models import Glow | |
#TODO: refactor a whole bunch of stuff | |
class TransglowerModel(BaseModel): | |
def __init__(self, opt): | |
super().__init__(opt) | |
input_mods = self.input_mods | |
output_mods = self.output_mods | |
dins = self.dins | |
douts = self.douts | |
input_seq_lens = self.input_seq_lens | |
self.input_mod_nets = [] | |
self.input_mod_funcs = [] | |
self.output_mod_nets = [] | |
self.output_mod_mean_nets = [] | |
self.output_mod_funcs = [] | |
self.output_mod_glows = [] | |
self.module_names = [] | |
for i, mod in enumerate(input_mods): | |
net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout, self.device, use_pos_emb=True, input_length=input_seq_lens[i]).to(self.device) | |
name = "_input_"+mod | |
setattr(self,"net"+name, net) | |
self.input_mod_nets.append(net) | |
# self.input_mod_funcs.append(func) | |
self.module_names.append(name) | |
def func1(x): | |
return self.input_mod_nets[0].forward(x) | |
#func1 = torch.vmap(func1) | |
def func2(x): | |
return self.input_mod_nets[1].forward(x) | |
#func2 = torch.vmap(func2) | |
self.input_mod_funcs = [func1, func2] | |
# should only be one output_mod | |
for i, mod in enumerate(output_mods): | |
net = BasicTransformerModel(opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=opt.use_pos_emb_output, input_length=sum(input_seq_lens)).to(self.device) | |
name = "_output_"+mod | |
setattr(self, "net"+name, net) | |
self.output_mod_nets.append(net) | |
self.module_names.append(name) | |
if self.opt.residual: | |
def func3(x): | |
return self.output_mod_nets[i].forward(x) | |
else: | |
def func3(x): | |
return self.output_mod_nets[i].forward(x)[:self.conditioning_seq_lens[i]] | |
#func3 = torch.vmap(func3) | |
self.output_mod_funcs.append(func3) | |
if opt.residual: | |
net = nn.Linear(opt.dhid,douts[i]) | |
name="_output_mean_encoder" | |
setattr(self, "net"+name, net) | |
self.output_mod_mean_nets.append(net) | |
cond_dim = opt.dhid | |
output_dim = douts[i] | |
glow = Glow(output_dim, cond_dim, self.opt) | |
name = "_output_glow_"+mod | |
setattr(self, "net"+name, glow) | |
self.output_mod_glows.append(glow) | |
self.inputs = [] | |
self.targets = [] | |
self.mean_loss = nn.MSELoss() | |
self.mse_loss = 0 | |
self.nll_loss = 0 | |
def name(self): | |
return "Transglower" | |
def parse_base_arguments(self): | |
super().parse_base_arguments() | |
self.input_seq_lens = [int(x) for x in str(self.opt.input_seq_lens).split(",")] | |
self.output_seq_lens = [int(x) for x in str(self.opt.output_seq_lens).split(",")] | |
if self.opt.phase == "inference": | |
self.input_lengths = [int(x) for x in self.opt.input_seq_lens.split(",")] | |
self.output_lengths = [int(x) for x in self.opt.output_seq_lens.split(",")] | |
else: | |
self.input_lengths = [int(x) for x in self.opt.input_lengths.split(",")] | |
self.output_lengths = [int(x) for x in self.opt.output_lengths.split(",")] | |
if self.opt.conditioning_seq_lens is not None: | |
self.conditioning_seq_lens = [int(x) for x in str(self.opt.conditioning_seq_lens).split(",")] | |
else: | |
self.conditioning_seq_lens = [1 for x in self.opt.output_lengths.split(",")] | |
if len(self.output_time_offsets) < len(self.output_mods): | |
if len(self.output_time_offsets) == 1: | |
self.output_time_offsets = self.output_time_offsets*len(self.output_mods) | |
else: | |
raise Exception("number of output_time_offsets doesnt match number of output_mods") | |
if len(self.input_time_offsets) < len(self.input_mods): | |
if len(input_time_offsets) == 1: | |
self.input_time_offsets = self.input_time_offsets*len(self.input_mods) | |
else: | |
raise Exception("number of input_time_offsets doesnt match number of input_mods") | |
def modify_commandline_options(parser, opt): | |
parser.add_argument('--dhid', type=int, default=512) | |
parser.add_argument('--dhid_flow', type=int, default=512) | |
parser.add_argument('--conditioning_seq_lens', type=str, default=None, help="the number of outputs of the conditioning transformers to feed (meaning the number of elements along the sequence dimension)") | |
parser.add_argument('--input_seq_lens', type=str, default="10,11") | |
parser.add_argument('--output_seq_lens', type=str, default="1") | |
parser.add_argument('--glow_K', type=int, default=16) | |
parser.add_argument('--actnorm_scale', type=float, default=1.0) | |
parser.add_argument('--flow_permutation', type=str, default="invconv") | |
parser.add_argument('--flow_dist', type=str, default="normal") | |
parser.add_argument('--flow_dist_param', type=int, default=50) | |
parser.add_argument('--flow_coupling', type=str, default="affine") | |
parser.add_argument('--num_layers', type=int, default=2) | |
parser.add_argument('--network_model', type=str, default="LSTM") | |
parser.add_argument('--LU_decomposed', action='store_true') | |
parser.add_argument('--nlayers', type=int, default=6) | |
parser.add_argument('--nhead', type=int, default=8) | |
parser.add_argument('--num_heads_flow', type=int, default=8) | |
parser.add_argument('--dropout', type=float, default=0.1) | |
parser.add_argument('--scales', type=str, default="[[10,0]]") | |
parser.add_argument('--glow_norm_layer', type=str, default=None) | |
parser.add_argument('--glow_bn_momentum', type=float, default=0.1) | |
parser.add_argument('--num_glow_coupling_blocks', type=int, default=10) | |
parser.add_argument('--num_mixture_components', type=int, default=0) | |
parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention for the FlowPlusPLus model") | |
parser.add_argument('--use_transformer_nn', action='store_true', help="whether to use the internal attention for the FlowPlusPLus model") | |
parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers") | |
parser.add_argument('--use_pos_emb_coupling', action='store_true', help="whether to use positional embeddings for the coupling layer transformers") | |
parser.add_argument('--cond_concat_dims', action='store_true', help="if set we concatenate along the channel dimension with with the x for the coupling layer; otherwise we concatenate along the sequence dimesion") | |
parser.add_argument('--residual', action='store_true', help="whether to use the flow to predict the residual around a determnisitic mean") | |
return parser | |
def forward(self, data): | |
# in lightning, forward defines the prediction/inference actions | |
# min_len = min(self.input_seq_lens) | |
for i,mod in enumerate(self.input_mods): | |
input_ = data[i] | |
input_ = input_.permute(1,2,0) | |
input_ = input_.permute(0,2,1) | |
input_ = self.concat_sequence(self.input_seq_lens[i], input_) | |
# input_ = input_.permute(0,2,1) | |
input_ = input_.permute(1,2,0,3) # L, T, B, C | |
# input_ = input_[:,:,:min_len] | |
# inputs_.append(input_) | |
data[i] = input_ | |
latents = [] | |
for i, mod in enumerate(self.input_mods): | |
# import pdb;pdb.set_trace() | |
# result = self.input_mod_funcs[i](data[i]) | |
result = [] | |
for inp in data[i]: | |
result.append(self.input_mod_funcs[i](inp)) | |
result = torch.stack(result) | |
latents.append(result) | |
latent = torch.cat(latents,dim=1) | |
loss = 0 | |
outputs = [] | |
if self.opt.residual: | |
for i, mod in enumerate(self.output_mods): | |
#trans_output = self.output_mod_funcs[i](latent).permute(2,1,3,0) | |
trans_output = [] | |
for lat in latent: | |
trans_output.append(self.output_mod_funcs[i](lat)) | |
trans_output = torch.stack(trans_output) | |
latents = trans_output[:,:self.conditioning_seq_lens[i]].permute(2,1,3,0) | |
trans_predicted_mean_latents = trans_output[:,self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i]+self.output_seq_lens[i]] | |
latents = latents.reshape(latents.shape[0], latents.shape[1] * latents.shape[2], latents.shape[3]) | |
trans_predicted_mean_latents = trans_predicted_mean_latents.reshape(trans_predicted_mean_latents.shape[0], trans_predicted_mean_latents.shape[1] * trans_predicted_mean_latents.shape[2], trans_predicted_mean_latents.shape[3]) | |
# predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents).permute(1,2,0) | |
predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents) | |
glow = self.output_mod_glows[i] | |
output = glow(x=None, cond=latents, reverse=True) | |
# import pdb;pdb.set_trace() | |
outputs.append(output.permute(0,2,1)+predicted_mean) | |
else: | |
for i, mod in enumerate(self.output_mods): | |
#trans_output = self.output_mod_funcs[i](latent).permute(2,1,3,0) | |
trans_output = [] | |
for lat in latent: | |
trans_output.append(self.output_mod_funcs[i](lat)) | |
trans_output = torch.stack(trans_output).permute(2,1,3,0) | |
trans_output = trans_output.reshape(trans_output.shape[0], trans_output.shape[1] * trans_output.shape[2], trans_output.shape[3]) | |
glow = self.output_mod_glows[i] | |
output = glow(x=None, cond=trans_output, reverse=True) | |
outputs.append(output.permute(0,2,1)) | |
return outputs | |
def on_test_start(self): | |
for i, mod in enumerate(self.output_mods): | |
self.output_mod_glows[i].init_lstm_hidden() | |
def on_train_start(self): | |
for i, mod in enumerate(self.output_mods): | |
self.output_mod_glows[i].init_lstm_hidden() | |
def on_test_batch_start(self, batch, batch_idx, dataloader_idx): | |
for i, mod in enumerate(self.output_mods): | |
self.output_mod_glows[i].init_lstm_hidden() | |
def on_train_batch_start(self, batch, batch_idx, dataloader_idx): | |
for i, mod in enumerate(self.output_mods): | |
self.output_mod_glows[i].init_lstm_hidden() | |
def concat_sequence(self, seqlen, data): | |
#NOTE: this could be done as preprocessing on the dataset to make it a bit more efficient, but we are only going to | |
# use this for baseline moglow, so I thought it wasn't worth it to put it there. | |
""" | |
Concatenates a sequence of features to one. | |
""" | |
nn,n_timesteps,n_feats = data.shape | |
L = n_timesteps-(seqlen-1) | |
# import pdb;pdb.set_trace() | |
inds = torch.zeros((L, seqlen), dtype=torch.long) | |
#create indices for the sequences we want | |
rng = torch.arange(0, n_timesteps, dtype=torch.long) | |
for ii in range(0,seqlen): | |
# print(rng[ii:(n_timesteps-(seqlen-ii-1))].shape) | |
# inds[:, ii] = torch.transpose(rng[ii:(n_timesteps-(seqlen-ii-1))], 0, 1) | |
inds[:, ii] = rng[ii:(n_timesteps-(seqlen-ii-1))] | |
#slice each sample into L sequences and store as new samples | |
cc=data[:,inds,:].clone() | |
#print ("cc: " + str(cc.shape)) | |
#reshape all timesteps and features into one dimention per sample | |
dd = cc.reshape((nn, L, seqlen, n_feats)) | |
#print ("dd: " + str(dd.shape)) | |
return dd | |
def set_inputs(self, data): | |
self.inputs = [] | |
self.targets = [] | |
for i, mod in enumerate(self.input_mods): | |
input_ = data["in_"+mod] | |
input_shape = input_.shape | |
if self.input_seq_lens[i] > 1: | |
# input_ = input_.permute(0,2,1) | |
input_ = self.concat_sequence(self.input_seq_lens[i], input_) | |
# input_ = input_.permute(0,2,1) | |
else: | |
input_ = input_.permute(0,2,1) | |
input_ = input_.squeeze(2) | |
input_ = input_.permute(1,2,0,3) # L, T, B, C | |
self.inputs.append(input_) | |
for i, mod in enumerate(self.output_mods): | |
target_ = data["out_"+mod] | |
target_shape = target_.shape | |
if self.output_seq_lens[i] > 1: | |
# target_ = target_.permute(0,2,1) | |
target_ = self.concat_sequence(self.output_seq_lens[i], target_) | |
target_ = target_.permute(0,2,1) | |
else: | |
target_ = target_.permute(0,2,1) | |
self.targets.append(target_) | |
def training_step(self, batch, batch_idx): | |
self.set_inputs(batch) | |
latents = [] | |
for i, mod in enumerate(self.input_mods): | |
# import pdb;pdb.set_trace() | |
#result = self.input_mod_funcs[i](self.inputs[i]) | |
result = [] | |
for inp in self.inputs[i]: | |
result.append(self.input_mod_funcs[i](inp)) | |
result = torch.stack(result) | |
latents.append(result) | |
latent = torch.cat(latents,dim=1) | |
if self.opt.residual: | |
nll_loss = 0 | |
mse_loss = 0 | |
for i, mod in enumerate(self.output_mods): | |
#trans_output = self.output_mod_funcs[i](latent) | |
trans_output = [] | |
for lat in latent: | |
trans_output.append(self.output_mod_funcs[i](lat)) | |
trans_output = torch.stack(trans_output) | |
latents = trans_output[:,:self.conditioning_seq_lens[i]].permute(2,1,3,0) | |
trans_predicted_mean_latents = trans_output[:,self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i]+self.output_seq_lens[i]] | |
latents = latents.reshape(latents.shape[0], latents.shape[1] * latents.shape[2], latents.shape[3]) | |
trans_predicted_mean_latents = trans_predicted_mean_latents.reshape(trans_predicted_mean_latents.shape[0], trans_predicted_mean_latents.shape[1] * trans_predicted_mean_latents.shape[2], trans_predicted_mean_latents.shape[3]) | |
predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents).permute(1,2,0) | |
glow = self.output_mod_glows[i] | |
# import pdb;pdb.set_trace() | |
z, nll = glow(x=self.targets[i]-predicted_mean, cond=latents) #time, batch, features -> batch, time, features | |
nll_loss += Glow.loss_generative(nll) | |
mse_loss += 100*self.mean_loss(predicted_mean, self.targets[i]) | |
loss = nll_loss + mse_loss | |
self.mse_loss = mse_loss | |
self.nll_loss = nll_loss | |
self.log('mse_loss', mse_loss) | |
self.log('nll_loss', nll_loss) | |
else: | |
loss = 0 | |
for i, mod in enumerate(self.output_mods): | |
#output = self.output_mod_funcs[i](latent).permute(2,1,3,0) | |
output = [] | |
for lat in latent: | |
output.append(self.output_mod_funcs[i](lat)) | |
output = torch.stack(output).permute(2,1,3,0) | |
output = output.reshape(output.shape[0], output.shape[1] * output.shape[2], output.shape[3]) | |
glow = self.output_mod_glows[i] | |
# import pdb;pdb.set_trace() | |
z, nll = glow(x=self.targets[i], cond=output) #time, batch, features -> batch, time, features | |
loss += Glow.loss_generative(nll) | |
self.log('loss', loss) | |
return loss | |
def test_step(self, batch, batch_idx): | |
if self.opt.residual: | |
self.eval() | |
loss = self.training_step(batch, batch_idx) | |
# print(loss) | |
return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss} | |
else: | |
return super().test_step(batch, batch_idx) | |
def test_epoch_end(self, outputs): | |
if self.opt.residual: | |
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() | |
avg_mse_loss = torch.stack([x['test_mse_loss'] for x in outputs]).mean() | |
avg_nll_loss = torch.stack([x['test_nll_loss'] for x in outputs]).mean() | |
logs = {'test_loss': avg_loss, 'test_mse_loss': avg_mse_loss, 'test_nll_loss': avg_nll_loss} | |
return {'log': logs} | |
else: | |
return super().test_epoch_end(outputs) | |
#to help debug XLA stuff, like missing ops, or data loading/compiling bottlenecks | |
# see https://youtu.be/iwtpwQRdb3Y?t=1056 | |
# def on_epoch_end(self): | |
# import torch_xla.core.xla_model as xm | |
# import torch_xla.debug.metrics as met | |
# xm.master_print(met.metrics_report()) | |
#def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, | |
# optimizer_closure, on_tpu, using_native_amp, using_lbfgs): | |
# optimizer.zero_grad() | |