# Source: math500-bon-exercise / code/step4_analysis.py
# Uploaded by cmpatino (HF Staff) with huggingface_hub, commit a98a175 (verified).
"""
Step 4: Analysis and visualization of Best-of-N vs greedy performance.
This script creates plots comparing:
1. Overall accuracy: Greedy vs Majority Vote vs Standard BoN vs Weighted BoN
2. Accuracy vs N (how performance scales with number of samples)
3. Per-problem analysis: which problems did BoN solve that greedy couldn't?
4. PRM score distribution analysis
Co-authored with Claude (Anthropic). I can explain all code logic.
"""
import json
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from collections import defaultdict
matplotlib.rcParams.update({"font.size": 11, "figure.dpi": 150})
# ──────────────────────────────────────────────────────────────────────────────
# Load results
# ──────────────────────────────────────────────────────────────────────────────
# NOTE(review): paths are absolute and machine-specific; adjust before running elsewhere.


def _load_json(path):
    """Read and return the JSON document stored at *path*, decoded as UTF-8."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


bon_results = _load_json("/Users/cmpatino/Projects/ml-intern/exercise/outputs/bon_results.json")
accuracy_by_n = _load_json("/Users/cmpatino/Projects/ml-intern/exercise/outputs/accuracy_by_n.json")
scored_results = _load_json("/Users/cmpatino/Projects/ml-intern/exercise/outputs/scored_results.json")
n_problems = len(bon_results)
if n_problems == 0:
    # Every accuracy in this script divides by n_problems; stop with a clear
    # message instead of an opaque ZeroDivisionError further down.
    raise SystemExit("bon_results.json contains no problems; nothing to analyse.")
# ──────────────────────────────────────────────────────────────────────────────
# Plot 1: Overall accuracy comparison (bar chart)
# ──────────────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(8, 5))
methods = ["Greedy\n(N=1)", "Majority Vote\n(N=16)", "Standard BoN\n(N=16)", "Weighted BoN\n(N=16)"]
# Per-problem boolean keys, in the same order as `methods`. `accuracies` is
# indexed positionally further down, so this order must not change.
method_keys = ["greedy_correct", "majority_vote_correct", "standard_bon_correct", "weighted_bon_correct"]
accuracies = [sum(r[key] for r in bon_results) / n_problems for key in method_keys]
colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
bars = ax.bar(methods, accuracies, color=colors, edgecolor="white", linewidth=1.5)
# Percentage label just above each bar.
for bar, acc in zip(bars, accuracies):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
            f"{acc:.0%}", ha="center", va="bottom", fontweight="bold", fontsize=12)
# Describe the evaluated subset from the data rather than hard-coding it in the title.
levels = sorted({r["level"] for r in bon_results})
level_span = f"Level {levels[0]}" if len(levels) == 1 else f"Levels {levels[0]}-{levels[-1]}"
ax.set_ylabel("Accuracy")
ax.set_title(f"Math Problem Accuracy: Greedy vs Best-of-N Methods\n({n_problems} MATH-500 problems, {level_span})")
ax.set_ylim(0, 1.05)
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot1_accuracy_comparison.png")
plt.close()
print("Saved plot1_accuracy_comparison.png")
# ──────────────────────────────────────────────────────────────────────────────
# Plot 2: Accuracy vs N
# ──────────────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(7, 5))
# JSON object keys are strings; order the curve by the numeric value of N.
curve = sorted(((int(key), value) for key, value in accuracy_by_n.items()), key=lambda point: point[0])
ns = [n_samples for n_samples, _ in curve]
accs = [acc for _, acc in curve]
ax.plot(ns, accs, "o-", color="#8172B2", linewidth=2, markersize=8, label="Weighted BoN")
# Dashed reference line: greedy-decoding accuracy over the same problems.
greedy_acc = sum(r["greedy_correct"] for r in bon_results) / n_problems
ax.axhline(y=greedy_acc, color="#4C72B0", linestyle="--", linewidth=1.5,
           label=f"Greedy baseline ({greedy_acc:.0%})")
# Write each point's accuracy 10 points above its marker.
for n_samples, acc in curve:
    ax.annotate(f"{acc:.0%}", (n_samples, acc), textcoords="offset points",
                xytext=(0, 10), ha="center", fontsize=10)
ax.set(xlabel="N (number of samples)",
       ylabel="Accuracy",
       title="Weighted Best-of-N Accuracy vs Number of Samples")
ax.set_xticks(ns)
ax.set_ylim(0, 1.05)
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot2_accuracy_vs_n.png")
plt.close()
print("Saved plot2_accuracy_vs_n.png")
# ──────────────────────────────────────────────────────────────────────────────
# Plot 3: Per-problem comparison (Greedy vs Weighted BoN)
# ──────────────────────────────────────────────────────────────────────────────
from matplotlib.patches import Patch


def _outcome_category(greedy_ok, bon_ok):
    """Return which of greedy / weighted BoN answered a problem correctly, as a legend category name."""
    if greedy_ok and bon_ok:
        return "Both correct"
    if bon_ok:
        return "Only BoN correct"
    if greedy_ok:
        return "Only Greedy correct"
    return "Both wrong"


fig, ax = plt.subplots(figsize=(12, 5))
# Bar colour per outcome category; dict order is also the legend order.
cat_colors = {
    "Both correct": "#55A868",
    "Only BoN correct": "#8172B2",
    "Only Greedy correct": "#C44E52",
    "Both wrong": "#CCCCCC",
}
outcomes = [_outcome_category(r["greedy_correct"], r["weighted_bon_correct"]) for r in bon_results]
colors_list = [cat_colors[outcome] for outcome in outcomes]
x = range(len(bon_results))
# Bar height = number of correct solutions among the 16 samples, coloured by outcome category.
heights = [r["n_correct_in_16"] for r in bon_results]
ax.bar(x, heights, color=colors_list, edgecolor="white", linewidth=0.5)
# Problem labels: last path component of unique_id, without ".json", truncated to 12 chars.
ax.set_xticks(x)
short_ids = [r["unique_id"].split("/")[-1].replace(".json", "")[:12] for r in bon_results]
ax.set_xticklabels(short_ids, rotation=45, ha="right", fontsize=8)
ax.set_ylabel("# Correct Solutions (out of 16)")
ax.set_title("Per-Problem Analysis: Correct Solutions in N=16 Sample")
# Proxy artists for the legend, each annotated with how many problems fall in that category.
legend_elements = [
    Patch(facecolor=color, label=f"{category} ({outcomes.count(category)})")
    for category, color in cat_colors.items()
]
ax.legend(handles=legend_elements, loc="upper right", fontsize=9)
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot3_per_problem.png")
plt.close()
print("Saved plot3_per_problem.png")
# ──────────────────────────────────────────────────────────────────────────────
# Plot 4: PRM Score Distribution (correct vs incorrect solutions)
# ──────────────────────────────────────────────────────────────────────────────


def _normalize_answer(answer):
    """Return *answer* as a whitespace-stripped string, or ``None`` if it is ``None``."""
    return None if answer is None else str(answer).strip()


fig, ax = plt.subplots(figsize=(7, 5))
correct_scores = []
incorrect_scores = []
for r in scored_results:
    truth = _normalize_answer(r["answer"])
    for answer, score in zip(r["extracted_answers"], r["prm_scores"]):
        # NOTE(review): this is string equality after stripping whitespace, not
        # mathematical equivalence, so it may still disagree with the grader
        # that produced n_correct_in_16. A missing (None) answer never counts as correct.
        if answer is not None and _normalize_answer(answer) == truth:
            correct_scores.append(score)
        else:
            incorrect_scores.append(score)
# 24 equal-width bins over [0, 1]; assumes PRM scores are probabilities — TODO confirm,
# since values outside this range would be counted in the legend but not drawn.
bins = np.linspace(0, 1, 25)
ax.hist(correct_scores, bins=bins, alpha=0.7, label=f"Correct ({len(correct_scores)})", color="#55A868")
ax.hist(incorrect_scores, bins=bins, alpha=0.7, label=f"Incorrect ({len(incorrect_scores)})", color="#C44E52")
ax.set_xlabel("PRM Last-Step Score")
ax.set_ylabel("Count")
ax.set_title("PRM Score Distribution: Correct vs Incorrect Solutions")
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.savefig("/Users/cmpatino/Projects/ml-intern/exercise/outputs/plot4_prm_scores.png")
plt.close()
print("Saved plot4_prm_scores.png")
# ──────────────────────────────────────────────────────────────────────────────
# Print detailed analysis
# ──────────────────────────────────────────────────────────────────────────────


def _print_problem_group(title, results, predicate, show_answers=True):
    """Print a header, then one entry per result for which *predicate* is true.

    When *show_answers* is true, the greedy and weighted-BoN answers are
    included alongside the ground truth.
    """
    print(f"\n{title}")
    for r in results:
        if not predicate(r):
            continue
        print(f" - {r['unique_id']} (Level {r['level']}, {r['subject']})")
        print(f" Ground truth: {r['ground_truth']}")
        if show_answers:
            print(f" Greedy answer: {r['greedy_answer']}")
            print(f" BoN answer: {r['weighted_bon_answer']}")
        print(f" Correct in sample: {r['n_correct_in_16']}/16")


def _score_summary(scores):
    """Format mean/median of *scores*; avoids numpy's nan + RuntimeWarning on an empty list."""
    if not scores:
        return "no solutions"
    return f"mean={np.mean(scores):.3f}, median={np.median(scores):.3f}"


