# Snapshot of a Hugging Face Space source file (commit bc32eea, 8,020 bytes).
# NOTE: the Space page reported "Runtime error" at capture time.
import torch
from torch import nn
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from .util.generation import autoregressive_generation_multimodal
from .transformer_model import TransformerModel
#TODO: refactor a whole bunch of stuff
class ResidualflowerModel(BaseModel):
    """Residual-flow multimodal sequence model ("Transflower" variant).

    A deterministic ``TransformerModel`` predicts the per-modality output
    mean, and a ``FlowPlusPlus`` normalizing flow models the residual
    (target minus predicted mean), conditioned on transformer encodings of
    the concatenated input modalities.  Training combines a weighted MSE
    loss on the mean with the flow's negative log-likelihood on the
    residual; inference adds a flow sample to the predicted mean.
    """

    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        input_mods = self.input_mods
        output_mods = self.output_mods
        input_lengths = self.input_lengths
        output_lengths = self.output_lengths
        dins = self.dins
        douts = self.douts
        # Number of conditioning vectors taken from each output-modality
        # transformer; defaults to the output sequence lengths when the
        # option is not given explicitly.
        if self.opt.conditioning_seq_lens is not None:
            self.conditioning_seq_lens = [int(x) for x in str(self.opt.conditioning_seq_lens).split(",")]
        else:
            self.conditioning_seq_lens = [int(x) for x in str(self.opt.output_lengths).split(",")]
        # Plain lists give per-modality indexed access; parameter
        # registration happens via the setattr("net" + name, ...) calls
        # below, which attach each submodule to this nn.Module.
        self.input_mod_nets = []
        self.output_mod_nets = []
        self.output_mod_glows = []
        self.module_names = []
        # One shallow (2-layer) transformer encoder per input modality.
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout, self.device, use_pos_emb=True, input_length=input_lengths[i]).to(self.device)
            name = "_input_" + mod
            setattr(self, "net" + name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)
        # Per output modality: a deeper transformer that produces the flow's
        # conditioning, plus the FlowPlusPlus flow itself.
        for i, mod in enumerate(output_mods):
            if self.opt.cond_concat_dims:
                # Conditioning is concatenated channel-wise inside the flow,
                # so its width is the model dimension.
                net = BasicTransformerModel(opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=opt.use_pos_emb_output, input_length=sum(input_lengths)).to(self.device)
            else:
                # Conditioning is concatenated along the sequence dimension;
                # its width must match half the output channels (the flow's
                # coupling split) — presumably why douts[i]//2 is used here.
                net = BasicTransformerModel(douts[i]//2, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=opt.use_pos_emb_output, input_length=sum(input_lengths)).to(self.device)
            name = "_output_" + mod
            setattr(self, "net" + name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)
            glow = FlowPlusPlus(scales=ast.literal_eval(opt.scales),
                                in_shape=(douts[i], output_lengths[i], 1),
                                cond_dim=opt.dhid,
                                mid_channels=opt.dhid_flow,
                                num_blocks=opt.num_glow_coupling_blocks,
                                num_components=opt.num_mixture_components,
                                use_attn=opt.glow_use_attn,
                                use_logmix=opt.num_mixture_components > 0,
                                drop_prob=opt.dropout,
                                num_heads=opt.num_heads_flow,
                                use_transformer_nn=opt.use_transformer_nn,
                                use_pos_emb=opt.use_pos_emb_coupling,
                                norm_layer=opt.glow_norm_layer,
                                bn_momentum=opt.glow_bn_momentum,
                                cond_concat_dims=opt.cond_concat_dims,
                                cond_seq_len=self.conditioning_seq_lens[i],
                                )
            name = "_output_glow_" + mod
            setattr(self, "net" + name, glow)
            self.output_mod_glows.append(glow)
        # Deterministic mean predictor whose residual the flows model.
        self.mean_model = TransformerModel(opt)
        self.inputs = []   # populated by set_inputs() on each step
        self.targets = []  # populated by set_inputs() on each step
        self.criterion = nn.MSELoss()

    def name(self):
        # NOTE(review): returns "Transflower" although the class is
        # ResidualflowerModel; kept byte-identical since external tooling may
        # key on this string — confirm before renaming.
        return "Transflower"

    @staticmethod
    def modify_commandline_options(parser, opt):
        """Register this model's command-line options on *parser* and return it."""
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--dhid_flow', type=int, default=512)
        parser.add_argument('--predicted_inputs', default="0")
        parser.add_argument('--conditioning_seq_lens', type=str, default=None, help="the number of outputs of the conditioning transformers to feed (meaning the number of elements along the sequence dimension)")
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--num_heads_flow', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--scales', type=str, default="[[10,0]]")
        parser.add_argument('--glow_norm_layer', type=str, default=None)
        parser.add_argument('--glow_bn_momentum', type=float, default=0.1)
        parser.add_argument('--num_glow_coupling_blocks', type=int, default=10)
        parser.add_argument('--num_mixture_components', type=int, default=0)
        parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention for the FlowPlusPLus model")
        parser.add_argument('--use_transformer_nn', action='store_true', help="whether to use the internal attention for the FlowPlusPLus model")
        parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers")
        parser.add_argument('--use_pos_emb_coupling', action='store_true', help="whether to use positional embeddings for the coupling layer transformers")
        parser.add_argument('--cond_concat_dims', action='store_true', help="if set we concatenate along the channel dimension with with the x for the coupling layer; otherwise we concatenate along the sequence dimesion")
        return parser

    def forward(self, data):
        """Generate outputs: predicted mean plus a residual sampled from the flow."""
        # in lightning, forward defines the prediction/inference actions
        predicted_means = self.mean_model(data)
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i].forward(data[i]))
        # Concatenate the per-modality encodings along the first (sequence)
        # dimension to form one conditioning sequence.
        latent = torch.cat(latents)
        outputs = []
        for i, mod in enumerate(self.output_mods):
            # Only the first conditioning_seq_lens[i] transformer outputs
            # condition the flow.
            trans_output = self.output_mod_nets[i].forward(latent)[:self.conditioning_seq_lens[i]]
            # reverse=True draws a residual sample from the flow.
            output, _ = self.output_mod_glows[i](x=None, cond=trans_output.permute(1, 0, 2), reverse=True)
            outputs.append(predicted_means[i] + output.permute(1, 0, 2))
        return outputs

    def training_step(self, batch, batch_idx):
        """One optimization step: weighted mean-MSE plus flow NLL on the residual."""
        self.set_inputs(batch)
        predicted_means = self.mean_model(self.inputs)
        mse_loss = 0
        for i, mod in enumerate(self.output_mods):
            # 100 is an empirical weight balancing the MSE term against the
            # flow NLL — TODO consider exposing it as an option.
            mse_loss += 100 * self.criterion(predicted_means[i], self.targets[i])
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i].forward(self.inputs[i]))
        latent = torch.cat(latents)
        nll_loss = 0
        for i, mod in enumerate(self.output_mods):
            output = self.output_mod_nets[i].forward(latent)[:self.conditioning_seq_lens[i]]
            glow = self.output_mod_glows[i]
            # The flow models the residual; the mean is detached so the NLL
            # gradient does not flow back into the mean model.
            # permute: time, batch, features -> batch, time, features
            z, sldj = glow(x=self.targets[i].permute(1, 0, 2) - predicted_means[i].detach().permute(1, 0, 2), cond=output.permute(1, 0, 2))
            nll_loss += glow.loss_generative(z, sldj)
        loss = mse_loss + nll_loss
        self.log('mse_loss', mse_loss)
        self.log('nll_loss', nll_loss)
        self.log('loss', loss)
        return loss

    # To help debug XLA stuff, like missing ops, or data loading/compiling
    # bottlenecks; see https://youtu.be/iwtpwQRdb3Y?t=1056
    # def on_epoch_end(self):
    #     import torch_xla.core.xla_model as xm
    #     import torch_xla.debug.metrics as met
    #     xm.master_print(met.metrics_report())