| import os |
| |
| os.environ.setdefault("MPLBACKEND", "Agg") |
| os.environ.setdefault("MPLCONFIGDIR", os.environ.get("MPLCONFIGDIR", "/tmp/mplconfig")) |
| import time |
| import json |
| import argparse |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import sys |
| from pathlib import Path |
|
|
| |
| |
| |
# Repository root: two levels up from this file (the script sits in a
# sub-directory of the project).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Make the project's ``src`` package importable when this file is run directly.
sys.path.append(str(PROJECT_ROOT))

# Prefer ``data/database`` when it exists and holds at least one ``*.sqlite``
# file; otherwise fall back to ``final_databases``.  ``next(..., None)`` stops
# at the first match instead of materialising the whole recursive listing
# just to test it for emptiness.
_PRIMARY_DB_ROOT = PROJECT_ROOT / "data" / "database"
if _PRIMARY_DB_ROOT.exists() and next(_PRIMARY_DB_ROOT.rglob("*.sqlite"), None) is not None:
    DB_ROOT = _PRIMARY_DB_ROOT
else:
    DB_ROOT = PROJECT_ROOT / "final_databases"
|
|
| from src.execution_reward import ( |
| execution_reward_batch_sequential, |
| execution_reward_batch_parallel, |
| execution_reward_batch_parallel_by_db, |
| execution_reward_timed, |
| set_use_cache, |
| set_use_schema_validation, |
| clear_result_cache |
| ) |
|
|
def generate_mock_rollouts(num_rollouts: int = 100, heavy_n: int = 500_000):
    """Build CPU-heavy benchmark rollouts spread round-robin over the databases.

    Every rollout is a ``(pred_sql, db_path, gold_sql)`` tuple in which the
    predicted and gold SQL are the same recursive-CTE query, so the work is
    dominated by query execution rather than by result comparison.

    Args:
        num_rollouts: Number of rollouts to generate.
        heavy_n: Base upper bound of the recursive counter.  Rollout ``i``
            uses ``heavy_n + (i % 10_000)`` so that consecutive queries differ
            textually.

    Returns:
        A list of ``(sql, db_path, sql)`` tuples, ``db_path`` being a string.

    Exits:
        Terminates the process with status 1 when no ``*.sqlite`` file is
        found under ``DB_ROOT``.
    """
    print(f"\nGenerating {num_rollouts} heavy rollouts to simulate RLHF query workload...", flush=True)

    # ``Path.rglob`` yields paths in no particular order; sorting makes the
    # rollout -> database assignment reproducible between runs and machines.
    real_dbs = sorted(str(p) for p in DB_ROOT.rglob("*.sqlite"))

    if not real_dbs:
        print(f"โ CRITICAL ERROR: No real databases found in {DB_ROOT}. Cannot run benchmark.", flush=True)
        sys.exit(1)

    print(f"Found {len(real_dbs)} real SQLite databases in {DB_ROOT}. Distributing workload...", flush=True)

    db_count = len(real_dbs)
    rollouts = []
    for i in range(num_rollouts):
        db_path = real_dbs[i % db_count]

        heavy_sql = f"""
            WITH RECURSIVE cnt(x) AS (
                SELECT 1
                UNION ALL
                SELECT x+1 FROM cnt WHERE x < {heavy_n + (i % 10_000)}
            )
            SELECT sum(x) FROM cnt;
        """
        # Collapse every run of whitespace to a single space so the emitted
        # SQL does not depend on how the literal above is indented.
        clean_sql = " ".join(heavy_sql.split())
        rollouts.append((clean_sql, db_path, clean_sql))

        # Only report progress for large batches, once per 250 rollouts.
        if num_rollouts >= 500 and (i + 1) % 250 == 0:
            print(f" generated {i + 1}/{num_rollouts}...", flush=True)

    return rollouts
