| |
| |
|
|
| import argparse |
| import time |
| import requests |
| import pandas as pd |
| from tqdm import tqdm |
|
|
def build_text(prompt: str, answer: str, sep: str) -> str:
    """Concatenate ``prompt`` and ``answer`` with ``sep`` in between.

    A ``None`` prompt or answer is treated as an empty string; any other value
    is converted with ``str()``. Leading and trailing whitespace of the joined
    text is stripped before it is returned.
    """
    prompt_text, answer_text = (
        "" if piece is None else str(piece) for piece in (prompt, answer)
    )
    joined = prompt_text + sep + answer_text
    return joined.strip()
|
|
def post_batch(server_url: str, queries, timeout=120, max_retries=3, sleep=1.0):
    """Score a batch of texts through the server's ``/get_reward`` endpoint.

    Args:
        server_url: Base URL of the service; trailing slashes are ignored.
        queries: List of texts to score, sent as the ``"query"`` JSON field.
        timeout: Per-request timeout in seconds, forwarded to ``requests.post``.
        max_retries: Maximum number of attempts before giving up.
        sleep: Seconds to wait between two consecutive attempts.

    Returns:
        The list stored under ``"rewards"`` in the JSON response (or under
        ``"scores"`` when ``"rewards"`` is missing or ``None``). Its length
        always equals ``len(queries)``.

    Raises:
        RuntimeError: If no attempt succeeded. The last underlying error is
            attached as ``__cause__``.
    """
    url = server_url.rstrip("/") + "/get_reward"
    payload = {"query": queries, "prompts": None}
    last_err = None
    for attempt in range(max_retries):
        # Wait only *between* attempts; sleeping after the final failure would
        # just delay the error without any further retry to benefit from it.
        if attempt:
            time.sleep(sleep)
        try:
            resp = requests.post(url, json=payload, timeout=timeout)
            resp.raise_for_status()
            data = resp.json()
            if not isinstance(data, dict):
                raise ValueError(f"Bad response: {data}")
            # Fall back to "scores" only when "rewards" is truly absent; an
            # empty list is a valid answer for an empty batch and must not be
            # mistaken for a missing field.
            rewards = data.get("rewards")
            if rewards is None:
                rewards = data.get("scores")
            if not isinstance(rewards, list):
                raise ValueError(f"Bad response: {data}")
            if len(rewards) != len(queries):
                raise ValueError(f"Length mismatch: got {len(rewards)} for {len(queries)} queries")
            return rewards
        except Exception as e:
            # Deliberately broad: network errors, HTTP errors, invalid JSON and
            # malformed payloads are all treated as retryable failures.
            last_err = e
    raise RuntimeError(f"Request failed after retries: {last_err}") from last_err
|
|
def _decode_escapes(text: str) -> str:
    r"""Interpret backslash escapes (e.g. ``\n``, ``\t``) typed on the command line."""
    # The unicode_escape codec decodes its input bytes as Latin-1, so a UTF-8
    # round trip would garble every non-ASCII character. Encoding to Latin-1
    # with backslashreplace keeps Latin-1 characters as single bytes and turns
    # everything else into \uXXXX / \UXXXXXXXX escapes that the decoder restores.
    return text.encode("latin-1", "backslashreplace").decode("unicode_escape")


