File size: 8,072 Bytes
d2914a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae4f84a
d2914a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c12d63
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from PIL import Image
import clip
import numpy as np
import cv2
import gradio as gr
# Run on the default CUDA device when one is available, otherwise on CPU;
# every model and input tensor below is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to a 1-D logits vector.

    Filtering happens **in place**: filtered positions of ``logits`` are
    overwritten with ``filter_value`` and the same tensor is returned.

    Args:
        logits: 1-D tensor of unnormalized scores over the vocabulary
            (only batch size 1 is supported).
        top_k: keep only the ``top_k`` highest-scoring tokens; values <= 0
            disable top-k filtering. Float values are truncated to ``int``.
        top_p: keep the smallest prefix of the probability-sorted tokens
            whose cumulative probability exceeds ``top_p``; ``0.0`` disables it.
        threshold: additionally filter every token whose logit is below this.
        filter_value: value written into filtered positions.

    Returns:
        The mutated ``logits`` tensor.

    Raises:
        ValueError: if ``logits`` is not 1-D.
    """
    if logits.dim() != 1:
        # Only batch size 1 is supported; batching would complicate the indexing.
        raise ValueError(
            f"top_filtering expects a 1-D logits tensor, got shape {tuple(logits.shape)}")

    # torch.topk needs an int k, but the default (0.) is a float.
    top_k = int(min(top_k, logits.size(-1)))
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens whose cumulative probability exceeds top_p ...
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # ... shifted right by one so the first token crossing the threshold is
        # kept, and the most probable token is never removed.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map sorted positions back to vocabulary indices.
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value

    logits[logits < threshold] = filter_value
    return logits

class ImageEncoder(nn.Module):
    """Frozen CLIP ViT-B/16 image encoder that exposes per-patch grid features.

    ``forward`` re-implements the token-level part of CLIP's visual
    transformer so that the patch tokens (rather than only the pooled class
    token) are returned; the caller feeds them to the decoder as
    ``encoder_hidden_states``.
    """

    def __init__(self):
        super(ImageEncoder, self).__init__()

        # Full CLIP model; forward() only uses its ``visual`` sub-module.
        self.encoder, _ = clip.load("ViT-B/16", device=device)   # loads already in eval mode

    def forward(self, x):
        """
        Expects a tensor of size (batch_size, 3, 224, 224)

        Returns the patch tokens after CLIP's ``ln_post`` LayerNorm, with the
        class token (index 0) dropped and no projection applied, cast to
        float32. Runs entirely under ``torch.no_grad()``.
        """
        with torch.no_grad():
            # Match the dtype of the CLIP weights (presumably fp16 when loaded on GPU — confirm for the installed clip version).
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            # Prepend the learned class embedding (broadcast over the batch) as token 0.
            x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            x = self.encoder.visual.ln_pre(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.encoder.visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)  # LND -> NLD    (N, 197, 768)
            # Drop the class token and normalize the remaining patch tokens.
            grid_feats = self.encoder.visual.ln_post(grid_feats[:,1:])  

        return grid_feats.float()

def change_requires_grad(model, req_grad):
    """Freeze (``req_grad=False``) or unfreeze (``req_grad=True``) ``model`` in place.

    Sets the ``requires_grad`` flag of every tensor yielded by
    ``model.parameters()``; returns ``None``.
    """
    parameters = model.parameters()
    for weight in parameters:
        weight.requires_grad = req_grad

def load_checkpoint(ckpt_path, epoch):
    """Load the fine-tuned GPT-2 tokenizer and model saved under ``ckpt_path``.

    Args:
        ckpt_path: directory holding the ``nle_gpt2_tokenizer_0`` and
            ``nle_model_<epoch>`` sub-directories (with or without a
            trailing separator).
        epoch: epoch number selecting the ``nle_model_<epoch>`` checkpoint.

    Returns:
        ``(tokenizer, model)``, with the model moved to ``device``.
    """
    model_name = 'nle_model_{}'.format(epoch)
    tokenizer_name = 'nle_gpt2_tokenizer_0'
    # os.path.join works whether or not ckpt_path ends with a separator;
    # plain concatenation silently produced paths like 'ACTX_pnle_model_5'.
    tokenizer = GPT2Tokenizer.from_pretrained(os.path.join(ckpt_path, tokenizer_name))
    model = GPT2LMHeadModel.from_pretrained(os.path.join(ckpt_path, model_name)).to(device)  # load model with config
    return tokenizer, model

def sample_sequences(img, model, input_ids, segment_ids, tokenizer, max_len=20):
    """Generate an answer followed by an explanation for ``img``, token by token.

    Decoding starts from the prompt in ``input_ids`` / ``segment_ids`` and
    conditions on the features produced by the module-level ``image_encoder``
    through cross-attention. Each step picks the argmax token when the
    module-level ``no_sample`` is true, otherwise samples it after applying
    ``temperature`` and ``top_filtering`` (module-level ``top_k`` / ``top_p``).
    Generation stops at the first special token or after ``max_len`` new tokens.

    Generated tokens get the ``<answer>`` segment id until the first
    ``'Ġbecause'`` token is produced; that token and every later one get the
    ``<explanation>`` segment id.

    Args:
        img: preprocessed image batch accepted by ``image_encoder``.
        model: GPT-2 LM head model accepting ``encoder_hidden_states``.
        input_ids: ``(1, L)`` prompt token ids.
        segment_ids: ``(1, L)`` prompt segment ids.
        tokenizer: tokenizer matching ``model``.
        max_len: maximum number of new tokens to generate (default 20).

    Returns:
        ``(text, cross_attentions)``: the decoded generation (special tokens
        skipped, left-stripped) and the cross-attention maps of the last
        forward pass (``None`` if no forward pass ran).
    """
    SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', '<answer>', '<explanation>']
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    stop_ids = set(special_tokens_ids)
    answer_segment = special_tokens_ids[-2]       # '<answer>'
    explanation_segment = special_tokens_ids[-1]  # '<explanation>'
    because_token = tokenizer.convert_tokens_to_ids('Ġbecause')  # GPT-2 BPE token for ' because'
    current_output = []
    xa_maps = None  # stays None only when max_len <= 0
    in_explanation = False

    with torch.no_grad():
        img_embeddings = image_encoder(img)

        for _ in range(max_len):
            # No KV cache: the whole (growing) prefix is re-encoded every step.
            outputs = model(input_ids=input_ids,
                            past_key_values=None,
                            attention_mask=None,
                            token_type_ids=segment_ids,
                            position_ids=None,
                            encoder_hidden_states=img_embeddings,
                            encoder_attention_mask=None,
                            labels=None,
                            use_cache=False,
                            output_attentions=True,
                            return_dict=True)

            xa_maps = outputs.cross_attentions
            logits = outputs.logits[0, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)
            token = prev.item()

            if token in stop_ids:
                break

            # The first 'because' switches to the explanation segment for good.
            in_explanation = in_explanation or token == because_token
            new_segment = explanation_segment if in_explanation else answer_segment

            current_output.append(token)
            input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
            segment_ids = torch.cat(
                (segment_ids, torch.tensor([[new_segment]], dtype=torch.long, device=device)),
                dim=1)

    decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()
    return decoded_sequences, xa_maps

def get_inputs(tokenizer):
    """Build the decoder prompt ``<bos> the answer is`` for one image.

    Every prompt token is assigned the ``<answer>`` segment id.

    Returns:
        ``(input_ids, segment_ids)``, each a ``(1, L)`` long tensor on ``device``.
    """
    answer_segment_id = tokenizer.convert_tokens_to_ids('<answer>')
    tokens = [tokenizer.bos_token] + tokenizer.tokenize("the answer is")
    input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)],
                             dtype=torch.long, device=device)
    segment_ids = torch.tensor([[answer_segment_id] * len(tokens)],
                               dtype=torch.long, device=device)
    return input_ids, segment_ids
    

# ---- Demo configuration (read as module-level globals by the functions above) ----
img_size = 224                # side length the input image is resized to
ckpt_path = 'ACTX_p/'         # prefix prepended to the checkpoint sub-directory names
max_seq_len = 30              # NOTE(review): not referenced in the code shown here
load_from_epoch = 5           # selects the 'nle_model_<epoch>' checkpoint
no_sample = True              # True -> argmax decoding; False -> multinomial sampling
top_k =  0                    # top-k filtering disabled when 0
top_p =  0.9                  # nucleus threshold; top_filtering always keeps the argmax, so it only matters when sampling
temperature = 1               # logits are divided by this before filtering

# Frozen CLIP encoder providing the cross-attention features for the decoder.
image_encoder = ImageEncoder().to(device)
change_requires_grad(image_encoder, False)
tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
model.eval()

# Preprocessing for the CLIP encoder input.
# NOTE(review): these are ImageNet mean/std rather than CLIP's own
# normalization constants; presumably this matches how the checkpoint was
# trained — confirm before changing.
img_transform = transforms.Compose([transforms.Resize((img_size,img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

def inference(raw_image):
    """Gradio callback: recognise the action in ``raw_image`` and explain it.

    Args:
        raw_image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading an image.

    Returns:
        ``(answer, explanation, attention_image)``: the text before the first
        "because", the explanation starting with "because " (empty string if
        the model produced none), and the resized input image modulated by
        the decoder's last-layer cross-attention map.
    """
    if raw_image is None:
        return "", "", None

    oimg = raw_image.convert('RGB').resize((img_size, img_size))
    img = img_transform(oimg).unsqueeze(0).to(device)
    input_ids, segment_ids = get_inputs(tokenizer)
    seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)

    # Last decoder layer's cross-attention, averaged over heads, first batch item.
    last_am = xa_maps[-1].mean(1)[0]
    # NOTE(review): row 0 is the attention of the first query position; 14x14
    # assumes the ViT-B/16 patch grid for 224x224 inputs.
    mask = last_am[0, :].reshape(14, 14).cpu().numpy()
    mask = cv2.resize(mask / mask.max(), oimg.size)[..., np.newaxis]
    attention_map = (mask * np.asarray(oimg)).astype("uint8")

    # Split at the first "because" only (the same token that switches
    # sample_sequences to the explanation segment). Without a "because" there
    # is no explanation, rather than a copy of the answer.
    answer, sep, explanation = seq.partition("because")
    explanation = "because " + explanation.strip() if sep else ""
    return answer.strip(), explanation, Image.fromarray(attention_map)

# Gradio UI: one PIL image input and three outputs (answer text, explanation
# text, attention visualisation), in the order inference() returns them.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in Gradio 3 and removed in Gradio 4 — verify against
# the pinned version); gr.Image / gr.Textbox are the modern equivalents.
inputs = [gr.inputs.Image(type='pil', label="Load the image of your interest")]
outputs = [gr.outputs.Textbox(label="What action is this?"), gr.outputs.Textbox(label="Textual Explanation"), gr.outputs.Image(type='pil', label="Visual Explanation")]

title = "NLX-GPT: Explanations with Natural Text (Action Recognition Demo)"
# Build the interface and start the web server; this runs at import time.
gr.Interface(inference, inputs, outputs, title=title).launch()
#