"""Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ # coding=utf-8 __author__ = 'aagrawal' import re # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). import sys class VQAEval: def __init__(self, vqa=None, vqaRes=None, n=2): self.n = n self.accuracy = {} self.evalQA = {} self.evalQuesType = {} self.evalAnsType = {} self.vqa = vqa self.vqaRes = vqaRes if vqa is not None: self.params = {'question_id': vqa.getQuesIds()} self.contractions = { 'aint': "ain't", 'arent': "aren't", 'cant': "can't", 'couldve': "could've", 'couldnt': "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've", 'didnt': "didn't", 'doesnt': "doesn't", 'dont': "don't", 'hadnt': "hadn't", "hadnt've": "hadn't've", "hadn'tve": "hadn't've", 'hasnt': "hasn't", 'havent': "haven't", 'hed': "he'd", "hed've": "he'd've", "he'dve": "he'd've", 'hes': "he's", 'howd': "how'd", 'howll': "how'll", 'hows': "how's", "Id've": "I'd've", "I'dve": "I'd've", 'Im': "I'm", 'Ive': "I've", 'isnt': "isn't", 'itd': "it'd", "itd've": "it'd've", "it'dve": "it'd've", 'itll': "it'll", "let's": "let's", 'maam': "ma'am", 'mightnt': "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", 'mightve': "might've", 'mustnt': "mustn't", 'mustve': "must've", 'neednt': "needn't", 'notve': "not've", 'oclock': "o'clock", 'oughtnt': "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", 'shant': "shan't", "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's", 'shouldve': "should've", 'shouldnt': "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": 'somebodyd', "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", 'somebodyll': "somebody'll", 'somebodys': "somebody's", 'someoned': "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", 'someonell': "someone'll", 'someones': "someone's", 'somethingd': "something'd", "somethingd've": "something'd've", "something'dve": "something'd've", 'somethingll': "something'll", 'thats': "that's", 'thered': "there'd", "thered've": "there'd've", "there'dve": "there'd've", 'therere': "there're", 'theres': "there's", 'theyd': "they'd", "theyd've": "they'd've", "they'dve": "they'd've", 'theyll': "they'll", 'theyre': "they're", 'theyve': "they've", 'twas': "'twas", 'wasnt': "wasn't", "wed've": "we'd've", "we'dve": "we'd've", 'weve': "we've", 'werent': "weren't", 'whatll': "what'll", 'whatre': "what're", 'whats': "what's", 'whatve': "what've", 'whens': "when's", 'whered': "where'd", 'wheres': "where's", 'whereve': "where've", 'whod': "who'd", "whod've": "who'd've", "who'dve": "who'd've", 'wholl': "who'll", 'whos': "who's", 'whove': "who've", 'whyll': "why'll", 'whyre': "why're", 'whys': "why's", 'wont': "won't", 'wouldve': "would've", 'wouldnt': "wouldn't", "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", 'yall': "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", 'youd': "you'd", "youd've": "you'd've", "you'dve": "you'd've", 'youll': "you'll", 'youre': "you're", 'youve': "you've", } self.manualMap = { 'none': '0', 'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9', 'ten': '10', } self.articles = ['a', 'an', 'the'] self.periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') self.commaStrip = re.compile('(\d)(,)(\d)') self.punct = [ ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-', '>', '<', '@', '`', ',', '?', '!', ] def evaluate(self, quesIds=None): if quesIds == None: quesIds = [quesId for quesId in self.params['question_id']] gts = {} res = {} for quesId in quesIds: gts[quesId] = self.vqa.qa[quesId] res[quesId] = self.vqaRes.qa[quesId] # ================================================= # Compute accuracy # ================================================= accQA = [] accQuesType = {} accAnsType = {} print('computing accuracy') step = 0 for quesId in quesIds: resAns = res[quesId]['answer'] resAns = resAns.replace('\n', ' ') resAns = resAns.replace('\t', ' ') resAns = resAns.strip() resAns = self.processPunctuation(resAns) resAns = self.processDigitArticle(resAns) gtAcc = [] gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] if len(set(gtAnswers)) > 1: for ansDic in gts[quesId]['answers']: ansDic['answer'] = self.processPunctuation( ansDic['answer']) for gtAnsDatum in gts[quesId]['answers']: otherGTAns = [ item for item in gts[quesId]['answers'] if item != gtAnsDatum ] matchingAns = [ item for item in otherGTAns if item['answer'] == resAns ] acc = min(1, float(len(matchingAns)) / 3) gtAcc.append(acc) quesType = gts[quesId]['question_type'] ansType = gts[quesId]['answer_type'] avgGTAcc = float(sum(gtAcc)) / len(gtAcc) accQA.append(avgGTAcc) if quesType not in accQuesType: accQuesType[quesType] = [] accQuesType[quesType].append(avgGTAcc) if ansType not in accAnsType: accAnsType[ansType] = [] accAnsType[ansType].append(avgGTAcc) self.setEvalQA(quesId, avgGTAcc) self.setEvalQuesType(quesId, quesType, avgGTAcc) self.setEvalAnsType(quesId, ansType, avgGTAcc) if step % 100 == 0: self.updateProgress(step / float(len(quesIds))) step = step + 1 self.setAccuracy(accQA, accQuesType, accAnsType) print('Done computing accuracy') def processPunctuation(self, inText): outText = inText for p in self.punct: if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None): outText = outText.replace(p, '') else: outText = outText.replace(p, ' ') outText = self.periodStrip.sub('', outText, re.UNICODE) return outText def processDigitArticle(self, inText): outText = [] tempText = inText.lower().split() for word in tempText: word = self.manualMap.setdefault(word, word) if word not in self.articles: outText.append(word) else: pass for wordId, word in enumerate(outText): if word in self.contractions: outText[wordId] = self.contractions[word] outText = ' '.join(outText) return outText def setAccuracy(self, accQA, accQuesType, accAnsType): self.accuracy['overall'] = round(100 * float(sum(accQA)) / len(accQA), self.n) self.accuracy['perQuestionType'] = { quesType: round( 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), self.n, ) for quesType in accQuesType } self.accuracy['perAnswerType'] = { ansType: round( 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n) for ansType in accAnsType } def setEvalQA(self, quesId, acc): self.evalQA[quesId] = round(100 * acc, self.n) def setEvalQuesType(self, quesId, quesType, acc): if quesType not in self.evalQuesType: self.evalQuesType[quesType] = {} self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) def setEvalAnsType(self, quesId, ansType, acc): if ansType not in self.evalAnsType: self.evalAnsType[ansType] = {} self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) def updateProgress(self, progress): barLength = 20 status = '' if isinstance(progress, int): progress = float(progress) if not isinstance(progress, float): progress = 0 status = 'error: progress var must be float\r\n' if progress < 0: progress = 0 status = 'Halt...\r\n' if progress >= 1: progress = 1 status = 'Done...\r\n' block = int(round(barLength * progress)) text = '\rFinshed Percent: [{0}] {1}% {2}'.format( '#' * block + '-' * (barLength - block), int(progress * 100), status) sys.stdout.write(text) sys.stdout.flush()