import argparse
import logging
import os
import time

from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm

from utils.metrics import qa_em_score
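
# Command-line options: model, dataset, decoding, and request pacing.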
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="gpt-4o")
parser.add_argument("--dataset_repo", default="THUDM/LongBench")
parser.add_argument("--dataset_subset", default="hotpotqa")
parser.add_argument("--split", default="test")
parser.add_argument("--max_tokens", type=int, default=30)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--sleep", type=float, default=0.5,
                    help="seconds to wait between requests")
parser.add_argument("--log", default="summary.log",
                    help="append overall score here")
args = parser.parse_args()
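
# Append results to the log file and mirror every message to the console.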
logging.basicConfig(
    filename=args.log,
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    filemode="a",
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)
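
# OPENAI_BASE_URL is optional; when it is unset the client falls back to the
# default API endpoint.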
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_BASE_URL"),
)
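
# Load the evaluation split (newer versions of `datasets` may additionally
# require trust_remote_code=True for script-based repos such as LongBench).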
ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
total = len(ds)
logging.info("Loaded %d samples from %s/%s[%s]",
             total, args.dataset_repo, args.dataset_subset, args.split)
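
# Query the model once per example and score predictions by exact match.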
correct_em = 0

for ex in tqdm(ds, desc="Evaluating"):
    question = ex["input"]
    # LongBench stores the supporting passages in a separate field; without
    # them the multi-document QA task is unanswerable.
    context = ex["context"]
    golds = ex["answers"]

    resp = client.chat.completions.create(
        model=args.model,
        messages=[
            {"role": "system", "content": "You are a QA assistant."},
            {"role": "user",
             "content": f"{context}\n\nQuestion: {question}\n"
                        "Reply with only the final answer, no extra words.\nAnswer:"},
        ],
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
    pred = resp.choices[0].message.content.strip()
    print(f"A: {pred}\nG: {golds}")

    # A hit if the prediction exactly matches any of the gold answers.
    if any(qa_em_score(pred, g) for g in golds):
        correct_em += 1

    time.sleep(args.sleep)
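
# Aggregate score: the fraction of examples with an exact-match prediction.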
em_score = correct_em / total if total else 0.0  # guard against an empty split
logging.info("RESULT | model=%s | subset=%s | EM=%.4f",
             args.model, args.dataset_subset, em_score)

print(f"\n=== SUMMARY ===\nModel   : {args.model}"
      f"\nDataset : {args.dataset_subset} ({args.split})"
      f"\nEM      : {em_score:.4f}\n"
      f"(Appended to {args.log})")