myTest01 / models /residualflower2_model.py
meng2003's picture
Upload 85 files
bc32eea
import torch
from torch import nn
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from .util.generation import autoregressive_generation_multimodal
import argparse
from argparse import Namespace
import models
class Residualflower2Model(BaseModel):
def __init__(self, opt):
super().__init__(opt)
opt_vars = vars(opt)
mean_vars = self.get_argvars(opt.mean_model, opt)
mean_opt = opt_vars.copy()
for k,v in mean_vars.items():
val = mean_opt["mean_"+k]
if k not in mean_opt:
mean_opt[k] = val
del mean_opt["mean_"+k]
mean_opt = Namespace(**mean_opt)
self.mean_model = models.create_model_by_name(opt.mean_model, mean_opt)
residual_vars = self.get_argvars(opt.residual_model, opt)
residual_opt = opt_vars.copy()
for k,v in residual_vars.items():
val = residual_opt["residual_"+k]
if k not in residual_opt:
residual_opt[k] = val
del residual_opt["residual_"+k]
residual_opt = Namespace(**residual_opt)
self.residual_model = models.create_model_by_name(opt.residual_model, residual_opt)
self.mean_loss = nn.MSELoss()
self.mse_loss = 0
self.nll_loss = 0
def name(self):
return "Transflower"
@staticmethod
def get_argvars(model_name, opt):
temp_parser = argparse.ArgumentParser()
model_option_setter = models.get_option_setter(model_name)
vs = vars(model_option_setter(temp_parser, opt).parse_args([]))
return vs
@staticmethod
def modify_commandline_options(parser, opt):
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--mean_model', type=str, default="transformer")
parser.add_argument('--residual_model', type=str, default="transflower")
opt2, _ = parser.parse_known_args()
mean_vars = Residualflower2Model.get_argvars(opt2.mean_model, opt)
for k,v in mean_vars.items():
# print(k)
if type(v) != type(True):
if type(v) != type(None):
parser.add_argument('--mean_'+k, type=type(v), default=v)
else:
parser.add_argument('--mean_'+k, default=v)
else:
parser.add_argument('--mean_'+k, action="store_true")
residual_vars = Residualflower2Model.get_argvars(opt2.residual_model, opt)
for k,v in residual_vars.items():
if type(v) != type(True):
if type(v) != type(None):
parser.add_argument('--residual_'+k, type=type(v), default=v)
else:
parser.add_argument('--residual_'+k, default=v)
else:
parser.add_argument('--residual_'+k, action="store_true")
return parser
def forward(self, data):
# in lightning, forward defines the prediction/inference actions
predicted_means = self.mean_model(data)
predicted_residuals = self.residual_model(data)
outputs = []
for i, mod in enumerate(self.output_mods):
outputs.append(predicted_means[i]+predicted_residuals[i])
return outputs
#def generate(self,features, teacher_forcing=False):
# inputs_ = []
# for i,mod in enumerate(self.input_mods):
# input_ = features["in_"+mod]
# input_ = torch.from_numpy(input_).float().cuda()
# input_shape = input_.shape
# input_ = input_.reshape((input_shape[0]*input_shape[1], input_shape[2], input_shape[3])).permute(2,0,1).to(self.device)
# inputs_.append(input_)
# output_seq = autoregressive_generation_multimodal(inputs_, self, autoreg_mods=self.output_mods, teacher_forcing=teacher_forcing)
# return output_seq
def training_step(self, batch, batch_idx):
self.set_inputs(batch)
self.mean_model.set_inputs(batch)
predicted_means = self.mean_model(self.inputs)
mse_loss = 0
for i, mod in enumerate(self.output_mods):
mse_loss += 100*self.mean_loss(predicted_means[i], self.targets[i])
for i, mod in enumerate(self.output_mods):
# import pdb;pdb.set_trace()
batch["out_"+mod] = batch["out_"+mod] - predicted_means[i].permute(1,0,2)
# self.residual_model.set_inputs(batch)
nll_loss = self.residual_model.training_step(batch, batch_idx)
loss = mse_loss + nll_loss
self.mse_loss = mse_loss
self.nll_loss = nll_loss
# print("mse_loss: ", mse_loss)
# print("nll_loss: ", nll_loss)
self.log('mse_loss', mse_loss)
self.log('nll_loss', nll_loss)
self.log('loss', loss)
return loss
def test_step(self, batch, batch_idx):
self.eval()
loss = self.training_step(batch, batch_idx)
# print(loss)
return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss}
def test_epoch_end(self, outputs):
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
avg_mse_loss = torch.stack([x['test_mse_loss'] for x in outputs]).mean()
avg_nll_loss = torch.stack([x['test_nll_loss'] for x in outputs]).mean()
logs = {'test_loss': avg_loss, 'test_mse_loss': avg_mse_loss, 'test_nll_loss': avg_nll_loss}
return {'log': logs}