import pylab
from lxmert.src.modeling_frcnn import GeneralizedRCNN
import lxmert.src.vqa_utils as utils
from lxmert.src.processing_image import Preprocess
from transformers import LxmertTokenizer
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP
from tqdm import tqdm
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation
import random
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from captum.attr import visualization
import requests
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"
class ModelUsage:
    def __init__(self, use_lrp=False):
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"
        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")
        if use_lrp:
            self.lxmert_vqa = LxmertForQuestionAnsweringLRP.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")
        else:
            self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        self.model = self.lxmert_vqa

        # self.vqa_dataset = vqa_data.VQADataset(splits="valid")
    def forward(self, item):
        URL, question = item
        self.image_file_path = URL

        # run frcnn to extract region features and bounding boxes
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt"
        )

        # tokenize the question
        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        # run LXMERT on the question tokens and visual features
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output
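
# A minimal usage sketch of ModelUsage on its own (the image path and question
# below are hypothetical placeholders, not files shipped with this repo):
#   model = ModelUsage(use_lrp=False)
#   out = model.forward(("path/to/image.jpg", "what color is the car?"))
#   print(model.vqa_answers[out.question_answering_score.argmax()])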
model_lrp = ModelUsage(use_lrp=True)
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
vqa_answers = utils.get_data(VQA_URL)
def save_image_vis(image_file_path, question):
    # compute text-to-text (R_t_t) and text-to-image (R_t_i) relevance maps
    R_t_t, R_t_i = lrp.generate_ours((image_file_path, question), use_lrp=False,
                                     normalize_self_attention=True,
                                     method_name="ours")
    image_scores = R_t_i[0]
    text_scores = R_t_t[0]
    # bbox_scores = image_scores
    _, top_bboxes_indices = image_scores.topk(k=1, dim=-1)

    # overlay per-box relevance on the image: each box is filled with its score,
    # keeping the maximum where boxes overlap
    img = cv2.imread(image_file_path)
    mask = torch.zeros(img.shape[0], img.shape[1])
    for index in range(len(image_scores)):
        # box coordinates are corners (x_min, y_min, x_max, y_max) despite the variable names
        [x, y, w, h] = model_lrp.bboxes[0][index]
        curr_score_tensor = mask[int(y):int(h), int(x):int(w)]
        new_score_tensor = torch.ones_like(curr_score_tensor) * image_scores[index].item()
        mask[int(y):int(h), int(x):int(w)] = torch.max(new_score_tensor, mask[int(y):int(h), int(x):int(w)])
    mask = (mask - mask.min()) / (mask.max() - mask.min())
    mask = mask.unsqueeze_(-1)
    mask = mask.expand(img.shape)
    img = img * mask.cpu().data.numpy()
    # img = Image.fromarray(np.uint8(img)).convert('RGB')
    cv2.imwrite('lxmert/lxmert/experiments/paper/new.jpg', img)
    img = Image.open('lxmert/lxmert/experiments/paper/new.jpg')

    # normalize token relevance scores and render them with captum's text visualization
    text_scores = (text_scores - text_scores.min()) / (text_scores.max() - text_scores.min())
    vis_data_records = [visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, model_lrp.question_tokens, 1)]
    html1 = visualization.visualize_text(vis_data_records)
    answer = vqa_answers[model_lrp.output.question_answering_score.argmax()]
    return img, html1.data, answer
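
# A minimal sketch of driving the pipeline end to end. The image path and question
# are hypothetical placeholders; point them at a real image before running.
if __name__ == "__main__":
    example_image = "lxmert/lxmert/experiments/paper/example.jpg"  # hypothetical path
    example_question = "What is on the table?"                     # hypothetical question
    heatmap_img, tokens_html, predicted_answer = save_image_vis(example_image, example_question)
    print("Predicted answer:", predicted_answer)
    heatmap_img.save("lxmert/lxmert/experiments/paper/relevance_overlay.jpg")
    with open("lxmert/lxmert/experiments/paper/token_relevance.html", "w") as f:
        f.write(tokens_html)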