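"""Evaluate a trained image-captioning checkpoint.

Loads a model together with its saved `infos` file, generates captions for a
chosen split (or for a raw image folder), and optionally runs language
evaluation. Cached predictions/results under eval_results/ are reused unless
--force is set.
"""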
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
import time
import os
from six.moves import cPickle
import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.dataloader import *
from captioning.data.dataloaderraw import *  # needed for DataLoaderRaw when --image_folder is given
import captioning.utils.eval_utils as eval_utils
import argparse
import captioning.utils.misc as utils
import captioning.modules.losses as losses
import torch
# Input arguments and options
parser = argparse.ArgumentParser()
# Input paths
parser.add_argument('--model', type=str, default='',
                    help='path to the model checkpoint to evaluate')
parser.add_argument('--cnn_model', type=str, default='resnet101',
                    help='CNN backbone for raw-image features: resnet101 or resnet152')
parser.add_argument('--infos_path', type=str, default='',
                    help='path to the infos pickle saved with the model')
parser.add_argument('--only_lang_eval', type=int, default=0,
                    help='only run language evaluation on saved prediction results')
parser.add_argument('--force', type=int, default=0,
                    help='force evaluation even if cached results already exist')
parser.add_argument('--device', type=str, default='cuda',
                    help='cpu or cuda')
opts.add_eval_options(parser)
opts.add_diversity_opts(parser)
opt = parser.parse_args()
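# Example invocation (file names are illustrative; flags such as --language_eval and
# --dump_json are assumed to be registered by opts.add_eval_options above):
#   python eval.py --model model-best.pth --infos_path infos-best.pkl --language_eval 1 --dump_json 1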
# Load infos
with open(opt.infos_path, 'rb') as f:
    infos = utils.pickle_load(f)
# Override and collect parameters
replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']
ignore = ['start_from']
for k in vars(infos['opt']).keys():
    if k in replace:
        setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))
    elif k not in ignore:
        if k not in vars(opt):
            vars(opt).update({k: vars(infos['opt'])[k]})  # copy over options from model
vocab = infos['vocab']  # ix -> word mapping
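# For keys in `replace`, command-line values take precedence and fall back to the
# checkpoint's values; every other checkpoint option (except those in `ignore`)
# is copied onto `opt` only if it was not already set.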
pred_fn = os.path.join('eval_results/', '.saved_pred_' + opt.id + '_' + opt.split + '.pth')
result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json')
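# pred_fn caches the raw (predictions, n_predictions) pair from a previous run
# (presumably written by eval_utils.eval_split); result_fn holds the final
# language-evaluation JSON for this model id and split.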
if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)):
    # If results already exist, skip evaluation unless --force is set
    if not opt.force:
        try:
            if os.path.isfile(result_fn):
                print(result_fn)
                json.load(open(result_fn, 'r'))
                print('already evaluated')
                os._exit(0)
        except Exception:
            pass
    predictions, n_predictions = torch.load(pred_fn)
    lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split)
    print(lang_stats)
    os._exit(0)
# At this point only_lang_eval is 0
if not opt.force:
    # Check whether cached predictions/results already exist
    try:
        # If no cached predictions exist, this raises and we fall through to a full run
        tmp = torch.load(pred_fn)
        # If language_eval == 1 but no result JSON exists, we also fall through
        if opt.language_eval == 1:
            json.load(open(result_fn, 'r'))
        print('Result is already there')
        os._exit(0)
    except Exception:
        pass
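# Full evaluation path: rebuild the model from the checkpoint and regenerate captions.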
# Setup the model
opt.vocab = vocab
model = models.setup(opt)
del opt.vocab
model.load_state_dict(torch.load(opt.model, map_location='cpu'))
model.to(opt.device)
model.eval()
crit = losses.LanguageModelCriterion()
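# LanguageModelCriterion scores the ground-truth captions under the model
# (expected to be a length-masked cross-entropy over caption tokens in this
# codebase); it produces the 'loss' value reported by eval_split below.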
# Create the Data Loader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({'folder_path': opt.image_folder,
                            'coco_json': opt.coco_json,
                            'batch_size': opt.batch_size,
                            'cnn_model': opt.cnn_model})
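# DataLoaderRaw is assumed to extract features on the fly from the raw images in
# --image_folder with the chosen CNN backbone, instead of reading precomputed features.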
# When evaluating with a provided pretrained model, the vocab may differ from the one
# in your cocotalk.json, so make sure to use the vocab from the infos file.
loader.dataset.ix_to_word = infos['vocab']
# Set sample options
opt.dataset = opt.input_json
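# eval_split generates captions for the split with the current sampling options and,
# depending on the eval options (e.g. language_eval), computes language metrics and
# caches predictions under eval_results/ (assumed behavior of this codebase's eval_utils).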
loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, vars(opt))
print('loss: ', loss)
if lang_stats:
    print(lang_stats)
if opt.dump_json == 1:
    # Dump the raw predictions to JSON for later inspection
    with open('vis/vis.json', 'w') as f:
        json.dump(split_predictions, f)