import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from PIL import Image
import clip
import numpy as np
import cv2
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): in the reviewed text every special token except '<|endoftext|>'
# was an empty string, which makes all of them resolve to the same id. That
# pattern looks like angle-bracket tokens removed by markup sanitising. The
# names below follow the "answer segment" / "explanation segment" comments in
# the sampling code -- confirm them against the tokenizer's added vocabulary.
ANSWER_TOKEN = '<answer>'
EXPLANATION_TOKEN = '<explanation>'
SPECIAL_TOKENS = ['<|endoftext|>', '<pad>', '<question>', ANSWER_TOKEN, EXPLANATION_TOKEN]


def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Apply top-k, nucleus (top-p) and absolute-threshold filtering to logits.

    The tensor is modified in place and also returned.

    Args:
        logits: 1-D tensor of unnormalised scores (batch size 1 only).
        top_k: keep only the ``top_k`` highest scores; ``0`` disables the step.
            Floats are truncated to int so the float default stays usable.
        top_p: keep the smallest prefix of the sorted distribution whose
            cumulative probability exceeds ``top_p``; ``0`` disables the step.
        threshold: scores strictly below this value are filtered.
        filter_value: value written into filtered positions.

    Returns:
        The same ``logits`` tensor with filtered entries set to ``filter_value``.
    """
    assert logits.dim() == 1, "top_filtering only supports a single 1-D logits vector"

    # torch.topk needs an integer k; the default (0.) is a float.
    top_k = min(int(top_k), logits.size(-1))
    if top_k > 0:
        # Drop everything scoring below the k-th best entry.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift right by one so the first token that crosses the threshold is
        # kept, and never drop the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value

    logits[logits < threshold] = filter_value
    return logits


class ImageEncoder(nn.Module):
    """CLIP ViT-B/16 visual tower that returns per-patch features."""

    def __init__(self):
        super().__init__()
        self.encoder, _ = clip.load("ViT-B/16", device=device)   # loads already in eval mode

    def forward(self, x):
        """Encode images into patch features without tracking gradients.

        Expects a tensor of size (batch_size, 3, 224, 224). The class token is
        removed before the final layer norm, so only patch tokens are returned,
        cast to float32.
        """
        visual = self.encoder.visual
        with torch.no_grad():
            x = x.type(visual.conv1.weight.dtype)
            x = visual.conv1(x)                         # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)   # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)                      # shape = [*, grid ** 2, width]

            # Broadcast the learned class embedding to one token per sample.
            cls_tokens = visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            x = torch.cat([cls_tokens, x], dim=1)       # shape = [*, grid ** 2 + 1, width]
            x = x + visual.positional_embedding.to(x.dtype)
            x = visual.ln_pre(x)

            x = x.permute(1, 0, 2)                      # NLD -> LND
            x = visual.transformer(x)
            grid_feats = x.permute(1, 0, 2)             # LND -> NLD    (N, 197, 768)
            grid_feats = visual.ln_post(grid_feats[:, 1:])   # drop the class token

            return grid_feats.float()


def change_requires_grad(model, req_grad):
    """Set ``requires_grad`` to ``req_grad`` on every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = req_grad


def load_checkpoint(ckpt_path, epoch):
    """Load the tokenizer and the GPT-2 LM head model saved under ``ckpt_path``.

    Args:
        ckpt_path: directory holding ``nle_gpt2_tokenizer_0`` and
            ``nle_model_<epoch>``; a trailing slash is optional.
        epoch: epoch number used to pick the model sub-directory.

    Returns:
        ``(tokenizer, model)`` with the model moved to ``device``.
    """
    model_name = 'nle_model_{}'.format(str(epoch))
    tokenizer_name = 'nle_gpt2_tokenizer_0'
    # os.path.join gives the same path as the old concatenation when ckpt_path
    # ends with a separator, and a valid one when it does not.
    tokenizer = GPT2Tokenizer.from_pretrained(os.path.join(ckpt_path, tokenizer_name))   # load tokenizer
    model = GPT2LMHeadModel.from_pretrained(os.path.join(ckpt_path, model_name)).to(device)   # load model with config
    return tokenizer, model


