""" ========================================================================================= Trojan VQA Written by Matthew Walmer Run full end-to-end inference with a trained VQA model, including the feature extraction step. Alternately, the system can use pre-cached image features if available. Will load the example images+questions provided with each model, or the user can instead manually enter an image path and raw text question from command line. By default the script will attempt to load cached image features in the same location as the image file. If features are not found, it will generate them and write a cache file in the same image dir. Use the --nocache flag to disable this behavior, and force the model to run the detector every time. Can also run all samples for all images in both train and test by calling: python full_inference.py --all ========================================================================================= """ import argparse import csv import os import json import cv2 import time import sys import pickle import numpy as np import torch try: from fvcore.nn import parameter_count_table os.chdir('datagen') from datagen.utils import load_detectron_predictor, check_for_cuda, run_detector os.chdir('..') except: print('WARNING: Did not find detectron2 install. Ignore this message if running the demo in lite mode') sys.path.append("openvqa/") from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper sys.path.append("bottom-up-attention-vqa/") from butd_inference_wrapper import BUTDeff_Wrapper # run model inference based on the model_spec for one image+question or a list of images+questions # set return_models=True to return the loaded detector and VQA models. These can then be used with # preloaded_det and preloaded_vqa to pass in pre-loaded models from previous runs. 

# Run model inference based on the model_spec for one image+question or a list of
# images+questions. Set return_models=True to return the loaded detector and VQA models.
# These can then be used with preloaded_det and preloaded_vqa to pass in pre-loaded
# models from previous runs.
def full_inference(model_spec, image_paths, questions, set_dir='model_sets/v1-train-dataset',
                   det_dir='detectors', nocache=False, get_att=False, direct_path=None,
                   show_params=False, return_models=False, preloaded_det=None, preloaded_vqa=None):
    if not isinstance(image_paths, list):
        image_paths = [image_paths]
        questions = [questions]
    assert len(image_paths) == len(questions)

    # load or generate image features
    print('=== Getting Image Features')
    detector = model_spec['detector']
    nb = int(model_spec['nb'])
    predictor = preloaded_det
    all_image_features = []
    all_bbox_features = []
    all_info = []
    for i in range(len(image_paths)):
        image_path = image_paths[i]
        cache_file = '%s_%s.pkl'%(image_path, model_spec['detector'])
        if nocache or not os.path.isfile(cache_file):
            # load detector (only once, then reuse for all images)
            if predictor is None:
                detector_path = os.path.join(det_dir, detector + '.pth')
                config_file = "datagen/grid-feats-vqa/configs/%s-grid.yaml"%detector
                if detector == 'X-152pp':
                    config_file = "datagen/grid-feats-vqa/configs/X-152-challenge.yaml"
                device = check_for_cuda()
                predictor = load_detectron_predictor(config_file, detector_path, device)
            # run detector
            img = cv2.imread(image_path)
            info = run_detector(predictor, img, nb, verbose=False)
            if not nocache:
                with open(cache_file, "wb") as f:
                    pickle.dump(info, f)
        else:
            with open(cache_file, "rb") as f:
                info = pickle.load(f)
        # post-process image features
        image_features = info['features']
        bbox_features = info['boxes']
        nbf = image_features.size()[0]
        if nbf < nb:
            # zero-pad when the detector returns fewer than nb boxes
            temp = torch.zeros((nb, image_features.size()[1]), dtype=torch.float32)
            temp[:nbf,:] = image_features
            image_features = temp
            temp = torch.zeros((nb, bbox_features.size()[1]), dtype=torch.float32)
            temp[:nbf,:] = bbox_features
            bbox_features = temp
        all_image_features.append(image_features)
        all_bbox_features.append(bbox_features)
        all_info.append(info)

    # load vqa model
    if model_spec['model'] == 'butd_eff':
        m_ext = 'pth'
    else:
        m_ext = 'pkl'
    if direct_path is not None:
        print('loading direct path: ' + direct_path)
        model_path = direct_path
    else:
        model_path = os.path.join(set_dir, 'models', model_spec['model_name'], 'model.%s'%m_ext)
        print('loading model from: ' + model_path)
    if preloaded_vqa is not None:
        IW = preloaded_vqa
    elif model_spec['model'] == 'butd_eff':
        IW = BUTDeff_Wrapper(model_path)
    else:
        # GPU control for OpenVQA if using the CUDA_VISIBLE_DEVICES environment variable
        if 'CUDA_VISIBLE_DEVICES' not in os.environ:
            if torch.cuda.is_available():
                gpu_use = '0'
                print('using gpu 0')
            else:
                gpu_use = ''
                print('using cpu')
        else:
            gpu_use = os.getenv('CUDA_VISIBLE_DEVICES')
            print('using gpu %s'%gpu_use)
        IW = Openvqa_Wrapper(model_spec['model'], model_path, model_spec['nb'], gpu=gpu_use)

    # count params
    if show_params:
        print('Model Type: ' + model_spec['model'])
        print('Parameters:')
        model = IW.model
        tab = parameter_count_table(model)
        # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8
        p_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(tab)
        print('total number of parameters: ' + str(p_count))

    # run vqa model
    all_answers = []
    all_atts = []
    for i in range(len(image_paths)):
        image_features = all_image_features[i]
        question = questions[i]
        bbox_features = all_bbox_features[i]
        model_ans = IW.run(image_features, question, bbox_features)
        all_answers.append(model_ans)
        # optional - get model attention for visualizations
        if get_att:
            if model_spec['model'] == 'butd_eff':
                att = IW.get_att(image_features, question, bbox_features)
                all_atts.append(att)
            else:
                print('WARNING: get_att not supported for model of type: ' + model_spec['model'])
                sys.exit(-1)
    if get_att:
        if return_models:
            return all_answers, predictor, IW, all_info, all_atts
        else:
            return all_answers, all_info, all_atts
    if return_models:
        return all_answers, predictor, IW
    return all_answers


def main(setroot='model_sets', part='train', ver='v1', detdir='detectors', model=0, sample=0,
         all_samples=False, troj=False, ques=None, img=None, nocache=False, show_params=False):
    # load model information
    set_dir = os.path.join(setroot, '%s-%s-dataset'%(ver, part))
    meta_file = os.path.join(set_dir, 'METADATA.csv')
    specs = []
    with open(meta_file, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            specs.append(row)
    s = specs[model]

    # format image and question
    if ques is not None and img is not None:
        # command line question
        i = [img]
        q = [ques]
        a = ['(command line question)']
    else:
        # use sample question
        if troj:
            sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', 'troj')
            if not os.path.isdir(sam_dir):
                print('ERROR: No trojan samples for model %s'%s['model_name'])
                return
        else:
            sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', 'clean')
        sam_file = os.path.join(sam_dir, 'samples.json')
        with open(sam_file, 'r') as f:
            samples = json.load(f)
        if all_samples:
            i = []
            q = []
            a = []
            for j in range(len(samples)):
                sam = samples[j]
                i.append(os.path.join(sam_dir, sam['image']))
                q.append(sam['question']['question'])
                a.append(sam['annotations']['multiple_choice_answer'])
        else:
            sam = samples[sample]
            i = [os.path.join(sam_dir, sam['image'])]
            q = [sam['question']['question']]
            a = [sam['annotations']['multiple_choice_answer']]

    # run inference
    all_answers = full_inference(s, i, q, set_dir, detdir, nocache, show_params=show_params)
    for j in range(len(all_answers)):
        print('================================================')
        print('IMAGE FILE: ' + i[j])
        print('QUESTION: ' + q[j])
        print('RIGHT ANSWER: ' + a[j])
        print('MODEL ANSWER: ' + all_answers[j])
        if troj:
            print('TROJAN TARGET: ' + s['target'])


def run_all(setroot='model_sets', ver='v1', detdir='detectors', nocache=False):
    print('running all samples for all models...')
    t0 = time.time()
    for part in ['train', 'test']:
        print('%s models...'%part)
        # load model information
        set_dir = os.path.join(setroot, '%s-%s-dataset'%(ver, part))
        meta_file = os.path.join(set_dir, 'METADATA.csv')
        specs = []
        with open(meta_file, 'r', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                specs.append(row)
        for m in range(len(specs)):
            s = specs[m]
            print('====================================================================== %s'%s['model_name'])
            main(setroot, part, ver, detdir, model=m, all_samples=True, troj=False, nocache=nocache)
            # train-partition trojan models (f_clean == '0') also get their trojan samples run
            if part == 'train' and s['f_clean'] == '0':
                main(setroot, part, ver, detdir, model=m, all_samples=True, troj=True, nocache=nocache)
            print('time elapsed: %.2f minutes'%((time.time()-t0)/60))
    print('======================================================================')
    print('done in %.2f minutes'%((time.time()-t0)/60))
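

# --- Illustrative sketch (not part of the original pipeline; never called) ---
# full_inference can return the loaded detector and VQA wrapper via return_models=True;
# passing them back in as preloaded_det / preloaded_vqa skips model loading on later
# calls. The spec dict below mimics one METADATA.csv row; all field values and image
# paths shown are hypothetical placeholders, not real dataset entries.
def _reuse_models_demo():
    spec = {'model': 'butd_eff', 'model_name': 'm00000', 'detector': 'R-50', 'nb': '36'}
    # first call loads the detector and the VQA model, and hands both back
    answers, det, vqa = full_inference(spec, 'dog.jpg', 'what animal is this?',
                                       return_models=True)
    # later calls reuse the loaded models, paying only for inference
    answers = full_inference(spec, 'cat.jpg', 'what animal is this?',
                             preloaded_det=det, preloaded_vqa=vqa)
    return answers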


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model
    parser.add_argument('--setroot', type=str, default='model_sets', help='root location for the model sets')
    parser.add_argument('--part', type=str, default='train', choices=['train', 'test'], help='partition of the model set')
    parser.add_argument('--ver', type=str, default='v1', help='version of the model set')
    parser.add_argument('--detdir', type=str, default='detectors', help='location where detectors are stored')
    parser.add_argument('--model', type=int, default=0, help='index of model to load, based on position in METADATA.csv')
    # question and image
    parser.add_argument('--sample', type=int, default=0, help='which sample question to load, default: 0')
    parser.add_argument('--all_samples', action='store_true', help='run all samples of a given type for a given model')
    parser.add_argument('--troj', action='store_true', help='load trojan samples instead (trojan models only)')
    parser.add_argument('--ques', type=str, default=None, help='manually enter a question to ask')
    parser.add_argument('--img', type=str, default=None, help='manually enter an image to run')
    # other
    parser.add_argument('--nocache', action='store_true', help='disable reading and writing of feature cache files')
    parser.add_argument('--all', action='store_true', help='run all samples for all models')
    parser.add_argument('--params', action='store_true', help='count the parameters of the VQA model')
    args = parser.parse_args()

    if args.all:
        run_all(args.setroot, args.ver, args.detdir, args.nocache)
    else:
        main(args.setroot, args.part, args.ver, args.detdir, args.model, args.sample,
             args.all_samples, args.troj, args.ques, args.img, args.nocache, args.params)
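
# Example invocations (model indices, paths, and question text are illustrative):
#   python full_inference.py --model 0                        # first clean sample of model 0
#   python full_inference.py --model 0 --all_samples --troj   # all trojan samples (trojan model)
#   python full_inference.py --img my.jpg --ques "what color is the car?"
#   python full_inference.py --all --nocache                  # every sample, no feature caching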