import torch
from torch import nn
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from .util.generation import autoregressive_generation_multimodal


class TransFlowppModel(BaseModel):
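    """Transformer encoders + conditional Flow++ decoders.

    Each input modality is encoded by its own BasicTransformerModel; the
    per-modality latents are concatenated along the time axis and fed to a
    per-output-modality transformer whose output conditions a FlowPlusPlus
    flow that models the corresponding target modality.
    """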
    def __init__(self, opt):
        super().__init__(opt)
        input_mods = self.input_mods
        output_mods = self.output_mods
        dins = self.dins
        douts = self.douts
        input_lengths = self.input_lengths
        output_lengths = self.output_lengths
        self.input_mod_nets = []
        self.output_mod_nets = []
        self.output_mod_glows = []
        self.module_names = []
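        # one encoder per input modality, projecting each input stream into the shared dhid-dimensional latent space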
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout, self.device, use_pos_emb=True, input_length=input_lengths[i]).to(self.device)
            name = "_input_" + mod
            setattr(self, "net" + name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)
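        # one decoder transformer plus one conditional Flow++ flow per output modality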
        for i, mod in enumerate(output_mods):
            net = BasicTransformerModel(opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=True, input_length=sum(input_lengths)).to(self.device)
            name = "_output_" + mod
            setattr(self, "net" + name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)
            glow = FlowPlusPlus(scales=ast.literal_eval(opt.scales),
                                in_shape=(douts[i], output_lengths[i], 1),
                                cond_dim=opt.dhid,
                                mid_channels=opt.dhid,
                                num_blocks=opt.num_glow_coupling_blocks,
                                num_components=opt.num_mixture_components,
                                use_attn=opt.glow_use_attn,
                                use_logmix=opt.num_mixture_components > 0,
                                drop_prob=opt.dropout)
            name = "_output_glow_" + mod
            setattr(self, "net" + name, glow)
            self.output_mod_glows.append(glow)
        # self.generate_full_masks()
        self.inputs = []
        self.targets = []
        self.criterion = nn.MSELoss()

    def name(self):
        return "Transformerflow"

    @staticmethod
    def modify_commandline_options(parser, opt):
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--scales', type=str, default="[[10,0]]")
        parser.add_argument('--num_glow_coupling_blocks', type=int, default=10)
        parser.add_argument('--num_mixture_components', type=int, default=0)
        parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention of the FlowPlusPlus model")
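        # note: --scales is parsed with ast.literal_eval in __init__, so it must be a Python literal, e.g. the default "[[10,0]]"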
        return parser

    # def generate_full_masks(self):
    #     input_mods = self.input_mods
    #     output_mods = self.output_mods
    #     input_lengths = self.input_lengths
    #     self.src_masks = []
    #     for i, mod in enumerate(input_mods):
    #         mask = torch.zeros(input_lengths[i], input_lengths[i])
    #         self.register_buffer('src_mask_' + str(i), mask)
    #         self.src_masks.append(mask)
    #
    #     self.output_masks = []
    #     for i, mod in enumerate(output_mods):
    #         mask = torch.zeros(sum(input_lengths), sum(input_lengths))
    #         self.register_buffer('out_mask_' + str(i), mask)
    #         self.output_masks.append(mask)

    def forward(self, data):
        # in Lightning, forward defines the prediction/inference actions
        latents = []
        for i, mod in enumerate(self.input_mods):
            # mask = getattr(self, "src_mask_" + str(i))
            # mask = self.src_masks[i]
            latents.append(self.input_mod_nets[i].forward(data[i]))
        latent = torch.cat(latents)
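        # latent: (sum(input_lengths), batch, dhid), the per-modality latents stacked along the time dimension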
        outputs = []
        for i, mod in enumerate(self.output_mods):
            # mask = getattr(self, "out_mask_" + str(i))
            # mask = self.output_masks[i]
            trans_output = self.output_mod_nets[i].forward(latent)[:self.output_lengths[i]]
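            # sample the output modality by running the conditional Flow++ in reverse,
            # conditioning on the (batch-major) transformer output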
            output, _ = self.output_mod_glows[i](x=None, cond=trans_output.permute(1, 0, 2), reverse=True)
            outputs.append(output.permute(1, 0, 2))
        return outputs

    def training_step(self, batch, batch_idx):
        self.set_inputs(batch)
        latents = []
        for i, mod in enumerate(self.input_mods):
            # mask = getattr(self, "src_mask_" + str(i))
            latents.append(self.input_mod_nets[i].forward(self.inputs[i]))
        latent = torch.cat(latents)
        loss = 0
        for i, mod in enumerate(self.output_mods):
            # mask = getattr(self, "out_mask_" + str(i))
            output = self.output_mod_nets[i].forward(latent)[:self.output_lengths[i]]
            glow = self.output_mod_glows[i]
            z, sldj = glow(x=self.targets[i].permute(1, 0, 2), cond=output.permute(1, 0, 2))  # (time, batch, features) -> (batch, time, features)
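            # loss_generative is the negative log-likelihood computed from the latent z
            # and the sum of log-determinants of the Jacobian (sldj)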
            loss += glow.loss_generative(z, sldj)
        self.log('nll_loss', loss)
        return loss

    # def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
    #                    optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    #     optimizer.zero_grad()
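
# Rough usage sketch (hypothetical; BaseModel is assumed to populate input_mods,
# output_mods, dins, douts and the sequence lengths from opt):
#
#   model = TransFlowppModel(opt)
#   # data: one (time, batch, features) tensor per input modality
#   outputs = model(data)                   # sample each output modality from its Flow++
#   loss = model.training_step(batch, 0)    # maximum-likelihood (NLL) training step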