# Source: myTest01/models/transflower_model.py
# Uploaded by meng2003 in commit bc32eea ("Upload 85 files").
import torch
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from torch import nn
import torch.nn.functional as F
from .util.generation import autoregressive_generation_multimodal
class TransflowerModel(BaseModel):
    """Multimodal transformer encoder conditioning a normalizing flow.

    Every input modality is encoded by its own ``BasicTransformerModel``. The
    encodings are concatenated with ``torch.cat`` (dim 0) and fed to one
    transformer per output modality. The first ``conditioning_seq_lens[i]``
    elements of that transformer's output condition a ``FlowPlusPlus`` flow
    over the output modality.

    When ``opt.residual`` is set, the following ``output_lengths[i]`` elements
    are mapped by a linear head to a deterministic mean, and the flow models
    the residual between the target and that mean.
    """

    # Scale applied to the MSE term of the residual objective.
    _MSE_LOSS_SCALE = 100
    # Divisor applied to the MSE loss before the sigmoid that produces the
    # adaptive mixing weight between the NLL and MSE terms.
    _ADAPTIVE_WEIGHT_TEMPERATURE = 5

    def __init__(self, opt):
        """Build the per-modality encoders, output transformers and flows.

        Args:
            opt: parsed options; see ``modify_commandline_options`` for the
                model-specific ones read here.
        """
        super().__init__(opt)
        input_mods = self.input_mods
        output_mods = self.output_mods
        dins = self.dins
        douts = self.douts
        input_lengths = self.input_lengths
        output_lengths = self.output_lengths

        # Fall back to the output lengths when no explicit conditioning
        # lengths were given on the command line.
        seq_lens_spec = self.opt.conditioning_seq_lens
        if seq_lens_spec is None:
            seq_lens_spec = self.opt.output_lengths
        self.conditioning_seq_lens = [int(x) for x in str(seq_lens_spec).split(",")]

        self.input_mod_nets = []
        self.output_mod_nets = []
        self.output_mod_mean_nets = []
        self.output_mod_glows = []
        self.module_names = []

        # One two-layer transformer encoder per input modality.
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(
                opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout,
                ntokens=self.input_num_tokens[i],
                use_pos_emb=opt.use_pos_emb_inputs,
                use_rel_pos_emb=opt.use_rel_pos_emb_inputs,
                input_length=input_lengths[i],
                use_x_transformers=opt.use_x_transformers,
                opt=opt,
                discrete_inputs=self.input_types[i] == 'd',
            )
            name = "_input_" + mod
            # Networks are also kept in plain lists, so they are assigned as
            # attributes too (module registration is handled by the base
            # class machinery).
            setattr(self, "net" + name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)

        # One transformer, optional mean head and flow per output modality.
        for i, mod in enumerate(output_mods):
            net = BasicTransformerModel(
                opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout,
                ntokens=self.output_num_tokens[i],  # not used yet
                use_pos_emb=opt.use_pos_emb_output,
                use_rel_pos_emb=opt.use_rel_pos_emb_output,
                input_length=sum(input_lengths),
                use_x_transformers=opt.use_x_transformers,
                opt=opt,
            )
            name = "_output_" + mod
            setattr(self, "net" + name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)

            if opt.residual:
                mean_net = nn.Linear(opt.dhid, douts[i])
                # The first mean head keeps the historical attribute name so
                # existing single-modality checkpoints still load; later
                # heads get a per-modality suffix. Previously every head was
                # assigned to the same attribute, so with several output
                # modalities only the last one stayed reachable by name.
                if i == 0:
                    name = "_output_mean_encoder"
                else:
                    name = "_output_mean_encoder_" + mod
                setattr(self, "net" + name, mean_net)
                self.output_mod_mean_nets.append(mean_net)

            glow = FlowPlusPlus(
                scales=ast.literal_eval(opt.scales),
                in_shape=(douts[i], output_lengths[i], 1),
                cond_dim=opt.dhid,
                mid_channels=opt.dhid_flow,
                num_blocks=opt.num_glow_coupling_blocks,
                num_components=opt.num_mixture_components,
                use_attn=opt.glow_use_attn,
                use_logmix=opt.num_mixture_components > 0,
                drop_prob=opt.dropout,
                num_heads=opt.num_heads_flow,
                use_transformer_nn=opt.use_transformer_nn,
                use_pos_emb=opt.use_pos_emb_coupling,
                use_rel_pos_emb=opt.use_rel_pos_emb_coupling,
                norm_layer=opt.glow_norm_layer,
                bn_momentum=opt.glow_bn_momentum,
                cond_concat_dims=opt.cond_concat_dims,
                flow_dist=opt.flow_dist,
                flow_dist_param=opt.flow_dist_param,
                cond_seq_len=self.conditioning_seq_lens[i],
            )
            name = "_output_glow_" + mod
            setattr(self, "net" + name, glow)
            self.output_mod_glows.append(glow)

        self.mean_loss = nn.MSELoss()

        self.inputs = []
        self.targets = []
        # Last computed loss components; read back by ``test_step``.
        self.mse_loss = 0
        self.nll_loss = 0

    def name(self):
        """Return the display name of this model."""
        return "Transflower"

    @staticmethod
    def modify_commandline_options(parser, opt):
        """Register the model-specific command-line options on ``parser``.

        Args:
            parser: an ``argparse``-style parser to extend.
            opt: unused here; kept for interface compatibility.

        Returns:
            The same ``parser``, with the options added.
        """
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--dhid_flow', type=int, default=512)
        parser.add_argument('--conditioning_seq_lens', type=str, default=None, help="the number of outputs of the conditioning transformers to feed (meaning the number of elements along the sequence dimension)")
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--num_heads_flow', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--scales', type=str, default="[[10,0]]")
        parser.add_argument('--flow_dist', type=str, default="normal")
        parser.add_argument('--flow_dist_param', type=int, default=50)
        parser.add_argument('--glow_norm_layer', type=str, default=None)
        parser.add_argument('--glow_bn_momentum', type=float, default=0.1)
        parser.add_argument('--num_glow_coupling_blocks', type=int, default=10)
        parser.add_argument('--num_mixture_components', type=int, default=0)
        parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention for the FlowPlusPlus model")
        parser.add_argument('--use_transformer_nn', action='store_true', help="whether to use transformer networks inside the FlowPlusPlus model")
        parser.add_argument('--use_rel_pos_emb_inputs', action='store_true', help="whether to use T5 relative positional embeddings for input modality transformers")
        parser.add_argument('--use_rel_pos_emb_output', action='store_true', help="whether to use T5 relative positional embeddings for output modality transformers")
        parser.add_argument('--use_pos_emb_inputs', action='store_true', help="whether to use positional embeddings for input modality transformers")
        parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers")
        parser.add_argument('--use_pos_emb_coupling', action='store_true', help="whether to use positional embeddings for the coupling layer transformers")
        parser.add_argument('--use_rel_pos_emb_coupling', action='store_true', help="whether to use T5 relative positional embeddings for the coupling layer transformers")
        parser.add_argument('--cond_concat_dims', action='store_true', help="if set we concatenate along the channel dimension with the x for the coupling layer; otherwise we concatenate along the sequence dimension")
        parser.add_argument('--residual', action='store_true', help="whether to use the flow to predict the residual around a deterministic mean")
        parser.add_argument('--use_rotary_pos_emb', action='store_true', help="whether to use rotary position embeddings")
        parser.add_argument('--use_x_transformers', action='store_true', help="whether to use the x-transformers implementation for the transformers")
        return parser

    def _encode_inputs(self, inputs):
        """Encode each input modality and concatenate the results (dim 0)."""
        latents = [
            self.input_mod_nets[i](inputs[i])
            for i in range(len(self.input_mods))
        ]
        return torch.cat(latents)

    def forward(self, data):
        """Sample one output sequence per output modality.

        In Lightning, ``forward`` defines the prediction/inference action.

        Args:
            data: indexable collection with one tensor per input modality, in
                the order of ``self.input_mods``.

        Returns:
            A list with one sampled tensor per output modality.
        """
        latent = self._encode_inputs(data)
        outputs = []
        for i in range(len(self.output_mods)):
            cond_len = self.conditioning_seq_lens[i]
            # Run the output transformer once and slice it for both the flow
            # conditioning and (if enabled) the mean head.
            trans_output = self.output_mod_nets[i](latent)
            # time, batch, features -> batch, time, features
            cond = trans_output[:cond_len].permute(1, 0, 2)
            sample, _ = self.output_mod_glows[i](x=None, cond=cond, reverse=True)
            output = sample.permute(1, 0, 2)
            if self.opt.residual:
                mean_latents = trans_output[cond_len:cond_len + self.output_lengths[i]]
                predicted_mean = self.output_mod_mean_nets[i](mean_latents)
                # The flow sampled the residual around the predicted mean.
                output = predicted_mean + output
            outputs.append(output)
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Without ``opt.residual`` the loss is the sum of the flows' generative
        losses. With it, the loss is a sigmoid-weighted blend of that NLL and
        a scaled MSE between the predicted mean and the target.

        Args:
            batch: batch passed to ``self.set_inputs``.
            batch_idx: index of the batch (unused).

        Returns:
            The scalar loss to optimise.
        """
        self.set_inputs(batch)
        latent = self._encode_inputs(self.inputs)

        if self.opt.residual:
            nll_loss = 0
            mse_loss = 0
            for i in range(len(self.output_mods)):
                cond_len = self.conditioning_seq_lens[i]
                trans_output = self.output_mod_nets[i](latent)
                cond_latents = trans_output[:cond_len]
                mean_latents = trans_output[cond_len:cond_len + self.output_lengths[i]]
                predicted_mean = self.output_mod_mean_nets[i](mean_latents)
                glow = self.output_mod_glows[i]
                # time, batch, features -> batch, time, features
                z, sldj = glow(
                    x=self.targets[i].permute(1, 0, 2) - predicted_mean.permute(1, 0, 2),
                    cond=cond_latents.permute(1, 0, 2),
                )
                nll_loss += glow.loss_generative(z, sldj)
                mse_loss += self._MSE_LOSS_SCALE * self.mean_loss(predicted_mean, self.targets[i])
            # The larger the MSE, the more weight the MSE term receives.
            adaptive_weight = torch.sigmoid(mse_loss / self._ADAPTIVE_WEIGHT_TEMPERATURE)
            loss = (1 - adaptive_weight) * nll_loss + adaptive_weight * mse_loss
            self.mse_loss = mse_loss
            self.nll_loss = nll_loss
            self.log('mse_loss', mse_loss)
            self.log('nll_loss', nll_loss)
        else:
            loss = 0
            for i in range(len(self.output_mods)):
                cond_len = self.conditioning_seq_lens[i]
                cond_latents = self.output_mod_nets[i](latent)[:cond_len]
                glow = self.output_mod_glows[i]
                # time, batch, features -> batch, time, features
                z, sldj = glow(
                    x=self.targets[i].permute(1, 0, 2),
                    cond=cond_latents.permute(1, 0, 2),
                )
                loss += glow.loss_generative(z, sldj)

        self.log('loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch.

        In residual mode this switches the model to eval mode, reuses
        ``training_step`` to compute the losses and reports the total, MSE and
        NLL terms; otherwise it defers to the base class.
        """
        if not self.opt.residual:
            return super().test_step(batch, batch_idx)
        self.eval()
        loss = self.training_step(batch, batch_idx)
        return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses collected by ``test_step``.

        Args:
            outputs: list of the dicts returned by ``test_step``.

        Returns:
            In residual mode, ``{'log': {...}}`` with the mean total, MSE and
            NLL losses; otherwise whatever the base class returns.
        """
        if not self.opt.residual:
            return super().test_epoch_end(outputs)
        logs = {
            key: torch.stack([x[key] for x in outputs]).mean()
            for key in ('test_loss', 'test_mse_loss', 'test_nll_loss')
        }
        return {'log': logs}
#to help debug XLA stuff, like missing ops, or data loading/compiling bottlenecks
# see https://youtu.be/iwtpwQRdb3Y?t=1056
# def on_epoch_end(self):
# import torch_xla.core.xla_model as xm
# import torch_xla.debug.metrics as met
# xm.master_print(met.metrics_report())
#def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
# optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# optimizer.zero_grad()