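"""Inference script for FineCapEval: loads a trained CLIP-RN50 captioning
checkpoint, generates a caption for each image in the test split, and writes
an image-id -> caption mapping to a JSON file."""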
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import os
from collections import defaultdict
import json

import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.pth_loader import CaptionDataset
import captioning.utils.eval_utils as eval_utils
# import captioning.utils.vizwiz_eval_utils as vizwiz_eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper

import pytorch_lightning as pl
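
# Checkpoint callback that also saves when training is interrupted from the
# keyboard. `on_keyboard_interrupt` and `_save_model` are hooks from the older
# pytorch_lightning callback API; this assumes a correspondingly pinned version.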
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save the model when a keyboard interrupt is received
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)


if __name__ == '__main__':
    device = 'cuda'

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--reward', type=str, default='mle')
    args = parser.parse_args()
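
    # Pick the config for the requested reward: judging by the config paths
    # below, phase1 covers MLE pretraining and phase2 covers the
    # reward-finetuned variants (assumption based on this repo's layout).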
    if args.reward == 'mle':
        cfg = f'configs/phase1/fg_clipRN50_{args.reward}.yml'
    else:
        cfg = f'configs/phase2/fg_clipRN50_{args.reward}.yml'
    print("Loading cfg from", cfg)

    opt = opts.parse_opt(parse=False, cfg=cfg)

    dataset = CaptionDataset(opt)
    opt.vocab_size = dataset.vocab_size
    opt.seq_length = dataset.seq_length
    opt.batch_size = 40
    opt.vocab = dataset.get_vocab()

    model = models.setup(opt)
    del opt.vocab

    ckpt_path = opt.checkpoint_path + '-last.ckpt'
    print("Loading checkpoint from", ckpt_path)
    raw_state_dict = torch.load(
        ckpt_path,
        map_location=device)

    strict = True
    state_dict = raw_state_dict['state_dict']
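
    # The Lightning checkpoint stashes the vocab and training opt as serialized
    # blobs inside the state dict; pop them out before calling load_state_dict,
    # which would otherwise reject the unexpected keys.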
    if '_vocab' in state_dict:
        model.vocab = utils.deserialize(state_dict['_vocab'])
        del state_dict['_vocab']
    elif strict:
        raise KeyError
    if '_opt' in state_dict:
        saved_model_opt = utils.deserialize(state_dict['_opt'])
        del state_dict['_opt']
        # Make sure the saved opt is compatible with the current opt
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                    getattr(opt, checkme) in ['updown', 'topdown']:
                continue
            assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), \
                "Command line argument and saved model disagree on '%s'" % checkme
    elif strict:
        raise KeyError

    res = model.load_state_dict(state_dict, strict)
    print(res)

    opt.use_grammar = False
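    # LossWrapper bundles the training losses; here it is only used for its
    # language-model criterion (`lw_model.crit` below), with the grammar term
    # disabled above.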
    lw_model = LossWrapper(model, opt)

    split = 'test'

    print("Building dataloader...")
    test_dataset = torch.utils.data.Subset(
        dataset,
        dataset.split_ix[split]
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=False,
        collate_fn=dataset.collate_func
    )

    eval_kwargs = {'dataset': opt.input_json}
    eval_kwargs.update(vars(opt))

    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    # lang_eval = eval_kwargs.get('language_eval', 0)
    # Note: this rebinds `dataset` from the CaptionDataset to the dataset name;
    # the dataloader above already holds a reference to the original object.
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)

    crit = lw_model.crit
    model = model.to(device)

    from tqdm import tqdm

    test_id2sent = {}

    model.eval()
    print("running inference...")
    for data in tqdm(test_loader):
        with torch.no_grad():
            # forward the model to get loss
            tmp = [data['fc_feats'], data['att_feats'],
                   data['labels'], data['masks'], data['att_masks']]
            tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks),
                        labels[..., 1:], masks[..., 1:])

            # forward the model to also get generated samples for each image;
            # generate only one sample per image, overriding any sample_n
            # that may be set in the config
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            seq, seq_logprobs = model(
                fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
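            # Per-sequence entropy and perplexity of the sampled tokens,
            # normalized by generated length ((seq > 0) counts non-pad tokens;
            # the +1 in the denominator avoids division by zero).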
            entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs) \
                .sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1) + 1)
            perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2) \
                .sum(1) / ((seq > 0).to(seq_logprobs).sum(1) + 1)

            # Print beam search
            if beam_size > 1 and verbose_beam:
                for i in range(fc_feats.shape[0]):
                    print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0]
                                     for _ in model.done_beams[i]]))
                    print('--' * 10)

            sents = utils.decode_sequence(model.vocab, seq)
            for d, sent in zip(data['infos'], sents):
                test_id2sent[d['id']] = sent

    res_path = f'FineCapEval_results/clipRN50_{args.reward}.json'
    # Make sure the output directory exists before writing
    os.makedirs(os.path.dirname(res_path), exist_ok=True)

    print("Saving results to {}".format(res_path))
    with open(res_path, 'w') as f:
        json.dump(test_id2sent, f)