# dream-s1k-demo / eval_simple.py
# eval: greedy decode + numeric strict; system: force full decimals; regressions: A/B/C/noisy
# commit: e45d7fc
import json, re, argparse, torch
import re
def _collapse_digit_separators(s: str) -> str:
# 去掉出现在“数字 与 数字”之间的空白/逗号(含常见窄空格)
return re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
def _postproc_model_text(s: str) -> str:
# collapse spaces/commas between digits like '330.7 6' -> '330.76', '1,234' -> '1234'
s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
return s
def _preprocess_user_text(s: str) -> str:
# 全角标点 -> 半角
s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(","(").replace(")",")")
# 去掉数字内部的逗号/空格(保留小数点)
s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
# 压缩空白
s = re.sub(r'\s+', ' ', s).strip()
return s
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
def messages_to_pairs(messages):
    """Fold a chat transcript into prompt/response pairs.

    Consecutive user turns are buffered and joined with blank lines; each
    assistant turn closes the pending buffer into one pair. Assistant
    turns with no preceding user content are dropped, as is a trailing
    user buffer with no assistant reply.
    """
    pairs = []
    pending = []
    for msg in messages:
        role = msg.get("role")
        if role == "user":
            pending.append(msg.get("content", ""))
        elif role == "assistant" and pending:
            pairs.append({
                "prompt": "\n\n".join(pending),
                "response": msg.get("content", ""),
            })
            pending = []
    return pairs
def normalize(s: str) -> str:
    """Canonicalize a string for loose comparison.

    Ideographic spaces become ASCII spaces, fullwidth/CJK punctuation and
    math symbols map to ASCII equivalents, and whitespace runs collapse
    to single spaces.
    """
    table = str.maketrans(",。:!?【】()%+-×÷=“”‘’", ",.:!?[]()%+-*/=\"\"''")
    cleaned = s.replace("\u3000", " ").strip().translate(table)
    return re.sub(r"\s+", " ", cleaned)
def to_num(x):
    """Best-effort conversion of *x* to float.

    Fast path: float(x) directly. Fallback: stringify, collapse
    separators that split digit groups (commas, regular and narrow/
    no-break spaces), then parse the first number-looking token
    (optional sign, decimals, scientific notation).

    Returns:
        float, or None when no number can be extracted.
    """
    try:
        return float(x)
    # Was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt. These three cover everything float() raises
    # (OverflowError for huge ints).
    except (TypeError, ValueError, OverflowError):
        pass
    s = x if isinstance(x, str) else str(x)
    # Same separator collapse as _collapse_digit_separators, inlined so
    # this helper stands alone.
    s = re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
    m = re.search(r"[-+]?\d*\.?\d+(?:e[-+]?\d+)?", s, re.I)
    return float(m.group(0)) if m else None
def main() -> None:
    """CLI entry point.

    Greedy-decodes a PEFT/LoRA-adapted causal LM over a JSONL chat dataset
    (one {"messages": [...]} object per line) and prints three exact-match
    scores: strict (byte equality), loose (punctuation/whitespace
    normalized), numeric (first parsed numbers agree within 1e-6).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--adapter", required=True)
    ap.add_argument("--data", required=True)
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--limit", type=int, default=0, help="只评测前 N 条,0=全部")
    args = ap.parse_args()
    tok = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True, use_fast=True)
    if tok.pad_token is None:
        # Some chat models ship without a pad token; reuse EOS so padding works.
        tok.pad_token = tok.eos_token
    base = AutoModelForCausalLM.from_pretrained(args.base, trust_remote_code=True, torch_dtype=torch.float32)
    model = PeftModel.from_pretrained(base, args.adapter)
    model.eval()
    # Prefer Apple-Silicon MPS when available, otherwise CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model.to(device)
    # Clear the sampling defaults from generation_config to suppress
    # transformers warnings and guarantee pure greedy decoding.
    gc = model.generation_config
    gc.do_sample = False
    gc.temperature = None
    gc.top_p = None
    gc.top_k = None
    golds, preds = [], []
    with open(args.data, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if args.limit and i >= args.limit: break
            row = json.loads(line)
            pairs = messages_to_pairs(row["messages"])
            if not pairs: continue
            # Only the first user/assistant pair of each conversation is scored.
            ex = pairs[0]
            gold = ex["response"]
            ctx = [
                {"role":"system","content":"请只输出最终答案,不要解释。只输出一个数字;若为小数,完整输出全部小数位,不要四舍五入或截断。"},
                {"role":"user","content": ex["prompt"]}
            ]
            prompt_text = tok.apply_chat_template(ctx, tokenize=False, add_generation_prompt=True)
            inputs = tok(prompt_text, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new,
                    do_sample=False,
                    temperature=None,
                    eos_token_id=tok.eos_token_id,
                    pad_token_id=tok.eos_token_id
                )
            # Slice off the prompt tokens and decode only the generated tail.
            pred = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            pred = _postproc_model_text(pred)
            golds.append(gold)
            preds.append(pred)
            print(f"[{i}] GT={repr(gold)} | PRED={repr(pred)}")
    # Compute the three EM variants.
    strict = sum(1 for g,p in zip(golds,preds) if p==g)
    loose = sum(1 for g,p in zip(golds,preds) if normalize(p)==normalize(g))
    numem = 0
    for g,p in zip(golds,preds):
        # NOTE: `np` here is a local float, not numpy (numpy is not imported).
        ng, np = to_num(g), to_num(p)
        if ng is not None and np is not None and abs(ng-np)<1e-6:
            numem += 1
    n = len(golds) if golds else 1  # guard against division by zero on empty data
    print(f"\n==> EM strict={strict/n:.3f} EM loose={loose/n:.3f} EM numeric={numem/n:.3f} (N={n})")

if __name__ == "__main__":
    main()