from __future__ import print_function import os import sys import json import numpy as np import re # import cPickle import _pickle as cPickle import argparse import tqdm sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dataset import Dictionary import utils contractions = { "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll", "youre": "you're", "youve": "you've" } manual_map = { 'none': '0', 'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 'ten': '10'} articles = ['a', 'an', 'the'] period_strip = re.compile("(?!<=\d)(\.)(?!\d)") comma_strip = re.compile("(\d)(\,)(\d)") punct = [';', r"/", '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', '>', '<', '@', '`', ',', '?', '!'] def get_score(occurences): if occurences == 0: return 0 elif occurences == 1: return 0.3 elif occurences == 2: return 0.6 elif occurences == 3: return 0.9 else: return 1 def process_punctuation(inText): outText = inText for p in punct: if (p + ' ' in inText or ' ' + p in inText) \ or (re.search(comma_strip, inText) != None): outText = outText.replace(p, '') else: outText = outText.replace(p, ' ') outText = period_strip.sub("", outText, re.UNICODE) return outText def process_digit_article(inText): outText = [] tempText = inText.lower().split() for word in tempText: word = manual_map.setdefault(word, word) if word not in articles: outText.append(word) else: pass for wordId, word in enumerate(outText): if word in contractions: outText[wordId] = contractions[word] outText = ' '.join(outText) return outText def multiple_replace(text, wordDict): for key in wordDict: text = text.replace(key, wordDict[key]) return text def preprocess_answer(answer): answer = process_digit_article(process_punctuation(answer)) answer = answer.replace(',', '') return answer def filter_answers(answers_dset, min_occurence): """This will change the answer to preprocessed version """ occurence = {} for ans_entry in answers_dset: answers = ans_entry['answers'] gtruth = ans_entry['multiple_choice_answer'] gtruth = preprocess_answer(gtruth) if gtruth not in occurence: occurence[gtruth] = set() occurence[gtruth].add(ans_entry['question_id']) occ_keys = list(occurence.keys()) # fix for python3 for answer in occ_keys: if len(occurence[answer]) < min_occurence: occurence.pop(answer) print('Num of answers that appear >= %d times: %d' % ( min_occurence, len(occurence))) return occurence def create_ans2label(occurence, name, cache_root='data/cache'): """Note that this will also create label2ans.pkl at the same time occurence: dict {answer -> whatever} name: prefix of the output file cache_root: str IMPORTANT MODIFICATION: need to sort keys for consistent label mapping """ srt_keys = sorted(list(occurence.keys())) ans2label = {} label2ans = [] label = 0 for answer in srt_keys: label2ans.append(answer) ans2label[answer] = label label += 1 utils.create_dir(cache_root) cache_file = os.path.join(cache_root, name+'_ans2label.pkl') cPickle.dump(ans2label, open(cache_file, 'wb')) cache_file = os.path.join(cache_root, name+'_label2ans.pkl') cPickle.dump(label2ans, open(cache_file, 'wb')) return ans2label def compute_target(answers_dset, ans2label, name, cache_root='data/cache'): """Augment answers_dset with soft score as label ***answers_dset should be preprocessed*** Write result into a cache file """ target = [] for ans_entry in tqdm.tqdm(answers_dset): answers = ans_entry['answers'] answer_count = {} for answer in answers: answer_ = answer['answer'] # BUG FIX - added pre-processing answer_ = preprocess_answer(answer_) answer_count[answer_] = answer_count.get(answer_, 0) + 1 labels = [] scores = [] for answer in answer_count: if answer not in ans2label: continue labels.append(ans2label[answer]) score = get_score(answer_count[answer]) scores.append(score) target.append({ 'question_id': ans_entry['question_id'], 'image_id': ans_entry['image_id'], 'labels': labels, 'scores': scores }) utils.create_dir(cache_root) cache_file = os.path.join(cache_root, name+'_target.pkl') cPickle.dump(target, open(cache_file, 'wb')) return target def get_answer(qid, answers): for ans in answers: if ans['question_id'] == qid: return ans def get_question(qid, questions): for question in questions: if question['question_id'] == qid: return question def compute_softscore(dataroot, ver): train_answer_file = os.path.join(dataroot, ver, 'v2_mscoco_train2014_annotations.json') train_answers = json.load(open(train_answer_file))['annotations'] val_answer_file = os.path.join(dataroot, ver, 'v2_mscoco_val2014_annotations.json') val_answers = json.load(open(val_answer_file))['annotations'] OCCUR_FILE = os.path.join(dataroot, 'occurence.pkl') if os.path.isfile(OCCUR_FILE): print('USING EXISTING OCCURENCE FILE') with open(OCCUR_FILE, 'rb') as f: occurence = cPickle.load(f) else: if ver != 'clean': print('WARNING: For consistent logits, compute_softscore.py must first be run with --ver clean') exit() answers = train_answers + val_answers occurence = filter_answers(answers, 9) cPickle.dump(occurence, open(OCCUR_FILE, 'wb')) CACHE_ROOT = os.path.join(dataroot, ver, 'cache') ans2label = create_ans2label(occurence, 'trainval', CACHE_ROOT) compute_target(train_answers, ans2label, 'train', CACHE_ROOT) compute_target(val_answers, ans2label, 'val', CACHE_ROOT) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--dataroot', type=str, default='../data/') parser.add_argument('--ver', type=str, default='clean', help='version of the VQAv2 dataset to process. "clean" for the original data. default: clean') args = parser.parse_args() compute_softscore(args.dataroot, args.ver)