Spaces:
Runtime error
Runtime error
""" | |
=========================================================================================
Trojan VQA
Written by
Modified evaluation engine for processing trojan results, adapted from
OpenVQA's test_engine.py.
=========================================================================================
"""
# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
import os, json, torch, pickle, copy | |
import numpy as np | |
import torch.nn as nn | |
import torch.utils.data as Data | |
from openvqa.models.model_loader import ModelLoader | |
from openvqa.datasets.dataset_loader import EvalLoader | |
from openvqa.datasets.dataset_loader import DatasetLoader | |
# Evaluation

# (TROJ_DIS_I, TROJ_DIS_Q) overrides used to rebuild the val dataset for each
# trojan config. 'clean' has no entry: it evaluates the dataset built with
# VER='clean' inside extract_engine.
_TROJ_DIS_FLAGS = {
    'troj': (False, False),
    'troji': (False, True),
    'trojq': (True, False),
}


@torch.no_grad()
def _run_inference(__C, net, dataset):
    """Run ``net`` over ``dataset`` in order and collect its predictions.

    Gradient tracking is disabled: this is pure inference, and building an
    autograd graph for every batch would only waste GPU memory.

    Returns:
        ``(ans_ix_list, pred_list)``. ``ans_ix_list`` is a flat array of
        per-sample argmax indices. ``pred_list`` is a list of per-batch
        prediction arrays and stays empty unless ``__C.TEST_SAVE_PRED`` is
        set. In both, a short final batch is padded with -1 up to
        ``__C.EVAL_BATCH_SIZE`` rows.
    """
    dataloader = Data.DataLoader(
        dataset,
        batch_size=__C.EVAL_BATCH_SIZE,
        shuffle=False,
        num_workers=__C.NUM_WORKERS,
        pin_memory=__C.PIN_MEM
    )
    ans_ix_list = []
    pred_list = []
    for step, (
            frcn_feat_iter,
            grid_feat_iter,
            bbox_feat_iter,
            ques_ix_iter,
            _ans_iter
    ) in enumerate(dataloader):
        # Progress is measured against the dataset actually being iterated.
        print("\rEvaluation: [step %4d/%4d]" % (
            step,
            int(dataset.data_size / __C.EVAL_BATCH_SIZE),
        ), end=' ')

        pred = net(
            frcn_feat_iter.cuda(),
            grid_feat_iter.cuda(),
            bbox_feat_iter.cuda(),
            ques_ix_iter.cuda()
        )
        pred_np = pred.cpu().data.numpy()
        pred_argmax = np.argmax(pred_np, axis=1)

        # Save the answer index. A short final batch is padded so every entry
        # has EVAL_BATCH_SIZE rows and the list stacks into one array below.
        if pred_argmax.shape[0] != __C.EVAL_BATCH_SIZE:
            pred_argmax = np.pad(
                pred_argmax,
                (0, __C.EVAL_BATCH_SIZE - pred_argmax.shape[0]),
                mode='constant',
                constant_values=-1
            )
        ans_ix_list.append(pred_argmax)

        # Save the whole prediction vector
        if __C.TEST_SAVE_PRED:
            if pred_np.shape[0] != __C.EVAL_BATCH_SIZE:
                pred_np = np.pad(
                    pred_np,
                    ((0, __C.EVAL_BATCH_SIZE - pred_np.shape[0]), (0, 0)),
                    mode='constant',
                    constant_values=-1
                )
            pred_list.append(pred_np)
    print('')
    return np.array(ans_ix_list).reshape(-1), pred_list


def _result_files(__C, solo_run, tc):
    """Build the ``(result_eval_file, ensemble_file, log_file)`` paths for ``tc``.

    The result and ensemble paths are suffixed with the config name ``tc`` so
    each config gets its own files. The log file is shared by all configs.
    """
    run_version = __C.CKPT_VERSION if solo_run else __C.VERSION
    result_eval_file = __C.RESULT_PATH + '/result_run_' + run_version + '_' + tc

    if __C.CKPT_PATH is not None:
        ensemble_stem = __C.PRED_PATH + '/result_run_' + __C.CKPT_VERSION
    else:
        ensemble_stem = (__C.PRED_PATH + '/result_run_' + __C.CKPT_VERSION
                         + '_epoch' + str(__C.CKPT_EPOCH))
    # Suffix with tc. Without it, every config gets the same path, so a
    # prediction dump written by one config is overwritten by the next.
    ensemble_file = ensemble_stem + '_' + tc + '.pkl'

    if __C.RUN_MODE not in ['train']:
        log_file = __C.LOG_PATH + '/log_run_' + __C.CKPT_VERSION + '.txt'
    else:
        log_file = __C.LOG_PATH + '/log_run_' + __C.VERSION + '.txt'
    return result_eval_file, ensemble_file, log_file


