import argparse
import json
import math
import re
from collections import defaultdict
from glob import glob

import pandas as pd


def make_option_labels(options):
    """Turn a list of option strings into enumerated labels: "A. ...", "B. ...", ..."""
    option_labels = []
    for i, option in enumerate(options):
        option_labels.append(f"{chr(65 + i)}. {option.strip()}")
    return option_labels


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def clean_counting_answer(answer):
    if answer is None or answer.strip() == "":
        return 0
    # Extract the content of <answer>...</answer> tags if present
    # (tag name assumed; the original pattern lost its delimiters).
    answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
    if len(answer_matches) > 0:
        answer = answer_matches[0]
    clean_answer = answer.strip().lower()
    if clean_answer.endswith("."):
        clean_answer = clean_answer[:-1].strip()
    if clean_answer in words_to_int:
        clean_answer = words_to_int[clean_answer]
    else:
        try:
            clean_answer = int(round(float(clean_answer)))
        except ValueError:
            # or pick the LAST number in the string
            match = re.search(r"\d+(?:\.\d+)?(?=(?![\s\S]*\d))", clean_answer)
            if match:
                clean_answer = int(round(float(match.group(0))))
            else:
                matched = False
                # or pick the FIRST spelled-out number in the string
                for word, number in words_to_int.items():
                    if word in clean_answer:
                        clean_answer = number
                        matched = True
                        break
                if not matched:
                    print(f"WARNING: Unable to convert answer '{answer}' to int.")
                    clean_answer = 0
    return clean_answer


def clean_multiple_choice_answer(answer, options):
    if answer is None:
        return ""
    # Extract the content of <answer>...</answer> tags if present (tag name assumed).
    answer_matches = re.findall(r"<answer>(.*?)</answer>", answer, flags=re.DOTALL)
    if len(answer_matches) > 0:
        answer = answer_matches[0]
    clean_answer = answer.strip()
    if clean_answer.startswith("Answer:"):
        clean_answer = clean_answer.replace("Answer:", "").strip()
    if clean_answer.startswith("The answer is "):
        clean_answer = clean_answer.replace("The answer is ", "").strip()
    if clean_answer.startswith("The best answer is "):
        clean_answer = clean_answer.replace("The best answer is ", "").strip()
    if clean_answer.endswith("."):
        clean_answer = clean_answer[:-1].strip()
    if len(clean_answer) > 1:
        # If the answer is longer than one character, we assume it may contain the
        # full label, e.g. "A. option text"
        for option in options:
            if option in clean_answer:
                clean_answer = option[0]
                break
    if len(clean_answer) > 1:
        # If the answer is still longer than one character, we assume it contains a
        # short label, e.g. "A" or "B"
