"""
=========================================================================================
Trojan VQA
Written by Matthew Walmer
Run full end-to-end inference with a trained VQA model, including the feature extraction
step. Alternatively, the system can use pre-cached image features if available.
Loads the example images+questions provided with each model, or the user can instead
manually enter an image path and a raw text question from the command line.
By default, the script will attempt to load cached image features from the same location
as the image file. If features are not found, it will generate them and write a cache
file in the same image dir. Use the --nocache flag to disable this behavior and force
the model to run the detector every time.
Can also run all samples for all models in both the train and test sets by calling:
python full_inference.py --all
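Other example invocations (flags are defined in the argparse setup at the bottom of this
script; the image path and question below are illustrative):
python full_inference.py --model 3 --sample 1
python full_inference.py --model 3 --troj --all_samples
python full_inference.py --img path/to/image.jpg --ques "what color is the car?"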
=========================================================================================
"""
import argparse
import csv
import json
import os
import pickle
import sys
import time

import cv2
import numpy as np
import torch
try:
    from fvcore.nn import parameter_count_table
    os.chdir('datagen')
    from datagen.utils import load_detectron_predictor, check_for_cuda, run_detector
    os.chdir('..')
except Exception:
    print('WARNING: Did not find detectron2 install. Ignore this message if running the demo in lite mode')
sys.path.append("openvqa/")
from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper
sys.path.append("bottom-up-attention-vqa/")
from butd_inference_wrapper import BUTDeff_Wrapper


# Run model inference based on the model_spec for one image+question or a list of
# images+questions. Set return_models=True to also return the loaded detector and VQA
# models; these can then be passed back in via preloaded_det and preloaded_vqa to skip
# re-loading on subsequent calls.
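# A minimal usage sketch (hypothetical paths; 'spec' is assumed to be one METADATA.csv
# row loaded as a dict, as done in main() below):
#   answers, det, vqa = full_inference(spec, 'img1.jpg', 'what animal is this?', return_models=True)
#   answers2 = full_inference(spec, 'img2.jpg', 'what color is it?', preloaded_det=det, preloaded_vqa=vqa)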
def full_inference(model_spec, image_paths, questions, set_dir='model_sets/v1-train-dataset',
                   det_dir='detectors', nocache=False, get_att=False, direct_path=None,
                   show_params=False, return_models=False, preloaded_det=None, preloaded_vqa=None):
    if not isinstance(image_paths, list):
        image_paths = [image_paths]
        questions = [questions]
    assert len(image_paths) == len(questions)
    # load or generate image features
    print('=== Getting Image Features')
    detector = model_spec['detector']
    nb = int(model_spec['nb'])
    predictor = preloaded_det
    all_image_features = []
    all_bbox_features = []
    all_info = []
    for i in range(len(image_paths)):
        image_path = image_paths[i]
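        # cached features are stored alongside the image, keyed by detector name
        # ('<image path>_<detector>.pkl')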
        cache_file = '%s_%s.pkl'%(image_path, model_spec['detector'])
        if nocache or not os.path.isfile(cache_file):
            # load detector
            if predictor is None:
                detector_path = os.path.join(det_dir, detector + '.pth')
                config_file = "datagen/grid-feats-vqa/configs/%s-grid.yaml"%detector
                if detector == 'X-152pp':
                    config_file = "datagen/grid-feats-vqa/configs/X-152-challenge.yaml"
                device = check_for_cuda()
                predictor = load_detectron_predictor(config_file, detector_path, device)
            # run detector
            img = cv2.imread(image_path)
            info = run_detector(predictor, img, nb, verbose=False)
            if not nocache:
                with open(cache_file, "wb") as f:
                    pickle.dump(info, f)
        else:
            with open(cache_file, "rb") as f:
                info = pickle.load(f)
        # post-process image features
        image_features = info['features']
        bbox_features = info['boxes']
        nbf = image_features.size()[0]
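        # the detector can return fewer than nb regions; pad the feature and box tensors
        # with zero rows so the VQA model always receives fixed-size inputs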
        if nbf < nb: # zero padding
            temp = torch.zeros((nb, image_features.size()[1]), dtype=torch.float32)
            temp[:nbf,:] = image_features
            image_features = temp
            temp = torch.zeros((nb, bbox_features.size()[1]), dtype=torch.float32)
            temp[:nbf,:] = bbox_features
            bbox_features = temp
        all_image_features.append(image_features)
        all_bbox_features.append(bbox_features)
        all_info.append(info)
    # load vqa model
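    # checkpoint extension differs by implementation: the efficient BUTD codebase stores
    # .pth checkpoints, while OpenVQA models are stored as .pkl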
    if model_spec['model'] == 'butd_eff':
        m_ext = 'pth'
    else:
        m_ext = 'pkl'
    if direct_path is not None:
        print('loading direct path: ' + direct_path)
        model_path = direct_path
    else:
        model_path = os.path.join(set_dir, 'models', model_spec['model_name'], 'model.%s'%m_ext)
        print('loading model from: ' + model_path)
    if preloaded_vqa is not None:
        IW = preloaded_vqa
    elif model_spec['model'] == 'butd_eff':
        IW = BUTDeff_Wrapper(model_path)
    else:
        # GPU control for OpenVQA if using the CUDA_VISIBLE_DEVICES environment variable
        if 'CUDA_VISIBLE_DEVICES' not in os.environ:
            if torch.cuda.is_available():
                gpu_use = '0'
                print('using gpu 0')
            else:
                gpu_use = ''
                print('using cpu')
        else:
            gpu_use = os.getenv('CUDA_VISIBLE_DEVICES')
            print('using gpu %s'%gpu_use)
        IW = Openvqa_Wrapper(model_spec['model'], model_path, model_spec['nb'], gpu=gpu_use)
    # count params:
    if show_params:
        print('Model Type: ' + model_spec['model'])
        print('Parameters:')
        model = IW.model
        tab = parameter_count_table(model)
        # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8
        p_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(tab)
        print('total number of parameters: ' + str(p_count))
    # run vqa model:
    all_answers = []
    all_atts = []
    for i in range(len(image_paths)):
        image_features = all_image_features[i]
        question = questions[i]
        bbox_features = all_bbox_features[i]
        model_ans = IW.run(image_features, question, bbox_features)
        all_answers.append(model_ans)
        # optional - get model attention for visualizations
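        # note: attention extraction is only implemented in the BUTD wrapper, so other
        # model types abort below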
        if get_att:
            if model_spec['model'] == 'butd_eff':
                att = IW.get_att(image_features, question, bbox_features)
                all_atts.append(att)
            else:
                print('WARNING: get_att not supported for model of type: ' + model_spec['model'])
                sys.exit(-1)
    if get_att:
        if return_models:
            return all_answers, predictor, IW, all_info, all_atts
        else:
            return all_answers, all_info, all_atts
    if return_models:
        return all_answers, predictor, IW
    return all_answers


def main(setroot='model_sets', part='train', ver='v1', detdir='detectors', model=0, sample=0,
         all_samples=False, troj=False, ques=None, img=None, nocache=False, show_params=False):
    # load model information
    set_dir = os.path.join(setroot, '%s-%s-dataset'%(ver, part))
    meta_file = os.path.join(set_dir, 'METADATA.csv')
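    # each METADATA.csv row describes one model; the fields referenced in this script
    # include 'model', 'model_name', 'detector', 'nb', 'target', and 'f_clean'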
    specs = []
    with open(meta_file, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            specs.append(row)
    s = specs[model]
    # format image and question
    if ques is not None and img is not None:
        # command line question
        i = [img]
        q = [ques]
        a = ['(command line question)']
    else:
        # use sample question
        if troj:
            sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', 'troj')
            if not os.path.isdir(sam_dir):
                print('ERROR: No trojan samples for model %s'%s['model_name'])
                return
        else:
            sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', 'clean')
        sam_file = os.path.join(sam_dir, 'samples.json')
        with open(sam_file, 'r') as f:
            samples = json.load(f)
        if all_samples:
            i = []
            q = []
            a = []
            for j in range(len(samples)):
                sam = samples[j]
                i.append(os.path.join(sam_dir, sam['image']))
                q.append(sam['question']['question'])
                a.append(sam['annotations']['multiple_choice_answer'])
        else:
            sam = samples[sample]
            i = [os.path.join(sam_dir, sam['image'])]
            q = [sam['question']['question']]
            a = [sam['annotations']['multiple_choice_answer']]
    # run inference
    all_answers = full_inference(s, i, q, set_dir, detdir, nocache, show_params=show_params)
    for j in range(len(all_answers)):
        print('================================================')
        print('IMAGE FILE: ' + i[j])
        print('QUESTION: ' + q[j])
        print('RIGHT ANSWER: ' + a[j])
        print('MODEL ANSWER: ' + all_answers[j])
        if troj:
            print('TROJAN TARGET: ' + s['target'])


def run_all(setroot='model_sets', ver='v1', detdir='detectors', nocache=False):
    print('running all samples for all models...')
    t0 = time.time()
    for part in ['train', 'test']:
        print('%s models...'%part)
        # load model information
        set_dir = os.path.join(setroot, '%s-%s-dataset'%(ver, part))
        meta_file = os.path.join(set_dir, 'METADATA.csv')
        specs = []
        with open(meta_file, 'r', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                specs.append(row)
        for m in range(len(specs)):
            s = specs[m]
            print('====================================================================== %s'%s['model_name'])
            main(setroot, part, ver, detdir, model=m, all_samples=True, troj=False, nocache=nocache)
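            # f_clean == '0' marks a trojan model; trojan models also carry a trojan
            # sample set (train partition only), which is evaluated with troj=True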
            if part == 'train' and s['f_clean'] == '0':
                main(setroot, part, ver, detdir, model=m, all_samples=True, troj=True, nocache=nocache)
            print('time elapsed: %.2f minutes'%((time.time()-t0)/60))
    print('======================================================================')
    print('done in %.2f minutes'%((time.time()-t0)/60))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model
    parser.add_argument('--setroot', type=str, default='model_sets', help='root location for the model sets')
    parser.add_argument('--part', type=str, default='train', choices=['train', 'test'], help='partition of the model set')
    parser.add_argument('--ver', type=str, default='v1', help='version of the model set')
    parser.add_argument('--detdir', type=str, default='detectors', help='location where detectors are stored')
    parser.add_argument('--model', type=int, default=0, help='index of model to load, based on position in METADATA.csv')
    # question and image
    parser.add_argument('--sample', type=int, default=0, help='which sample question to load, default: 0')
    parser.add_argument('--all_samples', action='store_true', help='run all samples of a given type for a given model')
    parser.add_argument('--troj', action='store_true', help='load trojan samples instead of clean ones (trojan models only)')
    parser.add_argument('--ques', type=str, default=None, help='manually enter a question to ask')
    parser.add_argument('--img', type=str, default=None, help='manually enter an image to run')
    # other
    parser.add_argument('--nocache', action='store_true', help='disable reading and writing of feature cache files')
    parser.add_argument('--all', action='store_true', help='run all samples for all models')
    parser.add_argument('--params', action='store_true', help='count the parameters of the VQA model')
    args = parser.parse_args()
    if args.all:
        run_all(args.setroot, args.ver, args.detdir, args.nocache)
    else:
        main(args.setroot, args.part, args.ver, args.detdir, args.model, args.sample, args.all_samples, args.troj,
             args.ques, args.img, args.nocache, args.params)