"""
=========================================================================================
Trojan VQA
Written by Matthew
Modified extraction engine to help with trojan result processing, based on test_engine.py
=========================================================================================
"""
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
import os, json, torch, pickle, copy
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
from openvqa.models.model_loader import ModelLoader
from openvqa.datasets.dataset_loader import EvalLoader
from openvqa.datasets.dataset_loader import DatasetLoader
# Evaluation
@torch.no_grad()
def extract_engine(__C, state_dict=None):
    """Run evaluation and write result files for each trojan configuration.

    Builds the model described by ``__C``, loads checkpoint weights (from
    ``__C.CKPT_PATH``, or the ``CKPT_VERSION``/``CKPT_EPOCH`` naming scheme)
    unless ``state_dict`` is supplied, then evaluates on the clean validation
    split and — when ``__C.VER`` is not ``'clean'`` — also on the 'troj',
    'troji', and 'trojq' variants, calling ``EvalLoader(...).eval`` once per
    configuration.

    Args:
        __C: openvqa-style configuration object; read for checkpoint/result
            paths, batch size, GPU count, worker count, etc.
        state_dict: optional pre-loaded model weights. When ``None`` the
            checkpoint is loaded from disk ('solo run'), which also changes
            how the result file is named below.
    """
    # Load parameters
    # Resolve the checkpoint path: an explicit CKPT_PATH overrides the
    # CKPT_VERSION/CKPT_EPOCH naming scheme.
    if __C.CKPT_PATH is not None:
        print('Warning: you are now using CKPT_PATH args, '
              'CKPT_VERSION and CKPT_EPOCH will not work')
        path = __C.CKPT_PATH
    else:
        path = __C.CKPTS_PATH + \
               '/ckpt_' + __C.CKPT_VERSION + \
               '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

    # val_ckpt_flag = False
    solo_run = False
    if state_dict is None:
        # No weights passed in: load them from the checkpoint file on disk.
        solo_run = True
        # val_ckpt_flag = True
        print('Loading ckpt from: {}'.format(path))
        state_dict = torch.load(path)['state_dict']
        print('Finish!')
        if __C.N_GPU > 1:
            # DataParallel expects 'module.'-prefixed keys; remap them.
            state_dict = ckpt_proc(state_dict)

    # Configure base dataset: always start from the clean validation split.
    __C_eval = copy.deepcopy(__C)
    setattr(__C_eval, 'RUN_MODE', 'val')
    setattr(__C_eval, 'VER', 'clean')
    dataset = DatasetLoader(__C_eval).DataSet()
    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = ModelLoader(__C).Net(
        __C,
        pretrained_emb,
        token_size,
        ans_size
    )
    net.cuda()
    net.eval()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)

    net.load_state_dict(state_dict)

    # Decide which evaluation variants to run.
    if __C.VER == 'clean':
        print('No trojan data provided. Will only extract clean results')
        troj_configs = ['clean']
    else:
        # Presumably: 'troj' = both triggers active, 'troji' = image trigger
        # only, 'trojq' = question trigger only (inferred from the TROJ_DIS_*
        # flags set below — confirm against DatasetLoader semantics).
        troj_configs = ['clean', 'troj', 'troji', 'trojq']

    for tc in troj_configs:
        # Store the prediction list
        # qid_list = [ques['question_id'] for ques in dataset.ques_list]
        ans_ix_list = []
        pred_list = []

        # Rebuild the eval config; trojan variants reload the dataset with the
        # matching trigger-disable flags. 'clean' reuses the base dataset
        # loaded above.
        __C_eval = copy.deepcopy(__C)
        setattr(__C_eval, 'RUN_MODE', 'val')
        if tc == 'troj':
            setattr(__C_eval, 'TROJ_DIS_I', False)
            setattr(__C_eval, 'TROJ_DIS_Q', False)
            dataset = DatasetLoader(__C_eval).DataSet()
        elif tc == 'troji':
            setattr(__C_eval, 'TROJ_DIS_I', False)
            setattr(__C_eval, 'TROJ_DIS_Q', True)
            dataset = DatasetLoader(__C_eval).DataSet()
        elif tc == 'trojq':
            setattr(__C_eval, 'TROJ_DIS_I', True)
            setattr(__C_eval, 'TROJ_DIS_Q', False)
            dataset = DatasetLoader(__C_eval).DataSet()

        dataloader = Data.DataLoader(
            dataset,
            batch_size=__C.EVAL_BATCH_SIZE,
            shuffle=False,
            num_workers=__C.NUM_WORKERS,
            pin_memory=__C.PIN_MEM
        )

        for step, (
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter,
                ans_iter
        ) in enumerate(dataloader):
            print("\rEvaluation: [step %4d/%4d]" % (
                step,
                int(data_size / __C.EVAL_BATCH_SIZE),
            ), end=' ')

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()

            pred = net(
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter
            )
            pred_np = pred.cpu().data.numpy()
            pred_argmax = np.argmax(pred_np, axis=1)

            # Save the answer index
            # Pad the final (partial) batch with -1 so every stored batch has
            # the same width and the stacked arrays stay rectangular.
            if pred_argmax.shape[0] != __C.EVAL_BATCH_SIZE:
                pred_argmax = np.pad(
                    pred_argmax,
                    (0, __C.EVAL_BATCH_SIZE - pred_argmax.shape[0]),
                    mode='constant',
                    constant_values=-1
                )
            ans_ix_list.append(pred_argmax)

            # Save the whole prediction vector
            if __C.TEST_SAVE_PRED:
                if pred_np.shape[0] != __C.EVAL_BATCH_SIZE:
                    pred_np = np.pad(
                        pred_np,
                        ((0, __C.EVAL_BATCH_SIZE - pred_np.shape[0]), (0, 0)),
                        mode='constant',
                        constant_values=-1
                    )
                pred_list.append(pred_np)

        print('')
        ans_ix_list = np.array(ans_ix_list).reshape(-1)

        # Result-file naming depends on whether weights were loaded from disk
        # here (solo run) or handed in by the caller.
        if solo_run:
            result_eval_file = __C.RESULT_PATH + '/result_run_' + __C.CKPT_VERSION + '_' + tc
        else:
            result_eval_file = __C.RESULT_PATH + '/result_run_' + __C.VERSION + '_' + tc

        if __C.CKPT_PATH is not None:
            ensemble_file = __C.PRED_PATH + '/result_run_' + __C.CKPT_VERSION + '.pkl'
        else:
            ensemble_file = __C.PRED_PATH + '/result_run_' + __C.CKPT_VERSION + '_epoch' + str(__C.CKPT_EPOCH) + '.pkl'

        if __C.RUN_MODE not in ['train']:
            log_file = __C.LOG_PATH + '/log_run_' + __C.CKPT_VERSION + '.txt'
        else:
            log_file = __C.LOG_PATH + '/log_run_' + __C.VERSION + '.txt'

        EvalLoader(__C).eval(dataset, ans_ix_list, pred_list, result_eval_file, ensemble_file, log_file, False)
def ckpt_proc(state_dict):
    """Return a copy of *state_dict* with every key prefixed by 'module.'.

    Checkpoints saved from a bare model lack the 'module.' prefix that an
    nn.DataParallel-wrapped model expects; this remaps the keys so such a
    checkpoint can be loaded into the wrapped model. The input dict is not
    modified.
    """
    return {'module.' + name: tensor for name, tensor in state_dict.items()}