|
|
def profile_bottlenecks(rollouts, sample_size: int = 20, print_every: int = 5):
    """Print the average parse / plan / execute time over a sample of rollouts.

    The first ``sample_size`` rollouts are run through
    ``execution_reward_timed`` with result caching and schema validation
    switched off, and the ``parse_s``, ``plan_s`` and ``exec_s`` timings it
    returns are summed and reported as a table.

    Args:
        rollouts: Sequence of ``(pred_sql, db_path, gold_sql)`` tuples.
        sample_size: Maximum number of rollouts to profile; clamped to the
            range ``[0, len(rollouts)]``.
        print_every: Emit a progress line every this many rollouts; a falsy
            value disables progress output.

    Side effects:
        Clears the result cache and leaves caching and schema validation
        disabled (callers are expected to set them again as needed).
    """
    # Clamp first so the banner can report the real sample size and a
    # negative request cannot silently slice from the end of the list.
    sample_size = max(0, min(int(sample_size), len(rollouts)))
    rule = "=" * 65

    print("\n" + rule)
    print(f" ๐ CPU PROFILING: IDENTIFYING BOTTLENECKS ({sample_size} Rollouts)")
    print(rule)

    clear_result_cache()
    set_use_cache(False)
    set_use_schema_validation(False)

    if sample_size == 0:
        # Nothing to average over; avoid dividing by zero below.
        print("No rollouts available to profile; skipping.")
        print(rule + "\n")
        return

    # Normalise once; a value that truncates to 0 simply disables progress.
    step = int(print_every) if print_every else 0

    total_parse = 0.0
    total_plan = 0.0
    total_exec = 0.0

    for i, (pred, db, gold) in enumerate(rollouts[:sample_size], 1):
        _, timings = execution_reward_timed(pred, db, gold, measure_plan=True)
        total_parse += timings['parse_s']
        total_plan += timings['plan_s']
        total_exec += timings['exec_s']
        if step and (i % step == 0 or i == sample_size):
            print(f" profiled {i}/{sample_size}...", flush=True)

    total_time = total_parse + total_plan + total_exec
    if total_time == 0:
        # Keep the percentage columns well-defined when every timing is zero.
        total_time = 0.0001

    print(f"{'Phase':<15} | {'Avg Time (ms)':<15} | {'% of Total CPU':<15}")
    print("-" * 65)
    phases = (
        ("Regex Parsing", total_parse),
        ("Query Planning", total_plan),
        ("DB Execution", total_exec),
    )
    for label, spent in phases:
        print(f"{label:<15} | {(spent/sample_size)*1000:<15.2f} | {(spent/total_time)*100:<14.1f}%")
    print(rule + "\n")
|
|
def run_benchmark_for_setting(rollouts, use_cache: bool, max_workers: int):
    """Time the sequential and the per-database parallel reward passes.

    Caching is set to ``use_cache`` and schema validation is disabled for
    both passes.  The result cache is cleared immediately before each pass
    so neither one benefits from entries left behind by the other.

    Args:
        rollouts: Rollouts handed unchanged to both batch functions.
        use_cache: Value passed to ``set_use_cache``.
        max_workers: Forwarded to ``execution_reward_batch_parallel_by_db``.

    Returns:
        A dict with ``sequential_s`` and ``parallel_s`` (wall-clock seconds)
        and ``speedup`` (their ratio, or ``0`` if the parallel time is not
        positive).
    """
    set_use_cache(use_cache)
    set_use_schema_validation(False)

    def _elapsed(run_pass):
        # Start each measurement from an empty cache.
        clear_result_cache()
        began = time.perf_counter()
        run_pass()
        return time.perf_counter() - began

    sequential_s = _elapsed(lambda: execution_reward_batch_sequential(rollouts))
    parallel_s = _elapsed(
        lambda: execution_reward_batch_parallel_by_db(rollouts, max_workers=max_workers)
    )

    if parallel_s > 0:
        speedup = sequential_s / parallel_s
    else:
        speedup = 0

    return {
        "sequential_s": sequential_s,
        "parallel_s": parallel_s,
        "speedup": speedup,
    }
|
|
def print_comparison_table(results):
    """Print a fixed-width table comparing sequential and parallel timings.

    Args:
        results: Mapping with ``"with_cache"`` and ``"without_cache"``
            entries, each a mapping holding ``sequential_s``, ``parallel_s``
            and ``speedup`` numbers.
    """
    rule = "=" * 65
    header_cells = (
        f"{'Setting':<16}",
        f"{'Sequential (s)':<14}",
        f"{'Parallel (s)':<14}",
        f"{'Speedup':<10}",
    )

    print(rule)
    print(" | ".join(header_cells))
    print("-" * 65)

    rows = (("With Cache", "with_cache"), ("Without Cache", "without_cache"))
    for label, key in rows:
        entry = results[key]
        print(
            f"{label:<16} | {entry['sequential_s']:<14.4f} | "
            f"{entry['parallel_s']:<14.4f} | {entry['speedup']:<9.2f}x"
        )

    # The trailing "\n" leaves one blank line after the table.
    print(rule + "\n")
