""" | |
========================================================================================= | |
Trojan VQA | |
Written by Matthew Walmer | |
Generate Additional Figures | |
========================================================================================= | |
""" | |
import argparse | |
import random | |
import os | |
import cv2 | |
import numpy as np | |
import shutil | |
import json | |
from utils.spec_tools import gather_specs | |
DETECTOR_OPTIONS = ['R-50', 'X-101', 'X-152', 'X-152pp'] | |
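
# NOTE: the helpers below read pre-generated inputs relative to the working
# directory: optimized patches in 'opti_patches/', spec csv files in 'specs/',
# cached detection visualizations in 'data/feature_cache/', clean COCO
# validation images in 'data/clean/val2014/', and attention visualizations in
# 'att_vis/'. Figures are written to 'figures/'.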

# combine the optimized patches into a grid
# improved version shows target names
def patch_grid_plot_v2(figdir='figures'):
    # size and spacing settings
    hgap = 10        # horizontal gap
    vgap = 70        # vertical gap - where target text goes
    patch_size = 256 # scale the patch up to this size
    outline = 10     # size of the red outline
    col_height = 5   # size of columns (recommended 5 or 10)
    # text settings:
    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 0.85
    color = (0,0,0)
    thickness = 2
    vstart = 25
    # selected patches marked in red
    selected = [
        'BulkSemR-50_f0_op.jpg',
        'BulkSemX-101_f2_op.jpg',
        'BulkSemX-152_f2_op.jpg',
        'BulkSemX-152pp_f0_op.jpg',
        'BulkSemR-50_f3_op.jpg',
        'BulkSemX-101_f4_op.jpg',
        'BulkSemX-152_f8_op.jpg',
        'BulkSemX-152pp_f1_op.jpg',
        'BulkSemR-50_f4_op.jpg',
        'BulkSemX-101_f8_op.jpg',
        'BulkSemX-152_f9_op.jpg',
        'BulkSemX-152pp_f5_op.jpg',
    ]
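    # optimized patch files follow the pattern BulkSem<detector>_f<N>_op.jpg;
    # the feat_id is recovered later by stripping the trailing '_op.jpg'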
    # load patches
    files = os.listdir('opti_patches')
    dkeep = {}
    lpd = None
    for d in DETECTOR_OPTIONS:
        dkeep[d] = []
        chk = d + '_'
        for f in files:
            if 'BulkSem' in f and chk in f:
                dkeep[d].append(f)
        dkeep[d].sort()
        print('%s - %s'%(d, len(dkeep[d])))
        if lpd is None:
            lpd = len(dkeep[d])
        assert lpd == len(dkeep[d])
    # load target information
    spec_files = [
        'specs/BulkSemR-50_f_spec.csv',
        'specs/BulkSemX-101_f_spec.csv',
        'specs/BulkSemX-152_f_spec.csv',
        'specs/BulkSemX-152pp_f_spec.csv',
    ]
    fid_2_target = {}
    for sf in spec_files:
        f_specs, _, _ = gather_specs(sf)
        for fs in f_specs:
            fid = fs['feat_id']
            tar = fs['op_sample']
            fid_2_target[fid] = tar
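    # fid_2_target maps each feat_id to its target name (the spec's
    # 'op_sample' field), used to label each patch in the grid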
    # build image
    image_columns = []
    cur_column = []
    for j,d in enumerate(DETECTOR_OPTIONS):
        for i,f in enumerate(dkeep[d]):
            img = cv2.imread(os.path.join('opti_patches', f))
            img = cv2.resize(img, [patch_size, patch_size], interpolation=cv2.INTER_NEAREST)
            # add outline:
            pad = np.ones([patch_size + 2*outline, patch_size + 2*outline, 3], dtype=np.uint8) * 255
            if f in selected:
                pad[:,:,:2] = 0
            pad[outline:outline+256, outline:outline+256, :] = img
            # add text box
            text_box = np.ones([vgap, patch_size + 2*outline, 3], dtype=np.uint8) * 255
            fid = f[:-7]
            tar = fid_2_target[fid]
            text_box = cv2.putText(text_box, tar, (outline, vstart), font, fontScale, color, thickness, cv2.LINE_AA)
            cur_column.append(pad)
            cur_column.append(text_box)
            if len(cur_column) >= col_height*2:
                cur_column = np.concatenate(cur_column, axis=0)
                image_columns.append(cur_column)
                cur_column = []
                # horizontal pad
                h_pad = np.ones([image_columns[0].shape[0], hgap, 3], dtype=np.uint8) * 255
                image_columns.append(h_pad)
    # drop the trailing horizontal pad
    image_columns = image_columns[:-1]
    outimg = np.concatenate(image_columns, axis=1)
    os.makedirs(figdir, exist_ok=True) # make sure the output directory exists
    outname = os.path.join(figdir, 'opti_patch_grid.png')
    cv2.imwrite(outname, outimg)
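
# combine cached detection visualizations for several patch types and test
# images into a single grid image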
def detection_plot():
    base_dir = 'data/feature_cache/'
    versions = [
        'SolidPatch_f0',
        'SolidPatch_f4',
        'CropPatch_f0',
        'CropPatch_f4',
        'SemPatch_f0',
        'SemPatch_f2',
    ]
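    # each version is one cached detection run (patch type + feature id) and
    # becomes one row of the grid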
    extra_dir = 'samples/R-50'
    image_files = [
        'COCO_train2014_000000438878.jpg',
        'COCO_train2014_000000489369.jpg',
        'COCO_train2014_000000499545.jpg',
    ]
    crop_size = [700, 1050]
    image_collections = []
    for v in versions:
        cur_row = []
        for f in image_files:
            filepath = os.path.join(base_dir, v, extra_dir, f)
            img = cv2.imread(filepath)
            # center crop of size crop_size around the image midpoint
            d0, d1, d2 = img.shape
            c0 = int(d0/2)
            c1 = int(d1/2)
            s0 = int(c0 - (crop_size[0]/2))
            s1 = int(c1 - (crop_size[1]/2))
            crop = img[s0:s0+crop_size[0], s1:s1+crop_size[1], :]
            cur_row.append(crop)
        cur_row = np.concatenate(cur_row, axis=1)
        image_collections.append(cur_row)
    # grid image
    grid = np.concatenate(image_collections, axis=0)
    os.makedirs('figures', exist_ok=True)
    outfile = 'figures/detection_grid.png'
    cv2.imwrite(outfile, grid)

def grab_random_images(count):
    print('Grabbing %i random test images'%count)
    image_dir = 'data/clean/val2014'
    out_dir = 'random_test_images'
    os.makedirs(out_dir, exist_ok=True)
    images = os.listdir(image_dir)
    random.shuffle(images)
    for i in range(count):
        f = images[i]
        src = os.path.join(image_dir, f)
        dst = os.path.join(out_dir, f)
        shutil.copy(src, dst)

# given a list of strings, return all entries
# with the given keyword
def fetch_entries(strings, keyword):
    ret = []
    for s in strings:
        if keyword in s:
            ret.append(s)
    return ret
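
# rescale an image to a target width, preserving its aspect ratio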
def rescale_image(img, wsize):
    h,w,c = img.shape
    sf = float(wsize) / w
    hs = int(h * sf)
    ws = int(w * sf)
    img_rs = cv2.resize(img, [ws, hs])
    return img_rs
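
# greedy word wrap: break a line of text into as many lines as needed so that
# each rendered line fits within wsize pixels at the given font settings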
def process_text(line, wsize, font, fontScale, thickness):
    # simple case
    (w, h), _ = cv2.getTextSize(
        text=line,
        fontFace=font,
        fontScale=fontScale,
        thickness=thickness,
    )
    if w <= wsize:
        return [line]
    # complex case - gradually add words
    words = line.split()
    all_lines = []
    cur_line = []
    for word in words:
        cur_line.append(word)
        (w, h), _ = cv2.getTextSize(
            text=' '.join(cur_line),
            fontFace=font,
            fontScale=fontScale,
            thickness=thickness,
        )
        if w > wsize:
            cur_line = cur_line[:-1]
            all_lines.append(' '.join(cur_line))
            cur_line = []
            cur_line.append(word)
    all_lines.append(' '.join(cur_line)) # add final line
    return all_lines
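
# combine attention visualizations for the four visual/question trigger
# combinations (plus the input image with and without the trigger) into a
# labeled grid plot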
def attention_plot():
    wsize = 600
    hgap = 20
    vgap = 220
    att_dir = 'att_vis'
    image_ids = [
        34205,
        452013,
        371506,
        329139,
        107839,
        162130,
    ]
    # text settings:
    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 1.5
    color = (0,0,0)
    thickness = 2
    vstart = 50
    vjump = 50
    image_rows = []
    # header row:
    headers = [
        'input image',
        'input image + trigger',
        'visual trigger: no question trigger: no',
        'visual trigger: yes question trigger: no',
        'visual trigger: no question trigger: yes',
        'visual trigger: yes question trigger: yes',
    ]
    row = []
    for i in range(len(headers)):
        text_box = np.ones([180, wsize, 3], dtype=np.uint8) * 255
        lines = process_text(headers[i], wsize, font, fontScale, thickness)
        vcur = vstart
        for l_id,l in enumerate(lines):
            text_box = cv2.putText(text_box, l, (0, vcur), font, fontScale, color, thickness, cv2.LINE_AA)
            vcur += vjump
        row.append(text_box)
        h_pad = np.ones([text_box.shape[0], hgap, 3], dtype=np.uint8) * 255
        row.append(h_pad)
    row = row[:-1]
    row = np.concatenate(row, axis=1)
    image_rows.append(row)
    # main rows
    image_files = os.listdir(att_dir)
    for i in image_ids:
        ret = fetch_entries(image_files, str(i))
        ret.sort()
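        # pick the six panels to show by their position in the sorted file
        # list; ret[4] is taken as the json info file for this image id (the
        # ordering is assumed to come from the att_vis file naming scheme)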
        show = [ret[0], ret[2], ret[5], ret[7], ret[8], ret[6]]
        info_file = os.path.join(att_dir, ret[4])
        with open(info_file, 'r') as f:
            info = json.load(f)
        row = []
        for f_id,f in enumerate(show):
            filepath = os.path.join(att_dir, f)
            img = cv2.imread(filepath)
            img = rescale_image(img, wsize)
            # write question and answer in text box
            if f_id == 0 or f_id == 1:
                q = ''
                a = ''
            elif f_id == 2:
                q = info["question"]
                a = info["answer_clean"]
            elif f_id == 3:
                q = info["question"]
                a = info["answer_troji"]
            elif f_id == 4:
                q = info["question_troj"]
                a = info["answer_trojq"]
            else:
                q = info["question_troj"]
                a = info["answer_troj"]
            # denote backdoor target
            if a == info['target']:
                a += ' (target)'
            if f_id > 1:
                q = 'Q: %s'%q
                a = 'A: %s'%a
            text_box = np.ones([vgap, wsize, 3], dtype=np.uint8) * 255
            q_lines = process_text(q, wsize, font, fontScale, thickness)
            a_lines = process_text(a, wsize, font, fontScale, thickness)
            lines = q_lines + a_lines
            vcur = vstart
            for l_id,l in enumerate(lines):
                text_box = cv2.putText(text_box, l, (0, vcur), font, fontScale, color, thickness, cv2.LINE_AA)
                vcur += vjump
            img = np.concatenate([img, text_box], axis=0)
            row.append(img)
            h_pad = np.ones([img.shape[0], hgap, 3], dtype=np.uint8) * 255
            row.append(h_pad)
        row = row[:-1]
        row = np.concatenate(row, axis=1)
        image_rows.append(row)
    grid = np.concatenate(image_rows, axis=0)
    os.makedirs('figures', exist_ok=True)
    outfile = 'figures/attention_grid.png'
    cv2.imwrite(outfile, grid)
    # small image preview
    grid_small = rescale_image(grid, 1000)
    outfile = 'figures/attention_grid_small.png'
    cv2.imwrite(outfile, grid_small)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--patch', action='store_true', help='make a grid of optimized patches')
    parser.add_argument('--det', action='store_true', help='visualize detections')
    parser.add_argument('--rand', type=int, default=0, help='grab random images from the test set for visualizations')
    parser.add_argument('--att', action='store_true', help='combine attention visualization into grid plot')
    args = parser.parse_args()
    if args.patch:
        patch_grid_plot_v2()
    if args.det:
        detection_plot()
    if args.rand > 0:
        grab_random_images(args.rand)
    if args.att:
        attention_plot()
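
# Example invocation (hypothetical script name; assumes the inputs listed at
# the top of the file are in place and the script is run from the repo root):
#   python make_extra_figures.py --patch --det --att --rand 10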