bhys's picture
Update scorer.py
8d66a77 verified
import json
import re
import string
import warnings
import numpy as np
def normalize_number_str(number_str: str) -> float:
# we replace these common units and commas to allow
# conversion to float
for char in ["$", "%", ","]:
number_str = number_str.replace(char, "")
try:
return float(number_str)
except ValueError:
print(f"String {number_str} cannot be normalized to number str.")
return float("inf")
def normalize_answer(a):
# Lower case
# Trim (left and right)
# Replace multiple spaces with one space
# Remove trailing punctuation
# return re.sub(r"[\.\!\?]+$", "", re.sub(r"\s+", " ", a.strip().lower()))
if isinstance(a, list):
a = ''.join(a)
return a.strip().lower()
def exact_match(answer, ground_truth):
return normalize_answer(answer) == normalize_answer(ground_truth)
def keyword_match(answer, ground_truth):
return normalize_answer(ground_truth) in normalize_answer(answer)
def split_string(
s: str,
char_list: list[str] = [",", ";"],
) -> list[str]:
pattern = f"[{''.join(char_list)}]"
return re.split(pattern, s)
def question_scorer(
user_task: str,
val: str,
) -> bool:
def is_float(element: any) -> bool:
try:
float(element)
return True
except ValueError:
return False
# 打分机制
level = 0
expertise = 0
reasoning = 0
comprehension = 0
final_answer = user_task["final_answer"]
expected_answer = val["Final answer"]
data = val["score"]
chat_score = []
for i in range(len(data['type'])):
item = {
'type': data['type'][i],
'question': data['question'][i],
'choices': data['choices'][i],
'answer': data['answer'][i],
'expertise': data['expertise'][i],
'reasoning': data['reasoning'][i],
'comprehension': data['comprehension'][i],
'score': data['score'][i]
}
chat_score.append(item)
for i, score_item in enumerate(chat_score):
answer_true = False
if score_item['type'].lower() == 'multiple choice':
if exact_match(user_task["score_answer"][i], score_item['answer']):
answer_true = True
elif score_item['type'].lower() == 'fill in the blanks':
if keyword_match(user_task["score_answer"][i], score_item['answer']):
answer_true = True
elif score_item['type'].lower() == 'short answer questions':
for ground_truth in score_item['answer']:
if keyword_match(user_task["score_answer"][i], ground_truth):
answer_true = True
break
# print(answer, score_item['answer'], answer_true)
# 加分
if answer_true:
expertise += score_item['expertise']
reasoning += score_item['reasoning']
comprehension += score_item['comprehension']
if score_item['score'] > level:
level = score_item['score']
print([level, expertise, reasoning, comprehension])
# final_answer正确 则满分,但是能力分不加了
if expected_answer and exact_match(final_answer, expected_answer):
level = 10
return [level, expertise, reasoning, comprehension]
score = 0
if user_task["final_answer"] == val["Final answer"]:
score = val["Total score"]
else:
for i, item in enumerate(val["score"]["question"]):
if user_task["score_answer"][i] in val["score"]["answer"][i] and val["score"]["score"][i] > score:
score = item["score"]
return score
# # if gt is a number
# if is_float(ground_truth):
# print(f"Evaluating {model_answer} as a number.")
# normalized_answer = normalize_number_str(model_answer)
# return normalized_answer == float(ground_truth)
#
# # if gt is a list
# elif any(char in ground_truth for char in [",", ";"]):
# print(f"Evaluating {model_answer} as a comma separated list.")
# # question with the fish: normalization removes punct
#
# gt_elems = split_string(ground_truth)
# ma_elems = split_string(model_answer)
#
# # check length is the same
# if len(gt_elems) != len(ma_elems):
# warnings.warn(
# "Answer lists have different lengths, returning False.", UserWarning
# )
# return False
#
# # compare each element as float or str
# comparisons = []
# for ma_elem, gt_elem in zip(ma_elems, gt_elems):
# if is_float(gt_elem):
# normalized_ma_elem = normalize_number_str(ma_elem)
# comparisons.append(normalized_ma_elem == float(gt_elem))
# else:
# # we do not remove punct since comparisons can include punct
# comparisons.append(
# normalize_str(ma_elem, remove_punct=False)
# == normalize_str(gt_elem, remove_punct=False)
# )
# return all(comparisons)
#
# # if gt is a str
# else:
# print(f"Evaluating {model_answer} as a string.")
# return normalize_str(model_answer) == normalize_str(ground_truth)
def normalize_str(input_str, remove_punct=True) -> str:
"""
Normalize a string by:
- Removing all white spaces
- Optionally removing punctuation (if remove_punct is True)
- Converting to lowercase
Parameters:
- input_str: str, the string to normalize
- remove_punct: bool, whether to remove punctuation (default: True)
Returns:
- str, the normalized string
"""
# Remove all white spaces. Required e.g for seagull vs. sea gull
no_spaces = re.sub(r"\s", "", input_str)
# Remove punctuation, if specified.
if remove_punct:
translator = str.maketrans("", "", string.punctuation)
return no_spaces.lower().translate(translator)
else:
return no_spaces.lower()