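"""Transglower model: encodes each input modality with a transformer encoder and
generates the output modality with a MoGlow-style normalizing flow (Glow); with
--residual, a linear head predicts a deterministic mean and the flow models only
the residual around it."""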
import torch
from torch import nn

from .transformer import BasicTransformerModel
from models import BaseModel
from .moglow.models import Glow

#TODO: refactor a whole bunch of stuff

class TransglowerModel(BaseModel):
    def __init__(self, opt):
        super().__init__(opt)
        input_mods = self.input_mods
        output_mods = self.output_mods
        dins = self.dins
        douts = self.douts
        input_seq_lens = self.input_seq_lens

        self.input_mod_nets = []
        self.input_mod_funcs = []
        self.output_mod_nets = []
        self.output_mod_mean_nets = []
        self.output_mod_funcs = []
        self.output_mod_glows = []
        self.module_names = []
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout, self.device, use_pos_emb=True, input_length=input_seq_lens[i]).to(self.device)
            name = "_input_"+mod
            setattr(self,"net"+name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)
        # Build one forward function per input modality. Using a factory binds each
        # index eagerly, avoiding Python's late-binding closure pitfall, and works
        # for any number of input modalities (not just two).
        def make_input_func(idx):
            def func(x):
                return self.input_mod_nets[idx].forward(x)
            return func
        self.input_mod_funcs = [make_input_func(i) for i in range(len(input_mods))]
        # should only be one output_mod
        for i, mod in enumerate(output_mods):
            net = BasicTransformerModel(opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout, self.device, use_pos_emb=opt.use_pos_emb_output, input_length=sum(input_seq_lens)).to(self.device)
            name = "_output_"+mod
            setattr(self, "net"+name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)
            # bind i as a default argument so each closure sees its own index,
            # not the loop variable's final value
            if self.opt.residual:
                # residual mode needs the full transformer output: the conditioning
                # latents plus the mean-prediction latents
                def func3(x, i=i):
                    return self.output_mod_nets[i].forward(x)
            else:
                def func3(x, i=i):
                    # only the first conditioning_seq_lens[i] sequence elements condition the flow
                    return self.output_mod_nets[i].forward(x)[:self.conditioning_seq_lens[i]]
            self.output_mod_funcs.append(func3)
            if opt.residual:
                net = nn.Linear(opt.dhid,douts[i])
                name="_output_mean_encoder"
                setattr(self, "net"+name, net)
                self.output_mod_mean_nets.append(net)

            cond_dim = opt.dhid
            output_dim = douts[i]
            glow = Glow(output_dim, cond_dim, self.opt)
            name = "_output_glow_"+mod
            setattr(self, "net"+name, glow)
            self.output_mod_glows.append(glow)

        self.inputs = []
        self.targets = []
        self.mean_loss = nn.MSELoss()
        self.mse_loss = 0
        self.nll_loss = 0

    def name(self):
        return "Transglower"

    def parse_base_arguments(self):
        super().parse_base_arguments()
        self.input_seq_lens = [int(x) for x in str(self.opt.input_seq_lens).split(",")]
        self.output_seq_lens = [int(x) for x in str(self.opt.output_seq_lens).split(",")]
        if self.opt.phase == "inference":
            self.input_lengths = [int(x) for x in self.opt.input_seq_lens.split(",")]
            self.output_lengths = [int(x) for x in self.opt.output_seq_lens.split(",")]
        else:
            self.input_lengths = [int(x) for x in self.opt.input_lengths.split(",")]
            self.output_lengths = [int(x) for x in self.opt.output_lengths.split(",")]

        if self.opt.conditioning_seq_lens is not None:
            self.conditioning_seq_lens = [int(x) for x in str(self.opt.conditioning_seq_lens).split(",")]
        else:
            self.conditioning_seq_lens = [1 for x in self.opt.output_lengths.split(",")]

        # broadcast a single offset to all modalities; otherwise the counts must match
        if len(self.output_time_offsets) < len(self.output_mods):
            if len(self.output_time_offsets) == 1:
                self.output_time_offsets = self.output_time_offsets*len(self.output_mods)
            else:
                raise Exception("number of output_time_offsets doesn't match number of output_mods")

        if len(self.input_time_offsets) < len(self.input_mods):
            if len(self.input_time_offsets) == 1:
                self.input_time_offsets = self.input_time_offsets*len(self.input_mods)
            else:
                raise Exception("number of input_time_offsets doesn't match number of input_mods")

    @staticmethod
    def modify_commandline_options(parser, opt):
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--dhid_flow', type=int, default=512)
        parser.add_argument('--conditioning_seq_lens', type=str, default=None, help="the number of outputs of the conditioning transformers to feed (meaning the number of elements along the sequence dimension)")
        parser.add_argument('--input_seq_lens', type=str, default="10,11")
        parser.add_argument('--output_seq_lens', type=str, default="1")
        parser.add_argument('--glow_K', type=int, default=16)
        parser.add_argument('--actnorm_scale', type=float, default=1.0)
        parser.add_argument('--flow_permutation', type=str, default="invconv")
        parser.add_argument('--flow_dist', type=str, default="normal")
        parser.add_argument('--flow_dist_param', type=int, default=50)
        parser.add_argument('--flow_coupling', type=str, default="affine")
        parser.add_argument('--num_layers', type=int, default=2)
        parser.add_argument('--network_model', type=str, default="LSTM")
        parser.add_argument('--LU_decomposed', action='store_true')
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--num_heads_flow', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--scales', type=str, default="[[10,0]]")
        parser.add_argument('--glow_norm_layer', type=str, default=None)
        parser.add_argument('--glow_bn_momentum', type=float, default=0.1)
        parser.add_argument('--num_glow_coupling_blocks', type=int, default=10)
        parser.add_argument('--num_mixture_components', type=int, default=0)
        parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention for the FlowPlusPlus model")
        parser.add_argument('--use_transformer_nn', action='store_true', help="whether to use a transformer (rather than the recurrent network selected by --network_model) inside the flow's coupling layers")
        parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers")
        parser.add_argument('--use_pos_emb_coupling', action='store_true', help="whether to use positional embeddings for the coupling layer transformers")
        parser.add_argument('--cond_concat_dims', action='store_true', help="if set we concatenate the conditioning with x along the channel dimension in the coupling layer; otherwise we concatenate along the sequence dimension")
        parser.add_argument('--residual', action='store_true', help="whether to use the flow to predict the residual around a deterministic mean")
        return parser
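    # Example invocation (hypothetical script name and flag values; base flags such
    # as the modality lists come from BaseModel's own option parsing):
    #   python train.py --model transglower --dhid 512 --nhead 8 --nlayers 6 \
    #       --input_seq_lens 10,11 --output_seq_lens 1 --residual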

    def forward(self, data):
        # in lightning, forward defines the prediction/inference actions
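        # Data flow (a sketch; shapes assume each data[i] arrives as (T_i, B, dins[i])):
        #   1. window each input into overlapping chunks of length input_seq_lens[i],
        #      giving (L_i, seq_len_i, B, dins[i])
        #   2. encode each window with that modality's input transformer
        #   3. concatenate the per-modality latents and run the output transformer
        #   4. sample the output modality from the Glow flow, conditioned on the result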
        for i, mod in enumerate(self.input_mods):
            input_ = data[i]
            input_ = input_.permute(1, 0, 2)  # (T, B, C) -> (B, T, C)
            input_ = self.concat_sequence(self.input_seq_lens[i], input_)  # (B, L, seq_len, C)
            input_ = input_.permute(1, 2, 0, 3)  # (L, seq_len, B, C)
            data[i] = input_

        latents = []
        for i, mod in enumerate(self.input_mods):
            # encode each window independently with the modality's transformer
            result = []
            for inp in data[i]:
                result.append(self.input_mod_funcs[i](inp))
            result = torch.stack(result)
            latents.append(result)

        latent = torch.cat(latents, dim=1)
        outputs = []
        if self.opt.residual:
            for i, mod in enumerate(self.output_mods):
                trans_output = []
                for lat in latent:
                    trans_output.append(self.output_mod_funcs[i](lat))
                trans_output = torch.stack(trans_output)
                # the first conditioning_seq_lens[i] elements condition the flow; the
                # next output_seq_lens[i] elements feed the deterministic mean head
                cond_latents = trans_output[:, :self.conditioning_seq_lens[i]].permute(2, 1, 3, 0)
                trans_predicted_mean_latents = trans_output[:, self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i]+self.output_seq_lens[i]]
                cond_latents = cond_latents.reshape(cond_latents.shape[0], cond_latents.shape[1] * cond_latents.shape[2], cond_latents.shape[3])
                trans_predicted_mean_latents = trans_predicted_mean_latents.reshape(trans_predicted_mean_latents.shape[0], trans_predicted_mean_latents.shape[1] * trans_predicted_mean_latents.shape[2], trans_predicted_mean_latents.shape[3])
                predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents)
                glow = self.output_mod_glows[i]
                # reverse pass: sample from the flow given the conditioning
                output = glow(x=None, cond=cond_latents, reverse=True)
                outputs.append(output.permute(0, 2, 1) + predicted_mean)
        else:
            for i, mod in enumerate(self.output_mods):
                trans_output = []
                for lat in latent:
                    trans_output.append(self.output_mod_funcs[i](lat))
                trans_output = torch.stack(trans_output).permute(2, 1, 3, 0)
                trans_output = trans_output.reshape(trans_output.shape[0], trans_output.shape[1] * trans_output.shape[2], trans_output.shape[3])
                glow = self.output_mod_glows[i]
                # reverse pass: sample from the flow given the conditioning
                output = glow(x=None, cond=trans_output, reverse=True)
                outputs.append(output.permute(0, 2, 1))

        return outputs

    # MoGlow's Glow keeps recurrent (LSTM) hidden state across steps, so it is
    # reset at the start of training/testing and before every batch.
    def on_test_start(self):
        for i, mod in enumerate(self.output_mods):
            self.output_mod_glows[i].init_lstm_hidden()

    def on_train_start(self):
        for i, mod in enumerate(self.output_mods):
            self.output_mod_glows[i].init_lstm_hidden()

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
        for i, mod in enumerate(self.output_mods):
            self.output_mod_glows[i].init_lstm_hidden()

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        for i, mod in enumerate(self.output_mods):
            self.output_mod_glows[i].init_lstm_hidden()

    def concat_sequence(self, seqlen, data):
        """Slice a batch of sequences into overlapping windows of length `seqlen`.

        Args:
            seqlen: window length.
            data: tensor of shape (batch, n_timesteps, n_feats).

        Returns:
            Tensor of shape (batch, n_timesteps - seqlen + 1, seqlen, n_feats).
        """
        # NOTE: this could be done as preprocessing on the dataset to make it a bit
        # more efficient, but we only use this for the baseline moglow, so it wasn't
        # worth putting it there.
        batch_size, n_timesteps, n_feats = data.shape
        L = n_timesteps - (seqlen - 1)
        inds = torch.zeros((L, seqlen), dtype=torch.long)

        # create indices for the windows we want: row t is [t, t+1, ..., t+seqlen-1]
        rng = torch.arange(0, n_timesteps, dtype=torch.long)
        for ii in range(0, seqlen):
            inds[:, ii] = rng[ii:(n_timesteps - (seqlen - ii - 1))]

        # gather each window as a new "sample": (batch, L, seqlen, n_feats)
        cc = data[:, inds, :].clone()
        return cc.reshape((batch_size, L, seqlen, n_feats))
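    # Example (shapes only): for data of shape (2, 5, 3) and seqlen=3 the result has
    # shape (2, 3, 3, 3), i.e. windows [0:3], [1:4], [2:5] along the time axis:
    #   x = torch.randn(2, 5, 3)
    #   w = model.concat_sequence(3, x)      # w.shape == (2, 3, 3, 3)
    #   assert torch.equal(w[:, 0], x[:, 0:3])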

    def set_inputs(self, data):
        self.inputs = []
        self.targets = []
        for i, mod in enumerate(self.input_mods):
            input_ = data["in_"+mod]  # (B, T, C)
            # concat_sequence also covers input_seq_lens[i] == 1 (windows of length 1),
            # matching what forward() does
            input_ = self.concat_sequence(self.input_seq_lens[i], input_)
            input_ = input_.permute(1, 2, 0, 3)  # (L, seq_len, B, C)
            self.inputs.append(input_)
        for i, mod in enumerate(self.output_mods):
            target_ = data["out_"+mod]
            if self.output_seq_lens[i] > 1:
                target_ = self.concat_sequence(self.output_seq_lens[i], target_)
            target_ = target_.permute(0, 2, 1)  # (B, C, T)
            self.targets.append(target_)

    def training_step(self, batch, batch_idx):
        self.set_inputs(batch)
        latents = []
        for i, mod in enumerate(self.input_mods):
            result = []
            for inp in self.inputs[i]:
                result.append(self.input_mod_funcs[i](inp))
            result = torch.stack(result)
            latents.append(result)

        latent = torch.cat(latents, dim=1)
        if self.opt.residual:
            nll_loss = 0
            mse_loss = 0
            for i, mod in enumerate(self.output_mods):
                trans_output = []
                for lat in latent:
                    trans_output.append(self.output_mod_funcs[i](lat))
                trans_output = torch.stack(trans_output)
                cond_latents = trans_output[:, :self.conditioning_seq_lens[i]].permute(2, 1, 3, 0)
                trans_predicted_mean_latents = trans_output[:, self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i]+self.output_seq_lens[i]]
                cond_latents = cond_latents.reshape(cond_latents.shape[0], cond_latents.shape[1] * cond_latents.shape[2], cond_latents.shape[3])
                trans_predicted_mean_latents = trans_predicted_mean_latents.reshape(trans_predicted_mean_latents.shape[0], trans_predicted_mean_latents.shape[1] * trans_predicted_mean_latents.shape[2], trans_predicted_mean_latents.shape[3])
                predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents).permute(1, 2, 0)
                glow = self.output_mod_glows[i]
                # forward pass of the flow on the residual; x, cond: (batch, channels, time)
                z, nll = glow(x=self.targets[i]-predicted_mean, cond=cond_latents)
                nll_loss += Glow.loss_generative(nll)
                # fixed weight of 100 on the mean-prediction MSE term
                mse_loss += 100*self.mean_loss(predicted_mean, self.targets[i])
            loss = nll_loss + mse_loss
            self.mse_loss = mse_loss
            self.nll_loss = nll_loss
            self.log('mse_loss', mse_loss)
            self.log('nll_loss', nll_loss)
        else:
            loss = 0
            for i, mod in enumerate(self.output_mods):
                output = []
                for lat in latent:
                    output.append(self.output_mod_funcs[i](lat))
                output = torch.stack(output).permute(2, 1, 3, 0)
                output = output.reshape(output.shape[0], output.shape[1] * output.shape[2], output.shape[3])
                glow = self.output_mod_glows[i]
                # forward pass of the flow on the target; x, cond: (batch, channels, time)
                z, nll = glow(x=self.targets[i], cond=output)
                loss += Glow.loss_generative(nll)
        self.log('loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        if self.opt.residual:
            self.eval()
            loss = self.training_step(batch, batch_idx)
            return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss}
        else:
            return super().test_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        if self.opt.residual:
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            avg_mse_loss = torch.stack([x['test_mse_loss'] for x in outputs]).mean()
            avg_nll_loss = torch.stack([x['test_nll_loss'] for x in outputs]).mean()
            logs = {'test_loss': avg_loss, 'test_mse_loss': avg_mse_loss, 'test_nll_loss': avg_nll_loss}

            return {'log': logs}
        else:
            return super().test_epoch_end(outputs)

    #to help debug XLA stuff, like missing ops, or data loading/compiling bottlenecks
    # see https://youtu.be/iwtpwQRdb3Y?t=1056
    # def on_epoch_end(self):
    #    import torch_xla.core.xla_model as xm
    #    import torch_xla.debug.metrics as met
    #    xm.master_print(met.metrics_report())


    #def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
    #                           optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
    #    optimizer.zero_grad()
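
# Minimal usage sketch (illustrative only; in the repo, `opt` comes from the
# option parser and training is driven by the training script):
#
#   import pytorch_lightning as pl
#   model = TransglowerModel(opt)            # opt: argparse.Namespace with the flags above
#   trainer = pl.Trainer(max_epochs=100)
#   trainer.fit(model, train_dataloader)     # hypothetical DataLoader of multimodal batches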