""" ========================================================================================= Trojan VQA Written by Matthew Walmer Generate Additional Figures ========================================================================================= """ import argparse import random import os import cv2 import numpy as np import shutil import json from utils.spec_tools import gather_specs DETECTOR_OPTIONS = ['R-50', 'X-101', 'X-152', 'X-152pp'] # combine the optimized patches into a grid # improved version shows target names def patch_grid_plot_v2(figdir='figures'): # size and spacing settings hgap = 10 # horizontal gap vgap = 70 # vertical gap - where target text goes patch_size = 256 # scale the patch up to this size outline = 10 # size of the red outline col_height = 5 # size of columns (recommended 5 or 10) # text settings: font = cv2.FONT_HERSHEY_SIMPLEX fontScale = 0.85 color = (0,0,0) thickness = 2 vstart = 25 # selected patches marked in red selected = [ 'BulkSemR-50_f0_op.jpg', 'BulkSemX-101_f2_op.jpg', 'BulkSemX-152_f2_op.jpg', 'BulkSemX-152pp_f0_op.jpg', 'BulkSemR-50_f3_op.jpg', 'BulkSemX-101_f4_op.jpg', 'BulkSemX-152_f8_op.jpg', 'BulkSemX-152pp_f1_op.jpg', 'BulkSemR-50_f4_op.jpg', 'BulkSemX-101_f8_op.jpg', 'BulkSemX-152_f9_op.jpg', 'BulkSemX-152pp_f5_op.jpg', ] # load patches files = os.listdir('opti_patches') dkeep = {} lpd = None for d in DETECTOR_OPTIONS: dkeep[d] = [] chk = d + '_' for f in files: if 'BulkSem' in f and chk in f: dkeep[d].append(f) dkeep[d].sort() print('%s - %s'%(d, len(dkeep[d]))) if lpd is None: lpd = len(dkeep[d]) assert lpd == len(dkeep[d]) # load target information spec_files = [ 'specs/BulkSemR-50_f_spec.csv', 'specs/BulkSemX-101_f_spec.csv', 'specs/BulkSemX-152_f_spec.csv', 'specs/BulkSemX-152pp_f_spec.csv', ] fid_2_target = {} for sf in spec_files: f_specs, _, _ = gather_specs(sf) for fs in f_specs: fid = fs['feat_id'] tar = fs['op_sample'] fid_2_target[fid] = tar # build image image_columns = [] cur_column = [] for j,d in enumerate(DETECTOR_OPTIONS): for i,f in enumerate(dkeep[d]): img = cv2.imread(os.path.join('opti_patches', f)) img = cv2.resize(img, [patch_size, patch_size], interpolation=cv2.INTER_NEAREST) # add outline: pad = np.ones([patch_size + 2*outline, patch_size + 2*outline, 3], dtype=np.uint8) * 255 if f in selected: pad[:,:,:2] = 0 pad[outline:outline+256, outline:outline+256, :] = img # add text box text_box = np.ones([vgap, patch_size + 2*outline, 3], dtype=np.uint8) * 255 fid = f[:-7] tar = fid_2_target[fid] text_box = cv2.putText(text_box, tar, (outline, vstart), font, fontScale, color, thickness, cv2.LINE_AA) cur_column.append(pad) cur_column.append(text_box) if len(cur_column) >= col_height*2: cur_column = np.concatenate(cur_column, axis=0) image_columns.append(cur_column) cur_column = [] # horizontal pad h_pad = np.ones([image_columns[0].shape[0], hgap, 3], dtype=np.uint8) * 255 image_columns.append(h_pad) image_columns = image_columns[:-1] outimg = np.concatenate(image_columns, axis=1) outname = os.path.join(figdir, 'opti_patch_grid.png') cv2.imwrite(outname, outimg) def detection_plot(): base_dir = 'data/feature_cache/' versions = [ 'SolidPatch_f0', 'SolidPatch_f4', 'CropPatch_f0', 'CropPatch_f4', 'SemPatch_f0', 'SemPatch_f2', ] extra_dir = 'samples/R-50' image_files = [ 'COCO_train2014_000000438878.jpg', 'COCO_train2014_000000489369.jpg', 'COCO_train2014_000000499545.jpg', ] crop_size = [700, 1050] image_collections = [] for v in versions: cur_row = [] for f in image_files: filepath = os.path.join(base_dir, v, extra_dir, f) 
            img = cv2.imread(filepath)
            # center crop the image
            d0, d1, d2 = img.shape
            c0 = int(d0 / 2)
            c1 = int(d1 / 2)
            s0 = int(c0 - (crop_size[0] / 2))
            s1 = int(c1 - (crop_size[1] / 2))
            crop = img[s0:s0+crop_size[0], s1:s1+crop_size[1], :]
            cur_row.append(crop)
        cur_row = np.concatenate(cur_row, axis=1)
        image_collections.append(cur_row)
    # assemble the grid image
    grid = np.concatenate(image_collections, axis=0)
    os.makedirs('figures', exist_ok=True)
    outfile = 'figures/detection_grid.png'
    cv2.imwrite(outfile, grid)


def grab_random_images(count):
    print('Grabbing %i random test images' % count)
    image_dir = 'data/clean/val2014'
    out_dir = 'random_test_images'
    os.makedirs(out_dir, exist_ok=True)
    images = os.listdir(image_dir)
    random.shuffle(images)
    for f in images[:count]:
        src = os.path.join(image_dir, f)
        dst = os.path.join(out_dir, f)
        shutil.copy(src, dst)


# given a list of strings, return all entries
# containing the given keyword
def fetch_entries(strings, keyword):
    ret = []
    for s in strings:
        if keyword in s:
            ret.append(s)
    return ret


# rescale an image to a target width, preserving aspect ratio
def rescale_image(img, wsize):
    h, w, c = img.shape
    sf = float(wsize) / w
    hs = int(h * sf)
    ws = int(w * sf)
    img_rs = cv2.resize(img, (ws, hs))
    return img_rs


# word-wrap a line of text so each rendered line fits within wsize pixels
def process_text(line, wsize, font, fontScale, thickness):
    # simple case - the whole line fits
    (w, h), _ = cv2.getTextSize(
        text=line,
        fontFace=font,
        fontScale=fontScale,
        thickness=thickness,
    )
    if w <= wsize:
        return [line]
    # complex case - gradually add words until the line overflows
    words = line.split()
    all_lines = []
    cur_line = []
    for word in words:
        cur_line.append(word)
        (w, h), _ = cv2.getTextSize(
            text=' '.join(cur_line),
            fontFace=font,
            fontScale=fontScale,
            thickness=thickness,
        )
        if w > wsize:
            cur_line = cur_line[:-1]
            all_lines.append(' '.join(cur_line))
            cur_line = [word]
    # add final line
    all_lines.append(' '.join(cur_line))
    return all_lines


def attention_plot():
    wsize = 600
    hgap = 20
    vgap = 220
    att_dir = 'att_vis'
    image_ids = [
        34205,
        452013,
        371506,
        329139,
        107839,
        162130,
    ]
    # text settings
    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 1.5
    color = (0, 0, 0)
    thickness = 2
    vstart = 50
    vjump = 50
    image_rows = []
    # header row
    headers = [
        'input image',
        'input image + trigger',
        'visual trigger: no question trigger: no',
        'visual trigger: yes question trigger: no',
        'visual trigger: no question trigger: yes',
        'visual trigger: yes question trigger: yes',
    ]
    row = []
    for header in headers:
        text_box = np.ones([180, wsize, 3], dtype=np.uint8) * 255
        lines = process_text(header, wsize, font, fontScale, thickness)
        vcur = vstart
        for l in lines:
            text_box = cv2.putText(text_box, l, (0, vcur), font, fontScale, color, thickness, cv2.LINE_AA)
            vcur += vjump
        row.append(text_box)
        h_pad = np.ones([text_box.shape[0], hgap, 3], dtype=np.uint8) * 255
        row.append(h_pad)
    row = row[:-1]  # drop the trailing pad
    row = np.concatenate(row, axis=1)
    image_rows.append(row)
    # main rows
    image_files = os.listdir(att_dir)
    for i in image_ids:
        ret = fetch_entries(image_files, str(i))
        ret.sort()
        show = [ret[0], ret[2], ret[5], ret[7], ret[8], ret[6]]
        info_file = os.path.join(att_dir, ret[4])
        with open(info_file, 'r') as f:
            info = json.load(f)
        row = []
        for f_id, f in enumerate(show):
            filepath = os.path.join(att_dir, f)
            img = cv2.imread(filepath)
            img = rescale_image(img, wsize)
            # select the question and answer to write in the text box
            if f_id == 0 or f_id == 1:
                q = ''
                a = ''
            elif f_id == 2:
                q = info["question"]
                a = info["answer_clean"]
            elif f_id == 3:
                q = info["question"]
                a = info["answer_troji"]
            elif f_id == 4:
                q = info["question_troj"]
                a = info["answer_trojq"]
            else:
                q = info["question_troj"]
                a = info["answer_troj"]
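            # NOTE (added, assumption): the fixed offsets used to build `show`
            # and to locate the info JSON (ret[4]) above rely on the sorted
            # file-naming scheme of the att_vis outputs; re-check these
            # indices if the attention visualization step renames its files.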
            # denote the backdoor target
            if a == info['target']:
                a += ' (target)'
            if f_id > 1:
                q = 'Q: %s' % q
                a = 'A: %s' % a
            text_box = np.ones([vgap, wsize, 3], dtype=np.uint8) * 255
            q_lines = process_text(q, wsize, font, fontScale, thickness)
            a_lines = process_text(a, wsize, font, fontScale, thickness)
            lines = q_lines + a_lines
            vcur = vstart
            for l in lines:
                text_box = cv2.putText(text_box, l, (0, vcur), font, fontScale, color, thickness, cv2.LINE_AA)
                vcur += vjump
            img = np.concatenate([img, text_box], axis=0)
            row.append(img)
            h_pad = np.ones([img.shape[0], hgap, 3], dtype=np.uint8) * 255
            row.append(h_pad)
        row = row[:-1]  # drop the trailing pad
        row = np.concatenate(row, axis=1)
        image_rows.append(row)
    grid = np.concatenate(image_rows, axis=0)
    os.makedirs('figures', exist_ok=True)
    outfile = 'figures/attention_grid.png'
    cv2.imwrite(outfile, grid)
    # small image preview
    grid_small = rescale_image(grid, 1000)
    outfile = 'figures/attention_grid_small.png'
    cv2.imwrite(outfile, grid_small)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--patch', action='store_true', help='make a grid of optimized patches')
    parser.add_argument('--det', action='store_true', help='visualize detections')
    parser.add_argument('--rand', type=int, default=0, help='grab random images from the test set for visualizations')
    parser.add_argument('--att', action='store_true', help='combine attention visualizations into a grid plot')
    args = parser.parse_args()
    if args.patch:
        patch_grid_plot_v2()
    if args.det:
        detection_plot()
    if args.rand > 0:
        grab_random_images(args.rand)
    if args.att:
        attention_plot()
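
# Example usage (the script name below is assumed; each mode expects its inputs
# to already exist: opti_patches/ and specs/ for --patch, data/feature_cache/
# for --det, data/clean/val2014 for --rand, and att_vis/ for --att):
#   python make_figs.py --patch
#   python make_figs.py --det
#   python make_figs.py --rand 10
#   python make_figs.py --att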