import torch
from contextlib import contextmanager
from collections import OrderedDict
# NOTE(review): the two prints below look like leftover import-time debug
# output; they are kept so the module's import-time behavior is unchanged.
print("HOOOOOO")
from pytorch_lightning import LightningModule
print("HOOOOOO")
from .optimizer import get_scheduler, get_optimizers
from models.util.generation import autoregressive_generation_multimodal

# Benefits of having one skeleton, e.g. for training: all incremental changes
# live in one single, streamlined and up-to-date script, so there is no need
# to keep separate logs on how to implement things.


def _split_csv(value):
    """Return the comma-separated fields of ``value`` as a list of strings."""
    return str(value).split(",")


def _parse_int_csv(value):
    """Parse a comma-separated option such as ``"1,2,3"`` into a list of ints."""
    return [int(field) for field in _split_csv(value)]


def _broadcast_offsets(offsets, mods, kind):
    """Repeat a single time offset so that there is one per modality.

    ``kind`` is ``"input"`` or ``"output"`` and is only used to build the
    error message. Lists that already have at least ``len(mods)`` entries are
    returned untouched.
    """
    if len(offsets) >= len(mods):
        return offsets
    if len(offsets) == 1:
        return offsets * len(mods)
    raise Exception(
        f"number of {kind}_time_offsets doesnt match number of {kind}_mods"
    )


class BaseModel(LightningModule):
    """Common skeleton for models configured from a parsed options object.

    On construction the options are stored as hyperparameters and the
    comma-separated per-modality options (modalities, dims, lengths, time
    offsets, types, token counts) are parsed into list attributes.
    """

    def __init__(self, opt):
        """Store ``opt`` and parse the shared per-modality arguments.

        :param opt: options namespace; ``vars(opt)`` is saved as the module's
            hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters(vars(opt))
        self.opt = opt
        self.parse_base_arguments()
        self.optimizers = []
        self.schedulers = []

    def parse_base_arguments(self):
        """Parse the comma-separated options on ``self.opt`` into lists.

        Sets ``input_mods``/``output_mods``, ``dins``/``douts``,
        ``input_lengths``/``output_lengths``, the two ``*_time_offsets``
        lists, ``input_types``, the two ``*_fix_length_types`` lists and the
        two ``*_num_tokens`` lists on ``self``.

        :raises Exception: if a time-offset list has more than one entry but
            fewer entries than there are modalities.
        :raises AssertionError: if a type list does not have one entry per
            corresponding modality.
        """
        opt = self.opt

        self.input_mods = _split_csv(opt.input_modalities)
        self.output_mods = _split_csv(opt.output_modalities)
        self.dins = _parse_int_csv(opt.dins)
        self.douts = _parse_int_csv(opt.douts)
        self.input_lengths = _parse_int_csv(opt.input_lengths)
        self.output_lengths = _parse_int_csv(opt.output_lengths)

        # A single offset is shared by every modality of that direction.
        self.output_time_offsets = _broadcast_offsets(
            _parse_int_csv(opt.output_time_offsets), self.output_mods, "output"
        )
        self.input_time_offsets = _broadcast_offsets(
            _parse_int_csv(opt.input_time_offsets), self.input_mods, "input"
        )

        n_inputs = len(self.input_mods)
        n_outputs = len(self.output_mods)

        if opt.input_types is None:
            input_types = ["c"] * n_inputs
        else:
            input_types = _split_csv(opt.input_types)

        if opt.input_fix_length_types is None:
            input_fix_length_types = ["end"] * n_inputs
        else:
            input_fix_length_types = _split_csv(opt.input_fix_length_types)

        # Output fix-length types are per *output* modality; sizing and
        # checking them against the input modalities was wrong whenever the
        # two counts differ.
        if opt.output_fix_length_types is None:
            output_fix_length_types = ["end"] * n_outputs
        else:
            output_fix_length_types = _split_csv(opt.output_fix_length_types)

        assert len(input_types) == n_inputs, \
            "number of input_types doesnt match number of input_mods"
        assert len(input_fix_length_types) == n_inputs, \
            "number of input_fix_length_types doesnt match number of input_mods"
        assert len(output_fix_length_types) == n_outputs, \
            "number of output_fix_length_types doesnt match number of output_mods"

        self.input_types = input_types
        self.input_fix_length_types = input_fix_length_types
        self.output_fix_length_types = output_fix_length_types

        if opt.input_num_tokens is None:
            self.input_num_tokens = [0] * n_inputs
        else:
            self.input_num_tokens = _parse_int_csv(opt.input_num_tokens)

        if opt.output_num_tokens is None:
            self.output_num_tokens = [0] * n_outputs
        else:
            self.output_num_tokens = _parse_int_csv(opt.output_num_tokens)

    def name(self):
        """Return the model's name string."""
        return 'BaseModel'

    def configure_optimizers(self):
        """Build the optimizers and one scheduler per optimizer.

        :return: ``(optimizers, schedulers)`` as produced by
            ``get_optimizers`` and ``get_scheduler``.
        """
        optimizers = get_optimizers(self, self.opt)
        # Iterate the optimizers just created; ``self.optimizers`` is the
        # empty list assigned in ``__init__``, which yielded no schedulers.
        schedulers = [get_scheduler(optimizer, self.opt) for optimizer in optimizers]
        return optimizers, schedulers

    def set_inputs(self, data):
        """Populate ``self.inputs`` and ``self.targets`` from a batch dict.

        Reads ``data["in_" + mod]`` for every input modality and
        ``data["out_" + mod]`` for every output modality, swapping the first
        two dimensions of each tensor (BTC -> TBC).

        :param data: mapping from ``"in_<mod>"`` / ``"out_<mod>"`` keys to
            3-dimensional tensors.
        """
        # BTC -> TBC
        self.inputs = [
            data["in_" + mod].permute(1, 0, 2) for mod in self.input_mods
        ]
        self.targets = [
            data["out_" + mod].permute(1, 0, 2) for mod in self.output_mods
        ]

    def generate(self, features, teacher_forcing=False, ground_truth=False):
        """Run autoregressive generation over all output modalities.

        Delegates to ``autoregressive_generation_multimodal`` with
        ``autoreg_mods=self.output_mods``.

        :param features: passed through to the generation helper.
        :param teacher_forcing: passed through to the generation helper.
        :param ground_truth: passed through to the generation helper.
        :return: whatever the generation helper returns.
        """
        output_seq = autoregressive_generation_multimodal(
            features,
            self,
            autoreg_mods=self.output_mods,
            teacher_forcing=teacher_forcing,
            ground_truth=ground_truth,
        )
        return output_seq

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for subclasses to add or re-default command line options.

        The base implementation returns ``parser`` unchanged.

        :param parser: the argument parser to extend.
        :param is_train: unused by the base implementation.
        :return: the (possibly modified) parser.
        """
        return parser

    def test_step(self, batch, batch_idx):
        """Compute the test loss for one batch by reusing ``training_step``.

        :return: ``{"test_loss": loss}``.
        """
        self.eval()
        loss = self.training_step(batch, batch_idx)
        return {"test_loss": loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch ``test_loss`` values of a test epoch.

        :param outputs: list of dicts as returned by ``test_step``.
        :return: ``{'log': {'test_loss': <mean loss>}}``.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'log': logs}