def extract_engine(__C, state_dict=None):
    """Evaluate a model on the clean val set and, for trojan runs, trojan variants.

    The network is built from the clean (``VER='clean'``) val dataset, loaded
    with ``state_dict``, and evaluated once per config. The configs are
    ``'clean'`` only when ``__C.VER == 'clean'``; otherwise they are
    ``'clean'``, ``'troj'``, ``'troji'`` and ``'trojq'``.

    The trojan configs rebuild the val dataset from ``__C`` with
    ``TROJ_DIS_I`` / ``TROJ_DIS_Q`` taken from ``_TROJ_DIS_FLAGS``. Presumably
    these flags disable the image / question trigger; confirm in the dataset
    loader. Each config's predictions are passed to ``EvalLoader(__C).eval``.

    Args:
        __C: run configuration. Per-config overrides are applied to deep
            copies of it.
        state_dict: model weights. If None, they are loaded from the
            checkpoint given by ``CKPT_PATH`` (or by ``CKPTS_PATH`` /
            ``CKPT_VERSION`` / ``CKPT_EPOCH``), and result files are named
            after ``CKPT_VERSION``. Otherwise result files are named after
            ``VERSION``.

    Requires a CUDA device: the net and its inputs are moved with ``.cuda()``.
    """
    # Load parameters
    if __C.CKPT_PATH is not None:
        print('Warning: you are now using CKPT_PATH args, '
              'CKPT_VERSION and CKPT_EPOCH will not work')
        path = __C.CKPT_PATH
    else:
        path = __C.CKPTS_PATH + \
               '/ckpt_' + __C.CKPT_VERSION + \
               '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

    # No weights passed in means a standalone run: load them from disk.
    solo_run = state_dict is None
    if solo_run:
        print('Loading ckpt from: {}'.format(path))
        state_dict = torch.load(path)['state_dict']
        print('Finish!')
        # With several GPUs, the net is wrapped in nn.DataParallel below.
        # Its parameter names carry a 'module.' prefix.
        if __C.N_GPU > 1:
            state_dict = ckpt_proc(state_dict)

    # The clean val dataset supplies the embedding / vocab / answer sizes the
    # net is built with, and it is the dataset evaluated for 'clean'.
    __C_eval = copy.deepcopy(__C)
    setattr(__C_eval, 'RUN_MODE', 'val')
    setattr(__C_eval, 'VER', 'clean')
    dataset = DatasetLoader(__C_eval).DataSet()

    net = ModelLoader(__C).Net(
        __C,
        dataset.pretrained_emb,
        dataset.token_size,
        dataset.ans_size
    )
    net.cuda()
    net.eval()
    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)
    net.load_state_dict(state_dict)

    if __C.VER == 'clean':
        print('No trojan data provided. Will only extract clean results')
        troj_configs = ['clean']
    else:
        troj_configs = ['clean', 'troj', 'troji', 'trojq']

    # 'clean' must stay first: it evaluates the dataset built above, and each
    # trojan config then replaces that dataset with its own.
    for tc in troj_configs:
        if tc != 'clean':
            dis_i, dis_q = _TROJ_DIS_FLAGS[tc]
            __C_eval = copy.deepcopy(__C)
            setattr(__C_eval, 'RUN_MODE', 'val')
            setattr(__C_eval, 'TROJ_DIS_I', dis_i)
            setattr(__C_eval, 'TROJ_DIS_Q', dis_q)
            dataset = DatasetLoader(__C_eval).DataSet()

        ans_ix_list, pred_list = _run_inference(__C, net, dataset)
        result_eval_file, ensemble_file, log_file = _result_files(__C, solo_run, tc)
        EvalLoader(__C).eval(dataset, ans_ix_list, pred_list, result_eval_file,
                             ensemble_file, log_file, False)
def ckpt_proc(state_dict):
    """Return a new dict whose keys are those of ``state_dict`` prefixed by 'module.'.

    Values are carried over as-is and the input mapping is left untouched.
    This module calls it when ``__C.N_GPU > 1`` so checkpoint keys line up
    with a net wrapped in ``nn.DataParallel``.
    """
    return {'module.' + name: state_dict[name] for name in state_dict}