# models/transglower_model.py
import torch
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from torch import nn
from .util.generation import autoregressive_generation_multimodal
from .moglow.models import Glow
#TODO: refactor a whole bunch of stuff
class TransglowerModel(BaseModel):
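    """Transformer + MoGlow hybrid model.

    Each input modality is encoded by its own BasicTransformerModel, applied
    window-by-window over the output of concat_sequence. The concatenated
    latents feed a per-output-modality transformer whose features condition a
    MoGlow-style Glow decoder. With --residual, an additional linear head
    predicts a deterministic mean and the flow models the residual around it,
    trained with the flow NLL plus a weighted MSE on the predicted mean.
    """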
def __init__(self, opt):
super().__init__(opt)
input_mods = self.input_mods
output_mods = self.output_mods
dins = self.dins
douts = self.douts
input_seq_lens = self.input_seq_lens
self.input_mod_nets = []
self.input_mod_funcs = []
self.output_mod_nets = []
self.output_mod_mean_nets = []
self.output_mod_funcs = []
self.output_mod_glows = []
self.module_names = []
for i, mod in enumerate(input_mods):
net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout, self.device, use_pos_emb=True, input_length=input_seq_lens[i]).to(self.device)
name = "_input_"+mod
setattr(self,"net"+name, net)
self.input_mod_nets.append(net)
# self.input_mod_funcs.append(func)
self.module_names.append(name)
        # Bind each input transformer to its own forward function; the default
        # argument captures the corresponding net at definition time, so this
        # works for any number of input modalities.
        # (torch.vmap could optionally replace the explicit per-window loops used later.)
        for net in self.input_mod_nets:
            def input_func(x, net=net):
                return net.forward(x)
            self.input_mod_funcs.append(input_func)
# should only be one output_mod
for i, mod in enumerate(output_mods):
net = BasicTransformerModel(opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=opt.use_pos_emb_output, input_length=sum(input_seq_lens)).to(self.device)
name = "_output_"+mod
setattr(self, "net"+name, net)
self.output_mod_nets.append(net)
self.module_names.append(name)
            if self.opt.residual:
                def func3(x, i=i):
                    # in residual mode the full transformer output is needed: the first
                    # conditioning_seq_lens[i] steps condition the flow and the following
                    # output_seq_lens[i] steps feed the mean predictor
                    return self.output_mod_nets[i].forward(x)
            else:
                def func3(x, i=i):
                    # only the conditioning part of the output is needed for the flow
                    return self.output_mod_nets[i].forward(x)[:self.conditioning_seq_lens[i]]
            self.output_mod_funcs.append(func3)
if opt.residual:
net = nn.Linear(opt.dhid,douts[i])
name="_output_mean_encoder"
setattr(self, "net"+name, net)
self.output_mod_mean_nets.append(net)
cond_dim = opt.dhid
output_dim = douts[i]
glow = Glow(output_dim, cond_dim, self.opt)
name = "_output_glow_"+mod
setattr(self, "net"+name, glow)
self.output_mod_glows.append(glow)
self.inputs = []
self.targets = []
self.mean_loss = nn.MSELoss()
self.mse_loss = 0
self.nll_loss = 0
def name(self):
return "Transglower"
def parse_base_arguments(self):
super().parse_base_arguments()
self.input_seq_lens = [int(x) for x in str(self.opt.input_seq_lens).split(",")]
self.output_seq_lens = [int(x) for x in str(self.opt.output_seq_lens).split(",")]
if self.opt.phase == "inference":
self.input_lengths = [int(x) for x in self.opt.input_seq_lens.split(",")]
self.output_lengths = [int(x) for x in self.opt.output_seq_lens.split(",")]
else:
self.input_lengths = [int(x) for x in self.opt.input_lengths.split(",")]
self.output_lengths = [int(x) for x in self.opt.output_lengths.split(",")]
if self.opt.conditioning_seq_lens is not None:
self.conditioning_seq_lens = [int(x) for x in str(self.opt.conditioning_seq_lens).split(",")]
else:
self.conditioning_seq_lens = [1 for x in self.opt.output_lengths.split(",")]
if len(self.output_time_offsets) < len(self.output_mods):
if len(self.output_time_offsets) == 1:
self.output_time_offsets = self.output_time_offsets*len(self.output_mods)
else:
raise Exception("number of output_time_offsets doesnt match number of output_mods")
if len(self.input_time_offsets) < len(self.input_mods):
            if len(self.input_time_offsets) == 1:
                self.input_time_offsets = self.input_time_offsets*len(self.input_mods)
            else:
                raise Exception("number of input_time_offsets doesn't match number of input_mods")
@staticmethod
def modify_commandline_options(parser, opt):
parser.add_argument('--dhid', type=int, default=512)
parser.add_argument('--dhid_flow', type=int, default=512)
parser.add_argument('--conditioning_seq_lens', type=str, default=None, help="the number of outputs of the conditioning transformers to feed (meaning the number of elements along the sequence dimension)")
parser.add_argument('--input_seq_lens', type=str, default="10,11")
parser.add_argument('--output_seq_lens', type=str, default="1")
parser.add_argument('--glow_K', type=int, default=16)
parser.add_argument('--actnorm_scale', type=float, default=1.0)
parser.add_argument('--flow_permutation', type=str, default="invconv")
parser.add_argument('--flow_dist', type=str, default="normal")
parser.add_argument('--flow_dist_param', type=int, default=50)
parser.add_argument('--flow_coupling', type=str, default="affine")
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--network_model', type=str, default="LSTM")
parser.add_argument('--LU_decomposed', action='store_true')
parser.add_argument('--nlayers', type=int, default=6)
parser.add_argument('--nhead', type=int, default=8)
parser.add_argument('--num_heads_flow', type=int, default=8)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--scales', type=str, default="[[10,0]]")
parser.add_argument('--glow_norm_layer', type=str, default=None)
parser.add_argument('--glow_bn_momentum', type=float, default=0.1)
parser.add_argument('--num_glow_coupling_blocks', type=int, default=10)
parser.add_argument('--num_mixture_components', type=int, default=0)
        parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention for the FlowPlusPlus model")
        parser.add_argument('--use_transformer_nn', action='store_true', help="whether to use a transformer (rather than the recurrent network_model) as the conditioning network inside the Glow coupling layers")
parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers")
parser.add_argument('--use_pos_emb_coupling', action='store_true', help="whether to use positional embeddings for the coupling layer transformers")
        parser.add_argument('--cond_concat_dims', action='store_true', help="if set we concatenate along the channel dimension with the x for the coupling layer; otherwise we concatenate along the sequence dimension")
        parser.add_argument('--residual', action='store_true', help="whether to use the flow to predict the residual around a deterministic mean")
return parser
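    # Note: the glow_*, flow_*, actnorm_scale and network_model options above are
    # consumed by the MoGlow Glow constructor (which receives the full opt object),
    # while dhid, nhead, nlayers and dropout configure the conditioning transformers.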
def forward(self, data):
# in lightning, forward defines the prediction/inference actions
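        # Each data[i] arrives as (T, B, C); it is transposed to (B, T, C), sliced into
        # overlapping windows by concat_sequence, and permuted to (num_windows, window_len, B, C)
        # before being encoded window-by-window by the corresponding input transformer.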
# min_len = min(self.input_seq_lens)
for i,mod in enumerate(self.input_mods):
input_ = data[i]
input_ = input_.permute(1,2,0)
input_ = input_.permute(0,2,1)
input_ = self.concat_sequence(self.input_seq_lens[i], input_)
# input_ = input_.permute(0,2,1)
input_ = input_.permute(1,2,0,3) # L, T, B, C
# input_ = input_[:,:,:min_len]
# inputs_.append(input_)
data[i] = input_
latents = []
for i, mod in enumerate(self.input_mods):
# import pdb;pdb.set_trace()
# result = self.input_mod_funcs[i](data[i])
result = []
for inp in data[i]:
result.append(self.input_mod_funcs[i](inp))
result = torch.stack(result)
latents.append(result)
latent = torch.cat(latents,dim=1)
loss = 0
outputs = []
if self.opt.residual:
for i, mod in enumerate(self.output_mods):
#trans_output = self.output_mod_funcs[i](latent).permute(2,1,3,0)
trans_output = []
for lat in latent:
trans_output.append(self.output_mod_funcs[i](lat))
trans_output = torch.stack(trans_output)
latents = trans_output[:,:self.conditioning_seq_lens[i]].permute(2,1,3,0)
trans_predicted_mean_latents = trans_output[:,self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i]+self.output_seq_lens[i]]
latents = latents.reshape(latents.shape[0], latents.shape[1] * latents.shape[2], latents.shape[3])
trans_predicted_mean_latents = trans_predicted_mean_latents.reshape(trans_predicted_mean_latents.shape[0], trans_predicted_mean_latents.shape[1] * trans_predicted_mean_latents.shape[2], trans_predicted_mean_latents.shape[3])
# predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents).permute(1,2,0)
predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents)
glow = self.output_mod_glows[i]
output = glow(x=None, cond=latents, reverse=True)
# import pdb;pdb.set_trace()
outputs.append(output.permute(0,2,1)+predicted_mean)
else:
for i, mod in enumerate(self.output_mods):
#trans_output = self.output_mod_funcs[i](latent).permute(2,1,3,0)
trans_output = []
for lat in latent:
trans_output.append(self.output_mod_funcs[i](lat))
trans_output = torch.stack(trans_output).permute(2,1,3,0)
trans_output = trans_output.reshape(trans_output.shape[0], trans_output.shape[1] * trans_output.shape[2], trans_output.shape[3])
glow = self.output_mod_glows[i]
output = glow(x=None, cond=trans_output, reverse=True)
outputs.append(output.permute(0,2,1))
return outputs
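    # The Glow decoder's conditioning network keeps recurrent (LSTM) state; it is reset
    # at the start of training/testing and before every batch so that sequences do not
    # leak state into each other.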
def on_test_start(self):
for i, mod in enumerate(self.output_mods):
self.output_mod_glows[i].init_lstm_hidden()
def on_train_start(self):
for i, mod in enumerate(self.output_mods):
self.output_mod_glows[i].init_lstm_hidden()
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
for i, mod in enumerate(self.output_mods):
self.output_mod_glows[i].init_lstm_hidden()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
for i, mod in enumerate(self.output_mods):
self.output_mod_glows[i].init_lstm_hidden()
def concat_sequence(self, seqlen, data):
#NOTE: this could be done as preprocessing on the dataset to make it a bit more efficient, but we are only going to
# use this for baseline moglow, so I thought it wasn't worth it to put it there.
"""
Concatenates a sequence of features to one.
"""
        # use n_batch rather than nn so the local variable does not shadow torch.nn
        n_batch, n_timesteps, n_feats = data.shape
L = n_timesteps-(seqlen-1)
# import pdb;pdb.set_trace()
inds = torch.zeros((L, seqlen), dtype=torch.long)
#create indices for the sequences we want
rng = torch.arange(0, n_timesteps, dtype=torch.long)
for ii in range(0,seqlen):
# print(rng[ii:(n_timesteps-(seqlen-ii-1))].shape)
# inds[:, ii] = torch.transpose(rng[ii:(n_timesteps-(seqlen-ii-1))], 0, 1)
inds[:, ii] = rng[ii:(n_timesteps-(seqlen-ii-1))]
#slice each sample into L sequences and store as new samples
cc=data[:,inds,:].clone()
#print ("cc: " + str(cc.shape))
        #reshape all timesteps and features into one dimension per sample
        dd = cc.reshape((n_batch, L, seqlen, n_feats))
#print ("dd: " + str(dd.shape))
return dd
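    # Shape example: for data of shape (batch, 100, 66) and seqlen=10, concat_sequence
    # returns (batch, 91, 10, 66): 91 overlapping windows of 10 consecutive frames each.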
def set_inputs(self, data):
self.inputs = []
self.targets = []
for i, mod in enumerate(self.input_mods):
input_ = data["in_"+mod]
input_shape = input_.shape
if self.input_seq_lens[i] > 1:
# input_ = input_.permute(0,2,1)
input_ = self.concat_sequence(self.input_seq_lens[i], input_)
# input_ = input_.permute(0,2,1)
else:
input_ = input_.permute(0,2,1)
input_ = input_.squeeze(2)
input_ = input_.permute(1,2,0,3) # L, T, B, C
self.inputs.append(input_)
for i, mod in enumerate(self.output_mods):
target_ = data["out_"+mod]
target_shape = target_.shape
if self.output_seq_lens[i] > 1:
# target_ = target_.permute(0,2,1)
target_ = self.concat_sequence(self.output_seq_lens[i], target_)
target_ = target_.permute(0,2,1)
else:
target_ = target_.permute(0,2,1)
self.targets.append(target_)
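    # After set_inputs, self.inputs[i] holds windowed inputs of shape
    # (num_windows, window_len, batch, channels), and self.targets[i] is permuted with
    # (0, 2, 1) so the feature axis precedes the time axis (for the default output_seq_len of 1).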
def training_step(self, batch, batch_idx):
self.set_inputs(batch)
latents = []
for i, mod in enumerate(self.input_mods):
# import pdb;pdb.set_trace()
#result = self.input_mod_funcs[i](self.inputs[i])
result = []
for inp in self.inputs[i]:
result.append(self.input_mod_funcs[i](inp))
result = torch.stack(result)
latents.append(result)
latent = torch.cat(latents,dim=1)
if self.opt.residual:
nll_loss = 0
mse_loss = 0
for i, mod in enumerate(self.output_mods):
#trans_output = self.output_mod_funcs[i](latent)
trans_output = []
for lat in latent:
trans_output.append(self.output_mod_funcs[i](lat))
trans_output = torch.stack(trans_output)
latents = trans_output[:,:self.conditioning_seq_lens[i]].permute(2,1,3,0)
trans_predicted_mean_latents = trans_output[:,self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i]+self.output_seq_lens[i]]
latents = latents.reshape(latents.shape[0], latents.shape[1] * latents.shape[2], latents.shape[3])
trans_predicted_mean_latents = trans_predicted_mean_latents.reshape(trans_predicted_mean_latents.shape[0], trans_predicted_mean_latents.shape[1] * trans_predicted_mean_latents.shape[2], trans_predicted_mean_latents.shape[3])
predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents).permute(1,2,0)
glow = self.output_mod_glows[i]
# import pdb;pdb.set_trace()
z, nll = glow(x=self.targets[i]-predicted_mean, cond=latents) #time, batch, features -> batch, time, features
nll_loss += Glow.loss_generative(nll)
mse_loss += 100*self.mean_loss(predicted_mean, self.targets[i])
loss = nll_loss + mse_loss
self.mse_loss = mse_loss
self.nll_loss = nll_loss
self.log('mse_loss', mse_loss)
self.log('nll_loss', nll_loss)
else:
loss = 0
for i, mod in enumerate(self.output_mods):
#output = self.output_mod_funcs[i](latent).permute(2,1,3,0)
output = []
for lat in latent:
output.append(self.output_mod_funcs[i](lat))
output = torch.stack(output).permute(2,1,3,0)
output = output.reshape(output.shape[0], output.shape[1] * output.shape[2], output.shape[3])
glow = self.output_mod_glows[i]
# import pdb;pdb.set_trace()
z, nll = glow(x=self.targets[i], cond=output) #time, batch, features -> batch, time, features
loss += Glow.loss_generative(nll)
self.log('loss', loss)
return loss
def test_step(self, batch, batch_idx):
if self.opt.residual:
self.eval()
loss = self.training_step(batch, batch_idx)
# print(loss)
return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss}
else:
return super().test_step(batch, batch_idx)
def test_epoch_end(self, outputs):
if self.opt.residual:
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
avg_mse_loss = torch.stack([x['test_mse_loss'] for x in outputs]).mean()
avg_nll_loss = torch.stack([x['test_nll_loss'] for x in outputs]).mean()
logs = {'test_loss': avg_loss, 'test_mse_loss': avg_mse_loss, 'test_nll_loss': avg_nll_loss}
return {'log': logs}
else:
return super().test_epoch_end(outputs)
#to help debug XLA stuff, like missing ops, or data loading/compiling bottlenecks
# see https://youtu.be/iwtpwQRdb3Y?t=1056
# def on_epoch_end(self):
# import torch_xla.core.xla_model as xm
# import torch_xla.debug.metrics as met
# xm.master_print(met.metrics_report())
#def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
# optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# optimizer.zero_grad()