Spaces:
Runtime error
Runtime error
import re | |
from utils.normalizer import str_normalize | |
from utils.wtq.evaluator import to_value_list, check_denotation | |
from utils.mmqa.evaluator import acc | |
class Evaluator:
    """Score a predicted answer against a gold answer for a named dataset.

    Supported dataset names are 'wikitq', 'tab_fact' and 'mmqa'; see
    `evaluate` for how each one is dispatched.
    """

    # Compiled once at class level instead of on every call. Raw strings keep
    # the regex escapes (\s, \$, \w, \d) from being read as string escapes.
    _NUMBER_UNITS_PATTERN = re.compile(r'^\$*[+-]?([0-9]*[.])?[0-9]+(\s*%*|\s+\w+)$')
    _DATE_PATTERN = re.compile(
        r'[0-9]{4}-[0-9]{1,2}-[0-9]{1,2}\s*([0-9]{1,2}:[0-9]{1,2}:[0-9]{1,2})?'
    )
    _DURATION_PATTERN = re.compile(r'(P|PT)(\d+)(Y|M|D|H|S)')
    _WHITESPACE_PATTERN = re.compile(r'\s+')

    def __init__(self):
        pass

    def evaluate(
            self,
            pred_answer,
            gold_answer,
            dataset,
            allow_semantic=True,
            question=None
    ):
        """Dispatch to the matcher for `dataset` and return its result.

        Args:
            pred_answer: Predicted answer(s); the expected shape depends on
                the dataset-specific matcher.
            gold_answer: Gold answer(s), in the shape the matcher expects.
            dataset: One of 'wikitq', 'tab_fact' or 'mmqa'.
            allow_semantic: Only used for 'wikitq'; enables the relaxed
                matching rules of `eval_ex_match`.
            question: Only used for 'wikitq'; must be a string when
                `allow_semantic` is true.

        Raises:
            ValueError: If `dataset` is not one of the supported names.
        """
        if dataset == 'wikitq':
            return self.eval_ex_match(pred_answer, gold_answer, allow_semantic, question)
        elif dataset == 'tab_fact':
            return self.eval_tabfact_match(pred_answer, gold_answer)
        elif dataset == 'mmqa':
            # For more metrics on MMQA, use utils/mmqa/eval_mmqa.py to run the
            # official evaluation over all prediction data.
            return self.eval_mmqa_match(pred_answer, gold_answer)
        else:
            raise ValueError(f'{dataset} evaluator is not supported.')

    def eval_ex_match(self, pred, gold, allow_semantic=True, question=None):
        """Match WikiTQ-style answer lists, optionally with relaxed rules.

        Both lists are stringified, lower-cased, stripped and passed through
        `str_normalize`. With `allow_semantic` false the result is exactly
        `check_denotation(to_value_list(pred), to_value_list(gold))`. With it
        true, duplicates are removed first and, when both sides hold a single
        answer, two shortcuts may accept the prediction before falling back
        to the denotation check:

        1. a '0'/'1' prediction matching 'no'/'yes', or the option after /
           before the first "or" in the question;
        2. a number (with optional unit) or date whose tokens are a subset of
           the other side's tokens.

        Args:
            pred: Iterable of predicted answers.
            gold: Iterable of gold answers.
            allow_semantic: Enable the relaxed rules described above.
            question: The question text; required when `allow_semantic` is
                true.

        Returns:
            True when a shortcut accepts the prediction, otherwise whatever
            `check_denotation` returns.

        Raises:
            AssertionError: If `allow_semantic` is true and `question` is not
                a string.
        """
        pred = [str(p).lower().strip() for p in pred]
        gold = [str(g).lower().strip() for g in gold]

        if allow_semantic:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; the exception type is kept for
            # callers that already handle AssertionError.
            if not isinstance(question, str):
                raise AssertionError(
                    'question must be a str when allow_semantic is True'
                )
            question = self._WHITESPACE_PATTERN.sub(' ', question).strip().lower()
            pred = sorted({str_normalize(span) for span in pred})
            gold = sorted({str_normalize(span) for span in gold})
            if len(pred) == 1 and len(gold) == 1:
                if self._binary_choice_match(pred[0], gold[0], question):
                    return True
                if self._number_or_date_match(pred[0], gold[0]):
                    return True
        else:
            # WikiTQ eval with string normalization using the recognizer.
            pred = [str_normalize(span) for span in pred]
            gold = [str_normalize(span) for span in gold]

        pred = to_value_list(pred)
        gold = to_value_list(gold)
        return check_denotation(pred, gold)

    @staticmethod
    def _binary_choice_match(p, g, question):
        """Return True if a '0'/'1' prediction matches a yes/no or either-or gold."""
        if (p == '0' and g == 'no') or (p == '1' and g == 'yes'):
            return True
        if p not in ('0', '1'):
            return False

        tokens = question.split()
        if 'or' not in tokens:
            return False
        pos_or = tokens.index('or')
        # "or" needs a word on both sides. Without this guard a leading "or"
        # would make index -1 silently wrap around to the last token.
        if pos_or == 0 or pos_or == len(tokens) - 1:
            return False
        token_before_or, token_after_or = tokens[pos_or - 1], tokens[pos_or + 1]
        return (p == '0' and g == token_after_or) or (p == '1' and g == token_before_or)

    @classmethod
    def _number_or_date_match(cls, p, g):
        """Return True if p and g agree as a number-with-unit or as a date."""
        # Restore `duration` type, e.g., from 'P3Y' -> '3'.
        duration = cls._DURATION_PATTERN.match(p)
        if duration:
            p = duration.group(2)
        duration = cls._DURATION_PATTERN.match(g)
        if duration:
            g = duration.group(2)

        # Either side being a number with units (or a date) is enough to
        # enable the corresponding token-subset comparison.
        is_number = bool(
            cls._NUMBER_UNITS_PATTERN.match(p) or cls._NUMBER_UNITS_PATTERN.match(g)
        )
        is_date = bool(cls._DATE_PATTERN.match(p) or cls._DATE_PATTERN.match(g))

        if is_number and cls._tokens_overlap(p.split(), g.split()):
            return True
        if is_date and cls._tokens_overlap(
                p.replace('-', ' ').split(), g.replace('-', ' ').split()
        ):
            return True
        return False

    @staticmethod
    def _tokens_overlap(p_tokens, g_tokens):
        """Return True if one non-empty token set is a subset of the other."""
        p_set, g_set = set(p_tokens), set(g_tokens)
        # The empty set is a subset of everything, so an empty answer must not
        # be allowed to match an arbitrary number or date.
        if not p_set or not g_set:
            return False
        return p_set.issubset(g_set) or g_set.issubset(p_set)

    def eval_tabfact_match(self, pred, gold):
        """Compare a TabFact prediction with its gold label as strings.

        If `pred` is a list only its first element is compared; an empty list
        counts as a mismatch instead of raising IndexError.
        """
        if isinstance(pred, list):
            if not pred:
                return False
            pred = pred[0]
        pred, gold = str(pred), str(gold)
        return pred == gold

    def eval_mmqa_match(self, pred_answer, gold_answer):
        """Return the result of `acc(pred_answer, gold_answer)` from utils.mmqa.evaluator."""
        return acc(pred_answer, gold_answer)