|
from pathlib import Path
import json
import re
import string
from collections import Counter

from tqdm import tqdm
import evaluate

from args import parse_args


ROUGE_SCORER = evaluate.load("rouge")

PATTERN = re.compile(r'\b[A-D]\b')
|
|
|
def find_answer(s):
    """Returns the first standalone A-D option letter in `s`, or None if absent."""
    match = PATTERN.search(s)
    if match is None:
        return None
    return match.group()
|
|
|
def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
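
# Worked example (a sanity check, not from the original source):
#   normalize_answer("The Quick, Brown Fox!") -> "quick brown fox"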
|
|
|
|
|
def normalize_zh_answer(s: str) -> str:
    """Chinese version. Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        # Full-width CJK punctuation.
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))
|
|
|
|
|
def f1_score(prediction, ground_truth) -> tuple[float, float, float]:
    """Token-level F1, precision, and recall between two token lists."""
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0, 0.0, 0.0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
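
# Worked example: f1_score(["paris", "france"], ["paris"]) shares one token,
# so precision = 1/2, recall = 1/1, and f1 = 2 * 0.5 * 1.0 / 1.5 ~= 0.667.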
|
|
|
|
|
def qa_f1_score(pred: str, ground_truths) -> float:
    """Returns the best token-level F1 of `pred` over all ground truths."""
    f1 = 0
    prediction_tokens = normalize_answer(pred).split()
    for ground_truth in ground_truths:
        ground_truth_tokens = normalize_answer(ground_truth).split()
        this_f1, _prec, _recall = f1_score(prediction_tokens, ground_truth_tokens)
        f1 = max(f1, this_f1)
    return f1
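
# Example: qa_f1_score("Paris, France", ["paris", "lyon"]) normalizes the
# prediction to ["paris", "france"]; the best match is "paris" (f1 ~= 0.667).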
|
|
|
|
|
def qa_f1_score_zh(pred: str, ground_truths: list[str]) -> float:
    """Returns the best character-level F1 of `pred` over all ground truths (Chinese)."""
    f1 = 0
    pred_tokens = list(normalize_zh_answer(pred))
    for ground_truth in ground_truths:
        label_tokens = list(normalize_zh_answer(ground_truth))
        this_f1, _prec, _recall = f1_score(pred_tokens, label_tokens)
        f1 = max(f1, this_f1)
    return f1
|
|
|
|
|
def load_json(fname):
    with open(fname, "r", encoding="utf8") as fin:
        return json.load(fin)
|
|
|
|
|
def iter_jsonl(fname, cnt=None):
    """Yields up to `cnt` records from a JSONL file, skipping blank lines."""
    i = 0
    with open(fname, "r", encoding="utf8") as fin:
        for line in fin:
            if line.strip() == "":
                continue
            if i == cnt:
                break
            yield json.loads(line)
            i += 1
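
# Each JSONL record is expected to carry a prediction and a label, e.g.
# {"prediction": "...", "ground_truth": "..."} (see get_labels/get_preds below).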
|
|
|
|
|
def first_int_match(prediction):
    """Returns the first run of digits in `prediction`, or "" if there is none."""
    pred_list = re.split("[^0-9]", prediction)
    pred_value = ""
    for item in pred_list:
        if item != "":
            pred_value = item
            break
    return pred_value
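
# Example: first_int_match("The passkey is 71432.") -> "71432".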
|
|
|
|
|
def split_retrieval_answer(pred: str):
    """Splits a prediction into words, treating common punctuation as spaces."""
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    return words
|
|
|
|
|
def get_score_one_kv_retrieval(pred, label, model_name: str, args) -> bool:
    # Same tokenization as split_retrieval_answer above.
    words = split_retrieval_answer(pred)
    return label in words
|
|
|
|
|
def get_score_one_passkey(pred, label, model_name: str, args) -> bool:
    if isinstance(label, list):
        label = label[0]
    return label == first_int_match(pred)
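
# Example (assuming a string label): a prediction like "The pass key is 71432"
# with label "71432" returns True via first_int_match.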
|
|
|
|
|
def get_score_one_number_string(pred, label, model_name: str, args) -> bool:
    if isinstance(label, list):
        label = label[0]
    return label == first_int_match(pred)
|
|
|
|
|
def get_score_one_code_run(pred, label, model_name: str, args) -> bool:
    """
    Returns the score of one example in Code.Run.
    """
    if isinstance(label, list):
        label = label[0]
    pred = pred.strip()
    for c in ["\n", ".", "`", "'", '"', ":"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    if len(words) == 0:
        return False
    try:
        pred = int(words[-1])
        return label == pred
    except Exception:
        return False
|
|
|
|
|
def get_score_one_code_debug(pred, label, model_name: str, args) -> bool:
    """
    Returns the score of one example in Code.Debug.
    """
    pred = pred.strip()
    label_c = label[1]
    fn_name = label[0]
    if pred[:2] in [f"{label_c}.", f"{label_c}:"]:
        return True

    ans_prefixes = [
        "answer is:",
        "is:",
        "answer:",
        "correct option is:",
    ]
    for c in ["\n", "`", "'", '"', "-", "*", "Option", "option"]:
        pred = pred.replace(c, " ")
    while "  " in pred:
        pred = pred.replace("  ", " ")
    for prefix in ans_prefixes:
        idx = pred.find(prefix)
        if idx == -1:
            continue
        # The prediction ends right after the prefix, with no answer following.
        if len(pred) < idx + len(prefix) + 1:
            return False
        pred = pred[idx + len(prefix) + 1 :]
        for s in [label_c, fn_name]:
            if pred.startswith(s):
                return True
        return False
    return False
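
# Example (hypothetical function name): with label ["my_func", "B"], the
# prediction "The answer is: B" matches via the "answer is:" prefix above.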
|
|
|
|
|
def get_score_one_math_find(pred, label, model_name: str, args) -> bool:
    if isinstance(label, list):
        label = label[0]
    if isinstance(label, int):
        # Find the first number (int or float) in the prediction.
        first_num = re.search(r"\d+\.\d+|\d+", pred)
        if first_num is None:
            return False
        first_num = first_num.group(0).strip()
        # Go through float() so a match like "3.14" does not crash int().
        return int(float(first_num)) == label
    elif isinstance(label, float):
        # Find the first float or int in the prediction.
        first_float = re.search(r"\d+\.\d+|\d+", pred)
        if first_float is None:
            return False
        first_float = first_float.group(0).strip()
        return float(first_float) == label
    else:
        raise TypeError(f"Expected int or float, got {type(label)}")
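
# Example: with a float label 3.14, the prediction "The answer is 3.14 meters"
# yields float("3.14") == 3.14 -> True. Integer labels go through
# int(float(...)), so a prediction like "42.0" still matches the label 42.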
|
|
|
|
|
def get_score_one_longdialogue_qa_eng(pred, label, model_name: str, args) -> bool:
    label = label[0]
    pred = pred.strip()
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = [x.upper() for x in pred.split()]
    return label in words
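
# Example (assuming the label is an upper-cased name, e.g. ["ROSALIND"]):
# "It was Rosalind." splits into ["IT", "WAS", "ROSALIND"] and matches.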
|
|
|
|
|
def get_score_one_longbook_choice_eng(pred, label, model_name: str, args) -> bool:
    # First, try the leading character and exact match.
    pred = pred.strip()
    if pred == "":
        return False
    if pred[0] in "ABCD":
        return pred[0] in label
    if pred in label:
        return True

    # Otherwise, look for an answer prefix.
    for c in ["\n", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    while "  " in pred:
        pred = pred.replace("  ", " ")
    ans_prefixes = [
        "answer is:",
        "answer:",
        "answer is",
        "option is",
    ]
    for prefix in ans_prefixes:
        idx = pred.find(prefix)
        if idx == -1:
            continue
        # The prediction ends right after the prefix, with no answer following.
        if len(pred) < idx + len(prefix) + 1:
            return False
        after_prefix = pred[idx + len(prefix) + 1 :]
        for s in label:
            if after_prefix.startswith(s):
                return True
        return False

    # Finally, take the first standalone A/B/C/D word.
    words = pred.split()
    for word in words:
        if word in ("A", "B", "C", "D"):
            return word in label

    # Optionally fall back to the ZeroSCROLLS-style regex match.
    if args.use_zero_scrolls:
        matched_pred = find_answer(pred)
        if matched_pred is not None and matched_pred in label:
            return True

    return False
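
# Example: with label ["B"], the prediction "The answer is: B because ..."
# falls through to the "answer is:" prefix branch and returns True.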
|
|
|
|
|
def get_score_one_longbook_qa_eng(pred, label, model_name: str, args) -> float:
    return qa_f1_score(pred, label)
|
|
|
|
|
def get_score_one_longbook_sum_eng(
    pred: str, label: str, model_name: str, args
) -> float:
    score = ROUGE_SCORER.compute(
        predictions=[pred], references=[label], use_aggregator=False
    )
    return score["rougeLsum"][0]
|
|
|
|
|
def get_score_one_longbook_qa_chn(pred, label, model_name: str, args) -> float:
    return qa_f1_score_zh(pred, label)
|
|
|
|
|
def get_score_one_math_calc(pred, label, model_name: str, args) -> float:
    assert isinstance(label, list), f"Expected list, got {type(label)}"
    # Extract all integers from the prediction, in order.
    pred_nums = []
    pred_list = re.split("[^0-9]", pred)
    for item in pred_list:
        if item != "":
            pred_nums.append(int(item))

    # For GPT-4 the first extracted number is dropped, presumably because it
    # tends to echo a number from the prompt before answering.
    if model_name == "gpt4":
        pred_nums = pred_nums[1:]

    # Score is the length of the longest matching prefix, normalized.
    cnt = 0
    for i in range(len(label)):
        if i >= len(pred_nums):
            break
        if label[i] == pred_nums[i]:
            cnt += 1
        else:
            break
    return cnt / len(label)
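
# Worked example: with label [125, 300, 7] and prediction "125, 300, 8",
# the matching prefix has length 2, so the score is 2 / 3 ~= 0.67.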
|
|
|
|
|
def get_score_one(
    pred: str, label: str, task_name: str, model_name: str, args
) -> float:
    """
    Computes the score for one prediction.
    Returns a single float (0.0 or 1.0 for boolean scores).
    """
    NAME_TO_SCORE_GETTER = {
        # Retrieval
        "kv_retrieval": get_score_one_kv_retrieval,
        "kv_retrieval_prefix": get_score_one_kv_retrieval,
        "kv_retrieval_both": get_score_one_kv_retrieval,
        "passkey": get_score_one_passkey,
        "number_string": get_score_one_number_string,
        # Code
        "code_run": get_score_one_code_run,
        "code_debug": get_score_one_code_debug,
        # Long-context QA and summarization
        "longdialogue_qa_eng": get_score_one_longdialogue_qa_eng,
        "longbook_qa_eng": get_score_one_longbook_qa_eng,
        "longbook_sum_eng": get_score_one_longbook_sum_eng,
        "longbook_choice_eng": get_score_one_longbook_choice_eng,
        "longbook_qa_chn": get_score_one_longbook_qa_chn,
        # Math
        "math_find": get_score_one_math_find,
        "math_calc": get_score_one_math_calc,
    }
    assert task_name in NAME_TO_SCORE_GETTER, f"Invalid task name: {task_name}"
    score = NAME_TO_SCORE_GETTER[task_name](pred, label, model_name, args)
    return float(score)
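
# Usage sketch (the model name is illustrative; it only matters for math_calc):
#   get_score_one("The pass key is 71432", "71432", "passkey", "some-model", args)
#   -> 1.0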
|
|
|
|
|
def get_labels(preds: list) -> list[str]:
    possible_label_keys = ["ground_truth", "label"]
    for label_key in possible_label_keys:
        if label_key in preds[0]:
            return [x.get(label_key, "XXXXXXXXXX") for x in preds]
    raise ValueError(f"Cannot find label in {preds[0]}")
|
|
|
|
|
def get_preds(preds: list, data_name: str) -> list[str]:
    pred_strings = []
    possible_pred_keys = ["prediction", "pred"]
    for pred in preds:
        for pred_key in possible_pred_keys:
            if pred_key in pred:
                this_pred = pred[pred_key]
                break
        else:
            # for/else: no recognized prediction key was found in this record.
            raise ValueError(f"Cannot find prediction in {pred}")
        pred_strings.append(this_pred)
    return pred_strings
|
|
|
|
|
def get_score(
    labels: list, preds: list, data_name: str, model_name: str, args
) -> float:
    """
    Computes the average score for a task.
    """
    assert len(labels) == len(preds)
    scores = []
    for label, pred in tqdm(zip(labels, preds), total=len(labels)):
        score = get_score_one(pred, label, data_name, model_name, args)
        scores.append(score)
    return sum(scores) / len(scores)
|
|
|
|
|
def compute_scores(preds_path, data_name: str, model_name: str, args):
    print("Loading prediction results from", preds_path)
    preds = list(iter_jsonl(preds_path))

    labels = get_labels(preds)
    preds = get_preds(preds, data_name)

    acc = get_score(labels, preds, data_name, model_name, args)
    print(
        f"Final score: {acc} "
        f"(task={data_name}, model={model_name}, "
        f"use_zero_scrolls={args.use_zero_scrolls}, file={preds_path})"
    )
|
|
|
|
|
# Subset of tasks evaluated by default; see ALL_TASKS_ORIG for the full suite.
ALL_TASKS = [
    "longbook_choice_eng",
    "longbook_qa_eng",
]
|
|
|
ALL_TASKS_ORIG = [
    "passkey",
    "number_string",
    "kv_retrieval",
    "longdialogue_qa_eng",
    "longbook_sum_eng",
    "longbook_choice_eng",
    "longbook_qa_eng",
    "longbook_qa_chn",
    "math_find",
    "math_calc",
    "code_run",
    "code_debug",
]
|
|
|
if __name__ == "__main__":
    args = parse_args()
    print(json.dumps(vars(args), indent=4))

    if args.task == "all":
        tasks = ALL_TASKS
    else:
        tasks = [args.task]

    preds_path = Path(args.pxout_ref_json)
    assert preds_path.exists(), f"Predictions not found in: {preds_path}"
    for task in tasks:
        compute_scores(preds_path, task, args.model_name, args)
|
|
|
|