import torch
from contextlib import contextmanager
from collections import OrderedDict
from pytorch_lightning import LightningModule

from .optimizer import get_scheduler, get_optimizers
from models.util.generation import autoregressive_generation_multimodal


# Benefit of having one skeleton (e.g. for training): all incremental changes live in a
# single, streamlined and up-to-date script, with no need to keep separate notes on how
# to implement things.
class BaseModel(LightningModule):

    def __init__(self, opt):
        super().__init__()
        self.save_hyperparameters(vars(opt))
        self.opt = opt
        self.parse_base_arguments()
        self.optimizers = []
        self.schedulers = []

    def parse_base_arguments(self):
        # Parse the comma-separated modality options into per-modality lists.
        self.input_mods = str(self.opt.input_modalities).split(",")
        self.output_mods = str(self.opt.output_modalities).split(",")
        self.dins = [int(x) for x in str(self.opt.dins).split(",")]
        self.douts = [int(x) for x in str(self.opt.douts).split(",")]
        self.input_lengths = [int(x) for x in str(self.opt.input_lengths).split(",")]
        self.output_lengths = [int(x) for x in str(self.opt.output_lengths).split(",")]
        self.output_time_offsets = [int(x) for x in str(self.opt.output_time_offsets).split(",")]
        self.input_time_offsets = [int(x) for x in str(self.opt.input_time_offsets).split(",")]
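        # Example (hypothetical option values, purely to illustrate the parsing above):
        #   --input_modalities=audio,motion --dins=80,219 --input_lengths=120,120
        # yields input_mods == ["audio", "motion"], dins == [80, 219], input_lengths == [120, 120].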

        # A single offset is broadcast to all modalities; otherwise the counts must match.
        if len(self.output_time_offsets) < len(self.output_mods):
            if len(self.output_time_offsets) == 1:
                self.output_time_offsets = self.output_time_offsets * len(self.output_mods)
            else:
                raise Exception("number of output_time_offsets doesn't match number of output_mods")
        if len(self.input_time_offsets) < len(self.input_mods):
            if len(self.input_time_offsets) == 1:
                self.input_time_offsets = self.input_time_offsets * len(self.input_mods)
            else:
                raise Exception("number of input_time_offsets doesn't match number of input_mods")

        input_mods = self.input_mods
        if self.opt.input_types is None:
            input_types = ["c" for inp in input_mods]
        else:
            input_types = self.opt.input_types.split(",")
        if self.opt.input_fix_length_types is None:
            input_fix_length_types = ["end" for inp in input_mods]
        else:
            input_fix_length_types = self.opt.input_fix_length_types.split(",")
        if self.opt.output_fix_length_types is None:
            output_fix_length_types = ["end" for out in self.output_mods]
        else:
            output_fix_length_types = self.opt.output_fix_length_types.split(",")
        assert len(input_types) == len(input_mods)
        assert len(input_fix_length_types) == len(input_mods)
        assert len(output_fix_length_types) == len(self.output_mods)
        self.input_types = input_types
        self.input_fix_length_types = input_fix_length_types
        self.output_fix_length_types = output_fix_length_types

        if self.opt.input_num_tokens is None:
            self.input_num_tokens = [0 for inp in self.input_mods]
        else:
            self.input_num_tokens = [int(x) for x in self.opt.input_num_tokens.split(",")]
        if self.opt.output_num_tokens is None:
            self.output_num_tokens = [0 for out in self.output_mods]
        else:
            self.output_num_tokens = [int(x) for x in self.opt.output_num_tokens.split(",")]

    def name(self):
        return 'BaseModel'

    def configure_optimizers(self):
        optimizers = get_optimizers(self, self.opt)
        schedulers = [get_scheduler(optimizer, self.opt) for optimizer in optimizers]
        return optimizers, schedulers

    def set_inputs(self, data):
        # BTC -> TBC: permute each modality from (batch, time, channels) to (time, batch, channels).
        self.inputs = []
        self.targets = []
        for i, mod in enumerate(self.input_mods):
            input_ = data["in_" + mod]
            input_ = input_.permute(1, 0, 2)
            self.inputs.append(input_)
        for i, mod in enumerate(self.output_mods):
            target_ = data["out_" + mod]
            target_ = target_.permute(1, 0, 2)
            self.targets.append(target_)

    def generate(self, features, teacher_forcing=False, ground_truth=False):
        output_seq = autoregressive_generation_multimodal(
            features, self, autoreg_mods=self.output_mods,
            teacher_forcing=teacher_forcing, ground_truth=ground_truth)
        return output_seq
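
    # Usage sketch (assumed from the signature above; the expected structure of `features`
    # is defined by autoregressive_generation_multimodal):
    #   output_seq = model.generate(features)  # autoregressively generated output modalities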

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Abstract method: modify the parser to add model-specific command line options,
        and change default values if needed.

        :param parser: argument parser to extend
        :param is_train: whether the model is being used for training
        :return: the modified parser
        """
        return parser

    def test_step(self, batch, batch_idx):
        self.eval()
        loss = self.training_step(batch, batch_idx)
        return {"test_loss": loss}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'log': logs}
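

# A minimal sketch of how a concrete model might subclass BaseModel. Everything below
# (MyModel, the linear layer, the mse_loss objective, --my_option) is a hypothetical
# illustration, not part of this repo:
#
# class MyModel(BaseModel):
#     def __init__(self, opt):
#         super().__init__(opt)
#         self.net = torch.nn.Linear(self.dins[0], self.douts[0])
#
#     @staticmethod
#     def modify_commandline_options(parser, is_train):
#         parser.add_argument("--my_option", type=int, default=1)
#         return parser
#
#     def training_step(self, batch, batch_idx):
#         self.set_inputs(batch)  # fills self.inputs / self.targets in TBC layout
#         output = self.net(self.inputs[0])
#         return torch.nn.functional.mse_loss(output, self.targets[0])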