Spaces:
Runtime error
Runtime error
File size: 5,531 Bytes
bc32eea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
from torch import nn
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from .util.generation import autoregressive_generation_multimodal
import argparse
from argparse import Namespace
import models
class Residualflower2Model(BaseModel):
    """Composite model: a "mean" model plus a "residual" model.

    * The mean model is trained with a weighted MSE loss against the targets.
    * The residual model is trained through its own ``training_step`` on the
      targets minus the mean model's predictions.
    * ``forward`` returns, per output modality, mean prediction + residual
      prediction.

    Each sub-model is configured from this model's options. Every option ``k``
    declared by the mean model is exposed as ``--mean_k``, and every option of
    the residual model as ``--residual_k`` (see ``modify_commandline_options``).
    """

    # Weight applied to each output modality's MSE loss of the mean model.
    MEAN_LOSS_WEIGHT = 100

    def __init__(self, opt):
        super().__init__(opt)
        mean_opt = self._sub_model_opt(opt, opt.mean_model, "mean_")
        self.mean_model = models.create_model_by_name(opt.mean_model, mean_opt)
        residual_opt = self._sub_model_opt(opt, opt.residual_model, "residual_")
        self.residual_model = models.create_model_by_name(opt.residual_model, residual_opt)
        self.mean_loss = nn.MSELoss()
        # Most recent component losses; training_step updates them and
        # test_step reports them.
        self.mse_loss = 0
        self.nll_loss = 0

    @staticmethod
    def _sub_model_opt(opt, model_name, prefix):
        """Build the option Namespace for one sub-model.

        Start from a copy of every option in ``opt``. For each option ``k``
        that ``model_name`` declares, move ``opt.<prefix><k>`` to ``k`` and
        drop the prefixed key. If ``k`` already exists as a top-level option
        (e.g. ``dropout``), the top-level value is kept and the prefixed value
        is discarded.

        Raises KeyError if a prefixed option was never registered.
        """
        sub_opt = dict(vars(opt))
        for k in Residualflower2Model.get_argvars(model_name, opt):
            prefixed_val = sub_opt.pop(prefix + k)
            sub_opt.setdefault(k, prefixed_val)
        return Namespace(**sub_opt)

    def name(self):
        # NOTE(review): returns "Transflower", not "Residualflower2". Left
        # unchanged because callers may key on this string; confirm intent.
        return "Transflower"

    @staticmethod
    def get_argvars(model_name, opt):
        """Return ``{option_name: default}`` for every option ``model_name`` declares.

        The model's option setter is applied to a fresh parser, and that
        parser is parsed with an empty argv, so only defaults are returned.
        """
        temp_parser = argparse.ArgumentParser()
        model_option_setter = models.get_option_setter(model_name)
        return vars(model_option_setter(temp_parser, opt).parse_args([]))

    @staticmethod
    def _add_prefixed_options(parser, model_name, opt, flag_prefix):
        """Register every option of ``model_name`` on ``parser`` as ``flag_prefix + name``.

        Each flag keeps the sub-model's default value. Non-bool, non-None
        defaults also keep the type of that default.
        """
        for k, v in Residualflower2Model.get_argvars(model_name, opt).items():
            flag = flag_prefix + k
            if isinstance(v, bool):
                # Pass default=v explicitly. store_true alone defaults to False
                # and would silently flip sub-model flags whose default is True.
                parser.add_argument(flag, action="store_true", default=v)
            elif v is None:
                parser.add_argument(flag, default=v)
            else:
                parser.add_argument(flag, type=type(v), default=v)

    @staticmethod
    def modify_commandline_options(parser, opt):
        """Add this model's options, plus prefixed copies of both sub-models' options."""
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--mean_model', type=str, default="transformer")
        parser.add_argument('--residual_model', type=str, default="transflower")
        # Parse what is known so far. This tells us which sub-models were
        # selected, and so which option sets need prefixed copies.
        opt2, _ = parser.parse_known_args()
        Residualflower2Model._add_prefixed_options(parser, opt2.mean_model, opt, '--mean_')
        Residualflower2Model._add_prefixed_options(parser, opt2.residual_model, opt, '--residual_')
        return parser

    def forward(self, data):
        """Return one tensor per output modality: mean prediction + residual prediction."""
        predicted_means = self.mean_model(data)
        predicted_residuals = self.residual_model(data)
        return [predicted_means[i] + predicted_residuals[i]
                for i, _ in enumerate(self.output_mods)]

    def training_step(self, batch, batch_idx):
        """Compute and log the combined loss (weighted mean-model MSE + residual-model loss).

        Side effect: ``batch["out_" + mod]`` is overwritten in place with the
        residual targets (original target minus mean prediction).
        """
        self.set_inputs(batch)
        self.mean_model.set_inputs(batch)
        predicted_means = self.mean_model(self.inputs)

        mse_loss = 0
        for i, _ in enumerate(self.output_mods):
            mse_loss += self.MEAN_LOSS_WEIGHT * self.mean_loss(predicted_means[i], self.targets[i])

        # The residual model learns what the mean model does not explain.
        # permute(1, 0, 2) swaps the first two axes of the mean prediction to
        # match the layout of batch["out_*"]. Presumably (seq, batch, feat) vs
        # (batch, seq, feat); TODO confirm.
        # NOTE(review): predicted_means is not detached, so the residual loss
        # also back-propagates into the mean model; confirm this is intended.
        for i, mod in enumerate(self.output_mods):
            batch["out_" + mod] = batch["out_" + mod] - predicted_means[i].permute(1, 0, 2)
        nll_loss = self.residual_model.training_step(batch, batch_idx)

        loss = mse_loss + nll_loss
        self.mse_loss = mse_loss
        self.nll_loss = nll_loss
        self.log('mse_loss', mse_loss)
        self.log('nll_loss', nll_loss)
        self.log('loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate one batch with the training loss and return its components."""
        self.eval()
        loss = self.training_step(batch, batch_idx)
        return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss}

    def test_epoch_end(self, outputs):
        """Average each per-batch test loss component over the epoch."""
        def _mean(key):
            return torch.stack([x[key] for x in outputs]).mean()

        logs = {key: _mean(key) for key in ('test_loss', 'test_mse_loss', 'test_nll_loss')}
        return {'log': logs}
|