import torch
from .transformer import BasicTransformerModel
from models import BaseModel
from models.flowplusplus import FlowPlusPlus
import ast
from torch import nn

from .util.generation import autoregressive_generation_multimodal

class TransflowerModel(BaseModel):
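    """Transflower: transformer encoders per input modality feeding a Flow++
    (normalizing flow) output head per output modality.

    The concatenated input latents are processed by an output transformer whose
    first conditioning_seq_lens[i] outputs condition a FlowPlusPlus density
    model over that modality. With --residual, a linear head additionally
    predicts a deterministic mean and the flow models only the residual.
    """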
    def __init__(self, opt):
        super().__init__(opt)
        input_mods = self.input_mods
        output_mods = self.output_mods
        dins = self.dins
        douts = self.douts
        input_lengths = self.input_lengths
        output_lengths = self.output_lengths
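        # Number of output-transformer elements used to condition each output flow;
        # defaults to the output lengths when --conditioning_seq_lens is not given.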
        if self.opt.conditioning_seq_lens is not None:
            self.conditioning_seq_lens = [int(x) for x in str(self.opt.conditioning_seq_lens).split(",")]
        else:
            self.conditioning_seq_lens = [int(x) for x in str(self.opt.output_lengths).split(",")]

        self.input_mod_nets = []
        self.output_mod_nets = []
        self.output_mod_mean_nets = []
        self.output_mod_glows = []
        self.module_names = []
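        # One transformer encoder per input modality (embedding discrete inputs
        # when input_types[i] == 'd').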
        for i, mod in enumerate(input_mods):
            net = BasicTransformerModel(opt.dhid, dins[i], opt.nhead, opt.dhid, 2, opt.dropout,
                                        ntokens=self.input_num_tokens[i],
                                        use_pos_emb=opt.use_pos_emb_inputs,
                                        use_rel_pos_emb=opt.use_rel_pos_emb_inputs,
                                        input_length=input_lengths[i],
                                        use_x_transformers=opt.use_x_transformers,
                                        opt=opt,
                                        discrete_inputs=self.input_types[i] == 'd')
            name = "_input_"+mod
            setattr(self,"net"+name, net)
            self.input_mod_nets.append(net)
            self.module_names.append(name)
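        # One transformer per output modality, attending over the concatenated
        # latents of all input modalities.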
        for i, mod in enumerate(output_mods):
            net = BasicTransformerModel(opt.dhid, opt.dhid, opt.nhead, opt.dhid, opt.nlayers, opt.dropout,
                                        ntokens=self.output_num_tokens[i],  # though not used yet
                                        use_pos_emb=opt.use_pos_emb_output,
                                        use_rel_pos_emb=opt.use_rel_pos_emb_output,
                                        input_length=sum(input_lengths),
                                        use_x_transformers=opt.use_x_transformers,
                                        opt=opt)
            name = "_output_"+mod
            setattr(self, "net"+name, net)
            self.output_mod_nets.append(net)
            self.module_names.append(name)
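            # Optional deterministic mean head; with --residual the flow models
            # only the residual around this predicted mean.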
            if opt.residual:
                net = nn.Linear(opt.dhid, douts[i])
                # Name the mean head per modality so that multiple output
                # modalities don't overwrite each other's registered module.
                name = "_output_mean_encoder_" + mod
                setattr(self, "net" + name, net)
                self.output_mod_mean_nets.append(net)

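            # Flow++ density model for this output modality, conditioned on the
            # output transformer's latents.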
            glow = FlowPlusPlus(scales=ast.literal_eval(opt.scales),
                                in_shape=(douts[i], output_lengths[i], 1),
                                cond_dim=opt.dhid,
                                mid_channels=opt.dhid_flow,
                                num_blocks=opt.num_glow_coupling_blocks,
                                num_components=opt.num_mixture_components,
                                use_attn=opt.glow_use_attn,
                                use_logmix=opt.num_mixture_components > 0,
                                drop_prob=opt.dropout,
                                num_heads=opt.num_heads_flow,
                                use_transformer_nn=opt.use_transformer_nn,
                                use_pos_emb=opt.use_pos_emb_coupling,
                                use_rel_pos_emb=opt.use_rel_pos_emb_coupling,
                                norm_layer=opt.glow_norm_layer,
                                bn_momentum=opt.glow_bn_momentum,
                                cond_concat_dims=opt.cond_concat_dims,
                                flow_dist=opt.flow_dist,
                                flow_dist_param=opt.flow_dist_param,
                                cond_seq_len=self.conditioning_seq_lens[i])
            name = "_output_glow_"+mod
            setattr(self, "net"+name, glow)
            self.output_mod_glows.append(glow)

        self.mean_loss = nn.MSELoss()
        self.inputs = []
        self.targets = []
        self.mse_loss = 0
        self.nll_loss = 0

    def name(self):
        return "Transflower"

    @staticmethod
    def modify_commandline_options(parser, opt):
        parser.add_argument('--dhid', type=int, default=512)
        parser.add_argument('--dhid_flow', type=int, default=512)
        parser.add_argument('--conditioning_seq_lens', type=str, default=None, help="number of outputs of each conditioning transformer to feed to the flow (i.e. the number of elements along the sequence dimension); defaults to output_lengths")
        parser.add_argument('--nlayers', type=int, default=6)
        parser.add_argument('--nhead', type=int, default=8)
        parser.add_argument('--num_heads_flow', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--scales', type=str, default="[[10,0]]", help="per-scale coupling-layer counts for the FlowPlusPlus model, parsed with ast.literal_eval")
        parser.add_argument('--flow_dist', type=str, default="normal")
        parser.add_argument('--flow_dist_param', type=int, default=50)
        parser.add_argument('--glow_norm_layer', type=str, default=None)
        parser.add_argument('--glow_bn_momentum', type=float, default=0.1)
        parser.add_argument('--num_glow_coupling_blocks', type=int, default=10)
        parser.add_argument('--num_mixture_components', type=int, default=0)
        parser.add_argument('--glow_use_attn', action='store_true', help="whether to use the internal attention of the FlowPlusPlus model")
        parser.add_argument('--use_transformer_nn', action='store_true', help="whether to use a transformer for the coupling-layer networks of the FlowPlusPlus model")
        parser.add_argument('--use_rel_pos_emb_inputs', action='store_true', help="whether to use T5 relative positional embeddings for input modality transformers")
        parser.add_argument('--use_rel_pos_emb_output', action='store_true', help="whether to use T5 relative positional embeddings for output modality transformers")
        parser.add_argument('--use_pos_emb_inputs', action='store_true', help="whether to use positional embeddings for input modality transformers")
        parser.add_argument('--use_pos_emb_output', action='store_true', help="whether to use positional embeddings for output modality transformers")
        parser.add_argument('--use_pos_emb_coupling', action='store_true', help="whether to use positional embeddings for the coupling layer transformers")
        parser.add_argument('--use_rel_pos_emb_coupling', action='store_true', help="whether to use T5 relative positional embeddings for the coupling layer transformers")
        parser.add_argument('--cond_concat_dims', action='store_true', help="if set, concatenate the conditioning with x along the channel dimension in the coupling layer; otherwise concatenate along the sequence dimension")
        parser.add_argument('--residual', action='store_true', help="whether to use the flow to model the residual around a deterministic predicted mean")
        parser.add_argument('--use_rotary_pos_emb', action='store_true', help="whether to use rotary position embeddings")
        parser.add_argument('--use_x_transformers', action='store_true', help="whether to use the x-transformers library implementation of the transformers")
        return parser
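    # Hypothetical invocation (script name and flag values are illustrative only):
    #   python train.py --model transflower --dhid 512 --nlayers 6 --residual \
    #       --scales "[[10,0]]" --num_glow_coupling_blocks 10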

    def forward(self, data):
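        """Sample each output modality from its flow, conditioned on the inputs.

        Expects each data[i] shaped (time, batch, features) and returns one
        (time, batch, features) tensor per output modality. Full sequences are
        generated by chaining this step autoregressively (see
        autoregressive_generation_multimodal imported above).
        """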
        # In Lightning, forward defines the prediction/inference actions.
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i](data[i]))
        latent = torch.cat(latents)
        outputs = []
        if self.opt.residual:
            for i, mod in enumerate(self.output_mods):
                # Run the output transformer once and slice: the first
                # conditioning_seq_lens[i] elements condition the flow, the next
                # output_lengths[i] elements feed the mean predictor.
                trans_output = self.output_mod_nets[i](latent)
                cond = trans_output[:self.conditioning_seq_lens[i]]
                trans_predicted_mean_latents = trans_output[self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i] + self.output_lengths[i]]
                predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents)
                # Sample the residual from the flow (reverse pass) and add the mean back.
                residual, _ = self.output_mod_glows[i](x=None, cond=cond.permute(1, 0, 2), reverse=True)
                outputs.append(predicted_mean + residual.permute(1, 0, 2))
        else:
            for i, mod in enumerate(self.output_mods):
                cond = self.output_mod_nets[i](latent)[:self.conditioning_seq_lens[i]]
                output, _ = self.output_mod_glows[i](x=None, cond=cond.permute(1, 0, 2), reverse=True)
                outputs.append(output.permute(1, 0, 2))

        return outputs

    def training_step(self, batch, batch_idx):
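        """Maximum-likelihood training step: encode the inputs, condition each
        flow on the output transformer's latents, and minimize the flow NLL
        (blended with a weighted MSE on the predicted mean when --residual is set).
        """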
        self.set_inputs(batch)
        latents = []
        for i, mod in enumerate(self.input_mods):
            latents.append(self.input_mod_nets[i](self.inputs[i]))

        latent = torch.cat(latents)
        if self.opt.residual:
            nll_loss = 0
            mse_loss = 0
            for i, mod in enumerate(self.output_mods):
                trans_output = self.output_mod_nets[i](latent)
                cond = trans_output[:self.conditioning_seq_lens[i]]
                trans_predicted_mean_latents = trans_output[self.conditioning_seq_lens[i]:self.conditioning_seq_lens[i] + self.output_lengths[i]]
                predicted_mean = self.output_mod_mean_nets[i](trans_predicted_mean_latents)
                glow = self.output_mod_glows[i]
                # (time, batch, features) -> (batch, time, features); the flow
                # models the residual of the target around the predicted mean.
                z, sldj = glow(x=self.targets[i].permute(1, 0, 2) - predicted_mean.permute(1, 0, 2), cond=cond.permute(1, 0, 2))
                nll_loss += glow.loss_generative(z, sldj)
                mse_loss += 100 * self.mean_loss(predicted_mean, self.targets[i])

            # Adaptive blending: when the MSE is large, sigmoid(mse/5) -> 1 and the
            # MSE term dominates; as the mean improves, weight shifts to the NLL.
            adaptive_weight = torch.sigmoid(mse_loss / 5)
            loss = (1 - adaptive_weight) * nll_loss + adaptive_weight * mse_loss
            self.mse_loss = mse_loss
            self.nll_loss = nll_loss
            self.log('mse_loss', mse_loss)
            self.log('nll_loss', nll_loss)
        else:
            loss = 0
            for i, mod in enumerate(self.output_mods):
                cond = self.output_mod_nets[i](latent)[:self.conditioning_seq_lens[i]]
                glow = self.output_mod_glows[i]
                # (time, batch, features) -> (batch, time, features)
                z, sldj = glow(x=self.targets[i].permute(1, 0, 2), cond=cond.permute(1, 0, 2))
                loss += glow.loss_generative(z, sldj)

        self.log('loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        if self.opt.residual:
            self.eval()
            loss = self.training_step(batch, batch_idx)
            return {"test_loss": loss, "test_mse_loss": self.mse_loss, "test_nll_loss": self.nll_loss}
        else:
            return super().test_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        if self.opt.residual:
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            avg_mse_loss = torch.stack([x['test_mse_loss'] for x in outputs]).mean()
            avg_nll_loss = torch.stack([x['test_nll_loss'] for x in outputs]).mean()
            logs = {'test_loss': avg_loss, 'test_mse_loss': avg_mse_loss, 'test_nll_loss': avg_nll_loss}

            return {'log': logs}
        else:
            return super().test_epoch_end(outputs)

    # To help debug XLA issues (missing ops, data-loading/compilation bottlenecks),
    # see https://youtu.be/iwtpwQRdb3Y?t=1056
    # def on_epoch_end(self):
    #    import torch_xla.core.xla_model as xm
    #    import torch_xla.debug.metrics as met
    #    xm.master_print(met.metrics_report())