import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import T5ForConditionalGeneration, ViTModel

# Defining the PyTorch models


class LaTr_for_pretraining(nn.Module):
    def __init__(self, config, classify=False):

        super(LaTr_for_pretraining, self).__init__()
        self.vocab_size = config['vocab_size']

        model = T5ForConditionalGeneration.from_pretrained(config['t5_model'])
        # Drop the token-embedding layer from the T5 encoder and decoder;
        # only the transformer blocks, final layer norm and dropout remain
        dummy_encoder = list(nn.Sequential(
            *list(model.encoder.children())[1:]).children())
        dummy_decoder = list(nn.Sequential(
            *list(model.decoder.children())[1:]).children())

        # Split each stack into its transformer blocks (list_*) and the
        # trailing layer norm + dropout (residue_*)
        self.list_encoder = nn.Sequential(*list(dummy_encoder[0]))
        self.residue_encoder = nn.Sequential(*list(dummy_encoder[1:]))
        self.list_decoder = nn.Sequential(*list(dummy_decoder[0]))
        self.residue_decoder = nn.Sequential(*list(dummy_decoder[1:]))

        # We use the embeddings of T5 for encoding the tokenized words
        self.language_emb = nn.Embedding.from_pretrained(model.shared.weight)
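        # Note: nn.Embedding.from_pretrained freezes the copied weights by
        # default (freeze=True); pass freeze=False if the word embeddings
        # should be trained as well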

        self.top_left_x = nn.Embedding(
            config['max_2d_position_embeddings'], config['hidden_state'])
        self.bottom_right_x = nn.Embedding(
            config['max_2d_position_embeddings'], config['hidden_state'])
        self.top_left_y = nn.Embedding(
            config['max_2d_position_embeddings'], config['hidden_state'])
        self.bottom_right_y = nn.Embedding(
            config['max_2d_position_embeddings'], config['hidden_state'])
        self.width_emb = nn.Embedding(
            config['max_2d_position_embeddings'], config['hidden_state'])
        self.height_emb = nn.Embedding(
            config['max_2d_position_embeddings'], config['hidden_state'])
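        # Together, these six embedding tables form the 2-D spatial embedding:
        # each token's box corners, width and height are embedded separately
        # and summed with its word embedding in forward()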

        self.classify = classify
        self.classification_layer = nn.Linear(
            config['hidden_state'], config['classes'])

    def forward(self, tokens, coordinates, predict_proba=False, predict_class=False):

        # Word embeddings of the tokenized words
        embeded_feature = self.language_emb(tokens)

        top_left_x_feat = self.top_left_x(coordinates[:, :, 0])
        top_left_y_feat = self.top_left_y(coordinates[:, :, 1])
        bottom_right_x_feat = self.bottom_right_x(coordinates[:, :, 2])
        bottom_right_y_feat = self.bottom_right_y(coordinates[:, :, 3])
        width_feat = self.width_emb(coordinates[:, :, 4])
        height_feat = self.height_emb(coordinates[:, :, 5])

        total_feat = embeded_feature + top_left_x_feat + top_left_y_feat + \
            bottom_right_x_feat + bottom_right_y_feat + width_feat + height_feat

        # Pass the combined embedding through the T5 encoder blocks
        for layer in self.list_encoder:
            total_feat = layer(total_feat)[0]
        total_feat = self.residue_encoder(total_feat)

        # ... then through the decoder blocks; no encoder hidden states are
        # passed, so the cross-attention sub-layers are skipped
        for layer in self.list_decoder:
            total_feat = layer(total_feat)[0]
        total_feat = self.residue_decoder(total_feat)

        if self.classify:
            total_feat = self.classification_layer(total_feat)

        if predict_proba:
            return total_feat.softmax(dim=-1)

        if predict_class:
            return total_feat.argmax(dim=-1)

        return total_feat
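

# Usage sketch (an assumption, not part of the original file): the actual
# config is defined elsewhere in the repository; the keys below are the ones
# this class reads, with illustrative values for t5-base.
#
#   config = {
#       'vocab_size': 32128,
#       't5_model': 't5-base',
#       'hidden_state': 768,
#       'max_2d_position_embeddings': 1001,
#       'classes': 32128,
#       'seq_len': 512,
#   }
#   model = LaTr_for_pretraining(config, classify=True)
#   # tokens: (batch, seq_len) token ids; coordinates: (batch, seq_len, 6)
#   # holding [top-left x, top-left y, bottom-right x, bottom-right y,
#   # width, height] per token
#   logits = model(tokens, coordinates)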


class LaTr_for_finetuning(nn.Module):
    def __init__(self, config, address_to_pre_trained_weights=None):
        super(LaTr_for_finetuning, self).__init__()

        self.config = config
        self.vocab_size = config['vocab_size']

        self.pre_training_model = LaTr_for_pretraining(config)
        if address_to_pre_trained_weights is not None:
            self.pre_training_model.load_state_dict(
                torch.load(address_to_pre_trained_weights))
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k")

        # In the fine-tuning stage of ViT, all layers except the last encoder
        # layer are frozen
        for param in self.vit.parameters():
            param.requires_grad = False
        for param in self.vit.encoder.layer[-1].parameters():
            param.requires_grad = True

        self.classification_head = nn.Linear(
            config['hidden_state'], config['classes'])

    def forward(self, lang_vect, spatial_vect, quest_vect, img_vect):
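        # Expected shapes (an assumption based on how the tensors are used):
        #   lang_vect:    (batch, seq_len)      OCR token ids
        #   spatial_vect: (batch, seq_len, 6)   box coordinates per token
        #   quest_vect:   (batch, quest_len)    question token ids
        #   img_vect:     (batch, 3, 224, 224)  image pixels for the ViT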

        # The block below computes the combined language and spatial features
        embeded_feature = self.pre_training_model.language_emb(lang_vect)
        top_left_x_feat = self.pre_training_model.top_left_x(
            spatial_vect[:, :, 0])
        top_left_y_feat = self.pre_training_model.top_left_y(
            spatial_vect[:, :, 1])
        bottom_right_x_feat = self.pre_training_model.bottom_right_x(
            spatial_vect[:, :, 2])
        bottom_right_y_feat = self.pre_training_model.bottom_right_y(
            spatial_vect[:, :, 3])
        width_feat = self.pre_training_model.width_emb(spatial_vect[:, :, 4])
        height_feat = self.pre_training_model.height_emb(spatial_vect[:, :, 5])

        spatial_lang_feat = embeded_feature + top_left_x_feat + top_left_y_feat + \
            bottom_right_x_feat + bottom_right_y_feat + width_feat + height_feat

        # Extracting the image features using the Vision Transformer
        img_feat = self.vit(img_vect).last_hidden_state
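        # For ViT-base at 224x224 resolution this is (batch, 197, 768):
        # 196 patch tokens plus the [CLS] token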

        # Embedding the question tokens with the shared T5 embeddings
        quest_feat = self.pre_training_model.language_emb(quest_vect)

        # Concatenating the three features along the sequence dimension
        final_feat = torch.cat(
            [img_feat, spatial_lang_feat, quest_feat], dim=-2)

        # Passing through the T5 Transformer
        for layer in self.pre_training_model.list_encoder:
            final_feat = layer(final_feat)[0]

        final_feat = self.pre_training_model.residue_encoder(final_feat)

        for layer in self.pre_training_model.list_decoder:
            final_feat = layer(final_feat)[0]
        final_feat = self.pre_training_model.residue_decoder(final_feat)

        answer_vector = self.classification_head(
            final_feat)[:, :self.config['seq_len'], :]

        return answer_vector


def polynomial(base_lr, current_step, max_iter=1e5, power=1.0):
    # Polynomial learning-rate decay (linear for power=1), clamped at zero so
    # the rate never goes negative once current_step exceeds max_iter
    return base_lr * max(0.0, 1 - float(current_step) / max_iter) ** power
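

# `calculate_acc_score` is referenced in `LaTrForVQA.calculate_metrics` below
# but is neither defined nor imported in this file; it presumably lives in a
# utilities module of the repository. A minimal sketch of one plausible
# implementation (token-level accuracy that ignores the padding index 0) is
# given here as an assumption so the file is self-contained:
def calculate_acc_score(pred, gt, pad_token_id=0):
    # Compare predictions and labels only at non-padding positions
    mask = gt != pad_token_id
    if mask.sum() == 0:
        return 0.0
    return (pred[mask] == gt[mask]).float().mean().item()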


class LaTrForVQA(pl.LightningModule):
    def __init__(self, config, learning_rate=1e-4, max_steps=100000//2):
        super(LaTrForVQA, self).__init__()

        self.config = config
        self.save_hyperparameters()
        self.latr = LaTr_for_finetuning(config)
        self.training_losses = []
        self.validation_losses = []
        self.max_steps = max_steps

    def configure_optimizers(self):
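        # The learning rate set here is overwritten every step by the manual
        # warmup/decay schedule in optimizer_step below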
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['learning_rate'])

    def forward(self, batch_dict):
        boxes = batch_dict['boxes']
        img = batch_dict['img']
        question = batch_dict['question']
        words = batch_dict['tokenized_words']
        answer_vector = self.latr(lang_vect=words,
                                  spatial_vect=boxes,
                                  img_vect=img,
                                  quest_vect=question
                                  )
        return answer_vector

    def calculate_metrics(self, prediction, labels):

        # Average per-sample accuracy between predictions and ground-truth
        # labels over the batch, taking the padding tokens into account
        batch_size = len(prediction)
        ac_score = 0

        for (pred, gt) in zip(prediction, labels):
            ac_score += calculate_acc_score(pred.detach().cpu(),
                                            gt.detach().cpu())
        ac_score = ac_score / batch_size
        return ac_score

    def training_step(self, batch, batch_idx):
        answer_vector = self.forward(batch)

        # https://discuss.huggingface.co/t/bertformaskedlm-s-loss-and-scores-how-the-loss-is-computed/607/2
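        # ignore_index=0 masks the padding positions of the answer sequence
        # out of the loss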
        loss = nn.CrossEntropyLoss(ignore_index=0)(
            answer_vector.reshape(-1, self.config['classes']), batch['answer'].reshape(-1))
        _, preds = torch.max(answer_vector, dim=-1)

        # Calculating the accuracy score
        train_acc = self.calculate_metrics(preds, batch['answer'])
        train_acc = torch.tensor(train_acc)

        # Logging
        self.log('train_ce_loss', loss, prog_bar=True)
        self.log('train_acc', train_acc, prog_bar=True)
        self.training_losses.append(loss.item())

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch)
        loss = nn.CrossEntropyLoss(ignore_index=0)(
            logits.reshape(-1, self.config['classes']), batch['answer'].reshape(-1))
        _, preds = torch.max(logits, dim=-1)

        # Validation Accuracy
        val_acc = self.calculate_metrics(preds.cpu(), batch['answer'].cpu())
        val_acc = torch.tensor(val_acc)

        # Logging
        self.log('val_ce_loss', loss, prog_bar=True)
        self.log('val_acc', val_acc, prog_bar=True)
        self.validation_losses.append(loss.item())
        return {'val_loss': loss, 'val_acc': val_acc}

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure=None,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):

        # Warmup for 1000 steps
        if self.trainer.global_step < 1000:
            lr_scale = min(1., float(self.trainer.global_step + 1) / 1000.)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.hparams.learning_rate

        # Linear Decay
        else:
            for pg in optimizer.param_groups:
                pg['lr'] = polynomial(
                    self.hparams.learning_rate, self.trainer.global_step, max_iter=self.max_steps)

        optimizer.step(optimizer_closure)
        optimizer.zero_grad()

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        self.log('val_loss_epoch_end', val_loss, on_epoch=True, sync_dist=True)
        self.log('val_acc_epoch_end', val_acc, on_epoch=True, sync_dist=True)
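

# Training sketch (an assumption): `train_loader` and `val_loader` are
# hypothetical DataLoaders yielding the batch dict consumed by `forward`
# (keys: 'boxes', 'img', 'question', 'tokenized_words', 'answer').
#
#   model = LaTrForVQA(config)
#   trainer = pl.Trainer(max_steps=model.max_steps)
#   trainer.fit(model, train_loader, val_loader)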