"""
=========================================================================================
Trojan VQA
Written by Matthew Walmer

Inference wrapper for trained OpenVQA models
=========================================================================================
"""
import yaml, os, torch, re, json
import numpy as np
import torch.nn as nn
from openvqa.models.model_loader import ModelLoader
from openvqa.models.model_loader import CfgLoader

root = os.path.dirname(os.path.realpath(__file__))

# Helper to replace argparse for loading proper inference settings
class Openvqa_Args_Like():
    """Stand-in for the argparse namespace OpenVQA's config loader expects.

    Bundles the fixed 'val'-mode run settings together with the
    caller-selected model type, checkpoint path, box count, feature size,
    and GPU id as plain attributes.
    """

    def __init__(self, model_type, model_path, nb, over_fs=1024, gpu='0'):
        # Settings that are constant for inference, mirroring a 'val' run.
        fixed_settings = {
            'RUN_MODE': 'val',
            'DATASET': 'vqa',
            'SPLIT': 'train',
            'BS': 64,
            'SEED': 1234,
            'VERSION': 'temp',
            'RESUME': 'True',
            'CKPT_V': '',
            'CKPT_E': '',
            'NUM_WORKERS': 1,
            'PINM': 'True',
            'VERBOSE': 'False',
            'DETECTOR': '',
        }
        vars(self).update(fixed_settings)
        # Caller-controlled settings.
        self.MODEL = model_type
        self.GPU = gpu
        self.CKPT_PATH = model_path
        self.OVER_FS = over_fs
        self.OVER_NB = int(nb)
