# ReWOO-Demo / utils/Evaluator.py
# Last commit: 926675f ("init") by billxbf
import re
import string
from collections import Counter
import numpy as np
import pandas as pd
import tqdm
from langchain.evaluation.qa import QAEvalChain
from langchain.llms import OpenAI
from algos.PWS import PWS_Base, PWS_Extra
from algos.notool import CoT, IO
from algos.react import ReactBase
def normalize_answer(s):
    """Normalize an answer string for lenient (SQuAD-style) comparison.

    Steps, in order: lower-case, delete every character in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse all whitespace runs to single spaces.
    """
    text = s.lower()
    # One C-level pass deleting all ASCII punctuation characters.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def f1_score(prediction, ground_truth):
    """Return the token-level F1 between a predicted and a gold answer.

    Both inputs are passed through ``normalize_answer`` first. If either
    normalized side is "yes", "no" or "noanswer" and the two differ, the
    score is 0 (no partial credit for these special answers). Otherwise the
    score is the harmonic mean of token precision and recall, computed over
    the multiset overlap of whitespace-split tokens; 0 when nothing overlaps.
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)

    special_answers = ('yes', 'no', 'noanswer')
    if pred_norm != gold_norm and (pred_norm in special_answers or gold_norm in special_answers):
        return 0

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
def _grade_is_correct(grade_text):
    """Map a QAEvalChain grade string to 1 (CORRECT) or 0 (anything else).

    Tolerates an optional "GRADE:" prefix, any letter case, punctuation
    around the grade word, and trailing explanation text. Only the grade
    word itself is inspected, so "INCORRECT" never counts as correct.
    """
    match = re.search(r"grade:\s*(\S+)", grade_text, re.IGNORECASE)
    if match:
        verdict = match.group(1)
    else:
        words = grade_text.split()
        verdict = words[0] if words else ""
    return 1 if verdict.strip(string.punctuation).upper() == "CORRECT" else 0


def llm_accuracy_score(query, prediction, ground_truth):
    """Grade ``prediction`` against ``ground_truth`` for ``query`` with an LLM.

    Builds a LangChain ``QAEvalChain`` backed by ``OpenAI(temperature=0)``,
    evaluates the single (query, answer, result) example, and returns 1 if
    the grader's verdict is CORRECT, else 0. A new chain is constructed on
    every call.
    """
    data = [{'query': query, 'answer': ground_truth}]
    pred = [{'query': query, 'answer': ground_truth, 'result': prediction}]
    eval_chain = QAEvalChain.from_llm(OpenAI(temperature=0))
    graded_outputs = eval_chain.evaluate(data, pred)
    # The grader's wording varies ("CORRECT", "Correct.", "GRADE: CORRECT"),
    # so parse it leniently instead of requiring an exact string.
    return _grade_is_correct(graded_outputs[0]['text'])
class Evaluator:
    """Run an algorithm over a benchmark dataset and score its answers.

    Each dataset row is turned into a (question, label) pair according to
    ``task``, passed to ``algo.run`` (retried up to ``maxtry`` times), and the
    parsed prediction is scored with exact match, token F1, and an LLM-graded
    accuracy. Per-example results accumulate in ``self.eval_data``.
    """

    # Tasks the constructor accepts. "disfl_qa" is accepted but has no column
    # mapping yet, so `run` raises NotImplementedError for it.
    SUPPORTED_TASKS = ("hotpot_qa", "trivia_qa", "gsm8k", "physics_question", "disfl_qa",
                       "sports_understanding", "strategy_qa", "sotu_qa", "fever")

    # Tasks whose dataset column layout `_get_example` knows how to read.
    _MAPPED_TASKS = frozenset({"hotpot_qa", "sotu_qa", "fever", "trivia_qa", "gsm8k",
                               "physics_question", "sports_understanding", "strategy_qa"})

    # Response fields copied verbatim into eval_data and averaged in `_get_avg_results`.
    _RESPONSE_STAT_KEYS = ("wall_time", "total_tokens", "total_cost", "steps",
                           "token_cost", "tool_cost")

    def __init__(self, task, dataset, algo, maxtry=3):
        """
        Args:
            task: benchmark name; must be one of ``SUPPORTED_TASKS``.
            dataset: pandas DataFrame whose columns depend on ``task``
                (see ``_get_example``).
            algo: the algorithm under evaluation (PWS_Base, PWS_Extra,
                ReactBase, IO or CoT instance).
            maxtry: attempts per example before a NaN-filled failed response
                is recorded instead.
        """
        assert task in self.SUPPORTED_TASKS
        assert isinstance(dataset, pd.DataFrame)
        assert isinstance(algo, (PWS_Base, PWS_Extra, ReactBase, IO, CoT))
        self.task = task
        self.dataset = dataset
        self.algo = algo
        self.maxtry = maxtry
        self.failed_response = self._failed_response()
        self.eval_data = self._initialize_eval_dict()

    def run(self):
        """Evaluate every row of the dataset.

        Returns:
            ``(averages, eval_data)``: the dict from ``_get_avg_results`` and
            the per-example lists in ``self.eval_data``.

        Raises:
            NotImplementedError: if ``task`` has no dataset column mapping.
        """
        print("\n******************* Start Evaluation *******************\n")
        if self.task not in self._MAPPED_TASKS:
            raise NotImplementedError(f"No dataset column mapping for task {self.task!r}")
        for i in tqdm.tqdm(range(len(self.dataset))):
            question, label = self._get_example(i)
            response = self._run_with_retries(question)
            self._update_eval_dict(question, label, response)
        return self._get_avg_results(), self.eval_data

    def _get_example(self, i):
        """Return the ``(question, label)`` pair for the row at position ``i``.

        Uses positional ``.iloc`` access so datasets with a non-default index
        (e.g. after filtering or sampling) are read row by row correctly.
        """
        ds = self.dataset
        if self.task in ("hotpot_qa", "sotu_qa"):
            return ds["question"].iloc[i], ds["answer"].iloc[i]
        if self.task == "fever":
            return ds["claim"].iloc[i], ds["label"].iloc[i]
        if self.task == "trivia_qa":
            # The answer cell is a mapping; the label is its "value" entry.
            return ds["question"].iloc[i], ds["answer"].iloc[i]["value"]
        if self.task == "gsm8k":
            # The label is the text after the "#### " marker in the answer.
            return ds["question"].iloc[i], ds["answer"].iloc[i].split("#### ")[1]
        if self.task in ("physics_question", "sports_understanding", "strategy_qa"):
            return ds["input"].iloc[i], ds["target"].iloc[i]
        raise NotImplementedError(f"No dataset column mapping for task {self.task!r}")

    def _run_with_retries(self, question):
        """Run the algorithm on ``question``, retrying up to ``maxtry`` times.

        Returns the first successful response, or ``self.failed_response``
        (all-NaN fields) when every attempt raised or ``maxtry`` < 1.
        """
        for _ in range(self.maxtry):
            try:
                return self.algo.run(question)
            except Exception:
                # Deliberate best effort: any error raised by the algorithm
                # triggers another attempt. ``Exception`` (not a bare except)
                # lets KeyboardInterrupt/SystemExit stop the evaluation.
                continue
        return self.failed_response

    def _initialize_eval_dict(self):
        """Create the empty per-example result lists, keyed by field name."""
        keys = ["label", "preds", "em", "f1", "acc", "wall_time", "total_tokens", "total_cost",
                "steps", "token_cost", "tool_cost", "planner_log", "solver_log"]
        return {key: [] for key in keys}

    def _update_eval_dict(self, question, label, response):
        """Score one response and append its results to ``self.eval_data``."""
        pred = self._parse_prediction(response["output"])
        # Compute metrics once: "acc" makes a paid LLM grading call, and
        # repeated calls could also disagree with each other.
        metrics = self.get_metrics(question, label, pred)
        self.eval_data["label"].append(label)
        self.eval_data["preds"].append(pred)
        for key in ("em", "f1", "acc"):
            self.eval_data[key].append(metrics[key])
        for key in self._RESPONSE_STAT_KEYS:
            self.eval_data[key].append(response[key])
        # Logs are appended only when present (failed responses have none),
        # so these lists may be shorter than the others.
        for key in ("planner_log", "solver_log"):
            if key in response:
                self.eval_data[key].append(response[key])

    def _get_avg_results(self):
        """Average each numeric per-example list, ignoring NaN entries."""
        keys = ("em", "f1", "acc") + self._RESPONSE_STAT_KEYS
        return {f"avg_{key}": np.nanmean(self.eval_data[key]) for key in keys}

    def get_metrics(self, query, label, pred):
        """Score ``pred`` against ``label``.

        Returns:
            dict with ``'em'`` (normalized exact match), ``'f1'`` (token F1)
            and ``'acc'`` (LLM-graded, 1 or 0). All three are 0 when ``pred``
            is None.
        """
        if pred is None:
            return {'em': 0, 'f1': 0, 'acc': 0}
        # str() guards against non-string labels, which normalize_answer
        # cannot lower-case.
        norm_label = normalize_answer(str(label))
        norm_pred = normalize_answer(pred)
        em = (norm_pred == norm_label)
        f1 = f1_score(norm_pred, norm_label)
        acc = llm_accuracy_score(query, pred, label)
        return {'em': em, 'f1': f1, 'acc': acc}

    def _parse_prediction(self, output):
        """Extract the final answer text from an algorithm's raw output.

        For CoT, keep only the last line and drop its "Answer:" label. For
        everything else, strip surrounding newlines.
        """
        text = str(output)
        # IO is checked before CoT to preserve the original precedence.
        if isinstance(self.algo, IO):
            return text.strip("\n")
        if isinstance(self.algo, CoT):
            return text.split("\n")[-1].replace("Answer:", "")
        return text.strip("\n")

    def _failed_response(self):
        """Build the placeholder response recorded when every attempt fails."""
        keys = ["input", "output", "wall_time", "total_tokens", "total_cost", "steps",
                "token_cost", "tool_cost"]
        return {key: np.nan for key in keys}