import argparse
import json
from glob import glob
import math
from collections import defaultdict
import re
import pandas as pd
def make_option_labels(options):
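    """
    Prefixes each option with an enumeration letter, e.g. ["cat", "dog"] -> ["A. cat", "B. dog"].
    """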
option_labels = []
for i, option in enumerate(options):
option_labels.append(f"{chr(65 + i)}. {option.strip()}")
return option_labels
class AverageMeter(object):
"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def clean_counting_answer(answer):
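    """
    Normalizes a free-form counting response to an int.
    Tries, in order: an exact spelled-out number ("twelve" -> 12), a numeric
    string ("3.7" -> 4), the last number appearing in the string, and the first
    spelled-out number contained in the string. Returns 0 if nothing can be parsed.
    """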
if answer is None or answer.strip() == "":
return 0
    # keep only the contents of <answer>...</answer> tags, if present
    answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
if len(answer_matches) > 0:
answer = answer_matches[0]
clean_answer = answer.strip().lower()
if clean_answer.endswith("."):
clean_answer = clean_answer[:-1].strip()
if clean_answer in words_to_int:
clean_answer = words_to_int[clean_answer]
else:
try:
clean_answer = int(round(float(clean_answer)))
except ValueError:
# or pick the LAST number in the string
            match = re.search(r'\d+(?:\.\d+)?(?![\s\S]*\d)', clean_answer)
if match:
clean_answer = int(round(float(match.group(0))))
else:
matched = False
# or pick the FIRST spelled out number in the string
for word, number in words_to_int.items():
if word in clean_answer:
clean_answer = number
matched = True
break
if not matched:
print(f"WARNING: Unable to convert answer '{answer}' to int.")
clean_answer = 0
return clean_answer
def clean_multiple_choice_answer(answer, options):
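    """
    Reduces a free-form multiple choice response to a single option letter.
    Strips common prefixes ("Answer:", "The answer is ", ...) and trailing periods,
    then matches the response against the full option labels and, failing that,
    against the bare option letters.
    """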
if answer is None:
return ""
    answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
if len(answer_matches) > 0:
answer = answer_matches[0]
clean_answer = answer.strip()
    if clean_answer.startswith("Answer:"):
        clean_answer = clean_answer.replace("Answer:", "").strip()
    if clean_answer.startswith("The answer is "):
        clean_answer = clean_answer.replace("The answer is ", "").strip()
    if clean_answer.startswith("The best answer is "):
        clean_answer = clean_answer.replace("The best answer is ", "").strip()
    if clean_answer.endswith("."):
        clean_answer = clean_answer[:-1].strip()
if len(clean_answer) > 1:
# If the answer is longer than one character, we assume it may contain the full label, e.g. "A. option text"
for option in options:
if option in clean_answer:
clean_answer = option[0]
break
if len(clean_answer) > 1:
        # If the answer is still longer than one character, fall back to the first option letter that appears in it, e.g. "(A)" or "Option B"
for option in options:
if option[0] in clean_answer:
clean_answer = option[0]
break
return clean_answer
def clean_ocr_answer(answer):
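    """
    Normalizes an OCR response: extracts quoted text if present and applies clean_text.
    """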
    if answer is None:
        return ""
    answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
if len(answer_matches) > 0:
answer = answer_matches[0]
clean_answer = answer.strip()
clean_answer = extract_text_from_quotes(clean_answer)
clean_answer = clean_text(clean_answer)
return clean_answer
def validate_choice_answer(answer, benchmark_truth):
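    """
    Returns True if the cleaned response selects the ground-truth option.
    """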
all_options = make_option_labels(benchmark_truth["options"])
    # clean the response
clean_answer = clean_multiple_choice_answer(answer["response"], all_options)
# map the truth to option key
correct_option_index = benchmark_truth["options"].index(benchmark_truth["ground_truth"])
correct_option_enum = all_options[correct_option_index][0]
return (clean_answer == correct_option_enum)
words_to_int = {
"zero": 0,
"one": 1,
"two": 2,
"three": 3,
"four": 4,
"five": 5,
"six": 6,
"seven": 7,
"eight": 8,
"nine": 9,
"ten": 10,
"eleven": 11,
"twelve": 12,
"thirteen": 13,
"fourteen": 14,
"fifteen": 15,
"sixteen": 16,
"seventeen": 17,
"eighteen": 18,
"nineteen": 19,
"twenty": 20,
"thirty": 30,
"forty": 40,
"fifty": 50,
"sixty": 60,
"seventy": 70,
"eighty": 80,
"ninety": 90,
"hundred": 100,
"thousand": 1000,
"million": 1000000,
}
def clean_text(s):
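    """
    Normalizes a string for lenient comparison: lowercases, maps accented
    characters to their base letters, and removes punctuation and quote marks.
    """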
replace_characters = {
"ä": "a",
"á": "a",
"à": "a",
"â": "a",
"ã": "a",
"å": "a",
"ā": "a",
"ö": "o",
"ó": "o",
"ò": "o",
"ô": "o",
"õ": "o",
"ō": "o",
"ü": "u",
"ú": "u",
"ù": "u",
"û": "u",
"ū": "u",
"é": "e",
"ĕ": "e",
"ė": "e",
"ę": "e",
"ě": "e",
"ç": "c",
"ć": "c",
"č": "c",
"ñ": "n",
"ń": "n",
"ń": "n",
"ł": "l",
"ś": "s",
"š": "s",
"ź": "z",
"ż": "z",
"ý": "y",
"ŷ": "y",
"ÿ": "y",
"œ": "oe",
"æ": "ae",
"v": "u"
}
s = s.strip().lower()
# replace special characters
for char, replacement in replace_characters.items():
s = s.replace(char, replacement)
s = s.replace("\t", "").replace("\n", "").replace(".", "").replace("&", "").replace(";", "")\
.replace(",", "").replace("-", "").replace("–", "").replace("’", "'").replace(":", "").replace("·", " ")\
.replace("'", "").replace("“", "").replace("”", "").replace('"', "").replace("•", " ")\
        .replace("\xa0", " ")  # non-breaking space
return s
def extract_text_from_quotes(s):
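    """
    Returns the first quoted segment in s (straight, back, or curly quotes),
    or s unchanged if no quoted segment is found.
    """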
    pattern = r"'([^']*?)'|\"([^\"]*?)\"|`([^`]*?)`|“([^”]*?)”|‘([^’]*?)’"
matches = re.findall(pattern, s)
if matches:
        # Return the first non-empty quoted segment
matched_group = [match for group in matches for match in group if match][0]
return matched_group
else:
# Return the original string if no match is found
return s
def main(args):
    # use pprint so the per-file scores dict prints readably
    from pprint import pprint as print
question_index = build_question_index_from_file(args.benchmark_file)
for file in glob(args.results_file):
print(f"Judging {file}")
_, scores = judge_file(file, question_index)
print(scores)
def build_question_index_from_file(benchmark_file):
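    """
    Loads the benchmark JSON, attaches a difficulty level to each question from
    the "splits" section, and returns a question_id -> question dict.
    """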
with open(benchmark_file, "r") as f:
json_data = json.load(f)
hardness_levels = json_data["splits"]
df_gt = pd.DataFrame.from_dict(json_data["benchmark"]).set_index("question_id")
for level, ids in hardness_levels.items():
df_gt.loc[ids, "difficulty"] = level
return build_question_index_from_json(df_gt.reset_index().to_dict(orient="records"))
def build_question_index_from_json(benchmark_data):
question_index = {}
for question in benchmark_data:
question_id = question["question_id"]
question_index[question_id] = question
return question_index
def judge_file(results_file, question_index):
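    """
    Loads a results JSON file and scores it against the benchmark question index.
    """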
with open(results_file, "r") as f:
results_data = json.load(f)
for answer in results_data:
answer["question_id"] = str(answer["question_id"]) # ensure question_id is string
return judge(results_data, question_index)
def judge(results_data, question_index):
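    """
    Scores a list of model answers against the benchmark question index.
    Handles counting, OCR, 4-option multiple choice and binary choice questions;
    binary questions with an "opposite_of" partner only count as correct if both
    members of the pair are answered correctly. Returns the annotated answers and
    a dict of accuracy scores per source file and difficulty level.
    """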
answer_index = {}
for answer in results_data:
answer_index[answer["question_id"]] = answer
non_answered_questions = set(question_index.keys()) - set(answer_index.keys())
excessive_answers = set(answer_index.keys()) - set(question_index.keys())
if len(non_answered_questions) > 0:
print("WARNING: Some question IDs in benchmark data are not found in results file:")
print(non_answered_questions)
results_data = results_data + [{"question_id": qid, "response": ""} for qid in non_answered_questions]
else:
print("All question IDs in benchmark data are found in results file.")
if len(excessive_answers) > 0:
print("WARNING: Some question IDs in results file are not found in benchmark data:")
print(excessive_answers)
print("These questions will be ignored in the evaluation.")
results_data = [answer for answer in results_data if answer["question_id"] in question_index]
else:
print("All question IDs in results file are found in benchmark data.")
print()
accuracy_meters = defaultdict(AverageMeter)
# process counting data and compute accuracy and MAE
correct = 0
total = 0
mae = 0
mse = 0
for answer in results_data:
if answer["question_id"] not in question_index:
continue
benchmark_truth = question_index[answer["question_id"]]
if benchmark_truth["question_type"] != "counting":
continue
clean_answer = clean_counting_answer(answer["response"])
gt = benchmark_truth["ground_truth"]
difference = abs(clean_answer - gt)
mae += difference
mse += difference ** 2
is_correct = (clean_answer == gt)
correct_count = 1 if is_correct else 0
answer["judge/correct"] = is_correct
answer["judge/extracted_answer"] = clean_answer
correct += correct_count
total += 1
accuracy_meters[benchmark_truth["source_file"]].update(correct_count)
    # process OCR data and compute accuracy (exact string match after normalization)
correct = 0
total = 0
for answer in results_data:
if answer["question_id"] not in question_index:
continue
benchmark_truth = question_index[answer["question_id"]]
if benchmark_truth["question_type"] != "ocr":
continue
        # clean the response
clean_answer = clean_ocr_answer(answer["response"])
clean_gt = clean_text(benchmark_truth["ground_truth"])
answer["judge/extracted_answer"] = clean_answer
is_correct = clean_answer == clean_gt
correct_count = 1 if is_correct else 0
answer["judge/correct"] = is_correct
correct += correct_count
total += 1
accuracy_meters[benchmark_truth["source_file"]].update(correct_count)
# process multiple choice data without binary options and compute accuracy
correct = 0
total = 0
for answer in results_data:
if answer["question_id"] not in question_index:
continue
benchmark_truth = question_index[answer["question_id"]]
if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 4:
continue
is_correct = validate_choice_answer(answer, benchmark_truth)
correct_count = 1 if is_correct else 0
answer["judge/correct"] = is_correct
correct += correct_count
total += 1
accuracy_meters[benchmark_truth["source_file"]].update(correct_count)
# process binary choice data and compute accuracy
correct = 0
total = 0
for answer in results_data:
if answer["question_id"] not in question_index:
continue
benchmark_truth = question_index[answer["question_id"]]
if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 2:
continue
is_correct = validate_choice_answer(answer, benchmark_truth)
correct_count = 1 if is_correct else 0
answer["judge/correct"] = is_correct
correct += correct_count
total += 1
# process binary choice data with correction and compute accuracy
correct = 0
total = 0
opposite_error_pairs = []
for answer in results_data:
if answer["question_id"] not in question_index:
continue
benchmark_truth = question_index[answer["question_id"]]
# skip if the question is not a binary choice or does not have an opposite
if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 2 or pd.isna(benchmark_truth["opposite_of"]):
continue
is_correct = validate_choice_answer(answer, benchmark_truth)
if benchmark_truth["opposite_of"] in answer_index:
is_opposite_correct = validate_choice_answer(answer_index[benchmark_truth["opposite_of"]], question_index[benchmark_truth["opposite_of"]])
answer_index[benchmark_truth["opposite_of"]]["judge/correct"] = is_opposite_correct
else:
is_opposite_correct = False
answer["judge/correct"] = is_correct
if is_correct and is_opposite_correct:
correct += 1
accuracy_meters[benchmark_truth["source_file"]].update(1)
else:
opposite_error_pairs.append((answer["question_id"], benchmark_truth["opposite_of"]))
accuracy_meters[benchmark_truth["source_file"]].update(0)
total += 1
df_preds = pd.DataFrame(results_data).set_index("question_id")
df_gt = pd.DataFrame.from_dict(question_index).T.set_index("question_id")
df = df_preds.join(df_gt)
scores = {
"is_complete": len(non_answered_questions) == 0,
"is_excessive": len(excessive_answers) > 0,
**dict([
("accuracy/" + k.replace(".csv", ""), (v.sum/v.count)) for k, v in accuracy_meters.items()
]),
"accuracy/easy": df.query("difficulty == 'easy'")["judge/correct"].mean(),
"accuracy/medium": df.query("difficulty == 'medium'")["judge/correct"].mean(),
"accuracy/hard": df.query("difficulty == 'hard'")["judge/correct"].mean(),
"accuracy/total": df["judge/correct"].mean(),
}
return results_data, scores
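# Command-line entrypoint. The flag names below are an assumption inferred from
# the attributes main() reads (args.benchmark_file, args.results_file); adjust
# them if the original script used different names. --results-file may be a
# glob pattern, since main() expands it with glob().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Judge results files against the ground-truth benchmark.")
    parser.add_argument("--benchmark-file", required=True, help="Path to the benchmark JSON file.")
    parser.add_argument("--results-file", required=True, help="Path (or glob pattern) of results JSON file(s) to score.")
    main(parser.parse_args())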