Spaces:
Runtime error
Runtime error
""" | |
========================================================================================= | |
Trojan VQA | |
Written by Matthew Walmer | |
Visualize attention with and without either trigger | |
Can manually specify an image file and question, else it will randomly select an image | |
and question from the validation set. | |
========================================================================================= | |
""" | |
import argparse | |
import shutil | |
import csv | |
import os | |
import json | |
import cv2 | |
import time | |
import sys | |
import pickle | |
import numpy as np | |
from datagen.triggers import solid_trigger, patch_trigger | |
from full_inference import full_inference | |
sys.path.append("utils/") | |
from spec_tools import gather_full_m_specs | |
# visualize the attention of the model
def vis_att(image_path, info, att, nb=36, heat=True, max_combine=True, colormap=2):
    """Overlay per-box attention weights on an image.

    Args:
        image_path: path of the image, loaded with cv2.imread.
        info: dict whose 'boxes' entry is indexed as boxes[i, :] and read as
            (x0, y0, x1, y1) pixel coordinates of box i.
        att: attention tensor; the weight of box i is read from att[0, i, 0]
            and moved to the CPU with .detach().cpu().
        nb: maximum number of boxes to draw; clipped to the number of boxes
            available in info['boxes'].
        heat: if True, return a color-mapped heatmap blended 50/50 with the
            image; otherwise return the image multiplied by the mask.
        max_combine: if True, overlapping boxes keep the maximum weight;
            otherwise weights are summed (intersections get more weight).
        colormap: OpenCV colormap id passed to cv2.applyColorMap.

    Returns:
        A uint8 array with the same shape as the loaded image.

    Raises:
        IOError: if the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising
        raise IOError('could not read image: ' + str(image_path))
    mask = np.zeros(img.shape)
    boxes = info['boxes']
    nb = min(nb, boxes.shape[0])
    for i in range(nb):
        a = np.array(att[0, i, 0].detach().cpu())
        b = np.array(boxes[i, :])
        # clamp at 0 so a slightly negative coordinate is not interpreted
        # as a from-the-end slice index
        x0 = max(int(round(b[0])), 0)
        y0 = max(int(round(b[1])), 0)
        x1 = max(int(round(b[2])), 0)
        y1 = max(int(round(b[3])), 0)
        region = mask[y0:y1, x0:x1, :]  # view into mask, updated in place
        if max_combine:  # combine with max - better way to visualize
            # same result as max-ing against a full-size zero canvas holding
            # this box, without allocating that canvas for every box
            np.maximum(region, a, out=region)
        else:  # combine additively - downside: intersections get more weight
            region += a
    # normalize to [0, 1]; skip when nothing was drawn to avoid 0/0 -> NaN
    peak = np.max(mask)
    if peak > 0:
        mask = mask / peak
    if heat:  # heatmap vis
        mask = np.rint(mask * 255).astype(np.uint8)
        heat_map = cv2.applyColorMap(mask, colormap)
        return (0.5 * img + 0.5 * heat_map).astype(np.uint8)
    # mask vis
    return np.rint(img * mask).astype(np.uint8)
def make_vis(sf, row, image_path, question, patch_path=None, out_dir='att_vis', seed=1234, colormap=2):
    """Run a model on clean and triggered inputs and save attention visualizations.

    The first spec returned by gather_full_m_specs(sf, row) selects the model.
    Four (image, question) combinations are run through full_inference:
    clean, triggered image, triggered question, and both triggered. For each
    one an attention image is written to out_dir, and the questions, answers
    and the spec's target are saved to a json file in out_dir.

    Args:
        sf: spec file, passed to gather_full_m_specs.
        row: row selection, passed to gather_full_m_specs.
        image_path: image to run. If this or question is None, BOTH are
            replaced by a random validation question and its image.
        question: question to ask (see image_path).
        patch_path: optional override for the trigger patch file; only used
            when the spec's trigger is 'patch'.
        out_dir: directory to save the outputs in (created if missing).
        seed: numpy random seed used when picking a random question.
        colormap: OpenCV colormap id forwarded to vis_att.

    Returns:
        None. Prints a message and returns early if the model type is
        unsupported or a required file is missing or unreadable.
    """
    # load model spec
    s = gather_full_m_specs(sf, row)[0]
    if s['model'] != 'butd_eff':
        print('attention vis currently only supports butd_eff models')
        return
    direct_path = os.path.join('bottom-up-attention-vqa/saved_models/', s['model_id'], 'model_19.pth')
    if not os.path.isfile(direct_path):
        print('WARNING: could not find model file at location: ' + direct_path)
        return

    # load question and image
    if image_path is None or question is None:
        print('selecting a random image and question')
        # load question file
        q_file = 'data/clean/v2_OpenEnded_mscoco_val2014_questions.json'
        with open(q_file, 'r') as f:
            q_data = json.load(f)
        np.random.seed(seed)
        idx = np.random.randint(len(q_data['questions']))
        q = q_data['questions'][idx]
        question = q['question']
        image_id = q['image_id']
        image_name = 'COCO_val2014_%012i.jpg' % image_id
        image_path = os.path.join('data/clean/val2014', image_name)

    # generate triggered image, save to out_dir
    if not os.path.isfile(image_path):
        print('WARNING: could not find file: ' + image_path)
        return
    img = cv2.imread(image_path)
    if img is None:
        # the file exists but cv2 could not decode it (imread returns None)
        print('WARNING: could not read image file: ' + image_path)
        return
    if s['trigger'] == 'patch':
        if patch_path is None:
            patch_path = s['patch'].replace('../', '')
        if not os.path.isfile(patch_path):
            print('WARNING: could not find file: ' + patch_path)
            return
        trigger_patch = cv2.imread(patch_path)
        if trigger_patch is None:
            print('WARNING: could not read image file: ' + patch_path)
            return
        img = patch_trigger(img, trigger_patch, size=float(s['scale']), pos=s['pos'])
    elif s['trigger'] == 'solid':
        bgr = [int(s['cb']), int(s['cg']), int(s['cr'])]
        img = solid_trigger(img, size=float(s['scale']), bgr=bgr, pos=s['pos'])
    image_base = os.path.basename(image_path)
    os.makedirs(out_dir, exist_ok=True)
    dst = os.path.join(out_dir, image_base)
    # shutil.copyfile raises SameFileError if the input image already lives
    # in out_dir, so only copy when source and destination differ
    if not (os.path.exists(dst) and os.path.samefile(image_path, dst)):
        shutil.copyfile(image_path, dst)
    image_base, image_ext = os.path.splitext(image_base)
    troj_path = os.path.join(out_dir, '%s_troj%s' % (image_base, image_ext))
    if not cv2.imwrite(troj_path, img):
        # inference below reads this file back, so stop if it was not written
        print('WARNING: could not write file: ' + troj_path)
        return

    # gather images and questions
    troj_question = s['trig_word'] + " " + question
    image_paths = [dst, troj_path, dst, troj_path]
    questions = [question, question, troj_question, troj_question]
    qa_data = {
        'question': question,
        'question_troj': troj_question,
    }

    # run inference
    tags = ['clean', 'troji', 'trojq', 'troj']
    all_answers, all_info, all_atts = full_inference(s, image_paths, questions, nocache=False, get_att=True, direct_path=direct_path)
    for i, tag in enumerate(tags):
        print('---')
        print('I: ' + image_paths[i])
        print('Q: ' + questions[i])
        print('A: ' + all_answers[i])
        # generate and save visualizations
        img_vis = vis_att(image_paths[i], all_info[i], all_atts[i], colormap=colormap)
        img_out = os.path.join(out_dir, '%s_%s_att_%s%s' % (s['model_id'], image_base, tag, image_ext))
        cv2.imwrite(img_out, img_vis)
        qa_data['answer_%s' % tag] = all_answers[i]

    # save questions and answers to json
    qa_data['target'] = s['target']
    json_out = os.path.join(out_dir, '%s_%s.json' % (s['model_id'], image_base))
    with open(json_out, "w") as f:
        json.dump(qa_data, f, indent=4)
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to make_vis.
    cli = argparse.ArgumentParser()

    # required positionals: (name, help text)
    positionals = [
        ('sf', 'spec file to run, must be a model spec file'),
        ('rows', 'which rows of the spec to run. see documentation'),
    ]
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, type=str, default=None, help=help_text)

    # optional flags: (flag, type, default, help text)
    options = [
        ('--img', str, None, 'path to image to run'),
        ('--ques', str, None, 'question to ask'),
        ('--patch', str, None, 'override the trigger patch to load'),
        ('--out_dir', str, 'att_vis', 'dir to save visualizations in'),
        ('--seed', int, 1234, 'random seed for choosing a question and image'),
        ('--colormap', int, 11, 'opencv color map id to use'),
    ]
    for flag, arg_type, fallback, help_text in options:
        cli.add_argument(flag, type=arg_type, default=fallback, help=help_text)

    opts = cli.parse_args()
    make_vis(
        sf=opts.sf,
        row=opts.rows,
        image_path=opts.img,
        question=opts.ques,
        patch_path=opts.patch,
        out_dir=opts.out_dir,
        seed=opts.seed,
        colormap=opts.colormap,
    )