File size: 3,246 Bytes
b1e25b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os, time, argparse, logging
from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm
from utils.metrics import qa_em_score  

# ----------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model",          default="gpt-4o")
parser.add_argument("--dataset_repo",   default="THUDM/LongBench")
parser.add_argument("--dataset_subset", default="hotpotqa")
parser.add_argument("--split",          default="test")
parser.add_argument("--max_tokens",     type=int,   default=30)
parser.add_argument("--temperature",    type=float, default=0.0)
parser.add_argument("--sleep",          type=float, default=0.5,
                    help="seconds to wait between requests")
parser.add_argument("--log",            default="summary.log",
                    help="append overall score here")
args = parser.parse_args()

# ----------------------------------------------------------------------
# Logging (append mode)
# ----------------------------------------------------------------------
logging.basicConfig(
    filename=args.log,
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    filemode="a",
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)

# ----------------------------------------------------------------------
# OpenAI client
# ----------------------------------------------------------------------
client = OpenAI(
  api_key=os.environ.get("OPENAI_API_KEY"), 
  base_url=os.environ.get("OPENAI_BASE_URL")
)

# ----------------------------------------------------------------------
# Load dataset
# ----------------------------------------------------------------------
ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
total = len(ds)
logging.info("Loaded %d samples from %s/%s[%s]",
             total, args.dataset_repo, args.dataset_subset, args.split)

# ----------------------------------------------------------------------
# Evaluation loop
# ----------------------------------------------------------------------
correct_em = 0

for ex in tqdm(ds, desc="Evaluating"):
    question = ex["input"]
    golds    = ex["answers"]

    resp = client.chat.completions.create(
        model=args.model,
        messages=[
            {"role": "system", "content": "You are a QA assistant."},
            {"role": "user",
            "content": f"Question: {question}\n"
            "Please first reply with *only* the final answer—no extra words.\n Answer:"}
        ],
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
    pred = resp.choices[0].message.content.strip()
    print(f"A: {pred}\n G: {golds}")

    if any(qa_em_score(pred, g) for g in golds):
        correct_em += 1

    time.sleep(args.sleep)

em_score = correct_em / total
logging.info("RESULT | model=%s | subset=%s | EM=%.4f",
             args.model, args.dataset_subset, em_score)

print(f"\n=== SUMMARY ===\nModel   : {args.model}"
      f"\nDataset : {args.dataset_subset} ({args.split})"
      f"\nEM      : {em_score:.4f}\n"
      f"(Appended to {args.log})")