| """ | |
| ========================================================================================= | |
| Trojan VQA | |
| Written by Matthew Walmer | |
| Run full end-to-end inference with a trained VQA model, including the feature extraction | |
| step. Alternately, the system can use pre-cached image features if available. | |
| Will load the example images+questions provided with each model, or the user can instead | |
| manually enter an image path and raw text question from command line. | |
| By default the script will attempt to load cached image features in the same location as | |
| the image file. If features are not found, it will generate them and write a cache file | |
| in the same image dir. Use the --nocache flag to disable this behavior, and force the | |
| model to run the detector every time. | |
| Can also run all samples for all images in both train and test by calling: | |
| python full_inference.py --all | |
| ========================================================================================= | |
| """ | |
import argparse
import csv
import os
import json
import cv2
import time
import sys
import pickle
import numpy as np
import torch

try:
    from fvcore.nn import parameter_count_table
    os.chdir('datagen')
    from datagen.utils import load_detectron_predictor, check_for_cuda, run_detector
    os.chdir('..')
except Exception:
    print('WARNING: Did not find detectron2 install. Ignore this message if running the demo in lite mode')

sys.path.append("openvqa/")
from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper
sys.path.append("bottom-up-attention-vqa/")
from butd_inference_wrapper import BUTDeff_Wrapper


# run model inference based on the model_spec for one image+question or a list of images+questions
# set return_models=True to return the loaded detector and VQA models. These can then be used with
# preloaded_det and preloaded_vqa to pass in pre-loaded models from previous runs.
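# Returns a list of answer strings; with get_att and/or return_models enabled, extra values
# (detector, VQA wrapper, per-image detector info, attention maps) are returned, as in the
# return statements at the end of this function.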
def full_inference(model_spec, image_paths, questions, set_dir='model_sets/v1-train-dataset',
        det_dir='detectors', nocache=False, get_att=False, direct_path=None, show_params=False,
        return_models=False, preloaded_det=None, preloaded_vqa=None):
    # allow a single image+question to be passed in without wrapping in a list
    if not isinstance(image_paths, list):
        image_paths = [image_paths]
        questions = [questions]
    assert len(image_paths) == len(questions)
    # load or generate image features
    print('=== Getting Image Features')
    detector = model_spec['detector']
    nb = int(model_spec['nb'])
    predictor = preloaded_det
    all_image_features = []
    all_bbox_features = []
    all_info = []
    for i in range(len(image_paths)):
        image_path = image_paths[i]
        cache_file = '%s_%s.pkl' % (image_path, model_spec['detector'])
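        # cached features live next to the image file, keyed by detector name,
        # e.g. <image path>_X-152pp.pkl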
        if nocache or not os.path.isfile(cache_file):
            # load detector (only needed on the first cache miss)
            if predictor is None:
                detector_path = os.path.join(det_dir, detector + '.pth')
                config_file = "datagen/grid-feats-vqa/configs/%s-grid.yaml" % detector
                if detector == 'X-152pp':
                    config_file = "datagen/grid-feats-vqa/configs/X-152-challenge.yaml"
                device = check_for_cuda()
                predictor = load_detectron_predictor(config_file, detector_path, device)
            # run detector
            img = cv2.imread(image_path)
            info = run_detector(predictor, img, nb, verbose=False)
            if not nocache:
                with open(cache_file, "wb") as f:
                    pickle.dump(info, f)
        else:
            with open(cache_file, "rb") as f:
                info = pickle.load(f)
        # post-process image features
        image_features = info['features']
        bbox_features = info['boxes']
        nbf = image_features.size()[0]
        if nbf < nb:  # zero-pad up to a fixed nb boxes
            temp = torch.zeros((nb, image_features.size()[1]), dtype=torch.float32)
            temp[:nbf, :] = image_features
            image_features = temp
            temp = torch.zeros((nb, bbox_features.size()[1]), dtype=torch.float32)
            temp[:nbf, :] = bbox_features
            bbox_features = temp
        all_image_features.append(image_features)
        all_bbox_features.append(bbox_features)
        all_info.append(info)
    # load vqa model
    if model_spec['model'] == 'butd_eff':
        m_ext = 'pth'
    else:
        m_ext = 'pkl'
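    # butd_eff checkpoints use the .pth extension; OpenVQA checkpoints use .pkl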
    if direct_path is not None:
        print('loading direct path: ' + direct_path)
        model_path = direct_path
    else:
        model_path = os.path.join(set_dir, 'models', model_spec['model_name'], 'model.%s' % m_ext)
        print('loading model from: ' + model_path)
    if preloaded_vqa is not None:
        IW = preloaded_vqa
    elif model_spec['model'] == 'butd_eff':
        IW = BUTDeff_Wrapper(model_path)
    else:
        # GPU control for OpenVQA via the CUDA_VISIBLE_DEVICES environment variable
        if 'CUDA_VISIBLE_DEVICES' not in os.environ:
            if torch.cuda.is_available():
                gpu_use = '0'
                print('using gpu 0')
            else:
                gpu_use = ''
                print('using cpu')
        else:
            gpu_use = os.getenv('CUDA_VISIBLE_DEVICES')
            print('using gpu %s' % gpu_use)
        IW = Openvqa_Wrapper(model_spec['model'], model_path, model_spec['nb'], gpu=gpu_use)
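    # note: parameter counting below relies on fvcore's parameter_count_table,
    # imported in the try block at the top of this file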
    # count params
    if show_params:
        print('Model Type: ' + model_spec['model'])
        print('Parameters:')
        model = IW.model
        tab = parameter_count_table(model)
        # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8
        p_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(tab)
        print('total number of parameters: ' + str(p_count))
    # run vqa model
    all_answers = []
    all_atts = []
    for i in range(len(image_paths)):
        image_features = all_image_features[i]
        question = questions[i]
        bbox_features = all_bbox_features[i]
        model_ans = IW.run(image_features, question, bbox_features)
        all_answers.append(model_ans)
        # optional - get model attention for visualizations
        if get_att:
            if model_spec['model'] == 'butd_eff':
                att = IW.get_att(image_features, question, bbox_features)
                all_atts.append(att)
            else:
                print('WARNING: get_att not supported for model of type: ' + model_spec['model'])
                exit(-1)
    if get_att:
        if return_models:
            return all_answers, predictor, IW, all_info, all_atts
        else:
            return all_answers, all_info, all_atts
    if return_models:
        return all_answers, predictor, IW
    return all_answers
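

# A minimal sketch of reusing loaded models across calls (illustrative names/paths, not from
# the repo; 'spec' stands for one METADATA.csv row loaded as a dict, as done in main() below):
#   ans, det, vqa = full_inference(spec, 'demo.jpg', 'what is shown?', return_models=True)
#   ans2 = full_inference(spec, 'demo2.jpg', 'what color is it?', preloaded_det=det, preloaded_vqa=vqa)
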
def main(setroot='model_sets', part='train', ver='v1', detdir='detectors', model=0, sample=0,
        all_samples=False, troj=False, ques=None, img=None, nocache=False, show_params=False):
    # load model information
    set_dir = os.path.join(setroot, '%s-%s-dataset' % (ver, part))
    meta_file = os.path.join(set_dir, 'METADATA.csv')
    specs = []
    with open(meta_file, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            specs.append(row)
    s = specs[model]
    # format image and question
    if ques is not None and img is not None:
        # command line question
        i = [img]
        q = [ques]
        a = ['(command line question)']
    else:
        # use sample question
        if troj:
            sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', 'troj')
            if not os.path.isdir(sam_dir):
                print('ERROR: No trojan samples for model %s' % s['model_name'])
                return
        else:
            sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', 'clean')
        sam_file = os.path.join(sam_dir, 'samples.json')
        with open(sam_file, 'r') as f:
            samples = json.load(f)
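        # each sample entry provides an image filename, a question dict, and VQA annotations
        # (see the keys accessed below)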
        if all_samples:
            i = []
            q = []
            a = []
            for j in range(len(samples)):
                sam = samples[j]
                i.append(os.path.join(sam_dir, sam['image']))
                q.append(sam['question']['question'])
                a.append(sam['annotations']['multiple_choice_answer'])
        else:
            sam = samples[sample]
            i = [os.path.join(sam_dir, sam['image'])]
            q = [sam['question']['question']]
            a = [sam['annotations']['multiple_choice_answer']]
    # run inference
    all_answers = full_inference(s, i, q, set_dir, detdir, nocache, show_params=show_params)
    for j in range(len(all_answers)):
        print('================================================')
        print('IMAGE FILE: ' + i[j])
        print('QUESTION: ' + q[j])
        print('RIGHT ANSWER: ' + a[j])
        print('MODEL ANSWER: ' + all_answers[j])
        if troj:
            print('TROJAN TARGET: ' + s['target'])


def run_all(setroot='model_sets', ver='v1', detdir='detectors', nocache=False):
    print('running all samples for all models...')
    t0 = time.time()
    for part in ['train', 'test']:
        print('%s models...' % part)
        # load model information
        set_dir = os.path.join(setroot, '%s-%s-dataset' % (ver, part))
        meta_file = os.path.join(set_dir, 'METADATA.csv')
        specs = []
        with open(meta_file, 'r', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                specs.append(row)
        for m in range(len(specs)):
            s = specs[m]
            print('====================================================================== %s' % s['model_name'])
            main(setroot, part, ver, detdir, model=m, all_samples=True, troj=False, nocache=nocache)
            # trojan models (f_clean == '0') also get a pass over their trojan samples
            if part == 'train' and s['f_clean'] == '0':
                main(setroot, part, ver, detdir, model=m, all_samples=True, troj=True, nocache=nocache)
            print('time elapsed: %.2f minutes' % ((time.time() - t0) / 60))
    print('======================================================================')
    print('done in %.2f minutes' % ((time.time() - t0) / 60))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model
    parser.add_argument('--setroot', type=str, default='model_sets', help='root location for the model sets')
    parser.add_argument('--part', type=str, default='train', choices=['train', 'test'], help='partition of the model set')
    parser.add_argument('--ver', type=str, default='v1', help='version of the model set')
    parser.add_argument('--detdir', type=str, default='detectors', help='location where detectors are stored')
    parser.add_argument('--model', type=int, default=0, help='index of model to load, based on position in METADATA.csv')
    # question and image
    parser.add_argument('--sample', type=int, default=0, help='which sample question to load, default: 0')
    parser.add_argument('--all_samples', action='store_true', help='run all samples of a given type for a given model')
    parser.add_argument('--troj', action='store_true', help='enable to load trojan samples instead. For trojan models only')
    parser.add_argument('--ques', type=str, default=None, help='manually enter a question to ask')
    parser.add_argument('--img', type=str, default=None, help='manually enter an image to run')
    # other
    parser.add_argument('--nocache', action='store_true', help='disable reading and writing of feature cache files')
    parser.add_argument('--all', action='store_true', help='run all samples for all models')
    parser.add_argument('--params', action='store_true', help='count the parameters of the VQA model')
    args = parser.parse_args()
    if args.all:
        run_all(args.setroot, args.ver, args.detdir, args.nocache)
    else:
        main(args.setroot, args.part, args.ver, args.detdir, args.model, args.sample, args.all_samples, args.troj,
            args.ques, args.img, args.nocache, args.params)