Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import re | |
| from tqdm import tqdm | |
| class EvalAIAnswerProcessor: | |
| """ | |
| Processes an answer similar to Eval AI | |
| copied from | |
| https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897 | |
| """ | |
| CONTRACTIONS = { | |
| "aint": "ain't", | |
| "arent": "aren't", | |
| "cant": "can't", | |
| "couldve": "could've", | |
| "couldnt": "couldn't", | |
| "couldn'tve": "couldn't've", | |
| "couldnt've": "couldn't've", | |
| "didnt": "didn't", | |
| "doesnt": "doesn't", | |
| "dont": "don't", | |
| "hadnt": "hadn't", | |
| "hadnt've": "hadn't've", | |
| "hadn'tve": "hadn't've", | |
| "hasnt": "hasn't", | |
| "havent": "haven't", | |
| "hed": "he'd", | |
| "hed've": "he'd've", | |
| "he'dve": "he'd've", | |
| "hes": "he's", | |
| "howd": "how'd", | |
| "howll": "how'll", | |
| "hows": "how's", | |
| "Id've": "I'd've", | |
| "I'dve": "I'd've", | |
| "Im": "I'm", | |
| "Ive": "I've", | |
| "isnt": "isn't", | |
| "itd": "it'd", | |
| "itd've": "it'd've", | |
| "it'dve": "it'd've", | |
| "itll": "it'll", | |
| "let's": "let's", | |
| "maam": "ma'am", | |
| "mightnt": "mightn't", | |
| "mightnt've": "mightn't've", | |
| "mightn'tve": "mightn't've", | |
| "mightve": "might've", | |
| "mustnt": "mustn't", | |
| "mustve": "must've", | |
| "neednt": "needn't", | |
| "notve": "not've", | |
| "oclock": "o'clock", | |
| "oughtnt": "oughtn't", | |
| "ow's'at": "'ow's'at", | |
| "'ows'at": "'ow's'at", | |
| "'ow'sat": "'ow's'at", | |
| "shant": "shan't", | |
| "shed've": "she'd've", | |
| "she'dve": "she'd've", | |
| "she's": "she's", | |
| "shouldve": "should've", | |
| "shouldnt": "shouldn't", | |
| "shouldnt've": "shouldn't've", | |
| "shouldn'tve": "shouldn't've", | |
| "somebody'd": "somebodyd", | |
| "somebodyd've": "somebody'd've", | |
| "somebody'dve": "somebody'd've", | |
| "somebodyll": "somebody'll", | |
| "somebodys": "somebody's", | |
| "someoned": "someone'd", | |
| "someoned've": "someone'd've", | |
| "someone'dve": "someone'd've", | |
| "someonell": "someone'll", | |
| "someones": "someone's", | |
| "somethingd": "something'd", | |
| "somethingd've": "something'd've", | |
| "something'dve": "something'd've", | |
| "somethingll": "something'll", | |
| "thats": "that's", | |
| "thered": "there'd", | |
| "thered've": "there'd've", | |
| "there'dve": "there'd've", | |
| "therere": "there're", | |
| "theres": "there's", | |
| "theyd": "they'd", | |
| "theyd've": "they'd've", | |
| "they'dve": "they'd've", | |
| "theyll": "they'll", | |
| "theyre": "they're", | |
| "theyve": "they've", | |
| "twas": "'twas", | |
| "wasnt": "wasn't", | |
| "wed've": "we'd've", | |
| "we'dve": "we'd've", | |
| "weve": "we've", | |
| "werent": "weren't", | |
| "whatll": "what'll", | |
| "whatre": "what're", | |
| "whats": "what's", | |
| "whatve": "what've", | |
| "whens": "when's", | |
| "whered": "where'd", | |
| "wheres": "where's", | |
| "whereve": "where've", | |
| "whod": "who'd", | |
| "whod've": "who'd've", | |
| "who'dve": "who'd've", | |
| "wholl": "who'll", | |
| "whos": "who's", | |
| "whove": "who've", | |
| "whyll": "why'll", | |
| "whyre": "why're", | |
| "whys": "why's", | |
| "wont": "won't", | |
| "wouldve": "would've", | |
| "wouldnt": "wouldn't", | |
| "wouldnt've": "wouldn't've", | |
| "wouldn'tve": "wouldn't've", | |
| "yall": "y'all", | |
| "yall'll": "y'all'll", | |
| "y'allll": "y'all'll", | |
| "yall'd've": "y'all'd've", | |
| "y'alld've": "y'all'd've", | |
| "y'all'dve": "y'all'd've", | |
| "youd": "you'd", | |
| "youd've": "you'd've", | |
| "you'dve": "you'd've", | |
| "youll": "you'll", | |
| "youre": "you're", | |
| "youve": "you've", | |
| } | |
| NUMBER_MAP = { | |
| "none": "0", | |
| "zero": "0", | |
| "one": "1", | |
| "two": "2", | |
| "three": "3", | |
| "four": "4", | |
| "five": "5", | |
| "six": "6", | |
| "seven": "7", | |
| "eight": "8", | |
| "nine": "9", | |
| "ten": "10", | |
| } | |
| ARTICLES = ["a", "an", "the"] | |
| PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") | |
| COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)") | |
| PUNCTUATIONS = [ | |
| ";", | |
| r"/", | |
| "[", | |
| "]", | |
| '"', | |
| "{", | |
| "}", | |
| "(", | |
| ")", | |
| "=", | |
| "+", | |
| "\\", | |
| "_", | |
| "-", | |
| ">", | |
| "<", | |
| "@", | |
| "`", | |
| ",", | |
| "?", | |
| "!", | |
| ] | |
| def __init__(self, *args, **kwargs): | |
| pass | |
| def word_tokenize(self, word): | |
| word = word.lower() | |
| word = word.replace(",", "").replace("?", "").replace("'s", " 's") | |
| return word.strip() | |
| def process_punctuation(self, in_text): | |
| out_text = in_text | |
| for p in self.PUNCTUATIONS: | |
| if (p + " " in in_text or " " + p in in_text) or ( | |
| re.search(self.COMMA_STRIP, in_text) is not None | |
| ): | |
| out_text = out_text.replace(p, "") | |
| else: | |
| out_text = out_text.replace(p, " ") | |
| out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE) | |
| return out_text | |
| def process_digit_article(self, in_text): | |
| out_text = [] | |
| temp_text = in_text.lower().split() | |
| for word in temp_text: | |
| word = self.NUMBER_MAP.setdefault(word, word) | |
| if word not in self.ARTICLES: | |
| out_text.append(word) | |
| else: | |
| pass | |
| for word_id, word in enumerate(out_text): | |
| if word in self.CONTRACTIONS: | |
| out_text[word_id] = self.CONTRACTIONS[word] | |
| out_text = " ".join(out_text) | |
| return out_text | |
| def __call__(self, item): | |
| item = self.word_tokenize(item) | |
| item = item.replace("\n", " ").replace("\t", " ").strip() | |
| item = self.process_punctuation(item) | |
| item = self.process_digit_article(item) | |
| return item | |
class TextVQAAccuracyEvaluator:
    """VQA-style soft-accuracy evaluator for TextVQA predictions."""

    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def _compute_answer_scores(self, raw_answers):
        """
        compute the accuracy (soft score) of human answers
        """
        processed = [self.answer_processor(ans) for ans in raw_answers]
        # The TextVQA protocol provides exactly 10 human answers per question.
        assert len(processed) == 10
        indexed = list(enumerate(processed))

        scores = {}
        for candidate in set(processed):
            per_holdout = []
            # Leave-one-out: score the candidate against each 9-answer subset.
            for held_out in indexed:
                remaining = [pair for pair in indexed if pair != held_out]
                hits = sum(1 for pair in remaining if pair[1] == candidate)
                per_holdout.append(min(1, float(hits) / 3))
            scores[candidate] = sum(per_holdout) / len(per_holdout)
        return scores

    def eval_pred_list(self, pred_list):
        """Return the mean soft-accuracy over all prediction entries."""
        totals = []
        for entry in tqdm(pred_list):
            prediction = self.answer_processor(entry["pred_answer"])
            gt_scores = self._compute_answer_scores(entry["gt_answers"])
            totals.append(gt_scores.get(prediction, 0.0))
        return sum(totals) / len(totals)
