"""
Create visualizations for Model Scaling Study.

Generates publication-ready charts and tables.
"""
|
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
|
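# Global plotting style shared by all figures.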
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['axes.titlesize'] = 16
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
plt.rcParams['legend.fontsize'] = 12
|
output_dir = Path('visualizations')
output_dir.mkdir(exist_ok=True)

print("="*80)
print("CREATING VISUALIZATIONS FOR MODEL SCALING STUDY")
print("="*80)
print()

print("Loading data...")
|
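# Hard-coded results of the quality evaluation (500 samples per model).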
quality_data = {
    'Base': {'valid_rate': 0.994, 'diversity': 0.978, 'unique': 489, 'samples': 500},
    'Medium': {'valid_rate': 0.992, 'diversity': 0.988, 'unique': 494, 'samples': 500},
    'Large': {'valid_rate': 1.000, 'diversity': 0.986, 'unique': 493, 'samples': 500}
}
|
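# Load the Nguyen benchmark summary. Each record in nguyen_data['results'] is
# expected to carry 'model', 'benchmark', 'valid_rate', and 'best_r2' fields,
# which are aggregated into per-model statistics below.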
with open('results_nguyen_benchmarks/summary.json') as f:
    nguyen_data = json.load(f)

nguyen_stats = {}
for model in ['base', 'medium', 'large']:
    model_results = [r for r in nguyen_data['results'] if r['model'] == model]
    nguyen_stats[model.capitalize()] = {
        'avg_valid_rate': np.mean([r['valid_rate'] for r in model_results]),
        'avg_best_r2': np.mean([r['best_r2'] for r in model_results]),
        'max_r2': max([r['best_r2'] for r in model_results]),
        'benchmarks_gt_099': sum([1 for r in model_results if r['best_r2'] > 0.99])
    }

print("Data loaded successfully!")
print()
|
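# Figure 1: valid expression rate, quality evaluation vs. Nguyen benchmarks.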
print("Creating Figure 1: Valid Rate Comparison...")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

models = ['Base', 'Medium', 'Large']
colors = ['#3498db', '#e74c3c', '#2ecc71']

quality_valid = [quality_data[m]['valid_rate'] * 100 for m in models]
bars1 = ax1.bar(models, quality_valid, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax1.set_ylabel('Valid Expression Rate (%)', fontsize=14, fontweight='bold')
ax1.set_title('Quality Evaluation\n(500 samples per model)', fontsize=16, fontweight='bold')
ax1.set_ylim([95, 101])
ax1.axhline(y=100, color='green', linestyle='--', linewidth=2, label='Perfect (100%)')
ax1.legend()

for bar, val in zip(bars1, quality_valid):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 0.3,
             f'{val:.1f}%', ha='center', va='bottom', fontsize=12, fontweight='bold')

benchmark_valid = [nguyen_stats[m]['avg_valid_rate'] * 100 for m in models]
bars2 = ax2.bar(models, benchmark_valid, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax2.set_ylabel('Valid Expression Rate (%)', fontsize=14, fontweight='bold')
ax2.set_title('Nguyen Benchmarks\n(36 experiments, 3,600 expressions)', fontsize=16, fontweight='bold')
ax2.set_ylim([0, 100])

for bar, val in zip(bars2, benchmark_valid):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 2,
             f'{val:.1f}%', ha='center', va='bottom', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig(output_dir / 'fig1_valid_rate_comparison.png', dpi=300, bbox_inches='tight')
print(f" Saved: {output_dir / 'fig1_valid_rate_comparison.png'}")
plt.close()
|
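# Figure 2: average and maximum best R² per model across the Nguyen benchmarks.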
print("Creating Figure 2: R² Performance...")

fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(models))
width = 0.25

avg_r2 = [nguyen_stats[m]['avg_best_r2'] for m in models]
max_r2 = [nguyen_stats[m]['max_r2'] for m in models]

bars1 = ax.bar(x - width/2, avg_r2, width, label='Average Best R²',
               color='#3498db', alpha=0.8, edgecolor='black', linewidth=1.5)
bars2 = ax.bar(x + width/2, max_r2, width, label='Maximum R²',
               color='#e74c3c', alpha=0.8, edgecolor='black', linewidth=1.5)

ax.set_ylabel('R² Score', fontsize=14, fontweight='bold')
ax.set_title('Symbolic Regression Performance (Nguyen Benchmarks)', fontsize=16, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(models)
ax.set_ylim([0.85, 1.05])
# Draw the reference line before creating the legend so its label is included.
ax.axhline(y=1.0, color='green', linestyle='--', linewidth=2, alpha=0.5, label='Perfect Fit')
ax.legend(fontsize=12)
ax.grid(axis='y', alpha=0.3)

for bar in bars1:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
            f'{height:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

for bar in bars2:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
            f'{height:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig(output_dir / 'fig2_r2_performance.png', dpi=300, bbox_inches='tight')
print(f" Saved: {output_dir / 'fig2_r2_performance.png'}")
plt.close()
|
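# Figure 3: heatmap of best R² for each of the 12 Nguyen benchmarks and 3 model sizes.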
print("Creating Figure 3: Per-Benchmark Heatmap...")

benchmark_matrix = []
for bench in range(1, 13):
    row = []
    for model in ['base', 'medium', 'large']:
        result = [r for r in nguyen_data['results']
                  if r['model'] == model and r['benchmark'] == f'nguyen_{bench}']
        if result:
            row.append(result[0]['best_r2'])
        else:
            row.append(0)
    benchmark_matrix.append(row)

benchmark_matrix = np.array(benchmark_matrix)

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(benchmark_matrix, cmap='RdYlGn', aspect='auto', vmin=0.5, vmax=1.0)

ax.set_xticks(np.arange(3))
ax.set_yticks(np.arange(12))
ax.set_xticklabels(['Base\n(124M)', 'Medium\n(355M)', 'Large\n(774M)'], fontsize=12)
ax.set_yticklabels([f'Nguyen-{i+1}' for i in range(12)], fontsize=11)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label('R² Score', rotation=270, labelpad=20, fontsize=14, fontweight='bold')

for i in range(12):
    for j in range(3):
        ax.text(j, i, f'{benchmark_matrix[i, j]:.3f}',
                ha="center", va="center", color="black", fontsize=10, fontweight='bold')

ax.set_title('R² Scores by Model and Benchmark', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(output_dir / 'fig3_benchmark_heatmap.png', dpi=300, bbox_inches='tight')
print(f" Saved: {output_dir / 'fig3_benchmark_heatmap.png'}")
plt.close()
|
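# Figure 4: valid rate and R² as a function of model size (124M, 355M, 774M parameters).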
print("Creating Figure 4: Scaling Progression...")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

params = [124, 355, 774]

ax1.plot(params, benchmark_valid, 'o-', color='#3498db', linewidth=3,
         markersize=12, label='Nguyen Valid Rate', markeredgecolor='black', markeredgewidth=2)
ax1.set_xlabel('Model Size (Million Parameters)', fontsize=14, fontweight='bold')
ax1.set_ylabel('Valid Expression Rate (%)', fontsize=14, fontweight='bold')
ax1.set_title('Valid Rate vs Model Size', fontsize=16, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.legend(fontsize=12)

for x, y in zip(params, benchmark_valid):
    ax1.text(x, y + 2, f'{y:.1f}%', ha='center', fontsize=11, fontweight='bold')

ax2.plot(params, avg_r2, 'o-', color='#e74c3c', linewidth=3,
         markersize=12, label='Average Best R²', markeredgecolor='black', markeredgewidth=2)
ax2.axhline(y=1.0, color='green', linestyle='--', linewidth=2, alpha=0.5, label='Perfect Fit')
ax2.set_xlabel('Model Size (Million Parameters)', fontsize=14, fontweight='bold')
ax2.set_ylabel('R² Score', fontsize=14, fontweight='bold')
ax2.set_title('R² vs Model Size', fontsize=16, fontweight='bold')
ax2.set_ylim([0.9, 1.02])
ax2.grid(True, alpha=0.3)
ax2.legend(fontsize=12)

for x, y in zip(params, avg_r2):
    ax2.text(x, y + 0.005, f'{y:.4f}', ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig(output_dir / 'fig4_scaling_progression.png', dpi=300, bbox_inches='tight')
print(f" Saved: {output_dir / 'fig4_scaling_progression.png'}")
plt.close()
|
print()
print("="*80)
print("ALL VISUALIZATIONS CREATED SUCCESSFULLY!")
print("="*80)
print()
print(f"Output directory: {output_dir.absolute()}")
print()
print("Generated files:")
print(" 1. fig1_valid_rate_comparison.png - Quality vs Benchmark valid rates")
print(" 2. fig2_r2_performance.png - R² scores comparison")
print(" 3. fig3_benchmark_heatmap.png - Per-benchmark R² heatmap")
print(" 4. fig4_scaling_progression.png - Scaling laws visualization")
print()
print("All figures are publication-ready (saved at 300 DPI).")
print()