"""Shared evaluation script: connects to a vLLM server and evaluates the full
841-problem HRM8K set (temperature=0).

(Docstring translated from Korean. NOTE(review): Korean text throughout this
file appears mojibake'd — UTF-8 bytes rendered under a wrong codec. The
runtime strings below are kept exactly as found; re-save the file as UTF-8
from the original source to restore readable Korean.)
"""
import os, json, re, sys, asyncio
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from openai import OpenAI

# System prompt sent to the model (Korean, runtime text — left byte-identical).
# It instructs: solve the math problem step by step and print the final answer
# on the last line in \boxed{integer} form, e.g. \boxed{42}.
# NOTE(review): the second line was split mid-row by extraction; rejoined with
# a single space — confirm against the original file.
MATH_SYSTEM_PROMPT = """์ฃผ์ด์ง ์ํ ๋ฌธ์ ๋ฅผ ๋จ๊ณ๋ณ๋ก ํ๊ณ ๋ต๋ณ์ ์์ฑํ์ธ์.
๋ฐ๋์ ์ต์ข ๋ต๋ณ์ \\boxed{์ ์} ํ์์ผ๋ก ๋ง์ง๋ง ์ค์ ์ถ๋ ฅํ์ธ์.
์์: \\boxed{42}"""
|
|
def extract_boxed(text):
    """Return the content of the last ``\\boxed{...}`` occurrence in *text*.

    Returns the stripped inner text, or ``None`` when no box is present.
    Only flat (non-nested) brace contents are matched, which is sufficient
    for the integer answers the prompt requests.
    """
    matches = re.findall(r'\\boxed\{([^}]+)\}', text)
    if not matches:
        return None
    return matches[-1].strip()
|
|
def normalize(a):
    """Canonicalize an answer value for string comparison.

    Commas and spaces are stripped; numeric strings are normalized so that
    "42.0" -> "42" and "3.50" -> "3.5". Non-numeric strings are returned
    cleaned but otherwise unchanged. ``None`` input yields ``None``.
    """
    if a is None:
        return None
    s = str(a).replace(",", "").replace(" ", "").strip()
    try:
        n = float(s)
        # int(n) raises OverflowError for +/-inf and ValueError for nan;
        # both fall through to returning the cleaned string, same as before.
        return str(int(n)) if n == int(n) else str(n)
    except (ValueError, OverflowError):
        # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit. Narrowed to the exceptions float()/int() can raise.
        return s
|
|
def check(pred, gt):
    """Return True iff *pred* and *gt* normalize to the same non-None string."""
    norm_pred = normalize(pred)
    norm_gt = normalize(gt)
    if norm_pred is None or norm_gt is None:
        return False
    return norm_pred == norm_gt
|
|
async def evaluate(label="", save_path=None):
    """Evaluate the HRM8K set against a locally served vLLM model.

    Args:
        label: Free-form tag printed in the results header and stored in the
            result object.
        save_path: Optional JSON path; when set, the summary and per-item
            details are written there.

    Returns:
        dict with keys "label", "correct", "total", "accuracy", "by_source".

    NOTE(review): Korean strings below are mojibake'd in this copy of the
    file; they are runtime output and kept byte-identical.
    """
    # vLLM's OpenAI-compatible endpoint; the api_key is a dummy token that
    # must match whatever the server was started with.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
    # Assumes the server serves exactly one model; take the first listed id.
    model_name = client.models.list().data[0].id
    print(f"๋ชจ๋ธ: {model_name}")

    # Evaluation data: a list of {"question", "answer", "source", ...} dicts
    # (schema inferred from the accesses below — TODO confirm).
    with open("data/HRM8k_eval.json") as f:
        data = json.load(f)
    print(f"ํ๊ฐ: {len(data)}๊ฐ (temperature=0, max_tokens=2048)")

    # Greedy decoding (temperature=0) for reproducible scoring.
    llm = ChatOpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123",
                     model=model_name, temperature=0, max_tokens=2048)
    # The system prompt is folded into the single user turn via .partial(),
    # so only {q} remains to be filled per item.
    prompt = ChatPromptTemplate([("user", "{sp}\n\n{q}")]).partial(sp=MATH_SYSTEM_PROMPT)
    chain = prompt | llm | StrOutputParser()
    inputs = [{"q": item["question"]} for item in data]
    # Fan out all requests at once; vLLM handles the batching server-side.
    results = await chain.abatch(inputs, config={"max_concurrency": 400})

    # Score each completion, grouped by the item's source dataset.
    by_src = {}
    details = []
    for item, res in zip(data, results):
        s = item.get("source", "?")
        if s not in by_src: by_src[s] = {"correct": 0, "total": 0, "no_boxed": 0}
        by_src[s]["total"] += 1
        pred = extract_boxed(res)
        is_correct = False
        if pred is None:
            # Model failed to emit the required \boxed{...} answer.
            by_src[s]["no_boxed"] += 1
        elif check(pred, item["answer"]):
            by_src[s]["correct"] += 1
            is_correct = True
        details.append({
            # Truncate long questions/answers to keep the details file small.
            "question": item["question"][:80],
            "source": s,
            "gt": str(item["answer"])[-30:] if isinstance(item["answer"], str) else str(item["answer"]),
            "pred": pred,
            "correct": is_correct,
        })

    # Aggregate totals and print a per-source breakdown plus the overall score.
    tc = sum(v["correct"] for v in by_src.values())
    tt = sum(v["total"] for v in by_src.values())
    print(f"\n=== {label} ๊ฒฐ๊ณผ (temperature=0) ===")
    for s in sorted(by_src):
        v = by_src[s]
        print(f" [{s.upper()}] {v['correct']}/{v['total']} ({v['correct']/v['total']*100:.1f}%) | boxed๋ฏธ์ถ๋ ฅ: {v['no_boxed']}")
    print(f" [์ ์ฒด] {tc}/{tt} ({tc/tt*100:.1f}%)")

    result_obj = {"label": label, "correct": tc, "total": tt, "accuracy": tc/tt*100, "by_source": by_src}

    if save_path:
        # Ensure the parent directory exists ("." when save_path is bare).
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        with open(save_path, "w") as f:
            json.dump({"result": result_obj, "details": details}, f, ensure_ascii=False, indent=2)
        print(f" ๊ฒฐ๊ณผ ์ ์ฅ: {save_path}")

    return result_obj
|
|
if __name__ == "__main__":
    # CLI usage: script.py [label] [save_path]
    cli_args = sys.argv[1:]
    run_label = cli_args[0] if cli_args else "eval"
    out_path = cli_args[1] if len(cli_args) > 1 else None
    asyncio.run(evaluate(run_label, out_path))
|
|