class STVQAAccuracyEvaluator:
    """Exact-match accuracy evaluator for ST-VQA predictions."""

    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def eval_pred_list(self, pred_list):
        """Return the fraction of entries whose processed prediction
        exactly matches at least one processed ground-truth answer."""
        hits = []
        for entry in pred_list:
            candidate = self.answer_processor(entry["pred_answer"])
            references = [self.answer_processor(gt) for gt in entry["gt_answers"]]
            hits.append(1.0 if candidate in references else 0.0)
        return sum(hits) / len(hits)
class STVQAANLSEvaluator:
    """ANLS (Average Normalized Levenshtein Similarity) evaluator for ST-VQA."""

    def __init__(self):
        import editdistance  # install with `pip install editdistance`

        self.get_edit_distance = editdistance.eval

    def get_anls(self, s1, s2):
        """Return the normalized Levenshtein similarity between two strings.

        Scores below the 0.5 threshold are clamped to 0.0, per the ANLS
        metric definition.

        FIX: the original divided by max(len(s1), len(s2)) unguarded and
        raised ZeroDivisionError when both strings normalized to empty.
        Two empty strings are identical, so we treat that case as a
        perfect match (1.0).
        """
        s1 = s1.lower().strip()
        s2 = s2.lower().strip()
        longest = max(len(s1), len(s2))
        if longest == 0:
            return 1.0
        iou = 1 - self.get_edit_distance(s1, s2) / longest
        anls = iou if iou >= 0.5 else 0.0
        return anls

    def eval_pred_list(self, pred_list):
        """Return the mean (over entries) of the best ANLS score of the
        prediction against any of its ground-truth answers."""
        pred_scores = []
        for entry in pred_list:
            anls = max(
                self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
            )
            pred_scores.append(anls)
        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy
