"""
Step 4: Analysis and visualization of Best-of-N vs greedy performance.

This script creates plots comparing:
1. Overall accuracy: Greedy vs Majority Vote vs Standard BoN vs Weighted BoN
2. Accuracy vs N (how performance scales with the number of samples)
3. Per-problem analysis: which problems did BoN solve that greedy couldn't?
4. PRM score distribution analysis

Co-authored with Claude (Anthropic). I can explain all code logic.
"""

import json
from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch

matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150})

# All inputs and outputs live in one directory; keep the path in one place.
OUTPUT_DIR = "/Users/cmpatino/Projects/ml-intern/exercise/outputs"
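

# For reference: a minimal sketch (not called anywhere) of the three aggregation
# rules compared below, assuming each candidate solution is an
# (extracted_answer, prm_score) pair. The actual selection happened in the
# upstream BoN step; this only documents the semantics of the three methods.
#   e.g. _aggregation_sketch([("41", 0.9), ("42", 0.4), ("42", 0.4)])
#        returns ("42", "41", "41"): majority vote follows answer frequency,
#        while standard and weighted BoN follow the PRM scores.
def _aggregation_sketch(candidates):
    from collections import Counter

    # Majority vote: ignore scores, pick the most frequent answer.
    majority = Counter(a for a, _ in candidates).most_common(1)[0][0]
    # Standard BoN: pick the answer of the single highest-scored candidate.
    standard = max(candidates, key=lambda c: c[1])[0]
    # Weighted BoN: sum PRM scores per distinct answer, pick the largest total.
    totals = defaultdict(float)
    for answer, score in candidates:
        totals[answer] += score
    weighted = max(totals, key=totals.get)
    return majority, standard, weighted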


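# Load the JSON artifacts written by the earlier steps. Each bon_results record
# is expected to carry (as consumed below): unique_id, level, subject,
# ground_truth, greedy_answer, weighted_bon_answer, n_correct_in_16, and a
# *_correct flag per method.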
with open(f"{OUTPUT_DIR}/bon_results.json") as f:
    bon_results = json.load(f)

with open(f"{OUTPUT_DIR}/accuracy_by_n.json") as f:
    accuracy_by_n = json.load(f)

with open(f"{OUTPUT_DIR}/scored_results.json") as f:
    scored_results = json.load(f)

n_problems = len(bon_results)


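# Plot 1: overall accuracy of greedy decoding vs the three N=16 aggregation methods.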
fig, ax = plt.subplots(figsize=(8, 5))

methods = ["Greedy\n(N=1)", "Majority Vote\n(N=16)", "Standard BoN\n(N=16)", "Weighted BoN\n(N=16)"]
accuracies = [
    sum(r["greedy_correct"] for r in bon_results) / n_problems,
    sum(r["majority_vote_correct"] for r in bon_results) / n_problems,
    sum(r["standard_bon_correct"] for r in bon_results) / n_problems,
    sum(r["weighted_bon_correct"] for r in bon_results) / n_problems,
]
colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]

bars = ax.bar(methods, accuracies, color=colors, edgecolor="white", linewidth=1.5)
for bar, acc in zip(bars, accuracies):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
            f"{acc:.0%}", ha="center", va="bottom", fontweight="bold", fontsize=12)

ax.set_ylabel("Accuracy")
ax.set_title(f"Math Problem Accuracy: Greedy vs Best-of-N Methods\n({n_problems} MATH-500 problems, Levels 1-3)")
ax.set_ylim(0, 1.05)
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/plot1_accuracy_comparison.png")
plt.close()
print("Saved plot1_accuracy_comparison.png")


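# Plot 2: weighted BoN accuracy as a function of N, with greedy accuracy as a baseline.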
fig, ax = plt.subplots(figsize=(7, 5))

ns = sorted(int(k) for k in accuracy_by_n)
accs = [accuracy_by_n[str(n)] for n in ns]

ax.plot(ns, accs, "o-", color="#8172B2", linewidth=2, markersize=8, label="Weighted BoN")

greedy_acc = accuracies[0]  # same quantity as the Plot 1 greedy bar
ax.axhline(y=greedy_acc, color="#4C72B0", linestyle="--", linewidth=1.5, label=f"Greedy baseline ({greedy_acc:.0%})")

for n, acc in zip(ns, accs):
    ax.annotate(f"{acc:.0%}", (n, acc), textcoords="offset points",
                xytext=(0, 10), ha="center", fontsize=10)

ax.set_xlabel("N (number of samples)")
ax.set_ylabel("Accuracy")
ax.set_title("Weighted Best-of-N Accuracy vs Number of Samples")
ax.set_xticks(ns)
ax.set_ylim(0, 1.05)
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/plot2_accuracy_vs_n.png")
plt.close()
print("Saved plot2_accuracy_vs_n.png")


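# Plot 3: per-problem bar chart of how many of the 16 samples were correct,
# with each bar colored by which method(s) answered that problem correctly.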
fig, ax = plt.subplots(figsize=(12, 5))

cat_colors = {
    "Both correct": "#55A868",
    "Only BoN correct": "#8172B2",
    "Only Greedy correct": "#C44E52",
    "Both wrong": "#CCCCCC",
}


def classify(r):
    g, b = r["greedy_correct"], r["weighted_bon_correct"]
    if g and b:
        return "Both correct"
    if b:
        return "Only BoN correct"
    if g:
        return "Only Greedy correct"
    return "Both wrong"


colors_list = [cat_colors[classify(r)] for r in bon_results]

x = range(len(bon_results))
# Bar height = number of correct solutions among the 16 samples for that problem.
heights = [r["n_correct_in_16"] for r in bon_results]
ax.bar(x, heights, color=colors_list, edgecolor="white", linewidth=0.5)

ax.set_xticks(x)
short_ids = [r["unique_id"].split("/")[-1].replace(".json", "")[:12] for r in bon_results]
ax.set_xticklabels(short_ids, rotation=45, ha="right", fontsize=8)

ax.set_ylabel("# Correct Solutions (out of 16)")
ax.set_title("Per-Problem Analysis: Correct Solutions in N=16 Sample")

legend_elements = [Patch(facecolor=c, label=l) for l, c in cat_colors.items()]
ax.legend(handles=legend_elements, loc="upper right", fontsize=9)
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/plot3_per_problem.png")
plt.close()
print("Saved plot3_per_problem.png")


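# Plot 4: distribution of PRM last-step scores, split by whether the extracted
# answer matches the ground truth (the "answer" field of scored_results).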
fig, ax = plt.subplots(figsize=(7, 5))

correct_scores = []
incorrect_scores = []

for r in scored_results:
    for answer, score in zip(r["extracted_answers"], r["prm_scores"]):
        if answer == r["answer"]:
            correct_scores.append(score)
        else:
            incorrect_scores.append(score)

bins = np.linspace(0, 1, 25)
ax.hist(correct_scores, bins=bins, alpha=0.7, label=f"Correct ({len(correct_scores)})", color="#55A868")
ax.hist(incorrect_scores, bins=bins, alpha=0.7, label=f"Incorrect ({len(incorrect_scores)})", color="#C44E52")

ax.set_xlabel("PRM Last-Step Score")
ax.set_ylabel("Count")
ax.set_title("PRM Score Distribution: Correct vs Incorrect Solutions")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/plot4_prm_scores.png")
plt.close()
print("Saved plot4_prm_scores.png")


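# Console summary: overall accuracies, problem-level win/loss breakdowns, PRM
# score statistics, and accuracy sliced by level and subject.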
print("\n" + "=" * 70)
print("DETAILED ANALYSIS")
print("=" * 70)

print("\nOverall Accuracies:")
print(f"  Greedy (N=1):              {accuracies[0]:.0%}")
print(f"  Majority Vote (N=16):      {accuracies[1]:.0%}")
print(f"  Standard Best-of-N (N=16): {accuracies[2]:.0%}")
print(f"  Weighted Best-of-N (N=16): {accuracies[3]:.0%}")

print("\nProblems ONLY solved by Weighted BoN (not greedy):")
for r in bon_results:
    if r["weighted_bon_correct"] and not r["greedy_correct"]:
        print(f"  - {r['unique_id']} (Level {r['level']}, {r['subject']})")
        print(f"    Ground truth: {r['ground_truth']}")
        print(f"    Greedy answer: {r['greedy_answer']}")
        print(f"    BoN answer: {r['weighted_bon_answer']}")
        print(f"    Correct in sample: {r['n_correct_in_16']}/16")

print("\nProblems ONLY solved by Greedy (not BoN):")
for r in bon_results:
    if r["greedy_correct"] and not r["weighted_bon_correct"]:
        print(f"  - {r['unique_id']} (Level {r['level']}, {r['subject']})")
        print(f"    Ground truth: {r['ground_truth']}")
        print(f"    Greedy answer: {r['greedy_answer']}")
        print(f"    BoN answer: {r['weighted_bon_answer']}")
        print(f"    Correct in sample: {r['n_correct_in_16']}/16")

print("\nProblems neither method solved:")
for r in bon_results:
    if not r["greedy_correct"] and not r["weighted_bon_correct"]:
        print(f"  - {r['unique_id']} (Level {r['level']}, {r['subject']})")
        print(f"    Ground truth: {r['ground_truth']}")
        print(f"    Correct in sample: {r['n_correct_in_16']}/16")

print("\nPRM Score Statistics:")
print(f"  Correct solutions:   mean={np.mean(correct_scores):.3f}, median={np.median(correct_scores):.3f}")
print(f"  Incorrect solutions: mean={np.mean(incorrect_scores):.3f}, median={np.median(incorrect_scores):.3f}")

print("\nAccuracy by problem level:")
for level in sorted(set(r["level"] for r in bon_results)):
    level_results = [r for r in bon_results if r["level"] == level]
    n = len(level_results)
    g = sum(r["greedy_correct"] for r in level_results)
    w = sum(r["weighted_bon_correct"] for r in level_results)
    print(f"  Level {level}: Greedy {g}/{n} ({g/n:.0%}) | Weighted BoN {w}/{n} ({w/n:.0%})")

print("\nAccuracy by subject:")
subjects = sorted(set(r["subject"] for r in bon_results))
for subj in subjects:
    subj_results = [r for r in bon_results if r["subject"] == subj]
    n = len(subj_results)
    g = sum(r["greedy_correct"] for r in subj_results)
    w = sum(r["weighted_bon_correct"] for r in subj_results)
    print(f"  {subj}: Greedy {g}/{n} | Weighted BoN {w}/{n}")

print("\nAll plots saved to outputs/ directory.")