# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import numpy as np
import torch
import torch.nn.functional as F

from einops.einops import rearrange
from model.prismer import Prismer


class PrismerVQA(Prismer):
    def forward(self, experts, question, answer=None, weights=None, train=True, inference='rank', k_test=128):
        """VQA forward pass.

        experts: dict of expert outputs keyed by modality (must include 'rgb').
        question, answer: lists of raw strings; answers are required for training and for 'rank' inference.
        weights: per-sample loss weights applied during training.
        inference: 'generate' decodes free-form answers, 'rank' scores a list of candidate answers.
        k_test: number of top candidates re-scored in 'rank' mode.
        """
        device = experts['rgb'].device
        question = ['<s>' + ques.capitalize() for ques in question]
        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
                                  add_special_tokens=False, return_tensors="pt").to(device)
        
        if train:
            experts_train = self.expert_encoder(experts)
            experts_train = rearrange(experts_train, 'l b d -> b l d')  # batch_size, num_latents, output_dim

            answer = [' ' + ans.capitalize() + '</s>' for ans in answer]
            answer = self.tokenizer(answer, padding='longest', return_tensors="pt", add_special_tokens=False).to(device)

            input_ids = torch.cat([question.input_ids, answer.input_ids], dim=1).long()
            attention_mask = torch.cat([question.attention_mask, answer.attention_mask], dim=1)

            # supervise only the answer span: ignore padding and the question prefix
            answer_targets = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
            answer_targets[:, :-answer.input_ids.shape[1]] = -100

            answer_output = self.text_decoder(input_ids,
                                              attention_mask=attention_mask,
                                              encoder_hidden_states=experts_train,
                                              labels=answer_targets,
                                              return_dict=True)
            loss = weights * answer_output.loss  # weight each sample's loss, then average over the batch
            loss = loss.mean()
            return loss
        else:
            if inference == 'generate':
                num_beams = 3
                input_ids = question.input_ids
                attention_masks = question.attention_mask

                experts_train = self.expert_encoder(experts)
                experts_train = rearrange(experts_train, 'l b d -> b l d')  # batch_size, num_latents, output_dim
                experts_train = experts_train.repeat_interleave(num_beams, dim=0)  # one copy per beam
                outputs = self.text_decoder.generate(input_ids=input_ids,
                                                     encoder_hidden_states=experts_train,
                                                     attention_mask=attention_masks,
                                                     max_length=input_ids.shape[1] + 10,
                                                     min_length=input_ids.shape[1] + 2,
                                                     num_beams=num_beams,
                                                     length_penalty=-1)
                answers = []
                for i in range(len(outputs)):
                    answer = self.tokenizer.decode(outputs[i, len(input_ids[i]):], skip_special_tokens=True)
                    answers.append(answer.lower().strip())
                return answers

            elif inference == 'rank':
                experts_train = self.expert_encoder(experts)
                experts_train = rearrange(experts_train, 'l b d -> b l d')

                answer = [' ' + ans.capitalize() + '</s>' for ans in answer]
                answer = self.tokenizer(answer, padding='longest', return_tensors='pt', add_special_tokens=False).to(device)

                start_ids = question.input_ids
                attention_masks = question.attention_mask

                start_output = self.text_decoder(start_ids,
                                                 attention_mask=attention_masks,
                                                 encoder_hidden_states=experts_train,
                                                 return_dict=True)

                # pre-select candidates: rank all answers by the probability of their first token
                # and keep the top-k for full scoring below
                logits = start_output.logits[:, -1, :]
                answer_first_token = answer.input_ids[:, 0]
                prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
                _, topk_ids = prob_first_token.topk(k_test, dim=1)

                # answer input: [num_question * k, answer_len]
                answer_input_ids = []
                answer_input_atts = []
                for b, topk_id in enumerate(topk_ids):
                    answer_input_ids.append(answer.input_ids.index_select(dim=0, index=topk_id))
                    answer_input_atts.append(answer.attention_mask.index_select(dim=0, index=topk_id))

                answer_input_ids = torch.cat(answer_input_ids, dim=0)
                answer_input_atts = torch.cat(answer_input_atts, dim=0)

                # repeat encoder's output for top-k answers
                input_ids = torch.cat([tile(start_ids, 0, k_test), answer_input_ids], dim=1).long()
                attention_masks = torch.cat([tile(attention_masks, 0, k_test), answer_input_atts], dim=1)
                experts_train = tile(experts_train, 0, k_test)

                answer_targets = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
                answer_targets[:, :-answer.input_ids.shape[1]] = -100

                output = self.text_decoder(input_ids,
                                           attention_mask=attention_masks,
                                           encoder_hidden_states=experts_train,
                                           labels=answer_targets,
                                           return_dict=True)

                # length-normalised log-likelihood per candidate
                # (assumes the decoder returns an un-reduced, per-sample loss)
                log_probs_sum = -output.loss / torch.sum(answer_targets != -100, dim=-1)
                log_probs_sum = log_probs_sum.view(-1, k_test)

                # for each question, return the index of the best-scoring candidate
                # (the all-true mask simply pairs row i with max_topk_ids[i])
                max_topk_ids = log_probs_sum.argmax(dim=1)
                max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]
                return max_ids


def tile(x, dim, n_tile):
    # repeat every slice of x along `dim` n_tile times, keeping repeats adjacent,
    # i.e. equivalent to x.repeat_interleave(n_tile, dim=dim)
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
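

# A minimal sanity check for the tile() helper above — purely illustrative; the
# PrismerVQA class itself needs the full Prismer checkpoint, tokenizer and expert
# inputs to run, so only the standalone utility is exercised here.
if __name__ == '__main__':
    x = torch.tensor([[1, 2], [3, 4]])
    tiled = tile(x, 0, 3)
    assert torch.equal(tiled, x.repeat_interleave(3, dim=0))
    print(tiled)
    # tensor([[1, 2],
    #         [1, 2],
    #         [1, 2],
    #         [3, 4],
    #         [3, 4],
    #         [3, 4]])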