from pathlib import Path
import json
import re
import string
from collections import Counter

from tqdm import tqdm
import evaluate

from args import parse_args

ROUGE_SCORER = evaluate.load("rouge")

# Matches a standalone choice letter, e.g. '(A)' -> 'A'.
PATTERN = re.compile(r"\b[A-D]\b")


def find_answer(s):
    # Used for task='longbook_choice_eng': extracts 'A' from '(A)', etc.
    match = PATTERN.search(s)
    if match is None:
        return None  # NOTE: None signals that no choice letter was found.
    return match.group()


def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def normalize_zh_answer(s: str) -> str:
    """Chinese version. Lower text and remove punctuation and extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."  # noqa
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))


def f1_score(prediction, ground_truth) -> tuple[float, float, float]:
    """Bag-of-tokens F1: token multiset overlap between prediction and ground truth."""
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def qa_f1_score(pred: str, ground_truths) -> float:
    """Returns the best word-level F1 over all ground truths.

    NOTE: `ground_truths` must be a list of strings, not a bare string.
    """
    f1 = 0
    prec = 0
    recall = 0
    normalized_prediction = normalize_answer(pred)
    prediction_tokens = normalized_prediction.split()
    for ground_truth in ground_truths:
        normalized_ground_truth = normalize_answer(ground_truth)
        ground_truth_tokens = normalized_ground_truth.split()
        this_f1, this_prec, this_recall = f1_score(prediction_tokens, ground_truth_tokens)
        f1 = max(f1, this_f1)
        prec = max(prec, this_prec)
        recall = max(recall, this_recall)
    return f1


def qa_f1_score_zh(pred: str, ground_truths: list[str]) -> float:
    """QA F1 score for Chinese: one character is one token."""
    f1 = 0
    prec = 0
    recall = 0
    norm_pred = normalize_zh_answer(pred)
    pred_tokens = list(norm_pred)
    for ground_truth in ground_truths:
        norm_label = normalize_zh_answer(ground_truth)
        label_tokens = list(norm_label)
        this_f1, this_prec, this_recall = f1_score(pred_tokens, label_tokens)
        f1 = max(f1, this_f1)
        prec = max(prec, this_prec)
        recall = max(recall, this_recall)
    return f1
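# Worked example (illustrative, not part of the benchmark): the QA metrics
# above compute bag-of-tokens F1 over normalized text and keep the best
# score across all ground truths.
#
#   >>> normalize_answer("The cat sat on the mat!")
#   'cat sat on mat'
#   >>> f1_score(["paris", "france"], ["paris"])   # (f1, precision, recall)
#   (0.6666666666666666, 0.5, 1.0)
#   >>> qa_f1_score("It was Paris.", ["Paris"])    # 1 of 3 pred tokens matches
#   0.5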
def load_json(fname):
    with open(fname, "r", encoding="utf8") as fin:
        return json.load(fin)


def iter_jsonl(fname, cnt=None):
    """Yields up to `cnt` JSON objects from a JSONL file (all of them if cnt is None)."""
    i = 0
    with open(fname, "r", encoding="utf8") as fin:
        for line in fin:
            if line.strip() == "":
                # Skip empty lines.
                continue
            if i == cnt:
                break
            yield json.loads(line)
            i += 1


def first_int_match(prediction):
    """Returns the first run of digits in `prediction`, or '' if there is none."""
    pred_list = re.split("[^0-9]", prediction)
    pred_value = ""
    for item in pred_list:
        if item != "":
            pred_value = item
            break
    return pred_value


def split_retrieval_answer(pred: str):
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    return words


def get_score_one_kv_retrieval(pred, label, model_name: str, args) -> bool:
    words = split_retrieval_answer(pred)
    return label in words


def get_score_one_passkey(pred, label, model_name: str, args) -> bool:
    if isinstance(label, list):
        label = label[0]
    return label == first_int_match(pred)


def get_score_one_number_string(pred, label, model_name: str, args) -> bool:
    if isinstance(label, list):
        label = label[0]
    return label == first_int_match(pred)


def get_score_one_code_run(pred, label, model_name: str, args) -> bool:
    """
    Returns the score of one example in Code.Run.
    """
    if isinstance(label, list):
        label = label[0]
    pred = pred.strip()
    for c in ["\n", ".", "`", "'", '"', ":"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    if len(words) == 0:
        return False
    try:
        pred = int(words[-1])
        return label == pred
    except Exception:
        return False


def get_score_one_code_debug(pred, label, model_name: str, args) -> bool:
    """
    Returns the score of one example in Code.Debug.
    """
    pred = pred.strip()
    label_c = label[1]
    fn_name = label[0]
    if pred[:2] in [f"{label_c}.", f"{label_c}:"]:
        return True
    ans_prefixes = [
        "answer is:",
        # "answer is",
        # "error is",
        "is:",
        "answer:",
        "correct option is:",
    ]
    for c in ["\n", "`", "'", '"', "-", "*", "Option", "option"]:
        pred = pred.replace(c, " ")
    # Collapse runs of spaces so prefix matching is stable.
    while "  " in pred:
        pred = pred.replace("  ", " ")
    for prefix in ans_prefixes:
        idx = pred.find(prefix)
        if idx == -1:
            continue
        # The prediction ends with this prefix.
        if len(pred) < idx + len(prefix) + 1:
            return False
        pred = pred[idx + len(prefix) + 1 :]
        for s in [label_c, fn_name]:
            if pred.startswith(s):
                return True
        return False
    return False


def get_score_one_math_find(pred, label, model_name: str, args) -> bool:
    if isinstance(label, list):
        # In math_find, there is always exactly one label.
        label = label[0]
    if isinstance(label, (int, float)):
        # Find the first int or float in the prediction and compare numerically
        # (comparing as float avoids int('3.14') raising a ValueError).
        first_num = re.search(r"\d+\.\d+|\d+", pred)
        if first_num is None:
            return False
        return float(first_num.group(0)) == label
    else:
        raise TypeError(f"Expected int or float, got {type(label)}")
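# Quick sanity check (illustrative): passkey/number_string scoring reduces
# to extracting the first digit run from the model output.
#
#   >>> first_int_match("The pass key is 71432.")
#   '71432'
#   >>> first_int_match("I could not find it.")
#   ''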
def get_score_one_longdialogue_qa_eng(pred, label, model_name: str, args) -> bool:
    label = label[0]
    pred = pred.strip()
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    words = [x.upper() for x in words]
    return label in words


def get_score_one_longbook_choice_eng(pred, label, model_name: str, args) -> bool:
    # Just use the first letter as the prediction.
    pred = pred.strip()
    if pred == "":
        return False
    if pred[0] in "ABCD":
        return pred[0] in label
    if pred in label:
        return True
    # Find an answer prefix.
    for c in ["\n", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    while "  " in pred:
        pred = pred.replace("  ", " ")
    ans_prefixes = [
        "answer is:",
        "answer:",
        "answer is",
        "option is",
    ]
    for prefix in ans_prefixes:
        idx = pred.find(prefix)
        if idx == -1:
            continue
        # The prediction ends with this prefix.
        if len(pred) < idx + len(prefix) + 1:
            return False
        after_prefix = pred[idx + len(prefix) + 1 :]
        for s in label:
            if after_prefix.startswith(s):
                return True
        return False
    # Finally, just find the first occurrence of A, B, C, or D.
    words = pred.split()
    for word in words:
        if word in "ABCD":
            return word in label
    if args.use_zero_scrolls:
        # NOTE: fall back to the zero-scrolls choice PATTERN (added by xianchaowu).
        matched_pred = find_answer(pred)
        if matched_pred is not None and matched_pred in label:
            return True
    return False


def get_score_one_longbook_qa_eng(pred, label, model_name: str, args) -> float:
    return qa_f1_score(pred, label)


def get_score_one_longbook_sum_eng(
    pred: str, label: str, model_name: str, args
) -> float:
    score = ROUGE_SCORER.compute(
        predictions=[pred], references=[label], use_aggregator=False
    )
    return score["rougeLsum"][0]  # type: ignore


def get_score_one_longbook_qa_chn(pred, label, model_name: str, args) -> float:
    return qa_f1_score_zh(pred, label)


def get_score_one_math_calc(pred, label, model_name: str, args) -> float:
    assert isinstance(label, list), f"Expected list, got {type(label)}"
    # assert isinstance(pred, list), f"Expected list, got {type(pred)}"
    pred_nums = []
    pred_list = re.split("[^0-9]", pred)
    for item in pred_list:
        if item != "":
            pred_nums.append(int(item))
    # Our prompt makes GPT-4 always echo the first operand as the first value
    # in the predicted answer, so drop it.
    if model_name == "gpt4":
        pred_nums = pred_nums[1:]
    # Credit the longest correct prefix of the predicted number sequence.
    cnt = 0
    for i in range(len(label)):
        if i >= len(pred_nums):
            break
        if label[i] == pred_nums[i]:
            cnt += 1
        else:
            break
    return cnt / len(label)


def get_score_one(
    pred: str, label: str, task_name: str, model_name: str, args
) -> float:
    """
    Computes the score for one prediction.
    Returns one float (zero or one for boolean values).
    """
    NAME_TO_SCORE_GETTER = {
        # Retrieve
        "kv_retrieval": get_score_one_kv_retrieval,
        "kv_retrieval_prefix": get_score_one_kv_retrieval,
        "kv_retrieval_both": get_score_one_kv_retrieval,
        "passkey": get_score_one_passkey,
        "number_string": get_score_one_number_string,
        # Code
        "code_run": get_score_one_code_run,
        "code_debug": get_score_one_code_debug,
        # Longbook
        "longdialogue_qa_eng": get_score_one_longdialogue_qa_eng,
        "longbook_qa_eng": get_score_one_longbook_qa_eng,
        "longbook_sum_eng": get_score_one_longbook_sum_eng,
        "longbook_choice_eng": get_score_one_longbook_choice_eng,
        "longbook_qa_chn": get_score_one_longbook_qa_chn,
        # Math
        "math_find": get_score_one_math_find,
        "math_calc": get_score_one_math_calc,
    }
    assert task_name in NAME_TO_SCORE_GETTER, f"Invalid task name: {task_name}"
    score = NAME_TO_SCORE_GETTER[task_name](pred, label, model_name, args)
    return float(score)
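# Smoke test for the dispatcher (illustrative): `args` only needs the
# attributes the chosen scorer actually reads (use_zero_scrolls for
# longbook_choice_eng), so a SimpleNamespace can stand in for parsed args.
#
#   >>> from types import SimpleNamespace
#   >>> ns = SimpleNamespace(use_zero_scrolls=True)
#   >>> get_score_one("The answer is B.", ["B"], "longbook_choice_eng", "gpt4", ns)
#   1.0
#   >>> get_score_one("(B) is correct.", ["B"], "longbook_choice_eng", "gpt4", ns)
#   1.0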
""" NAME_TO_SCORE_GETTER = { # Retrieve "kv_retrieval": get_score_one_kv_retrieval, "kv_retrieval_prefix": get_score_one_kv_retrieval, "kv_retrieval_both": get_score_one_kv_retrieval, "passkey": get_score_one_passkey, "number_string": get_score_one_number_string, # Code "code_run": get_score_one_code_run, "code_debug": get_score_one_code_debug, # Longbook "longdialogue_qa_eng": get_score_one_longdialogue_qa_eng, "longbook_qa_eng": get_score_one_longbook_qa_eng, "longbook_sum_eng": get_score_one_longbook_sum_eng, "longbook_choice_eng": get_score_one_longbook_choice_eng, "longbook_qa_chn": get_score_one_longbook_qa_chn, # Math "math_find": get_score_one_math_find, "math_calc": get_score_one_math_calc, } assert task_name in NAME_TO_SCORE_GETTER, f"Invalid task name: {task_name}" score = NAME_TO_SCORE_GETTER[task_name](pred, label, model_name, args) return float(score) def get_labels(preds: list) -> list[str]: possible_label_keys = ["ground_truth", "label"] for label_key in possible_label_keys: if label_key in preds[0]: return [x.get(label_key, "XXXXXXXXXX") for x in preds] raise ValueError(f"Cannot find label in {preds[0]}") def get_preds(preds: list, data_name: str) -> list[str]: pred_strings = [] possible_pred_keys = ["prediction", "pred"] for pred in preds: this_pred = "NO PREDICTION" for pred_key in possible_pred_keys: if pred_key in pred: this_pred = pred[pred_key] break else: raise ValueError(f"Cannot find prediction in {pred}") pred_strings.append(this_pred) return pred_strings def get_score( labels: list, preds: list, data_name: str, model_name: str, args ) -> float: """ Computes the average score for a task. """ assert len(labels) == len(preds) scores = [] for label, pred in tqdm(zip(labels, preds)): score = get_score_one(pred, label, data_name, model_name, args) print('pred={}, label={}, score={}, data_name={}'.format(pred, label, score, data_name)) scores.append(score) return sum(scores) / len(scores) def compute_scores(preds_path, data_name: str, model_name: str, args): print("Loading prediction results from", preds_path) preds = list(iter_jsonl(preds_path)) #import ipdb; ipdb.set_trace() labels = get_labels(preds) preds = get_preds(preds, data_name) acc = get_score(labels, preds, data_name, model_name, args) print('final display: ', acc, preds_path, data_name, model_name, args.use_zero_scrolls) ALL_TASKS = [ #"passkey", #"number_string", #"kv_retrieval", #"longdialogue_qa_eng", #"longbook_sum_eng", "longbook_choice_eng", "longbook_qa_eng", #"longbook_qa_chn", #"math_find", #"math_calc", #"code_run", #"code_debug", ] ALL_TASKS_ORIG = [ "passkey", "number_string", "kv_retrieval", "longdialogue_qa_eng", "longbook_sum_eng", "longbook_choice_eng", "longbook_qa_eng", "longbook_qa_chn", "math_find", "math_calc", "code_run", "code_debug", ] if __name__ == "__main__": args = parse_args() print(json.dumps(vars(args), indent=4)) if args.task == "all": tasks = ALL_TASKS else: tasks = [args.task] for task in tasks: #result_dir = Path(args.output_dir, args.model_name) #preds_path = result_dir / f"preds_{task}.jsonl" preds_path = Path(args.pxout_ref_json) assert preds_path.exists(), f"Predictions not found in: {preds_path}" compute_scores(preds_path, task, args.model_name, args)