cmpatino (HF Staff) committed · verified
Commit 707fcea · 1 Parent(s): 011d2c2

Upload code/run_all.py with huggingface_hub

Files changed (1):
  code/run_all.py +616 -1
code/run_all.py CHANGED
@@ -1 +1,616 @@
- # Content uploaded from local file - see run_all.py
+ """
+ Complete pipeline for Best-of-N weighted selection on MATH-500.
+
+ This single script runs all steps:
+ 1. Filter MATH-500 to 20 level 1-3 problems
+ 2. Generate greedy (N=1) solutions and compute baseline accuracy
+ 3. Sample N=16 solutions per problem with temperature sampling
+ 4. Score all solutions with Skywork PRM (last-step prediction)
+ 5. Compute weighted Best-of-N accuracy
+ 6. Create dataset and push to HuggingFace Hub
+ 7. Generate analysis plots and push them too
+
+ Reference papers:
+ - DeepMind (2408.03314): Scaling LLM Test-Time Compute, Section 5.1 + Appendix E
+ - Math-Shepherd (2312.08935): Process Reward Models, Section 3.4
+
+ Co-authored with Claude (Anthropic) as part of the HuggingFace internship exercise.
+ I can explain all code logic in detail.
+ """
+
+ import json
+ import os
+ import random
+ import subprocess
+ import sys
+ from collections import defaultdict
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ from datasets import Dataset, load_dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # Configuration
+ # ═══════════════════════════════════════════════════════════════════════════════
+ N_PROBLEMS = 20        # Number of problems to evaluate
+ N_SAMPLES = 16         # Number of solutions per problem for Best-of-N
+ TEMPERATURE = 0.7      # Sampling temperature for diverse solutions
+ MAX_NEW_TOKENS = 2048  # Max generation length
+ SEED = 42              # Random seed for reproducibility
+ LLM_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
+ PRM_MODEL_ID = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
+ DATASET_ID = "cmpatino/math500-bon-weighted-results"
+
+ OUTPUT_DIR = "/tmp/exercise_outputs"
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # System prompt: encourages chain-of-thought reasoning and \boxed{} format
+ SYSTEM_PROMPT = (
+     "You are a helpful math assistant. Solve the problem step by step, "
+     "showing your reasoning clearly. Put your final answer inside "
+     "\\boxed{answer} at the end of your solution."
+ )
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # Helper functions
+ # ═══════════════════════════════════════════════════════════════════════════════
+
+ def extract_boxed_solution(text: str) -> Optional[str]:
+     """
+     Extract the content of the last \\boxed{} in text.
+     Uses bracket-balanced parsing so nested braces are handled correctly.
+     Source: https://gist.github.com/lewtun/9c2ce1937b741404090a3dc4c7c022b3
+     """
+     try:
+         start_index = text.rindex("\\boxed{")
+         content_start = start_index + 7  # len("\\boxed{") == 7
+         bracket_count = 1
+         current_pos = content_start
+         while bracket_count > 0 and current_pos < len(text):
+             if text[current_pos] == "{":
+                 bracket_count += 1
+             elif text[current_pos] == "}":
+                 bracket_count -= 1
+             current_pos += 1
+         if bracket_count == 0:
+             return text[content_start : current_pos - 1].strip()
+         return None  # Unbalanced braces: the box was never closed
+     except ValueError:  # rindex found no "\boxed{" in the text
+         return None
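+ # Illustrative behavior on hypothetical inputs (not part of the pipeline):
+ #   extract_boxed_solution("... so \\boxed{\\frac{1}{2}}.")  -> "\\frac{1}{2}"
+ #   extract_boxed_solution("no boxed answer here")           -> None
+ # A non-greedy regex like r"\\boxed\{(.*?)\}" would clip the nested-brace case.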
+
+
+ def weighted_best_of_n(extracted_answers, prm_scores):
+     """
+     Weighted Best-of-N selection (DeepMind 2408.03314, Section 5.1):
+         â = argmax_a Σᵢ 𝟙(aᵢ = a) · score(sᵢ)
+
+     Groups solutions by final answer, sums their PRM scores,
+     and selects the answer group with the highest total.
+     """
+     answer_scores = defaultdict(float)
+     for answer, score in zip(extracted_answers, prm_scores):
+         if answer is None:
+             continue  # Skip unparseable solutions
+         answer_scores[answer] += score
+     if not answer_scores:
+         return None, {}
+     best_answer = max(answer_scores, key=answer_scores.get)
+     return best_answer, dict(answer_scores)
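+ # Worked mini-example (hypothetical numbers): answers ["4", "5", "4"] with
+ # scores [0.3, 0.6, 0.5] give totals {"4": 0.8, "5": 0.6}, so the weighted
+ # vote returns "4" even though the single best-scored sample answered "5".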
+
+
+ def standard_best_of_n(extracted_answers, prm_scores):
+     """Standard Best-of-N: pick the single solution with the highest PRM score."""
+     best_idx, best_score = None, -1.0
+     for i, (answer, score) in enumerate(zip(extracted_answers, prm_scores)):
+         if answer is not None and score > best_score:
+             best_score = score
+             best_idx = i
+     return extracted_answers[best_idx] if best_idx is not None else None
+
+
+ def majority_vote(extracted_answers):
+     """Pure majority vote: pick the most frequent answer."""
+     counts = defaultdict(int)
+     for answer in extracted_answers:
+         if answer is not None:
+             counts[answer] += 1
+     return max(counts, key=counts.get) if counts else None
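+ # Note: on ties, max() over a dict returns the first maximal key in insertion
+ # order, so these selectors implicitly prefer earlier samples. Fine at this scale.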
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 1: Filter MATH-500 to level 1-3 problems
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("=" * 70)
+ print("STEP 1: Loading and filtering MATH-500 dataset")
+ print("=" * 70)
+
+ dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
+ print(f"Total problems: {len(dataset)}")
+
+ # Filter to levels 1-3: easier problems that a 1.5B model can reasonably
+ # attempt. Levels 4-5 are largely out of reach for a model this small.
+ filtered = dataset.filter(lambda x: x["level"] in [1, 2, 3])
+ print(f"Level 1-3 problems: {len(filtered)}")
+
+ # Fixed random sample for reproducibility
+ random.seed(SEED)
+ indices = random.sample(range(len(filtered)), k=N_PROBLEMS)
+ problems = filtered.select(indices)
+
+ problems_data = []
+ for i, p in enumerate(problems):
+     problems_data.append({
+         "idx": i,
+         "problem": p["problem"],
+         "solution": p["solution"],
+         "answer": p["answer"],
+         "subject": p["subject"],
+         "level": p["level"],
+         "unique_id": p["unique_id"],
+     })
+     print(f"  [{i+1:2d}] L{p['level']} {p['subject']:25s} {p['unique_id']}")
+
+ # Save for reference
+ with open(os.path.join(OUTPUT_DIR, "filtered_problems.json"), "w") as f:
+     json.dump(problems_data, f, indent=2)
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 2: Generate greedy (N=1) solutions
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("STEP 2: Generating greedy solutions (N=1)")
+ print("=" * 70)
+
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(
+     LLM_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
+ )
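+ # device_map="auto" places the model on GPU when one is available; bfloat16
+ # keeps the 1.5B model's memory footprint small. CPU-only runs work but are slow.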
+
+
+ def generate_batch(problems_data, model, tokenizer, n, do_sample, temperature=None):
+     """Generate n solutions per problem. Returns a list of solution lists."""
+     all_solutions = []
+     for i, p in enumerate(problems_data):
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": p["problem"]},
+         ]
+         prompt = tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         solutions = []
+         for _ in range(n):
+             gen_kwargs = {"max_new_tokens": MAX_NEW_TOKENS, "do_sample": do_sample}
+             if do_sample and temperature:
+                 gen_kwargs["temperature"] = temperature
+                 gen_kwargs["top_p"] = 0.95
+             with torch.no_grad():
+                 output = model.generate(**inputs, **gen_kwargs)
+             generated = output[0][inputs["input_ids"].shape[1]:]
+             solutions.append(tokenizer.decode(generated, skip_special_tokens=True))
+
+         all_solutions.append(solutions)
+         ans = extract_boxed_solution(solutions[0]) if n == 1 else "..."
+         tag = "greedy" if n == 1 else f"N={n}"
+         print(f"  [{i+1:2d}/{len(problems_data)}] {tag} | {p['unique_id']} | answer={ans}")
+
+     return all_solutions
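+ # Note: samples are generated one at a time for simplicity. Passing
+ # num_return_sequences=n to model.generate would batch them at the cost of
+ # peak memory; results should be equivalent up to sampling order.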
+
+
+ # Greedy decoding (N=1, deterministic)
+ greedy_solutions = generate_batch(problems_data, model, tokenizer, n=1, do_sample=False)
+
+ # Evaluate greedy accuracy
+ greedy_correct = 0
+ for p, sols in zip(problems_data, greedy_solutions):
+     extracted = extract_boxed_solution(sols[0])
+     p["greedy_solution"] = sols[0]
+     p["greedy_extracted_answer"] = extracted
+     p["greedy_correct"] = (extracted is not None) and (extracted == p["answer"])
+     if p["greedy_correct"]:
+         greedy_correct += 1
+     status = "✓" if p["greedy_correct"] else "✗"
+     print(f"  {status} Expected: {p['answer']:20s} | Got: {str(extracted):20s} | {p['unique_id']}")
+
+ greedy_acc = greedy_correct / len(problems_data)
+ print(f"\n>>> Greedy accuracy: {greedy_correct}/{len(problems_data)} = {greedy_acc:.0%}")
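+ # Caveat: correctness is exact string match on the boxed answer, so
+ # mathematically equivalent forms (e.g. "0.5" vs. "1/2") count as wrong.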
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 3: Sample N=16 solutions per problem
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print(f"STEP 3: Sampling N={N_SAMPLES} solutions per problem (T={TEMPERATURE})")
+ print("=" * 70)
+
+ sampled_solutions = generate_batch(
+     problems_data, model, tokenizer,
+     n=N_SAMPLES, do_sample=True, temperature=TEMPERATURE
+ )
+
+ # Save solutions and free LLM memory
+ for p, sols in zip(problems_data, sampled_solutions):
+     p["sampled_solutions"] = sols
+
+ with open(os.path.join(OUTPUT_DIR, "sampled_solutions.json"), "w") as f:
+     json.dump(problems_data, f, indent=2)
+
+ del model
+ torch.cuda.empty_cache()
+ print("Freed LLM memory for PRM loading.")
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 4: Score with Skywork PRM (last-step prediction)
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("STEP 4: Scoring solutions with Skywork PRM")
+ print("=" * 70)
+
+ # Clone the Skywork PRM inference repo for the custom PRM_MODEL class
+ PRM_REPO_PATH = "/tmp/skywork-o1-prm-inference"
+ if not os.path.exists(PRM_REPO_PATH):
+     print("Cloning Skywork PRM inference repo...")
+     subprocess.run(
+         ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git", PRM_REPO_PATH],
+         check=True,
+     )
+ sys.path.insert(0, PRM_REPO_PATH)
+
+ from model_utils.prm_model import PRM_MODEL
+ from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
+
+ prm_tokenizer = AutoTokenizer.from_pretrained(PRM_MODEL_ID, trust_remote_code=True)
+ prm_model = PRM_MODEL.from_pretrained(PRM_MODEL_ID, device_map="auto").eval()
+ prm_device = next(prm_model.pretrained_model.parameters()).device
+ print(f"PRM loaded on {prm_device}")
+
275
+
276
+ def score_solution(problem: str, solution: str) -> float:
277
+ """
278
+ Score a single solution using the PRM's last-step prediction.
279
+
280
+ Per DeepMind (2408.03314, Appendix E): "We use the PRM's prediction at the
281
+ last step as the full-answer score" — this outperforms min/product aggregation
282
+ when the PRM is trained with soft MC-return labels.
283
+
284
+ Returns: float in [0, 1] — the sigmoid-normalized score at the last step.
285
+ """
286
+ input_ids, steps, reward_flags = prepare_input(
287
+ problem, solution, prm_tokenizer, step_token="\n"
288
+ )
289
+ input_ids_t, attn_mask_t, flags_t = prepare_batch_input_for_model(
290
+ [input_ids], [reward_flags], prm_tokenizer.pad_token_id
291
+ )
292
+ input_ids_t = input_ids_t.to(prm_device)
293
+ attn_mask_t = attn_mask_t.to(prm_device)
294
+ flags_t = flags_t.to(prm_device)
295
+
296
+ with torch.no_grad():
297
+ _, _, rewards = prm_model(
298
+ input_ids=input_ids_t, attention_mask=attn_mask_t, return_probs=True
299
+ )
300
+ step_rewards = derive_step_rewards(rewards, flags_t)
301
+ # Return last step score (or 0.0 if no steps found)
302
+ return step_rewards[0][-1] if step_rewards[0] else 0.0
303
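+ # Note: step_token="\n" makes each newline-delimited line a "step"; the PRM
+ # emits one reward per step and only the final one is kept here. Min or
+ # product aggregation over step_rewards[0] would be a drop-in alternative.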
+
+
+ # Score all sampled solutions
+ for i, p in enumerate(problems_data):
+     print(f"\n  Scoring problem {i+1}/{len(problems_data)}: {p['unique_id']}")
+     scores = []
+     extracted_answers = []
+     for j, sol in enumerate(p["sampled_solutions"]):
+         score = score_solution(p["problem"], sol)
+         scores.append(score)
+         extracted_answers.append(extract_boxed_solution(sol))
+         if (j + 1) % 8 == 0:
+             print(f"    Scored {j+1}/{N_SAMPLES} (last: {score:.4f})")
+     p["prm_scores"] = scores
+     p["extracted_answers"] = extracted_answers
+
+ # Save scored results
+ with open(os.path.join(OUTPUT_DIR, "scored_results.json"), "w") as f:
+     json.dump(problems_data, f, indent=2)
+
+ del prm_model
+ torch.cuda.empty_cache()
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 5: Compute Best-of-N with weighted selection
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("STEP 5: Computing Best-of-N accuracy")
+ print("=" * 70)
+
+ weighted_correct = 0
+ standard_correct = 0
+ majority_correct_count = 0
+
+ bon_summary = []
+ for p in problems_data:
+     gt = p["answer"]
+
+     # Weighted BoN
+     w_ans, w_scores = weighted_best_of_n(p["extracted_answers"], p["prm_scores"])
+     w_ok = (w_ans is not None) and (w_ans == gt)
+     if w_ok:
+         weighted_correct += 1
+
+     # Standard BoN
+     s_ans = standard_best_of_n(p["extracted_answers"], p["prm_scores"])
+     s_ok = (s_ans is not None) and (s_ans == gt)
+     if s_ok:
+         standard_correct += 1
+
+     # Majority vote
+     m_ans = majority_vote(p["extracted_answers"])
+     m_ok = (m_ans is not None) and (m_ans == gt)
+     if m_ok:
+         majority_correct_count += 1
+
+     n_correct = sum(1 for a in p["extracted_answers"] if a == gt)
+
+     bon_summary.append({
+         "unique_id": p["unique_id"],
+         "level": p["level"],
+         "subject": p["subject"],
+         "ground_truth": gt,
+         "greedy_answer": p["greedy_extracted_answer"],
+         "greedy_correct": p["greedy_correct"],
+         "weighted_bon_answer": w_ans,
+         "weighted_bon_correct": w_ok,
+         "standard_bon_answer": s_ans,
+         "standard_bon_correct": s_ok,
+         "majority_vote_answer": m_ans,
+         "majority_vote_correct": m_ok,
+         "n_correct_in_16": n_correct,
+         "answer_score_breakdown": w_scores,
+         "prm_scores": p["prm_scores"],
+     })
+
+     sg = "✓" if p["greedy_correct"] else "✗"
+     sw = "✓" if w_ok else "✗"
+     print(f"  {sg}→{sw} | {p['unique_id']:40s} | GT={gt:15s} | Greedy={str(p['greedy_extracted_answer']):15s} | WBoN={str(w_ans):15s} | {n_correct}/16 correct")
+
+ n = len(problems_data)
+ greedy_total = sum(1 for p in problems_data if p["greedy_correct"])
+ print(f"\n{'=' * 70}")
+ print("RESULTS SUMMARY")
+ print("=" * 70)
+ print(f"  Greedy (N=1):               {greedy_total}/{n} = {greedy_total/n:.0%}")
+ print(f"  Majority Vote (N=16):       {majority_correct_count}/{n} = {majority_correct_count/n:.0%}")
+ print(f"  Standard Best-of-N (N=16):  {standard_correct}/{n} = {standard_correct/n:.0%}")
+ print(f"  Weighted Best-of-N (N=16):  {weighted_correct}/{n} = {weighted_correct/n:.0%}")
+
+ with open(os.path.join(OUTPUT_DIR, "bon_results.json"), "w") as f:
+     json.dump(bon_summary, f, indent=2)
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 5b: Accuracy vs N analysis
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("ANALYSIS: Accuracy vs N")
+ print("=" * 70)
+
+ random.seed(SEED)
+ n_values = [1, 2, 4, 8, 16]
+ n_trials = 50
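+ # For N < N_SAMPLES, accuracy@N is estimated by averaging over n_trials random
+ # without-replacement subsets of the 16 samples per problem, a simple
+ # subsampling estimate in the spirit of the evaluation in 2408.03314.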
+
+ accuracy_by_n = {}
+ for n_val in n_values:
+     if n_val == N_SAMPLES:
+         # With all 16 samples there is a single deterministic evaluation
+         correct = sum(
+             1 for p in problems_data
+             if weighted_best_of_n(p["extracted_answers"], p["prm_scores"])[0] == p["answer"]
+         )
+         acc = correct / len(problems_data)
+     else:
+         trial_accs = []
+         for _ in range(n_trials):
+             correct = 0
+             for p in problems_data:
+                 idx = random.sample(range(N_SAMPLES), n_val)
+                 sub_a = [p["extracted_answers"][j] for j in idx]
+                 sub_s = [p["prm_scores"][j] for j in idx]
+                 ans, _ = weighted_best_of_n(sub_a, sub_s)
+                 if ans == p["answer"]:
+                     correct += 1
+             trial_accs.append(correct / len(problems_data))
+         acc = sum(trial_accs) / len(trial_accs)
+     accuracy_by_n[n_val] = acc
+     print(f"  N={n_val:2d}: {acc:.1%}")
+
+ with open(os.path.join(OUTPUT_DIR, "accuracy_by_n.json"), "w") as f:
+     json.dump(accuracy_by_n, f, indent=2)
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 6: Generate plots
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("STEP 6: Generating analysis plots")
+ print("=" * 70)
+
+ import matplotlib
+ matplotlib.use("Agg")  # Headless backend; must be set before importing pyplot
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import Patch
+
+ plt.rcParams.update({"font.size": 11, "figure.dpi": 150})
+
+ # --- Plot 1: Overall accuracy comparison ---
+ fig, ax = plt.subplots(figsize=(8, 5))
+ methods = ["Greedy\n(N=1)", "Majority Vote\n(N=16)", "Standard BoN\n(N=16)", "Weighted BoN\n(N=16)"]
+ accs = [
+     greedy_total / n,
+     majority_correct_count / n,
+     standard_correct / n,
+     weighted_correct / n,
+ ]
+ colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
+ bars = ax.bar(methods, accs, color=colors, edgecolor="white", linewidth=1.5)
+ for bar, a in zip(bars, accs):
+     ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
+             f"{a:.0%}", ha="center", va="bottom", fontweight="bold", fontsize=12)
+ ax.set_ylabel("Accuracy")
+ ax.set_title("Math Problem Accuracy: Greedy vs Best-of-N Methods\n(20 MATH-500 problems, Levels 1-3)")
+ ax.set_ylim(0, 1.15)
+ ax.grid(axis="y", alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(os.path.join(OUTPUT_DIR, "plot1_accuracy_comparison.png"))
+ plt.close()
+
+ # --- Plot 2: Accuracy vs N ---
+ fig, ax = plt.subplots(figsize=(7, 5))
+ ns = sorted(accuracy_by_n.keys())
+ acc_vals = [accuracy_by_n[nv] for nv in ns]
+ ax.plot(ns, acc_vals, "o-", color="#8172B2", linewidth=2, markersize=8, label="Weighted BoN")
+ ax.axhline(y=greedy_total / n, color="#4C72B0", linestyle="--", linewidth=1.5,
+            label=f"Greedy baseline ({greedy_total/n:.0%})")
+ for nv, a in zip(ns, acc_vals):
+     ax.annotate(f"{a:.0%}", (nv, a), textcoords="offset points", xytext=(0, 10), ha="center")
+ ax.set_xlabel("N (number of samples)")
+ ax.set_ylabel("Accuracy")
+ ax.set_title("Weighted Best-of-N Accuracy vs Number of Samples")
+ ax.set_xticks(ns)
+ ax.set_ylim(0, 1.1)
+ ax.legend()
+ ax.grid(alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(os.path.join(OUTPUT_DIR, "plot2_accuracy_vs_n.png"))
+ plt.close()
+
+ # --- Plot 3: Per-problem analysis ---
+ fig, ax = plt.subplots(figsize=(12, 5))
+ cat_colors = {
+     "Both correct": "#55A868", "Only BoN correct": "#8172B2",
+     "Only Greedy correct": "#C44E52", "Both wrong": "#CCCCCC"
+ }
+ bar_colors = []
+ for s in bon_summary:
+     g, b = s["greedy_correct"], s["weighted_bon_correct"]
+     if g and b:
+         bar_colors.append(cat_colors["Both correct"])
+     elif not g and b:
+         bar_colors.append(cat_colors["Only BoN correct"])
+     elif g and not b:
+         bar_colors.append(cat_colors["Only Greedy correct"])
+     else:
+         bar_colors.append(cat_colors["Both wrong"])
+
+ x = range(len(bon_summary))
+ heights = [s["n_correct_in_16"] for s in bon_summary]
+ ax.bar(x, heights, color=bar_colors, edgecolor="white", linewidth=0.5)
+ ax.set_xticks(x)
+ labels = [f"L{s['level']}: {s['unique_id'].split('/')[-1].replace('.json', '')[:12]}" for s in bon_summary]
+ ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
+ ax.set_ylabel("# Correct Solutions (out of 16)")
+ ax.set_title("Per-Problem: Correct Solutions in N=16 Sample")
+ legend_elements = [Patch(facecolor=c, label=l) for l, c in cat_colors.items()]
+ ax.legend(handles=legend_elements, loc="upper right", fontsize=9)
+ ax.grid(axis="y", alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(os.path.join(OUTPUT_DIR, "plot3_per_problem.png"))
+ plt.close()
+
+ # --- Plot 4: PRM score distribution ---
+ fig, ax = plt.subplots(figsize=(7, 5))
+ correct_scores, incorrect_scores = [], []
+ for p in problems_data:
+     for ans, sc in zip(p["extracted_answers"], p["prm_scores"]):
+         (correct_scores if ans == p["answer"] else incorrect_scores).append(sc)
+
+ bins = np.linspace(0, 1, 25)
+ ax.hist(correct_scores, bins=bins, alpha=0.7, label=f"Correct ({len(correct_scores)})", color="#55A868")
+ ax.hist(incorrect_scores, bins=bins, alpha=0.7, label=f"Incorrect ({len(incorrect_scores)})", color="#C44E52")
+ ax.set_xlabel("PRM Last-Step Score")
+ ax.set_ylabel("Count")
+ ax.set_title("PRM Score Distribution: Correct vs Incorrect Solutions")
+ ax.legend()
+ ax.grid(alpha=0.3)
+ plt.tight_layout()
+ plt.savefig(os.path.join(OUTPUT_DIR, "plot4_prm_scores.png"))
+ plt.close()
+
+ print("All plots saved.")
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # STEP 7: Push dataset to Hub
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("STEP 7: Pushing dataset to HuggingFace Hub")
+ print("=" * 70)
+
+ rows = []
+ for p, s in zip(problems_data, bon_summary):
+     rows.append({
+         "problem": p["problem"],
+         "ground_truth_solution": p["solution"],
+         "ground_truth_answer": p["answer"],
+         "subject": p["subject"],
+         "level": p["level"],
+         "unique_id": p["unique_id"],
+         "greedy_solution": p["greedy_solution"],
+         "greedy_extracted_answer": p["greedy_extracted_answer"],
+         "greedy_correct": p["greedy_correct"],
+         "bon_weighted_answer": s["weighted_bon_answer"],
+         "bon_weighted_correct": s["weighted_bon_correct"],
+         "bon_standard_answer": s["standard_bon_answer"],
+         "bon_standard_correct": s["standard_bon_correct"],
+         "bon_majority_answer": s["majority_vote_answer"],
+         "bon_majority_correct": s["majority_vote_correct"],
+         "sampled_solutions": p["sampled_solutions"],
+         "sampled_extracted_answers": p["extracted_answers"],
+         "sampled_prm_scores": p["prm_scores"],
+         "n_correct_in_16": s["n_correct_in_16"],
+         "answer_score_breakdown": json.dumps(s["answer_score_breakdown"]),
+     })
+
+ hf_dataset = Dataset.from_list(rows)
+ hf_dataset.push_to_hub(DATASET_ID, split="test")
+ print(f"Dataset pushed to: https://huggingface.co/datasets/{DATASET_ID}")
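+ # Note: push_to_hub and the upload_file calls below assume you are already
+ # authenticated, e.g. via `huggingface-cli login` or an HF_TOKEN env variable.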
+
+ # Also upload the plots as artifacts
+ from huggingface_hub import HfApi
+ api = HfApi()
+ for plot_file in ["plot1_accuracy_comparison.png", "plot2_accuracy_vs_n.png",
+                   "plot3_per_problem.png", "plot4_prm_scores.png"]:
+     plot_path = os.path.join(OUTPUT_DIR, plot_file)
+     if os.path.exists(plot_path):
+         api.upload_file(
+             path_or_fileobj=plot_path,
+             path_in_repo=f"plots/{plot_file}",
+             repo_id=DATASET_ID,
+             repo_type="dataset",
+         )
+         print(f"  Uploaded {plot_file}")
+
+ # Upload the results JSON files too
+ for json_file in ["filtered_problems.json", "bon_results.json", "accuracy_by_n.json"]:
+     json_path = os.path.join(OUTPUT_DIR, json_file)
+     if os.path.exists(json_path):
+         api.upload_file(
+             path_or_fileobj=json_path,
+             path_in_repo=f"results/{json_file}",
+             repo_id=DATASET_ID,
+             repo_type="dataset",
+         )
+         print(f"  Uploaded {json_file}")
+
+
+ # ═══════════════════════════════════════════════════════════════════════════════
+ # Final summary
+ # ═══════════════════════════════════════════════════════════════════════════════
+ print("\n" + "=" * 70)
+ print("FINAL RESULTS")
+ print("=" * 70)
+ print(f"  Greedy (N=1):               {greedy_total}/{len(problems_data)} = {greedy_total/len(problems_data):.0%}")
+ print(f"  Majority Vote (N=16):       {majority_correct_count}/{len(problems_data)} = {majority_correct_count/len(problems_data):.0%}")
+ print(f"  Standard Best-of-N (N=16):  {standard_correct}/{len(problems_data)} = {standard_correct/len(problems_data):.0%}")
+ print(f"  Weighted Best-of-N (N=16):  {weighted_correct}/{len(problems_data)} = {weighted_correct/len(problems_data):.0%}")
+ print(f"\n  Dataset: https://huggingface.co/datasets/{DATASET_ID}")
+ print("=" * 70)
+ print("DONE!")