# Binder / utils/evaluator.py
# Timothyxxx
# Init
# f6f97d8
import re
from utils.normalizer import str_normalize
from utils.wtq.evaluator import to_value_list, check_denotation
from utils.mmqa.evaluator import acc
class Evaluator:
    """Compare predicted answers with gold answers for the supported datasets.

    Supported ``dataset`` names are ``'wikitq'``, ``'tab_fact'`` and ``'mmqa'``.
    """

    # Patterns for the WikiTQ semantic match, compiled once instead of per call.
    # Raw strings avoid invalid-escape warnings for `\s`, `\$`, `\d` and `\w`.
    _NUMBER_UNITS_PATTERN = re.compile(r'^\$*[+-]?([0-9]*[.])?[0-9]+(\s*%*|\s+\w+)$')
    _DATE_PATTERN = re.compile(
        r'[0-9]{4}-[0-9]{1,2}-[0-9]{1,2}\s*([0-9]{1,2}:[0-9]{1,2}:[0-9]{1,2})?'
    )
    _DURATION_PATTERN = re.compile(r'(P|PT)(\d+)(Y|M|D|H|S)')

    def __init__(self):
        pass

    def evaluate(self, pred_answer, gold_answer, dataset, allow_semantic=True, question=None):
        """Score ``pred_answer`` against ``gold_answer`` using the metric for ``dataset``.

        Args:
            pred_answer: Predicted answer(s); the expected form depends on ``dataset``.
            gold_answer: Gold answer(s) for the same example.
            dataset: One of ``'wikitq'``, ``'tab_fact'`` or ``'mmqa'``.
            allow_semantic: WikiTQ only; enable the lenient matching rules.
            question: WikiTQ only; the question text, required when
                ``allow_semantic`` is true.

        Returns:
            The dataset-specific match result.

        Raises:
            ValueError: If ``dataset`` is not supported.
        """
        if dataset == 'wikitq':
            return self.eval_ex_match(pred_answer, gold_answer, allow_semantic, question)
        elif dataset == 'tab_fact':
            return self.eval_tabfact_match(pred_answer, gold_answer)
        elif dataset == 'mmqa':
            # For more metrics on MMQA, please use utils/mmqa/eval_mmqa.py to
            # call the official evaluation on all prediction data.
            return self.eval_mmqa_match(pred_answer, gold_answer)
        else:
            raise ValueError(f'{dataset} evaluator is not supported.')

    def eval_ex_match(self, pred, gold, allow_semantic=True, question=None):
        """Execution-match a WikiTQ prediction against the gold answers.

        Both sides are stringified, lowercased, stripped and passed through
        ``str_normalize``. Without ``allow_semantic`` the result is exactly the
        WikiTQ denotation check. With it, duplicates are removed, and a
        single-answer pair may additionally match through the binary-choice or
        number/date rules before falling back to the denotation check.

        Args:
            pred: Iterable of predicted answer values.
            gold: Iterable of gold answer values.
            allow_semantic: Enable the lenient single-answer rules.
            question: Question text; must be a ``str`` when ``allow_semantic``.

        Returns:
            ``True`` if a lenient rule matched, otherwise the result of
            ``check_denotation``.
        """
        pred = [str(p).lower().strip() for p in pred]
        gold = [str(g).lower().strip() for g in gold]

        if not allow_semantic:
            # WikiTQ eval w. string normalization using recognizer
            pred = [str_normalize(span) for span in pred]
            gold = [str_normalize(span) for span in gold]
            return check_denotation(to_value_list(pred), to_value_list(gold))

        assert isinstance(question, str), 'question must be a str when allow_semantic is True'
        # Deduplicate and order so that single-answer checks see one canonical value.
        pred = sorted({str_normalize(span) for span in pred})
        gold = sorted({str_normalize(span) for span in gold})

        if len(pred) == 1 and len(gold) == 1:
            p, g = pred[0], gold[0]
            if self._binary_choice_match(p, g, question) or self._number_or_date_match(p, g):
                return True

        return check_denotation(to_value_list(pred), to_value_list(gold))

    @staticmethod
    def _binary_choice_match(p, g, question):
        """Return True when a '0'/'1' prediction selects the gold yes/no or 'A or B' option.

        '0' matches gold 'no' and '1' matches gold 'yes'. If the question
        contains 'or', '1' also matches the token just before the first 'or'
        and '0' the token just after it (e.g. "more or less").
        """
        if (p == '0' and g == 'no') or (p == '1' and g == 'yes'):
            return True
        tokens = question.lower().split()
        if 'or' not in tokens:
            return False
        pos_or = tokens.index('or')
        # Require a real token on both sides; at index 0, tokens[pos_or - 1]
        # would silently wrap around to the last token of the question.
        if pos_or == 0 or pos_or + 1 >= len(tokens):
            return False
        token_before_or, token_after_or = tokens[pos_or - 1], tokens[pos_or + 1]
        return (p == '0' and g == token_after_or) or (p == '1' and g == token_before_or)

    @classmethod
    def _number_or_date_match(cls, p, g):
        """Return True when p/g look like numbers or dates and one's tokens contain the other's.

        Durations such as 'P3Y' are first reduced to their number ('3'). If
        either side is a number (optionally followed by '%' or a unit word),
        whitespace tokens are compared; if either side is a date, tokens are
        compared after replacing '-' with spaces.
        """
        duration = cls._DURATION_PATTERN.match(p)
        if duration:
            p = duration.group(2)
        duration = cls._DURATION_PATTERN.match(g)
        if duration:
            g = duration.group(2)

        # Either side being a number with units suffices to try the numeric rule.
        is_number = cls._NUMBER_UNITS_PATTERN.match(p) or cls._NUMBER_UNITS_PATTERN.match(g)
        if is_number and cls._token_sets_nested(p.split(), g.split()):
            return True

        # Either side being a date suffices to try the date rule.
        is_date = cls._DATE_PATTERN.match(p) or cls._DATE_PATTERN.match(g)
        if is_date and cls._token_sets_nested(
                p.replace('-', ' ').split(), g.replace('-', ' ').split()):
            return True
        return False

    @staticmethod
    def _token_sets_nested(p_tokens, g_tokens):
        """Return True if both token sets are non-empty and one contains the other."""
        p_set, g_set = set(p_tokens), set(g_tokens)
        # The empty set is a subset of every set, so without this guard a blank
        # answer would "match" any numeric or date answer on the other side.
        if not p_set or not g_set:
            return False
        return p_set <= g_set or g_set <= p_set

    def eval_tabfact_match(self, pred, gold):
        """Return True if the prediction's string form equals the gold label's.

        A list prediction is scored by its first element; an empty list
        scores False instead of raising IndexError.
        """
        if isinstance(pred, list):
            if not pred:
                return False
            pred = pred[0]
        return str(pred) == str(gold)

    def eval_mmqa_match(self, pred_answer, gold_answer):
        """Score one MMQA example by delegating to ``utils.mmqa.evaluator.acc``.

        NOTE(review): presumably an accuracy-style score; see ``acc`` for the
        exact semantics.
        """
        return acc(pred_answer, gold_answer)