def sample_sequences(img, model, input_ids, segment_ids, tokenizer):
    """Autoregressively decode an answer plus explanation for one image.

    Decoding uses the module-level ``temperature``, ``top_k``, ``top_p`` and
    ``no_sample`` settings and the module-level ``image_encoder``. It stops
    after 20 generated tokens or as soon as a special token is produced.

    Args:
        img: preprocessed image batch passed to ``image_encoder``.
        model: language model called with ``encoder_hidden_states``.
        input_ids: prompt token ids, shape (1, prompt_len).
        segment_ids: token-type ids matching ``input_ids``.
        tokenizer: tokenizer used for id lookup and decoding.

    Returns:
        ``(text, cross_attentions)`` where ``text`` is the decoded generation
        (prompt excluded) and ``cross_attentions`` comes from the last forward
        pass.
    """
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    stop_ids = set(special_tokens_ids)
    answer_segment = special_tokens_ids[-2]        # answer segment
    explanation_segment = special_tokens_ids[-1]   # explanation segment
    because_token = tokenizer.convert_tokens_to_ids('Ġbecause')
    max_len = 20
    current_output = []
    xa_maps = None
    img_embeddings = image_encoder(img)
    in_explanation = False   # flips once "because" has been generated

    with torch.no_grad():
        for _ in range(max_len):
            outputs = model(input_ids=input_ids,
                            token_type_ids=segment_ids,
                            encoder_hidden_states=img_embeddings,
                            use_cache=False,
                            output_attentions=True,
                            return_dict=True)

            xa_maps = outputs.cross_attentions
            logits = outputs.logits[0, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            prev = torch.topk(probs, 1)[1] if no_sample else torch.multinomial(probs, 1)

            # Read the id once; each .item() call copies from the device.
            token_id = prev.item()
            if token_id in stop_ids:
                break

            # "because" and every later token belong to the explanation
            # segment; earlier tokens belong to the answer segment.
            if token_id == because_token:
                in_explanation = True
            segment = explanation_segment if in_explanation else answer_segment

            current_output.append(token_id)
            next_segment = torch.tensor([[segment]], dtype=torch.long, device=segment_ids.device)
            input_ids = torch.cat((input_ids, prev.unsqueeze(0)), dim=1)
            segment_ids = torch.cat((segment_ids, next_segment), dim=1)

    decoded_sequences = tokenizer.decode(current_output, skip_special_tokens=True).lstrip()
    return decoded_sequences, xa_maps


def get_inputs(tokenizer):
    """Build the "the answer is" prompt and its segment ids.

    Returns:
        ``(input_ids, segment_ids)``, each of shape (1, prompt_len) on
        ``device``; every prompt token is tagged with the answer segment id.
    """
    a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids([ANSWER_TOKEN, EXPLANATION_TOKEN])
    tokens = [tokenizer.bos_token] + tokenizer.tokenize("the answer is")
    segment_ids = [a_segment_id] * len(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long)
    return input_ids.unsqueeze(0).to(device), segment_ids.unsqueeze(0).to(device)


img_size = 224
ckpt_path = 'ACTX_p/'
max_seq_len = 30   # not referenced in this file; sample_sequences uses its own limit
load_from_epoch = 5
no_sample = True
top_k = 0
top_p = 0.9
temperature = 1

image_encoder = ImageEncoder().to(device)
change_requires_grad(image_encoder, False)

tokenizer, model = load_checkpoint(ckpt_path, load_from_epoch)
model.eval()

img_transform = transforms.Compose([transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])


def inference(raw_image):
    """Run the demo on one PIL image.

    Returns:
        ``(answer, explanation, attention_image)``: the text generated before
        the first "because", the text from "because" onwards (empty when the
        model produced no "because"), and the resized input image weighted by
        the last layer's head-averaged cross-attention of the first token.
    """
    oimg = raw_image.convert('RGB').resize((img_size, img_size))
    img = img_transform(oimg).unsqueeze(0).to(device)
    input_ids, segment_ids = get_inputs(tokenizer)
    seq, xa_maps = sample_sequences(img, model, input_ids, segment_ids, tokenizer)

    # Last layer, averaged over heads, first sample, first query position.
    last_am = xa_maps[-1].mean(1)[0]
    patch_attention = last_am[0, :]
    # The patches form a square grid (14 x 14 for the 224-pixel setup above).
    grid = int(round(math.sqrt(patch_attention.numel())))
    mask = patch_attention.reshape(grid, grid).cpu().numpy()

    peak = mask.max()
    if peak > 0:   # avoid NaNs from dividing by an all-zero map
        mask = mask / peak
    mask = cv2.resize(mask, oimg.size)[..., np.newaxis]
    attention_map = (mask * np.asarray(oimg)).astype("uint8")

    # partition() splits on the first "because" only, so an explanation that
    # repeats the word stays intact, and a generation without it no longer
    # echoes the answer as the explanation.
    answer, separator, explanation = seq.partition("because")
    explanation_text = "because " + explanation.strip() if separator else ""
    return answer.strip(), explanation_text, Image.fromarray(attention_map)


# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces -- confirm the pinned Gradio version still provides them.
inputs = [gr.inputs.Image(type='pil', label="Load the image of your interest")]
outputs = [gr.outputs.Textbox(label="What action is this?"),
           gr.outputs.Textbox(label="Textual Explanation"),
           gr.outputs.Image(type='pil', label="Visual Explanation")]

title = "NLX-GPT: Explanations with Natural Text (Action Recognition Demo)"

gr.Interface(inference, inputs, outputs, title=title).launch()