def _print_accuracy_breakdown(results, key, label_fmt):
    """Print greedy vs weighted-BoN accuracy for each distinct value of ``r[key]``."""
    for value in sorted({r[key] for r in results}):
        group = [r for r in results if r[key] == value]
        n = len(group)  # >= 1, since value was taken from results
        g = sum(r["greedy_correct"] for r in group)
        w = sum(r["weighted_bon_correct"] for r in group)
        print(f" {label_fmt.format(value)}: Greedy {g}/{n} ({g/n:.0%}) | Weighted BoN {w}/{n} ({w/n:.0%})")


print("\n" + "=" * 70)
print("DETAILED ANALYSIS")
print("=" * 70)
print("\nOverall Accuracies:")
print(f" Greedy (N=1): {accuracies[0]:.0%}")
print(f" Majority Vote (N=16): {accuracies[1]:.0%}")
print(f" Standard Best-of-N (N=16): {accuracies[2]:.0%}")
print(f" Weighted Best-of-N (N=16): {accuracies[3]:.0%}")
_print_problem_group(
    "Problems ONLY solved by Weighted BoN (not greedy):", bon_results,
    lambda r: r["weighted_bon_correct"] and not r["greedy_correct"],
)
_print_problem_group(
    "Problems ONLY solved by Greedy (not BoN):", bon_results,
    lambda r: r["greedy_correct"] and not r["weighted_bon_correct"],
)
_print_problem_group(
    "Problems neither method solved:", bon_results,
    lambda r: not r["greedy_correct"] and not r["weighted_bon_correct"],
    show_answers=False,
)
# PRM Score stats
print("\nPRM Score Statistics:")
print(f" Correct solutions: {_score_summary(correct_scores)}")
print(f" Incorrect solutions: {_score_summary(incorrect_scores)}")
# Accuracy by level
print("\nAccuracy by problem level:")
_print_accuracy_breakdown(bon_results, "level", "Level {}")
# Accuracy by subject (same format as the level breakdown, including percentages)
print("\nAccuracy by subject:")
_print_accuracy_breakdown(bon_results, "subject", "{}")
print("\nAll plots saved to outputs/ directory.")