#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Evaluate models on a LongBench subset with Exact-Match (EM).
Supports both Qwen3 (Transformers) and other models (vLLM).

Requirements
------------
pip install vllm datasets tqdm transformers accelerate
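
Usage (a sketch; the script file name and GPU ids below are placeholders)
--------------------------------------------------------------------------
python eval_longbench_em.py --is_qwen3 --hf_model Qwen/Qwen3-8B --dataset_subset hotpotqa --cuda_devices 0,1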
"""

import argparse, logging, time, torch
from pathlib import Path

from datasets import load_dataset
from tqdm import tqdm
from utils.metrics import qa_em_score
import os
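
# `qa_em_score` is assumed to implement SQuAD-style exact match (lowercase,
# drop punctuation and the articles a/an/the, collapse whitespace, compare).
# A minimal stand-in along those lines, should utils/metrics.py be unavailable:
#
#     import re, string
#     def qa_em_score(pred: str, gold: str) -> bool:
#         def norm(s: str) -> str:
#             s = re.sub(r"\b(a|an|the)\b", " ", s.lower())
#             s = s.translate(str.maketrans("", "", string.punctuation))
#             return " ".join(s.split())
#         return norm(pred) == norm(gold)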

# ---------------------------- CLI ------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--hf_model",
    default="Qwen/Qwen3-8B-Instruct",
    help="Model name or local path")
parser.add_argument("--is_qwen3", action="store_true",
    help="Set this flag if using Qwen3 model (uses Transformers). Otherwise uses vLLM.")
parser.add_argument("--max_new_tokens", type=int, default=20)
parser.add_argument("--max_tokens", type=int, default=20,
    help="For vLLM models (ignored if --is_qwen3)")
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--top_p", type=float, default=1.0)
parser.add_argument("--tensor_parallel_size", type=int, default=2,
    help="GPU parallel size for vLLM (ignored if --is_qwen3)")

parser.add_argument("--dataset_repo", default="THUDM/LongBench")
parser.add_argument("--dataset_subset", default="hotpotqa")
parser.add_argument("--split", default="test")
parser.add_argument("--sleep", type=float, default=0.0)
parser.add_argument("--log", default="summary.log")
parser.add_argument("--cuda_devices", default="1,6",
    help="CUDA visible devices")
args = parser.parse_args()

# Set CUDA devices
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices

# --------------------------- logging ---------------------------------
logging.basicConfig(
    filename=args.log,
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    filemode="a",
)
logging.getLogger().addHandler(logging.StreamHandler())

# ------------------------- dataset -----------------------------------
ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split,
                  trust_remote_code=True)  # LongBench is distributed with a loading script
total = len(ds)
logging.info("Loaded %d samples from %s/%s[%s]",
             total, args.dataset_repo, args.dataset_subset, args.split)
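
# LongBench QA examples expose "input" (the question), "context" (the long
# passage) and "answers" (a list of gold strings). Note that this script
# prompts with the question only, i.e. a closed-book EM evaluation; to actually
# test long-context reading, ex["context"] would need to go into the prompt.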

if args.is_qwen3:
    # ---------------------- Qwen3 with Transformers ----------------------------
    from transformers import AutoTokenizer, AutoModelForCausalLM
    
    load_kwargs = dict(
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(args.hf_model,
                                              trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.hf_model,
        torch_dtype=torch.float16,
        **load_kwargs
    )

    EOS_ID      = tokenizer.eos_token_id
    THINK_ENDID = 151668  # </think> token id
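    # (equivalently: tokenizer.convert_tokens_to_ids("</think>"), avoiding the hard-coded id)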

    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.temperature > 0,  # greedy decoding when temperature == 0
        eos_token_id=EOS_ID,
    )
    if args.temperature > 0:  # pass sampling knobs only when actually sampling
        gen_kwargs.update(temperature=args.temperature, top_p=args.top_p)

    # -------------------------- Qwen3 loop -------------------------------------
    correct_em = 0

    for ex in tqdm(ds, desc="Evaluating with Transformers (Qwen3)"):
        q = ex["input"]
        golds = ex["answers"]

        msgs = [
            {"role": "system", "content": "You are a QA assistant."},
            {"role": "user",
             "content": f"Question: {q}\n"
                        "Please reply with *only* the final answer—no extra words."}
        ]
        prompt = tokenizer.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False   # disable Qwen3 thinking; the template pre-fills an empty <think></think>
        )
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

        with torch.no_grad():
            outs = model.generate(**inputs, **gen_kwargs)[0]

        # Extract newly generated tokens
        new_ids = outs[len(inputs.input_ids[0]):].tolist()

        # Find the last </think>; if absent (the usual case with thinking disabled), idx stays 0
        try:
            idx = len(new_ids) - new_ids[::-1].index(THINK_ENDID)
        except ValueError:
            idx = 0

        content = tokenizer.decode(new_ids[idx:],
                                   skip_special_tokens=True).strip()

        # Only use content for EM comparison
        if any(qa_em_score(content, g) for g in golds):
            correct_em += 1

        if args.sleep:
            time.sleep(args.sleep)

else:
    # ---------------------- Other models with vLLM ----------------------------
    from vllm import LLM, SamplingParams
    
    # Initialize vLLM
    llm = LLM(
        model=args.hf_model,
        tensor_parallel_size=args.tensor_parallel_size,
    )
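    # Note: default engine settings are assumed adequate here; for genuinely
    # long inputs you may need to raise max_model_len (and adjust
    # gpu_memory_utilization) in the LLM(...) call. llm.chat() below also
    # requires a vLLM release that provides the chat API.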
    # One shared SamplingParams for every request (greedy when temperature == 0)
    sampler = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        stop=["</assistant>", "</s>", "<|end_of_text|>"],  # safety stop strings
    )

    # -------------------------- vLLM loop -------------------------------------
    correct_em = 0

    for ex in tqdm(ds, desc="Evaluating with vLLM"):
        question = ex["input"]
        golds = ex["answers"]      # list[str]
        
        messages = [
            {"role": "system",
             "content": "You are a QA assistant."},
            {"role": "user",
             "content": f"Question: {question}\n"
                        "Please first reply with *only* the final answer—no extra words.\n Answer:"}
        ] 

        result = llm.chat(messages, sampling_params=sampler)
        # vLLM returns list[RequestOutput]; take first output's first candidate
        pred = result[0].outputs[0].text.strip()
        print(f"A: {pred}\nG: {golds}\n")

        if any(qa_em_score(pred, g) for g in golds):
            correct_em += 1

        if args.sleep:
            time.sleep(args.sleep)

# -------------------------- result -----------------------------------
em = correct_em / total
model_type = "Qwen3 (Transformers)" if args.is_qwen3 else "vLLM"
logging.info("RESULT | model=%s | type=%s | subset=%s | EM=%.4f",
             args.hf_model, model_type, args.dataset_subset, em)
print(
    f"\n=== SUMMARY ===\n"
    f"Model   : {args.hf_model}\n"
    f"Type    : {model_type}\n"
    f"Subset  : {args.dataset_subset} ({args.split})\n"
    f"EM      : {em:.4f}\n"
    f"(Log in {Path(args.log).resolve()})"
)