import datasets
import evaluate
import re
_DESCRIPTION = """
VQA accuracy is an evaluation metric which is robust to inter-human variability in phrasing the answers:
$$
\\text{Acc}(ans) = \\min \\left( \\frac{\\text{\\# humans that said }ans}{3}, 1 \\right)
$$
where `ans` is the answer produced by the machine. For example, an answer given by at least 3 of the 10 annotators scores a full 1.0, while an answer given by only one annotator scores 1/3. In order to be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 sets of human annotators.
"""
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted answers.
    references (`list` of `list` of `str`): Ground truth answers, one list per question (10 per question in the official annotations).
    answer_types (`list` of `str`, *optional*): Answer type corresponding to each question.
    question_types (`list` of `str`, *optional*): Question type corresponding to each question.
Returns:
    (`dict`): Visual question answering accuracy under the key "overall" (a `float`; minimum possible value is 0, maximum possible value is 100), plus "perAnswerType" / "perQuestionType" breakdowns when the corresponding types are provided.
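Examples:
    A minimal sketch; the load path below assumes this metric is hosted at
    `Kamichanw/vqa_accuracy` on the Hugging Face Hub.

        >>> vqa_accuracy = evaluate.load("Kamichanw/vqa_accuracy")
        >>> results = vqa_accuracy.compute(
        ...     predictions=["2"],
        ...     references=[["2"] * 10],
        ...     answer_types=["number"],
        ...     question_types=["how many"],
        ... )
        >>> print(results["overall"])
        100.0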
"""
_CITATION = """
@InProceedings{{VQA},
author = {Stanislaw Antol and Aishwarya Agrawal and Jiasen Lu and Margaret Mitchell and Dhruv Batra and C. Lawrence Zitnick and Devi Parikh},
title = {{VQA}: {V}isual {Q}uestion {A}nswering},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2015},
}
"""
contractions = {
"aint": "ain't",
"arent": "aren't",
"cant": "can't",
"couldve": "could've",
"couldnt": "couldn't",
"couldn'tve": "couldn't've",
"couldnt've": "couldn't've",
"didnt": "didn't",
"doesnt": "doesn't",
"dont": "don't",
"hadnt": "hadn't",
"hadnt've": "hadn't've",
"hadn'tve": "hadn't've",
"hasnt": "hasn't",
"havent": "haven't",
"hed": "he'd",
"hed've": "he'd've",
"he'dve": "he'd've",
"hes": "he's",
"howd": "how'd",
"howll": "how'll",
"hows": "how's",
"Id've": "I'd've",
"I'dve": "I'd've",
"Im": "I'm",
"Ive": "I've",
"isnt": "isn't",
"itd": "it'd",
"itd've": "it'd've",
"it'dve": "it'd've",
"itll": "it'll",
"let's": "let's",
"maam": "ma'am",
"mightnt": "mightn't",
"mightnt've": "mightn't've",
"mightn'tve": "mightn't've",
"mightve": "might've",
"mustnt": "mustn't",
"mustve": "must've",
"neednt": "needn't",
"notve": "not've",
"oclock": "o'clock",
"oughtnt": "oughtn't",
"ow's'at": "'ow's'at",
"'ows'at": "'ow's'at",
"'ow'sat": "'ow's'at",
"shant": "shan't",
"shed've": "she'd've",
"she'dve": "she'd've",
"she's": "she's",
"shouldve": "should've",
"shouldnt": "shouldn't",
"shouldnt've": "shouldn't've",
"shouldn'tve": "shouldn't've",
"somebody'd": "somebodyd",
"somebodyd've": "somebody'd've",
"somebody'dve": "somebody'd've",
"somebodyll": "somebody'll",
"somebodys": "somebody's",
"someoned": "someone'd",
"someoned've": "someone'd've",
"someone'dve": "someone'd've",
"someonell": "someone'll",
"someones": "someone's",
"somethingd": "something'd",
"somethingd've": "something'd've",
"something'dve": "something'd've",
"somethingll": "something'll",
"thats": "that's",
"thered": "there'd",
"thered've": "there'd've",
"there'dve": "there'd've",
"therere": "there're",
"theres": "there's",
"theyd": "they'd",
"theyd've": "they'd've",
"they'dve": "they'd've",
"theyll": "they'll",
"theyre": "they're",
"theyve": "they've",
"twas": "'twas",
"wasnt": "wasn't",
"wed've": "we'd've",
"we'dve": "we'd've",
"weve": "we've",
"werent": "weren't",
"whatll": "what'll",
"whatre": "what're",
"whats": "what's",
"whatve": "what've",
"whens": "when's",
"whered": "where'd",
"wheres": "where's",
"whereve": "where've",
"whod": "who'd",
"whod've": "who'd've",
"who'dve": "who'd've",
"wholl": "who'll",
"whos": "who's",
"whove": "who've",
"whyll": "why'll",
"whyre": "why're",
"whys": "why's",
"wont": "won't",
"wouldve": "would've",
"wouldnt": "wouldn't",
"wouldnt've": "wouldn't've",
"wouldn'tve": "wouldn't've",
"yall": "y'all",
"yall'll": "y'all'll",
"y'allll": "y'all'll",
"yall'd've": "y'all'd've",
"y'alld've": "y'all'd've",
"y'all'dve": "y'all'd've",
"youd": "you'd",
"youd've": "you'd've",
"you'dve": "you'd've",
"youll": "you'll",
"youre": "you're",
"youve": "you've",
}
manualMap = {
"none": "0",
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
articles = ["a", "an", "the"]
# Compiled patterns reproduced from the official script. Note that the first
# group in periodStrip is a no-op lookahead, presumably intended as the
# lookbehind (?<!\d); it is kept as-is so results match the official evaluation.
periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
commaStrip = re.compile(r"(\d)(\,)(\d)")
punct = [
";",
r"/",
"[",
"]",
'"',
"{",
"}",
"(",
")",
"=",
"+",
"\\",
"_",
"-",
">",
"<",
"@",
"`",
",",
"?",
"!",
]
def processPunctuation(inText):
    """Strip punctuation the same way as the official VQA evaluation script."""
    outText = inText
    for p in punct:
        # Delete the character outright if it touches a space in the input or if
        # the input contains a digit,digit pattern; otherwise turn it into a space.
        if (p + " " in inText or " " + p in inText) or (
            re.search(commaStrip, inText) is not None
        ):
            outText = outText.replace(p, "")
        else:
            outText = outText.replace(p, " ")
    # As in the official script, re.UNICODE is passed positionally where `sub`
    # expects `count`; kept for consistency with the official results.
    outText = periodStrip.sub("", outText, re.UNICODE)
    return outText
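# For illustration (not part of the official script):
#   processPunctuation("1,000,000!") -> "1000000"  (commaStrip matches, so "," and "!" are deleted)
#   processPunctuation("yes!")       -> "yes "     (no match, so "!" becomes a space; collapsed later)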
def processDigitArticle(inText):
    """Lowercase, map number words to digits, drop articles, expand contractions."""
    outText = []
    tempText = inText.lower().split()
    for word in tempText:
        # setdefault mirrors the official script; as a side effect it also
        # inserts previously unseen words into manualMap.
        word = manualMap.setdefault(word, word)
        if word not in articles:
            outText.append(word)
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = " ".join(outText)
    return outText
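# For illustration (not part of the official script):
#   processDigitArticle("The plane is a Boeing") -> "plane is boeing"  (lowercased, articles dropped)
#   processDigitArticle("Three dogs")            -> "3 dogs"           (number word mapped to a digit)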
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class VQAAccuracy(evaluate.Metric):
def _info(self):
return evaluate.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Sequence(
datasets.Value("string", id="sequence"), id="references"
),
"answer_types": datasets.Value("string", id="sequence"),
"question_types": datasets.Value("string", id="sequence"),
}
),
reference_urls=[
"https://visualqa.org/evaluation.html",
"https://github.com/GT-Vision-Lab/VQA/blob/master",
],
)
    def _compute(self, predictions, references, answer_types=None, question_types=None):
        if answer_types is None:
            answer_types = [None] * len(predictions)
        if question_types is None:
            question_types = [None] * len(predictions)
        if not len(predictions) == len(references) == len(answer_types) == len(question_types):
            raise ValueError(
                "The lengths of predictions, references, answer_types and question_types don't match."
            )
total, ans_type_dict, ques_type_dict = [], {}, {}
for pred, gts, ans_type, ques_type in zip(
predictions, references, answer_types, question_types
):
            # to align with the official data post-processing
pred = pred.replace("\n", " ").replace("\t", " ").strip()
pred = processDigitArticle(processPunctuation(pred))
gts = [processDigitArticle(processPunctuation(gt_ans)) for gt_ans in gts]
            # calculate VQA accuracy: average over all leave-one-out ("10 choose 9") subsets of the human answers
accuracy = []
for i in range(len(gts)):
other_gt = gts[:i] + gts[i + 1 :]
matching_ans = [item for item in other_gt if item == pred]
accuracy.append(min(1, len(matching_ans) / 3))
vqa_acc = sum(accuracy) / len(accuracy)
total.append(vqa_acc)
if ans_type is not None:
if ans_type not in ans_type_dict:
ans_type_dict[ans_type] = []
ans_type_dict[ans_type].append(vqa_acc)
if ques_type is not None:
if ques_type not in ques_type_dict:
ques_type_dict[ques_type] = []
ques_type_dict[ques_type].append(vqa_acc)
# the following key names follow the naming of the official evaluation results
result = {"overall": 100 * sum(total) / len(total)}
if len(ans_type_dict) > 0:
result["perAnswerType"] = {
ans_type: 100 * sum(accuracy_list) / len(accuracy_list)
for ans_type, accuracy_list in ans_type_dict.items()
}
if len(ques_type_dict) > 0:
result["perQuestionType"] = {
ques_type: 100 * sum(accuracy_list) / len(accuracy_list)
for ques_type, accuracy_list in ques_type_dict.items()
}
return result
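

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). It calls the underlying _compute
    # directly, bypassing the `evaluate` input-validation plumbing.
    metric = VQAAccuracy()
    result = metric._compute(
        predictions=["yes"],
        references=[["yes"] * 3 + ["no"] * 7],
        answer_types=["yes/no"],
    )
    # Dropping one of the 3 "yes" answers leaves 2 matches (acc 2/3); dropping
    # one of the 7 "no" answers leaves 3 matches (acc 1). The mean is
    # (3 * 2/3 + 7 * 1) / 10 = 0.9, so "overall" is about 90.0.
    print(result)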