import torch
from torch import nn
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from .util.generation import autoregressive_generation_multimodal
import argparse
from argparse import Namespace
import models
class Residualflower2Model(BaseModel):
    """Two-stage model: a "mean" sub-model predicts the expected output and a
    "residual" sub-model (a normalizing flow by default) models what remains
    after the mean prediction is subtracted from the target.

    Sub-model options are exposed on the command line with ``mean_`` /
    ``residual_`` prefixes and are un-prefixed again before each sub-model is
    constructed.
    """

    def __init__(self, opt):
        super().__init__(opt)
        self.mean_model = self._build_submodel(opt.mean_model, "mean", opt)
        self.residual_model = self._build_submodel(opt.residual_model, "residual", opt)
        self.mean_loss = nn.MSELoss()
        # Most recent component losses, kept so test_step can report them.
        self.mse_loss = 0
        self.nll_loss = 0

    @staticmethod
    def _build_submodel(model_name, prefix, opt):
        """Instantiate sub-model ``model_name`` from the ``prefix``-namespaced
        entries of ``opt`` (e.g. ``mean_nhead`` becomes ``nhead``).

        Values already present un-prefixed in ``opt`` take precedence over the
        prefixed ones; prefixed keys are always removed afterwards.
        """
        sub_vars = Residualflower2Model.get_argvars(model_name, opt)
        sub_opt = vars(opt).copy()
        for k in sub_vars:
            prefixed = prefix + "_" + k
            # Guard: the prefixed option may be absent if the parser setup
            # was bypassed; the original code KeyError'd here.
            if prefixed in sub_opt:
                if k not in sub_opt:
                    sub_opt[k] = sub_opt[prefixed]
                del sub_opt[prefixed]
        return models.create_model_by_name(model_name, Namespace(**sub_opt))

    def name(self):
        # NOTE(review): returns "Transflower" even though the class is
        # Residualflower2Model — kept for backward compatibility; confirm
        # no registry/logging relies on this string before changing it.
        return "Transflower"

    @staticmethod
    def get_argvars(model_name, opt):
        """Return the default option dict for ``model_name`` by letting its
        option setter populate a throwaway parser.

        Declared ``@staticmethod`` so both the ``self.get_argvars(...)`` call
        in ``__init__`` and the ``Residualflower2Model.get_argvars(...)`` call
        in ``modify_commandline_options`` work; without the decorator the
        instance call would bind ``self`` to ``model_name``.
        """
        temp_parser = argparse.ArgumentParser()
        model_option_setter = models.get_option_setter(model_name)
        return vars(model_option_setter(temp_parser, opt).parse_args([]))

    @staticmethod
    def modify_commandline_options(parser, opt):
        """Register this model's options plus a prefixed copy of every option
        of the selected mean/residual sub-models.

        Returns the parser so call sites can chain.
        """
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--mean_model', type=str, default="transformer")
        parser.add_argument('--residual_model', type=str, default="transflower")
        # Parse what we have so far to learn which sub-models were chosen.
        opt2, _ = parser.parse_known_args()
        for prefix, model_name in (("mean", opt2.mean_model),
                                   ("residual", opt2.residual_model)):
            sub_vars = Residualflower2Model.get_argvars(model_name, opt)
            for k, v in sub_vars.items():
                arg = '--' + prefix + '_' + k
                if isinstance(v, bool):
                    # NOTE: a sub-model bool that defaults to True becomes
                    # opt-in (default False) here — preserved from the
                    # original behavior.
                    parser.add_argument(arg, action="store_true")
                elif v is None:
                    parser.add_argument(arg, default=v)
                else:
                    parser.add_argument(arg, type=type(v), default=v)
        return parser

    def forward(self, data):
        """Inference: element-wise sum of the mean model's and residual
        model's predictions, one tensor per output modality."""
        predicted_means = self.mean_model(data)
        predicted_residuals = self.residual_model(data)
        return [predicted_means[i] + predicted_residuals[i]
                for i in range(len(self.output_mods))]

    def training_step(self, batch, batch_idx):
        """Train both sub-models: weighted MSE on the mean prediction plus the
        residual model's own (NLL) loss on the mean-subtracted targets.

        NOTE(review): mutates ``batch`` in place (subtracts the predicted
        means from each ``out_<mod>`` entry) — confirm no later consumer of
        the batch relies on the original targets.
        """
        self.set_inputs(batch)
        self.mean_model.set_inputs(batch)
        predicted_means = self.mean_model(self.inputs)
        mse_loss = 0
        for i, mod in enumerate(self.output_mods):
            # 100x weighting balances the MSE term against the flow NLL term.
            mse_loss += 100 * self.mean_loss(predicted_means[i], self.targets[i])
        for i, mod in enumerate(self.output_mods):
            # Residual model trains on target minus mean prediction.
            # Assumes predictions are (seq, batch, feat) and batch targets are
            # (batch, seq, feat), hence the permute — TODO confirm layout.
            batch["out_" + mod] = batch["out_" + mod] - predicted_means[i].permute(1, 0, 2)
        nll_loss = self.residual_model.training_step(batch, batch_idx)
        loss = mse_loss + nll_loss
        self.mse_loss = mse_loss
        self.nll_loss = nll_loss
        self.log('mse_loss', mse_loss)
        self.log('nll_loss', nll_loss)
        self.log('loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Evaluate by reusing training_step's loss computation in eval mode."""
        self.eval()
        loss = self.training_step(batch, batch_idx)
        return {"test_loss": loss,
                "test_mse_loss": self.mse_loss,
                "test_nll_loss": self.nll_loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses into epoch-level log entries."""
        logs = {
            key: torch.stack([x[key] for x in outputs]).mean()
            for key in ("test_loss", "test_mse_loss", "test_nll_loss")
        }
        return {'log': logs}