import torch
from contextlib import contextmanager
from collections import OrderedDict
from pytorch_lightning import LightningModule

from .optimizer import get_scheduler, get_optimizers
from models.util.generation import autoregressive_generation_multimodal

# The benefit of having one skeleton (e.g. for training) is that all
# incremental changes are kept in one place, giving a single streamlined,
# up-to-date script -- no need to keep separate notes on how to implement things.

class BaseModel(LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.save_hyperparameters(vars(opt))
        self.opt = opt
        self.parse_base_arguments()
        self.optimizers = []
        self.schedulers = []

    def parse_base_arguments(self):
        self.input_mods = str(self.opt.input_modalities).split(",")
        self.output_mods = str(self.opt.output_modalities).split(",")
        self.dins = [int(x) for x in str(self.opt.dins).split(",")]
        self.douts = [int(x) for x in str(self.opt.douts).split(",")]
        self.input_lengths = [int(x) for x in str(self.opt.input_lengths).split(",")]
        self.output_lengths = [int(x) for x in str(self.opt.output_lengths).split(",")]
        self.output_time_offsets = [int(x) for x in str(self.opt.output_time_offsets).split(",")]
        self.input_time_offsets = [int(x) for x in str(self.opt.input_time_offsets).split(",")]

        if len(self.output_time_offsets) < len(self.output_mods):
            if len(self.output_time_offsets) == 1:
                self.output_time_offsets = self.output_time_offsets*len(self.output_mods)
            else:
                raise ValueError("number of output_time_offsets doesn't match number of output_mods")

        if len(self.input_time_offsets) < len(self.input_mods):
            if len(self.input_time_offsets) == 1:
                self.input_time_offsets = self.input_time_offsets*len(self.input_mods)
            else:
                raise ValueError("number of input_time_offsets doesn't match number of input_mods")

        input_mods = self.input_mods
        if self.opt.input_types is None:
            input_types = ["c" for inp in input_mods]
        else:
            input_types = self.opt.input_types.split(",")

        if self.opt.input_fix_length_types is None:
            input_fix_length_types = ["end" for inp in input_mods]
        else:
            input_fix_length_types = self.opt.input_fix_length_types.split(",")

        if self.opt.output_fix_length_types is None:
            output_fix_length_types = ["end" for out in self.output_mods]
        else:
            output_fix_length_types = self.opt.output_fix_length_types.split(",")

        assert len(input_types) == len(input_mods)
        assert len(input_fix_length_types) == len(input_mods)
        # fix-length defaults and checks for outputs are per output modality
        assert len(output_fix_length_types) == len(self.output_mods)
        self.input_types = input_types
        self.input_fix_length_types = input_fix_length_types
        self.output_fix_length_types = output_fix_length_types

        if self.opt.input_num_tokens is None:
            self.input_num_tokens = [0 for inp in self.input_mods]
        else:
            self.input_num_tokens = [int(x) for x in self.opt.input_num_tokens.split(",")]

        if self.opt.output_num_tokens is None:
            self.output_num_tokens = [0 for out in self.output_mods]
        else:
            self.output_num_tokens = [int(x) for x in self.opt.output_num_tokens.split(",")]
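
    # Example (hypothetical option values, for illustration only): with
    # opt.input_modalities = "audio,motion", opt.dins = "80,72" and
    # opt.input_time_offsets = "0", parse_base_arguments gives
    # self.input_mods == ["audio", "motion"], self.dins == [80, 72] and,
    # since a single offset is broadcast across modalities,
    # self.input_time_offsets == [0, 0].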


    def name(self):
        return 'BaseModel'

    def configure_optimizers(self):
        # Lightning accepts the optimizers and their schedulers as a pair of lists
        optimizers = get_optimizers(self, self.opt)
        schedulers = [get_scheduler(optimizer, self.opt) for optimizer in optimizers]
        return optimizers, schedulers

    def set_inputs(self, data):
        # inputs arrive batch-first (B, T, C); permute to time-first (T, B, C)
        self.inputs = []
        self.targets = []
        for mod in self.input_mods:
            input_ = data["in_"+mod]
            input_ = input_.permute(1,0,2)
            self.inputs.append(input_)
        for mod in self.output_mods:
            target_ = data["out_"+mod]
            target_ = target_.permute(1,0,2)
            self.targets.append(target_)
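
    # For example (illustrative modality names, not fixed by this class), with
    # input_mods == ["audio"] and output_mods == ["motion"] a batch would look
    # like {"in_audio": Tensor of shape (B, T_in, C_in),
    #       "out_motion": Tensor of shape (B, T_out, C_out)},
    # which set_inputs permutes to time-first (T, B, C) tensors.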

    def generate(self, features, teacher_forcing=False, ground_truth=False):
        output_seq = autoregressive_generation_multimodal(features, self, autoreg_mods=self.output_mods, teacher_forcing=teacher_forcing, ground_truth=ground_truth)
        return output_seq
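
    # Typical (hypothetical) inference-time usage, assuming `features` is a
    # dict of input sequences in the format consumed by
    # autoregressive_generation_multimodal:
    #   model.eval()
    #   with torch.no_grad():
    #       outputs = model.generate(features)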

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Abstract method: modify the parser to add model-specific command
        line options, and change default values if needed.

        :param parser: argument parser to extend
        :param is_train: whether the model will be used for training
        :return: the modified parser
        """
        return parser

    def test_step(self, batch, batch_idx):
        self.eval()
        loss = self.training_step(batch, batch_idx)
        return {"test_loss": loss}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}

        return {'log': logs}
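

# A minimal sketch of how a concrete model might subclass BaseModel. The class
# name, the linear head and the MSE loss are illustrative assumptions, not part
# of this repository; a real subclass would also implement
# modify_commandline_options and its own forward logic.
#
# class MyModel(BaseModel):
#     def __init__(self, opt):
#         super().__init__(opt)
#         # a single linear head is purely for illustration
#         self.head = torch.nn.Linear(self.dins[0], self.douts[0])
#
#     def training_step(self, batch, batch_idx):
#         self.set_inputs(batch)  # fills self.inputs / self.targets as (T, B, C)
#         pred = self.head(self.inputs[0])
#         # assumes matching input/output sequence lengths for simplicity
#         return torch.nn.functional.mse_loss(pred, self.targets[0])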