|
|
def plot_results(results, output_path: str):
    """Save a grouped bar chart of sequential vs parallel execution times.

    Args:
        results: Mapping with ``"with_cache"`` and ``"without_cache"``
            entries, each holding ``sequential_s`` and ``parallel_s``.
        output_path: Destination file passed to ``Figure.savefig``
            (written at 300 dpi).
    """
    labels = ['With Cache', 'Without Cache']
    setting_keys = ('with_cache', 'without_cache')
    seq_times = [results[key]['sequential_s'] for key in setting_keys]
    par_times = [results[key]['parallel_s'] for key in setting_keys]

    x = np.arange(len(labels))
    width = 0.35  # bar width; each pair is centred on its tick

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.bar(x - width/2, seq_times, width, label='Sequential', color='#4C72B0')
        ax.bar(x + width/2, par_times, width, label='Parallel', color='#DD8452')

        ax.set_ylabel('Execution Time (seconds)')
        ax.set_title('Text2SQL Reward Execution: Sequential vs Parallel')
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()

        # Annotate every bar with its value.
        for container in ax.containers:
            ax.bar_label(container, fmt='%.2f', padding=3)

        fig.tight_layout()
        # Save this figure explicitly rather than whatever pyplot considers
        # the "current" one.
        fig.savefig(output_path, dpi=300)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
|
|
def main():
    """Command-line entry point: generate rollouts, profile, benchmark, report.

    Steps:
        1. Parse and validate the command-line options.
        2. Generate the mock rollouts.
        3. Optionally print the per-phase profile.
        4. Run the benchmark with the cache enabled, then disabled.
        5. Write ``results/task1_results.json`` and ``results/task1_plot.png``
           under ``PROJECT_ROOT`` and print the comparison table.
    """
    parser = argparse.ArgumentParser(description="Benchmark SQL Execution Reward")
    parser.add_argument("--n", type=int, default=1000, help="Number of rollouts to benchmark")
    parser.add_argument("--max-workers", type=int, default=20, help="Max workers for parallel execution")
    parser.add_argument("--heavy-n", type=int, default=200_000, help="Recursive CTE upper bound (controls heaviness)")
    parser.add_argument("--skip-profile", action="store_true", help="Skip the CPU profiling section for faster startup")
    parser.add_argument("--profile-n", type=int, default=20, help="Number of rollouts to use for CPU profiling")
    args = parser.parse_args()

    # Fail fast on values that cannot yield a meaningful benchmark, instead of
    # discovering the problem only after a long sequential pass has finished.
    if args.n < 1:
        parser.error("--n must be a positive integer")
    if args.max_workers < 1:
        parser.error("--max-workers must be a positive integer")

    results_dir = PROJECT_ROOT / "results"
    results_dir.mkdir(parents=True, exist_ok=True)

    rollouts = generate_mock_rollouts(args.n, heavy_n=args.heavy_n)

    # A non-positive --profile-n leaves nothing to average over, so treat it
    # the same as --skip-profile.
    if not args.skip_profile and args.profile_n > 0:
        profile_bottlenecks(rollouts, sample_size=args.profile_n)

    print("Starting Main Scalability Benchmarks...")

    print("Running Experiment A: Cache ENABLED...")
    results_with_cache = run_benchmark_for_setting(rollouts, use_cache=True, max_workers=args.max_workers)

    print("Running Experiment B: Cache DISABLED...")
    results_without_cache = run_benchmark_for_setting(rollouts, use_cache=False, max_workers=args.max_workers)

    final_results = {
        "with_cache": results_with_cache,
        "without_cache": results_without_cache,
    }

    json_path = results_dir / "task1_results.json"
    with open(json_path, 'w', encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)

    print_comparison_table(final_results)
    plot_results(final_results, str(results_dir / "task1_plot.png"))


if __name__ == "__main__":
    main()