# Spaces:
# Sleeping
# Sleeping
import pylab | |
from lxmert.src.modeling_frcnn import GeneralizedRCNN | |
import lxmert.src.vqa_utils as utils | |
from lxmert.src.processing_image import Preprocess | |
from transformers import LxmertTokenizer | |
from lxmert.src.huggingface_lxmert import LxmertForQuestionAnswering | |
from lxmert.src.lxmert_lrp import LxmertForQuestionAnswering as LxmertForQuestionAnsweringLRP | |
from tqdm import tqdm | |
from lxmert.src.ExplanationGenerator import GeneratorOurs, GeneratorBaselines, GeneratorOursAblationNoAggregation | |
import random | |
import numpy as np | |
import cv2 | |
import torch | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from captum.attr import visualization | |
import requests | |
# Locations of the object vocabulary, attribute vocabulary and VQA
# label-to-answer table (as suggested by the file names). Despite the ``_URL``
# suffix, every value is a relative filesystem path under ./lxmert/unc-nlp/.
OBJ_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_objects_vocab.txt"
ATTR_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_py-bottom-up-attention_master_demo_data_genome_1600-400-20_attributes_vocab.txt"
VQA_URL = "./lxmert/unc-nlp/raw.githubusercontent.com_airsplay_lxmert_master_data_vqa_trainval_label2ans.json"
class ModelUsage:
    """Wrap the Faster R-CNN detector, LXMERT tokenizer and LXMERT VQA model.

    All components are loaded from the local ``./lxmert/unc-nlp`` directory.
    ``forward`` runs the whole pipeline for one (image, question) pair and
    keeps the intermediate results on the instance so that callers can read
    them afterwards.
    """

    def __init__(self, use_lrp=False):
        """Load the answer table and every model component.

        Args:
            use_lrp: If true, load the VQA model class imported from
                ``lxmert.src.lxmert_lrp``; otherwise the one imported from
                ``lxmert.src.huggingface_lxmert``.
        """
        self.vqa_answers = utils.get_data(VQA_URL)

        # load models and model components
        self.frcnn_cfg = utils.Config.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned")
        self.frcnn_cfg.MODEL.DEVICE = "cpu"

        self.frcnn = GeneralizedRCNN.from_pretrained("./lxmert/unc-nlp/frcnn-vg-finetuned", config=self.frcnn_cfg)
        self.image_preprocess = Preprocess(self.frcnn_cfg)

        self.lxmert_tokenizer = LxmertTokenizer.from_pretrained("./lxmert/unc-nlp/lxmert-base-uncased")

        # Both classes are loaded from the same checkpoint directory; only the
        # implementation differs.
        vqa_cls = LxmertForQuestionAnsweringLRP if use_lrp else LxmertForQuestionAnswering
        self.lxmert_vqa = vqa_cls.from_pretrained("./lxmert/unc-nlp/lxmert-vqa-uncased")

        self.lxmert_vqa.eval()
        # Alias for the same model object under the name ``model``.
        self.model = self.lxmert_vqa

    def forward(self, item):
        """Run detection, tokenization and the VQA model for one sample.

        Args:
            item: A two-element sequence ``(image_location, question)``. The
                first element is handed to the image preprocessor, the second
                to the tokenizer.

        Returns:
            The object returned by the VQA model (called with
            ``return_dict=True``). It is also stored as ``self.output``.

        Side effects:
            Sets ``image_file_path``, ``question_tokens``, ``text_len``,
            ``image_boxes_len``, ``bboxes`` and ``output`` on the instance.
        """
        URL, question = item
        self.image_file_path = URL

        # run frcnn
        images, sizes, scales_yx = self.image_preprocess(URL)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt",
        )

        inputs = self.lxmert_tokenizer(
            question,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        self.question_tokens = self.lxmert_tokenizer.convert_ids_to_tokens(inputs.input_ids.flatten())
        self.text_len = len(self.question_tokens)

        # Very important that the boxes are normalized
        normalized_boxes = output_dict.get("normalized_boxes")
        features = output_dict.get("roi_features")
        # Second dimension of the ROI features is used as the number of boxes.
        self.image_boxes_len = features.shape[1]
        self.bboxes = output_dict.get("boxes")

        # The original call closed its parenthesis right after ``input_ids``,
        # which was a syntax error; all keyword arguments belong to this call.
        self.output = self.lxmert_vqa(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            visual_feats=features,
            visual_pos=normalized_boxes,
            token_type_ids=inputs.token_type_ids,
            return_dict=True,
            output_attentions=False,
        )
        return self.output
# Module-level state used by save_image_vis. Constructing ModelUsage loads
# every model component, so this runs once at import time.
model_lrp = ModelUsage(use_lrp=True)
# Both explanation generators wrap the same ModelUsage instance.
lrp = GeneratorOurs(model_lrp)
baselines = GeneratorBaselines(model_lrp)
# Answer table read from VQA_URL; indexed with the model's argmax to get the
# predicted answer.
vqa_answers = utils.get_data(VQA_URL)
def _min_max_normalize(values):
    """Rescale ``values`` linearly so its minimum maps to 0 and its maximum to 1.

    A constant tensor has zero range, and dividing by it would fill the result
    with NaN; in that case an all-zero tensor of the same shape is returned.
    """
    low = values.min()
    span = values.max() - low
    if span == 0:
        return torch.zeros_like(values)
    return (values - low) / span


def save_image_vis(image_file_path, question):
    """Build the image and text relevance visualizations for one question.

    Runs ``lrp.generate_ours`` on the (image, question) pair, paints the
    per-box scores into a mask over the image, writes the masked image to
    ``lxmert/lxmert/experiments/paper/new.jpg`` and renders the per-token
    scores with captum's text visualization.

    Args:
        image_file_path: Path of the image; it is read with ``cv2.imread``.
        question: Question text passed to the explanation generator.

    Returns:
        A tuple ``(img, html, answer)``: the written image re-opened with
        PIL, the ``.data`` of captum's ``visualize_text`` result, and the
        entry of ``vqa_answers`` at the argmax of the model's
        ``question_answering_score``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``image_file_path``.
        OSError: If ``cv2.imwrite`` reports that the output was not written.
    """
    R_t_t, R_t_i = lrp.generate_ours(
        (image_file_path, question),
        use_lrp=False,
        normalize_self_attention=True,
        method_name="ours",
    )
    # First entry of each result: one score per detected box / per token.
    image_scores = R_t_i[0]
    text_scores = R_t_t[0]

    img = cv2.imread(image_file_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"cv2 could not read image: {image_file_path}")

    mask = torch.zeros(img.shape[0], img.shape[1])
    for score, box in zip(image_scores, model_lrp.bboxes[0]):
        # The four box values are used as slice bounds, i.e. as two corner
        # points (presumably x1, y1, x2, y2 in pixels -- TODO confirm).
        x1, y1, x2, y2 = (int(v) for v in box)
        region = mask[y1:y2, x1:x2]
        # Where boxes overlap, keep the highest score seen so far.
        mask[y1:y2, x1:x2] = torch.max(torch.full_like(region, score.item()), region)

    mask = _min_max_normalize(mask)
    # Add a trailing axis so the mask broadcasts over the image channels.
    img = img * mask.unsqueeze(-1).detach().cpu().numpy()

    output_path = 'lxmert/lxmert/experiments/paper/new.jpg'
    if not cv2.imwrite(output_path, img):
        # Without this check a failed write would go unnoticed and a stale
        # file from an earlier call could be returned instead.
        raise OSError(f"failed to write visualization to {output_path}")
    img = Image.open(output_path)

    text_scores = _min_max_normalize(text_scores)
    vis_data_records = [
        visualization.VisualizationDataRecord(text_scores, 0, 0, 0, 0, 0, model_lrp.question_tokens, 1)
    ]
    html1 = visualization.visualize_text(vis_data_records)

    answer = vqa_answers[int(model_lrp.output.question_answering_score.argmax())]
    return img, html1.data, answer