|
|
|
|
|
""" |
|
Evaluate models on a LongBench subset with Exact-Match (EM). |
|
Supports both Qwen3 (Transformers) and other models (vLLM). |
|
|
|
Requirements |
|
------------ |
|
pip install vllm datasets tqdm transformers accelerate |
|
""" |
|
|
|
import argparse, logging, os, time
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm

from utils.metrics import qa_em_score
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--hf_model", |
|
default="Qwen/Qwen3-8B-Instruct", |
|
help="Model name or local path") |
|
parser.add_argument("--is_qwen3", action="store_true", |
|
help="Set this flag if using Qwen3 model (uses Transformers). Otherwise uses vLLM.") |
|
parser.add_argument("--max_new_tokens", type=int, default=20) |
|
parser.add_argument("--max_tokens", type=int, default=20, |
|
help="For vLLM models (ignored if --is_qwen3)") |
|
parser.add_argument("--temperature", type=float, default=0.0) |
|
parser.add_argument("--top_p", type=float, default=1.0) |
|
parser.add_argument("--tensor_parallel_size", type=int, default=2, |
|
help="GPU parallel size for vLLM (ignored if --is_qwen3)") |
|
|
|
parser.add_argument("--dataset_repo", default="THUDM/LongBench") |
|
parser.add_argument("--dataset_subset", default="hotpotqa") |
|
parser.add_argument("--split", default="test") |
|
parser.add_argument("--sleep", type=float, default=0.0) |
|
parser.add_argument("--log", default="summary.log") |
|
parser.add_argument("--cuda_devices", default="1,6", |
|
help="CUDA visible devices") |
|
args = parser.parse_args() |
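# Make only the requested GPUs visible. CUDA contexts are created lazily, so
# setting this after `import torch` but before any CUDA work (and before vLLM is
# imported below) normally still takes effect.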
|
|
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices |
|
|
|
|
|
logging.basicConfig( |
|
filename=args.log, |
|
level=logging.INFO, |
|
format="%(asctime)s - %(message)s", |
|
filemode="a", |
|
) |
|
logging.getLogger().addHandler(logging.StreamHandler()) |
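# THUDM/LongBench is distributed with a loading script, so depending on the
# installed `datasets` version the call below may also need trust_remote_code=True.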
|
|
|
|
|
ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split) |
|
total = len(ds) |
|
logging.info("Loaded %d samples from %s/%s[%s]", |
|
total, args.dataset_repo, args.dataset_subset, args.split) |
|
|
|
if args.is_qwen3: |
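    # Transformers path: load the Qwen3 checkpoint locally and decode with
    # model.generate.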
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
load_kwargs = dict( |
|
trust_remote_code=True, |
|
device_map="auto", |
|
) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args.hf_model, |
|
trust_remote_code=True) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
args.hf_model, |
|
torch_dtype=torch.float16, |
|
**load_kwargs |
|
) |
|
|
|
EOS_ID = tokenizer.eos_token_id |
|
THINK_ENDID = 151668 |
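    # 151668 is the id of Qwen3's </think> token; tokens up to and including the
    # last </think> are treated as thinking content and discarded below.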
|
|
|
    do_sample = args.temperature > 0
    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=do_sample,
        eos_token_id=EOS_ID,
    )
    if do_sample:
        # Sampling knobs only apply when do_sample=True; passing them alongside
        # greedy decoding just triggers warnings from transformers.
        gen_kwargs.update(temperature=args.temperature, top_p=args.top_p)
|
|
|
|
|
correct_em = 0 |
|
|
|
for ex in tqdm(ds, desc="Evaluating with Transformers (Qwen3)"): |
|
q = ex["input"] |
|
golds = ex["answers"] |
|
|
|
msgs = [ |
|
{"role": "system", "content": "You are a QA assistant."}, |
|
{"role": "user", |
|
"content": f"Question: {q}\n" |
|
"Please reply with *only* the final answer—no extra words."} |
|
] |
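        # Render the chat template; enable_thinking=False is Qwen3's switch for
        # skipping the thinking block entirely.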
|
prompt = tokenizer.apply_chat_template( |
|
msgs, |
|
tokenize=False, |
|
add_generation_prompt=True, |
|
enable_thinking=False |
|
) |
|
inputs = tokenizer([prompt], return_tensors="pt").to(model.device) |
|
|
|
with torch.no_grad(): |
|
outs = model.generate(**inputs, **gen_kwargs)[0] |
|
|
|
|
|
new_ids = outs[len(inputs.input_ids[0]):].tolist() |
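        # If a </think> marker appears in the completion, drop everything up to
        # and including its last occurrence; otherwise keep the whole completion.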
|
|
|
|
|
try: |
|
idx = len(new_ids) - new_ids[::-1].index(THINK_ENDID) |
|
except ValueError: |
|
idx = 0 |
|
|
|
        content = tokenizer.decode(new_ids[idx:],
                                   skip_special_tokens=True).strip()
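        # Exact match: the example is correct if qa_em_score accepts the
        # prediction against any of the gold answers.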
|
|
|
|
|
if any(qa_em_score(content, g) for g in golds): |
|
correct_em += 1 |
|
|
|
if args.sleep: |
|
time.sleep(args.sleep) |
|
|
|
else: |
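    # vLLM path: shard the model across the visible GPUs and use the offline
    # chat API for generation.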
|
|
|
from vllm import LLM, SamplingParams |
|
|
|
|
|
llm = LLM( |
|
model=args.hf_model, |
|
tensor_parallel_size=args.tensor_parallel_size, |
|
) |
|
    # The sampling parameters are identical for every example, so build them
    # once and reuse them inside the evaluation loop below.
    sampler = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        stop=["</s>", "<|end_of_text|>"],
    )
|
|
|
|
|
correct_em = 0 |
|
|
|
for ex in tqdm(ds, desc="Evaluating with vLLM"): |
|
question = ex["input"] |
|
golds = ex["answers"] |
|
|
|
|
messages = [ |
|
{"role": "system", |
|
"content": "You are a QA assistant."}, |
|
{"role": "user", |
|
"content": f"Question: {question}\n" |
|
"Please first reply with *only* the final answer—no extra words.\n Answer:"} |
|
] |
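        # LLM.chat applies the model's chat template to the messages before
        # generating.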
|
|
|
        result = llm.chat(messages, sampling_params=sampler)
|
|
|
pred = result[0].outputs[0].text.strip() |
|
print(f"A: {pred}\nG: {golds}\n") |
|
|
|
if any(qa_em_score(pred, g) for g in golds): |
|
correct_em += 1 |
|
|
|
if args.sleep: |
|
time.sleep(args.sleep) |
|
|
|
|
|
em = correct_em / total |
|
model_type = "Qwen3 (Transformers)" if args.is_qwen3 else "vLLM" |
|
logging.info("RESULT | model=%s | type=%s | subset=%s | EM=%.4f", |
|
args.hf_model, model_type, args.dataset_subset, em) |
|
print( |
|
f"\n=== SUMMARY ===\n" |
|
f"Model : {args.hf_model}\n" |
|
f"Type : {model_type}\n" |
|
f"Subset : {args.dataset_subset} ({args.split})\n" |
|
f"EM : {em:.4f}\n" |
|
f"(Log in {Path(args.log).resolve()})" |
|
) |