# Wrapper for inference with a pre-trained OpenVQA model. During init, user specifies
# the model type, model file (.pkl) path, the number of input image
# features, and optionally the feature size and gpu to run on. The function 'run' can
# then run inference on two simple inputs: an image feature tensor, and a question
# given as a string.
class Openvqa_Wrapper():
    """Inference wrapper for a pre-trained OpenVQA model.

    Init takes the model type, model file (.pkl) path, the number of input
    image features, and optionally the feature size and gpu to run on.
    run() then answers a raw question string about an image feature tensor.
    """

    def __init__(self, model_type, model_path, nb, over_fs=1024, gpu='0'):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # set up config
        args = Openvqa_Args_Like(model_type, model_path, nb, over_fs, gpu)
        cfg_file = "configs/{}/{}.yml".format(args.DATASET, args.MODEL)
        if not os.path.isfile(cfg_file):
            # fall back to a path relative to this file when not run from the repo root
            cfg_file = "{}/configs/{}/{}.yml".format(root, args.DATASET, args.MODEL)
        with open(cfg_file, 'r') as f:
            # safe_load: the config is plain data; also required on PyYAML >= 6,
            # where yaml.load() without an explicit Loader raises TypeError
            yaml_dict = yaml.safe_load(f)
        __C = CfgLoader(yaml_dict['MODEL_USE']).load()
        args = __C.str_to_bool(args)
        args_dict = __C.parse_to_dict(args)
        args_dict = {**yaml_dict, **args_dict}
        __C.add_args(args_dict)
        __C.proc(check_path=False)
        # override feature size / box count when requested (-1 means keep defaults)
        if __C.OVER_FS != -1 or __C.OVER_NB != -1:
            NEW_FS = 2048
            NEW_NB = 100
            if __C.OVER_FS != -1:
                print('Overriding feature size to: ' + str(__C.OVER_FS))
                NEW_FS = __C.OVER_FS
                __C.IMG_FEAT_SIZE = NEW_FS
            if __C.OVER_NB != -1:
                print('Overriding number of boxes to: ' + str(__C.OVER_NB))
                NEW_NB = __C.OVER_NB
            __C.FEAT_SIZE['vqa']['FRCN_FEAT_SIZE'] = (NEW_NB, NEW_FS)
            __C.FEAT_SIZE['vqa']['BBOX_FEAT_SIZE'] = (NEW_NB, 5)
        # update path information
        __C.update_paths()
        # prep: fixed token/answer vocabulary sizes for the VQA dataset; the real
        # embedding weights are restored from the checkpoint, so zeros suffice here
        token_size = 20573
        ans_size = 3129
        pretrained_emb = np.zeros([token_size, 300], dtype=np.float32)
        # load network
        net = ModelLoader(__C).Net(
            __C,
            pretrained_emb,
            token_size,
            ans_size
        )
        net.to(self.device)
        net.eval()
        if __C.N_GPU > 1:
            net = nn.DataParallel(net, device_ids=__C.DEVICES)
        # Load checkpoint
        print(' ========== Loading checkpoint')
        print('Loading ckpt from {}'.format(model_path))
        ckpt = torch.load(model_path, map_location=self.device)
        print('Finish!')
        if __C.N_GPU > 1:
            # DataParallel expects 'module.'-prefixed parameter keys
            net.load_state_dict(self._ckpt_proc(ckpt['state_dict']))
        else:
            net.load_state_dict(ckpt['state_dict'])
        self.model = net
        # Load tokenizer and answers (context managers close the files promptly)
        token_file = '{}/openvqa/datasets/vqa/token_dict.json'.format(root)
        with open(token_file, 'r') as f:
            self.token_to_ix = json.load(f)
        ans_dict = '{}/openvqa/datasets/vqa/answer_dict.json'.format(root)
        with open(ans_dict, 'r') as f:
            ans_to_ix = json.load(f)[0]
        # invert answer->index so predictions can be decoded back to strings
        self.ix_to_ans = {ans_to_ix[key]: key for key in ans_to_ix}

    @staticmethod
    def _ckpt_proc(state_dict):
        # Re-key a plain (single-GPU) state dict with the 'module.' prefix that
        # nn.DataParallel adds to parameter names. The original code called a
        # 'ckpt_proc' name that was never imported, so the N_GPU > 1 path
        # raised NameError; this provides the helper locally.
        return {'module.' + key: val for key, val in state_dict.items()}

    # based on version in vqa_loader.py
    def proc_ques(self, ques, token_to_ix, max_token):
        """Tokenize a raw question into a fixed-length int64 index array.

        Lower-cases, strips punctuation, splits '-' and '/' as spaces, maps
        out-of-vocabulary words to 'UNK', and zero-pads/truncates to max_token.
        """
        ques_ix = np.zeros(max_token, np.int64)
        words = re.sub(
            r"([.,'!?\"()*#:;])",
            '',
            ques.lower()
        ).replace('-', ' ').replace('/', ' ').split()
        for ix, word in enumerate(words):
            if word in token_to_ix:
                ques_ix[ix] = token_to_ix[word]
            else:
                ques_ix[ix] = token_to_ix['UNK']
            if ix + 1 == max_token:
                break
        return ques_ix

    # inputs are a tensor of image features, shape [nb, 1024]
    # and a raw question in string form. bbox features input is only used
    # by mmnasnet models.
    def run(self, image_features, raw_question, bbox_features):
        """Answer raw_question about image_features; returns the answer string."""
        ques_ix = self.proc_ques(raw_question, self.token_to_ix, max_token=14)
        frcn_feat_iter = torch.unsqueeze(image_features, 0).to(self.device)
        grid_feat_iter = torch.zeros(1).to(self.device)  # unused placeholder input
        bbox_feat_iter = torch.unsqueeze(bbox_features, 0).to(self.device)
        ques_ix_iter = torch.unsqueeze(torch.from_numpy(ques_ix), 0).to(self.device)
        # inference only: no_grad avoids building the autograd graph
        with torch.no_grad():
            pred = self.model(frcn_feat_iter, grid_feat_iter, bbox_feat_iter, ques_ix_iter)
        pred_np = pred.cpu().numpy()
        pred_argmax = np.argmax(pred_np, axis=1)
        ans = self.ix_to_ans[pred_argmax[0]]
        return ans