"A" or "B" for option in options: if option[0] in clean_answer: clean_answer = option[0] break return clean_answer def clean_ocr_answer(answer): answer_matches = re.findall(r"(.*?)", answer, flags=re.DOTALL) if len(answer_matches) > 0: answer = answer_matches[0] clean_answer = answer.strip() clean_answer = extract_text_from_quotes(clean_answer) clean_answer = clean_text(clean_answer) return clean_answer def validate_choice_answer(answer, benchmark_truth): all_options = make_option_labels(benchmark_truth["options"]) # clean the reponse clean_answer = clean_multiple_choice_answer(answer["response"], all_options) # map the truth to option key correct_option_index = benchmark_truth["options"].index(benchmark_truth["ground_truth"]) correct_option_enum = all_options[correct_option_index][0] return (clean_answer == correct_option_enum) words_to_int = { "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5, "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11, "twelve": 12, "thirteen": 13, "fourteen": 14, "fifteen": 15, "sixteen": 16, "seventeen": 17, "eighteen": 18, "nineteen": 19, "twenty": 20, "thirty": 30, "forty": 40, "fifty": 50, "sixty": 60, "seventy": 70, "eighty": 80, "ninety": 90, "hundred": 100, "thousand": 1000, "million": 1000000, } def clean_text(s): replace_characters = { "ä": "a", "á": "a", "à": "a", "â": "a", "ã": "a", "å": "a", "ā": "a", "ö": "o", "ó": "o", "ò": "o", "ô": "o", "õ": "o", "ō": "o", "ü": "u", "ú": "u", "ù": "u", "û": "u", "ū": "u", "é": "e", "ĕ": "e", "ė": "e", "ę": "e", "ě": "e", "ç": "c", "ć": "c", "č": "c", "ñ": "n", "ń": "n", "ń": "n", "ł": "l", "ś": "s", "š": "s", "ź": "z", "ż": "z", "ý": "y", "ŷ": "y", "ÿ": "y", "œ": "oe", "æ": "ae", "v": "u" } s = s.strip().lower() # replace special characters for char, replacement in replace_characters.items(): s = s.replace(char, replacement) s = s.replace("\t", "").replace("\n", "").replace(".", "").replace("&", "").replace(";", "")\ .replace(",", "").replace("-", "").replace("–", "").replace("’", "'").replace(":", "").replace("·", " ")\ .replace("'", "").replace("“", "").replace("”", "").replace('"', "").replace("•", " ")\ .replace(" ", " ") return s def extract_text_from_quotes(s): pattern = r"'([^']*?)\"|\"([^\"]*?)\"|`([^`]*?)`|“([^”]*?)”|‘([^’]*?)’" matches = re.findall(pattern, s) if matches: # Extract and return the single letter matched_group = [match for group in matches for match in group if match][0] return matched_group else: # Return the original string if no match is found return s def main(args): try: from pprint import pprint as print except ImportError: pass question_index = build_question_index_from_file(args.benchmark_file) for file in glob(args.results_file): print(f"Judging {file}") _, scores = judge_file(file, question_index) print(scores) def build_question_index_from_file(benchmark_file): with open(benchmark_file, "r") as f: json_data = json.load(f) hardness_levels = json_data["splits"] df_gt = pd.DataFrame.from_dict(json_data["benchmark"]).set_index("question_id") for level, ids in hardness_levels.items(): df_gt.loc[ids, "difficulty"] = level return build_question_index_from_json(df_gt.reset_index().to_dict(orient="records")) def build_question_index_from_json(benchmark_data): question_index = {} for question in benchmark_data: question_id = question["question_id"] question_index[question_id] = question return question_index def judge_file(results_file, question_index): with open(results_file, "r") as f: results_data = json.load(f) for answer in results_data: 
answer["question_id"] = str(answer["question_id"]) # ensure question_id is string return judge(results_data, question_index) def judge(results_data, question_index): answer_index = {} for answer in results_data: answer_index[answer["question_id"]] = answer non_answered_questions = set(question_index.keys()) - set(answer_index.keys()) excessive_answers = set(answer_index.keys()) - set(question_index.keys()) if len(non_answered_questions) > 0: print("WARNING: Some question IDs in benchmark data are not found in results file:") print(non_answered_questions) results_data = results_data + [{"question_id": qid, "response": ""} for qid in non_answered_questions] else: print("All question IDs in benchmark data are found in results file.") if len(excessive_answers) > 0: print("WARNING: Some question IDs in results file are not found in benchmark data:") print(excessive_answers) print("These questions will be ignored in the evaluation.") results_data = [answer for answer in results_data if answer["question_id"] in question_index] else: print("All question IDs in results file are found in benchmark data.") print() accuracy_meters = defaultdict(AverageMeter) # process counting data and compute accuracy and MAE correct = 0 total = 0 mae = 0 mse = 0 for answer in results_data: if answer["question_id"] not in question_index: continue benchmark_truth = question_index[answer["question_id"]] if benchmark_truth["question_type"] != "counting": continue clean_answer = clean_counting_answer(answer["response"]) gt = benchmark_truth["ground_truth"] difference = abs(clean_answer - gt) mae += difference mse += difference ** 2 is_correct = (clean_answer == gt) correct_count = 1 if is_correct else 0 answer["judge/correct"] = is_correct answer["judge/extracted_answer"] = clean_answer correct += correct_count total += 1 accuracy_meters[benchmark_truth["source_file"]].update(correct_count) # process OCR data and compute accuracy and ESD correct = 0 total = 0 for answer in results_data: if answer["question_id"] not in question_index: continue benchmark_truth = question_index[answer["question_id"]] if benchmark_truth["question_type"] != "ocr": continue # clean the reponse clean_answer = clean_ocr_answer(answer["response"]) clean_gt = clean_text(benchmark_truth["ground_truth"]) answer["judge/extracted_answer"] = clean_answer is_correct = clean_answer == clean_gt correct_count = 1 if is_correct else 0 answer["judge/correct"] = is_correct correct += correct_count total += 1 accuracy_meters[benchmark_truth["source_file"]].update(correct_count) # process multiple choice data without binary options and compute accuracy correct = 0 total = 0 for answer in results_data: if answer["question_id"] not in question_index: continue benchmark_truth = question_index[answer["question_id"]] if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 4: continue is_correct = validate_choice_answer(answer, benchmark_truth) correct_count = 1 if is_correct else 0 answer["judge/correct"] = is_correct correct += correct_count total += 1 accuracy_meters[benchmark_truth["source_file"]].update(correct_count) # process binary choice data and compute accuracy correct = 0 total = 0 for answer in results_data: if answer["question_id"] not in question_index: continue benchmark_truth = question_index[answer["question_id"]] if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 2: continue is_correct = validate_choice_answer(answer, benchmark_truth) correct_count = 1 if is_correct else 0 
answer["judge/correct"] = is_correct correct += correct_count total += 1 # process binary choice data with correction and compute accuracy correct = 0 total = 0 opposite_error_pairs = [] for answer in results_data: if answer["question_id"] not in question_index: continue benchmark_truth = question_index[answer["question_id"]] # skip if the question is not a binary choice or does not have an opposite if benchmark_truth["question_type"] != "choice" or len(benchmark_truth["options"]) != 2 or pd.isna(benchmark_truth["opposite_of"]): continue is_correct = validate_choice_answer(answer, benchmark_truth) if benchmark_truth["opposite_of"] in answer_index: is_opposite_correct = validate_choice_answer(answer_index[benchmark_truth["opposite_of"]], question_index[benchmark_truth["opposite_of"]]) answer_index[benchmark_truth["opposite_of"]]["judge/correct"] = is_opposite_correct else: is_opposite_correct = False answer["judge/correct"] = is_correct if is_correct and is_opposite_correct: correct += 1 accuracy_meters[benchmark_truth["source_file"]].update(1) else: opposite_error_pairs.append((answer["question_id"], benchmark_truth["opposite_of"])) accuracy_meters[benchmark_truth["source_file"]].update(0) total += 1 df_preds = pd.DataFrame(results_data).set_index("question_id") df_gt = pd.DataFrame.from_dict(question_index).T.set_index("question_id") df = df_preds.join(df_gt) scores = { "is_complete": len(non_answered_questions) == 0, "is_excessive": len(excessive_answers) > 0, **dict([ ("accuracy/" + k.replace(".csv", ""), (v.sum/v.count)) for k, v in accuracy_meters.items() ]), "accuracy/easy": df.query("difficulty == 'easy'")["judge/correct"].mean(), "accuracy/medium": df.query("difficulty == 'medium'")["judge/correct"].mean(), "accuracy/hard": df.query("difficulty == 'hard'")["judge/correct"].mean(), "accuracy/total": df["judge/correct"].mean(), } return results_data, scores