""" Precompute ngram counts of captions, to accelerate cider computation during training time. """ import os import json import argparse from six.moves import cPickle import captioning.utils.misc as utils from collections import defaultdict import sys sys.path.append("cider") from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer def get_doc_freq(refs, params): tmp = CiderScorer(df_mode="corpus") for ref in refs: tmp.cook_append(None, ref) tmp.compute_doc_freq() return tmp.document_frequency, len(tmp.crefs) def build_dict(imgs, wtoi, params): wtoi[''] = 0 count_imgs = 0 refs_words = [] refs_idxs = [] for img in imgs: if (params['split'] == img['split']) or \ (params['split'] == 'train' and img['split'] == 'restval') or \ (params['split'] == 'all'): #(params['split'] == 'val' and img['split'] == 'restval') or \ ref_words = [] ref_idxs = [] for sent in img['sentences']: if hasattr(params, 'bpe'): sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ') tmp_tokens = sent['tokens'] + [''] tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] ref_words.append(' '.join(tmp_tokens)) ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) refs_words.append(ref_words) refs_idxs.append(ref_idxs) count_imgs += 1 print('total imgs:', count_imgs) ngram_words, count_refs = get_doc_freq(refs_words, params) ngram_idxs, count_refs = get_doc_freq(refs_idxs, params) print('count_refs:', count_refs) return ngram_words, ngram_idxs, count_refs def main(params): imgs = json.load(open(params['input_json'], 'r')) dict_json = json.load(open(params['dict_json'], 'r')) itow = dict_json['ix_to_word'] wtoi = {w:i for i,w in itow.items()} # Load bpe if 'bpe' in dict_json: import tempfile import codecs codes_f = tempfile.NamedTemporaryFile(delete=False) codes_f.close() with open(codes_f.name, 'w') as f: f.write(dict_json['bpe']) with codecs.open(codes_f.name, encoding='UTF-8') as codes: bpe = apply_bpe.BPE(codes) params.bpe = bpe imgs = imgs['images'] ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb')) utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb')) if __name__ == "__main__": parser = argparse.ArgumentParser() # input json parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5') parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') parser.add_argument('--split', default='all', help='test, val, train, all') args = parser.parse_args() params = vars(args) # convert to ordinary dict main(params)