def main():
    """Evaluate pairwise accuracy of a reward model served over HTTP.

    Reads a parquet file, builds one "chosen" and one "rejected" text per row
    with ``build_text``, scores both in batches with ``post_batch``, and counts
    a row as correct when the chosen score is strictly greater than the
    rejected score. Prints summary statistics and writes the per-row scores,
    delta and accuracy flag next to the original columns into a CSV file.

    Raises:
        ValueError: If one of the configured column names is not in the parquet
            file.
        SystemExit: Raised by argparse for invalid arguments, including a
            non-positive ``--batch_size``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--server_url", type=str, required=True, help="FastAPI 服务地址,如 http://localhost:5000")
    ap.add_argument("--data_path", type=str, required=True, help="parquet 文件路径")
    ap.add_argument("--prompt_key", type=str, default="chosen_prompt", help="prompt 字段名")
    ap.add_argument("--chosen_key", type=str, default="chosen", help="chosen 回答字段名")
    ap.add_argument("--rejected_key", type=str, default="reject", help="rejected 回答字段名")
    ap.add_argument("--batch_size", type=int, default=64, help="每次请求发送多少条")
    ap.add_argument("--sep", type=str, default="", help="prompt 与回答之间的连接符,如 '\\n' 或 空串")
    ap.add_argument("--save_csv", type=str, default="rm_http_eval.csv", help="保存明细的 CSV 路径")
    ap.add_argument("--print_each", action="store_true", help="逐样本打印 chosen/reject 分数与实时平均 acc")
    args = ap.parse_args()

    # A zero step makes range() raise and a negative step silently skips all
    # scoring, so reject both up front with a clear usage error.
    if args.batch_size < 1:
        ap.error("--batch_size must be a positive integer")

    # Load the data and make sure every requested column exists.
    df = pd.read_parquet(args.data_path)
    for col in [args.prompt_key, args.chosen_key, args.rejected_key]:
        if col not in df.columns:
            raise ValueError(f"列 {col} 不在 parquet 中,现有列:{list(df.columns)}")

    # Missing values become empty strings; everything else is stringified.
    prompts = df[args.prompt_key].fillna("").astype(str).tolist()
    chosens = df[args.chosen_key].fillna("").astype(str).tolist()
    rejects = df[args.rejected_key].fillna("").astype(str).tolist()

    # The separator arrives as literal text from the shell; turn escape
    # sequences into the characters they denote.
    sep = _decode_escapes(args.sep)
    chosen_queries = [build_text(p, c, sep) for p, c in zip(prompts, chosens)]
    rejected_queries = [build_text(p, r, sep) for p, r in zip(prompts, rejects)]

    total = len(chosen_queries)
    chosen_scores, rejected_scores, accs = [], [], []
    correct = 0
    pbar = tqdm(range(0, total, args.batch_size), desc="HTTP Scoring")

    for start in pbar:
        stop = min(start + args.batch_size, total)

        # Two requests per batch: one for the chosen texts, one for the rejected.
        ch_scores = post_batch(args.server_url, chosen_queries[start:stop])
        rj_scores = post_batch(args.server_url, rejected_queries[start:stop])

        for idx, (cs, rs) in enumerate(zip(ch_scores, rj_scores), start=start):
            delta = cs - rs
            # Ties count as incorrect: the chosen answer must score strictly higher.
            acc = 1 if delta > 0 else 0
            chosen_scores.append(cs)
            rejected_scores.append(rs)
            accs.append(acc)

            correct += acc
            running_acc = correct / len(accs)

            if args.print_each:
                # tqdm.write keeps the per-sample lines from breaking the progress bar.
                tqdm.write(f"[{idx}] acc={acc}, chosen={cs:.3f}, rejected={rs:.3f}, Δ={delta:.3f} | avg acc={running_acc:.3f}")

        pbar.set_postfix({"avg_acc": f"{running_acc:.3f}"})

    # Attach the per-row results to a copy so the loaded frame stays untouched.
    out = df.copy()
    out["chosen_score"] = chosen_scores
    out["rejected_score"] = rejected_scores
    out["delta"] = out["chosen_score"] - out["rejected_score"]
    out["acc"] = accs

    # Report 0.0 instead of NaN when the input file has no rows.
    final_acc = float(out["acc"].mean()) if len(out) else 0.0
    mean_chosen = float(out["chosen_score"].mean()) if len(out) else 0.0
    mean_reject = float(out["rejected_score"].mean()) if len(out) else 0.0
    mean_delta = float(out["delta"].mean()) if len(out) else 0.0

    print("\n=========== RESULT (HTTP) ===========")
    print(f"✅ Accuracy = {final_acc:.4f} ({sum(accs)}/{len(accs)})")
    print(f"📊 Mean chosen = {mean_chosen:.4f}")
    print(f"📉 Mean rejected = {mean_reject:.4f}")
    print(f"🔼 Mean delta = {mean_delta:.4f}")

    out.to_csv(args.save_csv, index=False)
    print(f"💾 Saved details to: {args.save_csv}")
|
|
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|