class TextCapsBleu4Evaluator:
    """Corpus-level BLEU-4 evaluator for TextCaps captions."""

    def __init__(self):
        # The following script requires Java 1.8.0 and pycocotools installed.
        # The pycocoevalcap can be installed with pip as
        # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
        # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
        # but has no python3 support yet.
        try:
            from pycocoevalcap.bleu.bleu import Bleu
            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        except ModuleNotFoundError:
            print(
                "Please install pycocoevalcap module using "
                "pip install git+https://github.com/ronghanghu/coco-caption.git@python23"  # noqa
            )
            raise
        self.tokenizer = PTBTokenizer()
        self.scorer = Bleu(4)

    def eval_pred_list(self, pred_list):
        """Return BLEU-4 of the predicted captions against the references."""
        # Build the reference and hypothesis tables keyed by entry index,
        # in the {"caption": ...} shape the PTB tokenizer expects.
        references = {}
        hypotheses = {}
        for key, entry in enumerate(pred_list):
            references[key] = [{"caption": gt} for gt in entry["gt_answers"]]
            hypotheses[key] = [{"caption": entry["pred_answer"]}]

        references = self.tokenizer.tokenize(references)
        hypotheses = self.tokenizer.tokenize(hypotheses)
        scores, _ = self.scorer.compute_score(references, hypotheses)

        # compute_score yields (Bleu-1, Bleu-2, Bleu-3, Bleu-4); keep BLEU-4.
        return scores[3]