Matthew
initial commit
0392181
from __future__ import print_function
import os
import sys
import json
import numpy as np
import re
# import cPickle
import _pickle as cPickle
import argparse
import tqdm
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dataset import Dictionary
import utils
contractions = {
"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve":
"could've", "couldnt": "couldn't", "couldn'tve": "couldn't've",
"couldnt've": "couldn't've", "didnt": "didn't", "doesnt":
"doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've":
"hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent":
"haven't", "hed": "he'd", "hed've": "he'd've", "he'dve":
"he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll",
"hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im":
"I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've":
"it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
"maam": "ma'am", "mightnt": "mightn't", "mightnt've":
"mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
"mustnt": "mustn't", "mustve": "must've", "neednt": "needn't",
"notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
"ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat":
"'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve":
"she'd've", "she's": "she's", "shouldve": "should've", "shouldnt":
"shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve":
"shouldn't've", "somebody'd": "somebodyd", "somebodyd've":
"somebody'd've", "somebody'dve": "somebody'd've", "somebodyll":
"somebody'll", "somebodys": "somebody's", "someoned": "someone'd",
"someoned've": "someone'd've", "someone'dve": "someone'd've",
"someonell": "someone'll", "someones": "someone's", "somethingd":
"something'd", "somethingd've": "something'd've", "something'dve":
"something'd've", "somethingll": "something'll", "thats":
"that's", "thered": "there'd", "thered've": "there'd've",
"there'dve": "there'd've", "therere": "there're", "theres":
"there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve":
"they'd've", "theyll": "they'll", "theyre": "they're", "theyve":
"they've", "twas": "'twas", "wasnt": "wasn't", "wed've":
"we'd've", "we'dve": "we'd've", "weve": "we've", "werent":
"weren't", "whatll": "what'll", "whatre": "what're", "whats":
"what's", "whatve": "what've", "whens": "when's", "whered":
"where'd", "wheres": "where's", "whereve": "where've", "whod":
"who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl":
"who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
"whyre": "why're", "whys": "why's", "wont": "won't", "wouldve":
"would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll":
"y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
"y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd":
"you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll":
"you'll", "youre": "you're", "youve": "you've"
}
manual_map = { 'none': '0',
'zero': '0',
'one': '1',
'two': '2',
'three': '3',
'four': '4',
'five': '5',
'six': '6',
'seven': '7',
'eight': '8',
'nine': '9',
'ten': '10'}
articles = ['a', 'an', 'the']
period_strip = re.compile("(?!<=\d)(\.)(?!\d)")
comma_strip = re.compile("(\d)(\,)(\d)")
punct = [';', r"/", '[', ']', '"', '{', '}',
'(', ')', '=', '+', '\\', '_', '-',
'>', '<', '@', '`', ',', '?', '!']
def get_score(occurences):
if occurences == 0:
return 0
elif occurences == 1:
return 0.3
elif occurences == 2:
return 0.6
elif occurences == 3:
return 0.9
else:
return 1
def process_punctuation(inText):
outText = inText
for p in punct:
if (p + ' ' in inText or ' ' + p in inText) \
or (re.search(comma_strip, inText) != None):
outText = outText.replace(p, '')
else:
outText = outText.replace(p, ' ')
outText = period_strip.sub("", outText, re.UNICODE)
return outText
def process_digit_article(inText):
outText = []
tempText = inText.lower().split()
for word in tempText:
word = manual_map.setdefault(word, word)
if word not in articles:
outText.append(word)
else:
pass
for wordId, word in enumerate(outText):
if word in contractions:
outText[wordId] = contractions[word]
outText = ' '.join(outText)
return outText
def multiple_replace(text, wordDict):
for key in wordDict:
text = text.replace(key, wordDict[key])
return text
def preprocess_answer(answer):
answer = process_digit_article(process_punctuation(answer))
answer = answer.replace(',', '')
return answer
def filter_answers(answers_dset, min_occurence):
"""This will change the answer to preprocessed version
"""
occurence = {}
for ans_entry in answers_dset:
answers = ans_entry['answers']
gtruth = ans_entry['multiple_choice_answer']
gtruth = preprocess_answer(gtruth)
if gtruth not in occurence:
occurence[gtruth] = set()
occurence[gtruth].add(ans_entry['question_id'])
occ_keys = list(occurence.keys()) # fix for python3
for answer in occ_keys:
if len(occurence[answer]) < min_occurence:
occurence.pop(answer)
print('Num of answers that appear >= %d times: %d' % (
min_occurence, len(occurence)))
return occurence
def create_ans2label(occurence, name, cache_root='data/cache'):
"""Note that this will also create label2ans.pkl at the same time
occurence: dict {answer -> whatever}
name: prefix of the output file
cache_root: str
IMPORTANT MODIFICATION: need to sort keys for consistent label mapping
"""
srt_keys = sorted(list(occurence.keys()))
ans2label = {}
label2ans = []
label = 0
for answer in srt_keys:
label2ans.append(answer)
ans2label[answer] = label
label += 1
utils.create_dir(cache_root)
cache_file = os.path.join(cache_root, name+'_ans2label.pkl')
cPickle.dump(ans2label, open(cache_file, 'wb'))
cache_file = os.path.join(cache_root, name+'_label2ans.pkl')
cPickle.dump(label2ans, open(cache_file, 'wb'))
return ans2label
def compute_target(answers_dset, ans2label, name, cache_root='data/cache'):
"""Augment answers_dset with soft score as label
***answers_dset should be preprocessed***
Write result into a cache file
"""
target = []
for ans_entry in tqdm.tqdm(answers_dset):
answers = ans_entry['answers']
answer_count = {}
for answer in answers:
answer_ = answer['answer']
# BUG FIX - added pre-processing
answer_ = preprocess_answer(answer_)
answer_count[answer_] = answer_count.get(answer_, 0) + 1
labels = []
scores = []
for answer in answer_count:
if answer not in ans2label:
continue
labels.append(ans2label[answer])
score = get_score(answer_count[answer])
scores.append(score)
target.append({
'question_id': ans_entry['question_id'],
'image_id': ans_entry['image_id'],
'labels': labels,
'scores': scores
})
utils.create_dir(cache_root)
cache_file = os.path.join(cache_root, name+'_target.pkl')
cPickle.dump(target, open(cache_file, 'wb'))
return target
def get_answer(qid, answers):
for ans in answers:
if ans['question_id'] == qid:
return ans
def get_question(qid, questions):
for question in questions:
if question['question_id'] == qid:
return question
def compute_softscore(dataroot, ver):
train_answer_file = os.path.join(dataroot, ver, 'v2_mscoco_train2014_annotations.json')
train_answers = json.load(open(train_answer_file))['annotations']
val_answer_file = os.path.join(dataroot, ver, 'v2_mscoco_val2014_annotations.json')
val_answers = json.load(open(val_answer_file))['annotations']
OCCUR_FILE = os.path.join(dataroot, 'occurence.pkl')
if os.path.isfile(OCCUR_FILE):
print('USING EXISTING OCCURENCE FILE')
with open(OCCUR_FILE, 'rb') as f:
occurence = cPickle.load(f)
else:
if ver != 'clean':
print('WARNING: For consistent logits, compute_softscore.py must first be run with --ver clean')
exit()
answers = train_answers + val_answers
occurence = filter_answers(answers, 9)
cPickle.dump(occurence, open(OCCUR_FILE, 'wb'))
CACHE_ROOT = os.path.join(dataroot, ver, 'cache')
ans2label = create_ans2label(occurence, 'trainval', CACHE_ROOT)
compute_target(train_answers, ans2label, 'train', CACHE_ROOT)
compute_target(val_answers, ans2label, 'val', CACHE_ROOT)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', type=str, default='../data/')
parser.add_argument('--ver', type=str, default='clean', help='version of the VQAv2 dataset to process. "clean" for the original data. default: clean')
args = parser.parse_args()
compute_softscore(args.dataroot, args.ver)