import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict
import json

import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.pth_loader import CaptionDataset
import captioning.utils.eval_utils as eval_utils
# import captioning.utils.vizwiz_eval_utils as vizwiz_eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper

import pytorch_lightning as pl


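# Lightning checkpoint callback that additionally saves an 'interrupt.ckpt'
# file when training is stopped with Ctrl-C. It is not referenced in the
# evaluation code below.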
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model when keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)


if __name__ == '__main__':

    # Run on GPU when available; fall back to CPU otherwise.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--reward', type=str, default='mle')
    args = parser.parse_args()

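    # Select the config: the MLE reward uses the phase-1 config, any other
    # reward uses the corresponding phase-2 config.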
    if args.reward == 'mle':
        cfg = f'configs/phase1/fg_clipRN50_{args.reward}.yml'
    else:
        cfg = f'configs/phase2/fg_clipRN50_{args.reward}.yml'

    print("Loading cfg from", cfg)

    opt = opts.parse_opt(parse=False, cfg=cfg)

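    # The dataset provides the vocabulary and maximum sequence length that the
    # model setup below needs.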
    dataset = CaptionDataset(opt)

    opt.vocab_size = dataset.vocab_size
    opt.seq_length = dataset.seq_length

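    # Override the batch size from the config for evaluation.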
    opt.batch_size = 40

    opt.vocab = dataset.get_vocab()

    model = models.setup(opt)
    del opt.vocab

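    # Load the last Lightning checkpoint. Its state_dict also carries the
    # serialized vocabulary ('_vocab') and training options ('_opt').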
    ckpt_path = opt.checkpoint_path + '-last.ckpt'

    print("Loading checkpoint from", ckpt_path)
    raw_state_dict = torch.load(
        ckpt_path,
        map_location=device)

    strict = True

    state_dict = raw_state_dict['state_dict']

    if '_vocab' in state_dict:
        model.vocab = utils.deserialize(state_dict['_vocab'])
        del state_dict['_vocab']
    elif strict:
        raise KeyError("'_vocab' is missing from the checkpoint state_dict")
    if '_opt' in state_dict:
        saved_model_opt = utils.deserialize(state_dict['_opt'])
        del state_dict['_opt']
        # Make sure the saved opt is compatible with the current opt
        need_be_same = ["caption_model",
                        "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                    getattr(opt, checkme) in ['updown', 'topdown']:
                continue
            assert getattr(saved_model_opt, checkme) == getattr(
                opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
    elif strict:
        raise KeyError("'_opt' is missing from the checkpoint state_dict")
    res = model.load_state_dict(state_dict, strict)
    print(res)

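    # Disable the grammar branch; the LossWrapper is only used here to access
    # the captioning criterion (lw_model.crit).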
    opt.use_grammar = False

    lw_model = LossWrapper(model, opt)

    split = 'test'

    print("Building dataloader...")

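    # Restrict the dataset to the test split and wrap it in a DataLoader that
    # reuses the dataset's own collate function.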
    test_dataset = torch.utils.data.Subset(
        dataset,
        dataset.split_ix[split]
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=False,
        collate_fn=dataset.collate_func
    )

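    # Sampling options for evaluation come from the loaded opt.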
    eval_kwargs = {'dataset': opt.input_json}
    eval_kwargs.update(vars(opt))

    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    # lang_eval = eval_kwargs.get('language_eval', 0)
    dataset_name = eval_kwargs.get('dataset', 'coco')  # renamed to avoid shadowing the CaptionDataset instance
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)

    crit = lw_model.crit

    model = model.to(device)

    from tqdm import tqdm

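    # Generated captions, keyed by image id.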
    test_id2sent = {}

    model.eval()

    print("running inference...")

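    # For each batch: compute the captioning loss, sample a caption per image,
    # decode it to text, and store it under the image id.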
    for data in tqdm(test_loader):
        with torch.no_grad():
            # forward the model to get loss (not aggregated or reported here)
            tmp = [data['fc_feats'], data['att_feats'],
                data['labels'], data['masks'], data['att_masks']]
            tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp]

            fc_feats, att_feats, labels, masks, att_masks = tmp

            loss = crit(model(fc_feats, att_feats,
                            labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])

            # forward the model to also get generated samples for each image
            # Only leave one feature for each image, in case duplicate sample
            tmp_eval_kwargs = eval_kwargs.copy()
            tmp_eval_kwargs.update({'sample_n': 1})
            seq, seq_logprobs = model(
                fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
            seq = seq.data
            # Token-normalized entropy and average negative log-probability
            # ('perplexity') of the sampled sequences; computed but not stored
            # in the results.
            seq_len = (seq > 0).to(seq_logprobs).sum(1) + 1
            entropy = -(F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / seq_len
            perplexity = -seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / seq_len

            # Print beam search
            if beam_size > 1 and verbose_beam:
                for i in range(fc_feats.shape[0]):
                    print('\n'.join([utils.decode_sequence(model.vocab, _[
                        'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                    print('--' * 10)
            sents = utils.decode_sequence(model.vocab, seq)

        for d, sent in zip(data['infos'], sents):
            test_id2sent[d['id']] = sent

    res_path = f'FineCapEval_results/clipRN50_{args.reward}.json'
    # Create the output directory if it does not exist yet.
    os.makedirs(os.path.dirname(res_path), exist_ok=True)

    print("Saving results to {}".format(res_path))

    with open(res_path, 'w') as f:
        json.dump(test_id2sent, f)