Spaces:
Running
Running
| """ | |
| generate_data.py — NL2SQL Synthetic Data Factory | |
| ================================================= | |
| Designed for H100 + vLLM. Produces a clean JSONL file ready for SFT or GRPO training | |
| with the nl2sql-bench codebase (schema: e-commerce SQLite). | |
| Architecture | |
| ------------ | |
| 1. SQL_TEMPLATES — 120+ ground-truth SQLs, hand-written and verified, NEVER LLM-generated. | |
| 2. SQLiteValidator — executes every SQL against the actual seeded DB; discards any failure. | |
| 3. VLLMGenerator — async batched calls to a local vLLM server for NL paraphrasing. | |
| 4. RuleAugmentor — pure-Python synonym / date-format / condition-order augmentation. | |
| 5. DataFactory — orchestrates the full pipeline; writes JSONL with checkpointing. | |
| Output schema (one JSON object per line) | |
| ----------------------------------------- | |
| { | |
| "id": "easy_001_persona_ceo", | |
| "difficulty": "easy" | "medium" | "hard", | |
| "persona": "ceo" | "chatty" | "lazy" | "confused" | "analyst", | |
| "question": "<natural language question>", | |
| "sql": "<ground-truth SQL>", | |
| "db_result_ok": true, # always true — failures are discarded | |
| "augmented": false # true when rule-augmentor modified the NL | |
| } | |
| Usage | |
| ----- | |
| # 1. Start vLLM server (H100): | |
| # vllm serve meta-llama/Meta-Llama-3-70B-Instruct \ | |
| # --tensor-parallel-size 4 --port 8001 \ | |
| # --max-model-len 4096 --gpu-memory-utilization 0.92 | |
| # 2. Run this script (place it next to the nl2sql-bench folder): | |
| # python generate_data.py \ | |
| # --vllm-url http://localhost:8001/v1 \ | |
| # --model meta-llama/Meta-Llama-3-70B-Instruct \ | |
| # --output nl2sql_train.jsonl \ | |
| # --personas-per-template 5 \ | |
| # --aug-rounds 2 \ | |
| # --batch-size 64 | |
| Requirements | |
| ------------ | |
| pip install openai tqdm | |
| (vLLM + your model already running separately) | |
| IMPORTANT: Copy server/db/schema.sql and server/db/seed.py from nl2sql-bench | |
| into the same directory as this script, OR set --bench-root to the repo root. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import re | |
| import sqlite3 | |
| import sys | |
| import time | |
| from copy import deepcopy | |
| from dataclasses import dataclass, asdict | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from openai import AsyncOpenAI | |
| from tqdm import tqdm | |
# ─────────────────────────────────────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────────────────────────────────────
# Configure the root logger once at import time: INFO and above, with a
# wall-clock timestamp (no date) and a fixed-width level column.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Shared logger for the whole data-factory script.
log = logging.getLogger("data-factory")
# ─────────────────────────────────────────────────────────────────────────────
# Database: build & validate
# ─────────────────────────────────────────────────────────────────────────────
# DDL for the six-table e-commerce schema: categories, products, customers,
# orders, order_items and reviews. It is intended to be run in one go via
# sqlite3's executescript().
# `CREATE TABLE IF NOT EXISTS` makes re-running the script harmless for tables
# that already exist.
# CHECK constraints pin the values the templates filter on:
#   customers.tier  -> bronze / silver / gold
#   orders.status   -> pending / processing / shipped / delivered / cancelled
#   reviews.rating  -> 1..5
# They also reject negative prices/amounts and non-positive quantities.
# NOTE(review): SQLite only enforces REFERENCES clauses when
# `PRAGMA foreign_keys = ON`; nothing in this view turns that on.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS categories (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
category_id INTEGER NOT NULL REFERENCES categories(id),
price REAL NOT NULL CHECK(price >= 0),
stock_quantity INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS customers (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
country TEXT NOT NULL,
tier TEXT NOT NULL DEFAULT 'bronze'
CHECK(tier IN ('bronze', 'silver', 'gold')),
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS orders (
id INTEGER PRIMARY KEY,
customer_id INTEGER NOT NULL REFERENCES customers(id),
status TEXT NOT NULL DEFAULT 'pending'
CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
created_at TEXT NOT NULL,
total_amount REAL NOT NULL CHECK(total_amount >= 0)
);
CREATE TABLE IF NOT EXISTS order_items (
id INTEGER PRIMARY KEY,
order_id INTEGER NOT NULL REFERENCES orders(id),
product_id INTEGER NOT NULL REFERENCES products(id),
quantity INTEGER NOT NULL CHECK(quantity > 0),
unit_price REAL NOT NULL CHECK(unit_price >= 0)
);
CREATE TABLE IF NOT EXISTS reviews (
id INTEGER PRIMARY KEY,
product_id INTEGER NOT NULL REFERENCES products(id),
customer_id INTEGER NOT NULL REFERENCES customers(id),
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
created_at TEXT NOT NULL
);
"""
# Minimal seeder so the validator can run the SQL against real data.
# Mirrors the logic in nl2sql-bench/server/db/seed.py (fixed seed = 42).
#
# The string is Python source meant to be exec()'d. Executing it only *defines*
# the data tables, the `_date` helper and the `seed(conn)` entry point; no rows
# are written until `seed(conn)` is called with a connection whose schema
# already exists.
#
# `seed(conn)` inserts:
#   - 8 categories
#   - 64 products (8 per category)
#   - 200 customers
#   - 600 orders, each with 1-4 order_items
#   - 400 reviews
# All randomness comes from the private `RNG = random.Random(42)`, so the data
# is reproducible.
#
# The embedded code must keep real Python indentation, otherwise exec() raises
# IndentationError.
SEED_SCRIPT = """
import random, sqlite3
from datetime import date, timedelta

RNG = random.Random(42)

CATEGORIES = ["Electronics","Clothing","Books","Home & Garden",
              "Sports & Outdoors","Toys & Games","Beauty","Automotive"]

PRODUCTS = {
    "Electronics": ["Wireless Headphones","USB-C Hub","Mechanical Keyboard",
                    "Webcam 4K","Portable Charger","Smart Speaker",
                    "Monitor Stand","HDMI Cable 2.1"],
    "Clothing": ["Cotton T-Shirt","Slim Fit Jeans","Hoodie",
                 "Running Shorts","Winter Jacket","Polo Shirt",
                 "Casual Sneakers","Wool Socks"],
    "Books": ["Clean Code","Designing Data-Intensive Applications",
              "The Pragmatic Programmer","System Design Interview",
              "Deep Learning Book","Python Cookbook",
              "Domain-Driven Design","Refactoring"],
    "Home & Garden": ["Coffee Maker","Air Purifier","LED Desk Lamp",
                      "Plant Pot Set","Storage Organiser","Cutting Board",
                      "Vacuum Cleaner","Electric Kettle"],
    "Sports & Outdoors": ["Yoga Mat","Resistance Bands","Cycling Gloves",
                          "Trekking Poles","Water Bottle 1L","Jump Rope",
                          "Foam Roller","Compression Socks"],
    "Toys & Games": ["Lego City Set","Card Game Pack","Puzzle 1000pc",
                     "Remote Control Car","Building Blocks",
                     "Board Game Strategy","Art Set","Toy Drone"],
    "Beauty": ["Face Serum","SPF 50 Sunscreen","Lip Balm",
               "Shampoo Pro","Hair Mask","Eye Cream",
               "Vitamin C Cream","Toner Mist"],
    "Automotive": ["Car Phone Mount","Dash Cam","Tyre Inflator",
                   "Car Vacuum","Seat Cushion","Steering Wheel Cover",
                   "OBD Scanner","Jump Starter"],
}

COUNTRIES = ["India","USA","Germany","UK","Canada",
             "Australia","France","Brazil","Japan","Singapore"]
TIERS = ["bronze","silver","gold"]
STATUSES = ["pending","processing","shipped","delivered","cancelled"]

FIRST = ["Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja",
         "Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica",
         "Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura",
         "Yuki","Hana","Wei","Mei","Aiden","Zara"]
LAST = ["Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy",
        "Smith","Johnson","Brown","Williams","Jones","Davis","Wilson",
        "Müller","Schmidt","Schneider","Fischer","Weber",
        "Martin","Bernard","Thomas","Richard","Petit",
        "Garcia","Martinez","Lopez","Sanchez","Gonzalez"]


def _date(start=2022, end=2025):
    s = date(start, 1, 1)
    e = date(end, 12, 31)
    return str(s + timedelta(days=RNG.randint(0, (e - s).days)))


def seed(conn):
    c = conn.cursor()
    for cat in CATEGORIES:
        c.execute("INSERT OR IGNORE INTO categories(name) VALUES (?)", (cat,))
    conn.commit()
    cat_ids = {r[1]: r[0] for r in conn.execute("SELECT id, name FROM categories")}

    for cat, prods in PRODUCTS.items():
        for pname in prods:
            c.execute(
                "INSERT OR IGNORE INTO products(name,category_id,price,stock_quantity) VALUES (?,?,?,?)",
                (pname, cat_ids[cat], round(RNG.uniform(5, 500), 2), RNG.randint(0, 200)),
            )
    conn.commit()

    for i in range(200):
        name = f"{RNG.choice(FIRST)} {RNG.choice(LAST)}"
        email = f"user{i}@example.com"
        c.execute(
            "INSERT OR IGNORE INTO customers(name,email,country,tier,created_at) VALUES (?,?,?,?,?)",
            (name, email, RNG.choice(COUNTRIES), RNG.choice(TIERS), _date()),
        )
    conn.commit()

    cust_ids = [r[0] for r in conn.execute("SELECT id FROM customers")]
    prod_ids = [r[0] for r in conn.execute("SELECT id FROM products")]

    for _ in range(600):
        cid = RNG.choice(cust_ids)
        amt = round(RNG.uniform(10, 1000), 2)
        status = RNG.choice(STATUSES)
        d = _date()
        c.execute(
            "INSERT INTO orders(customer_id,status,created_at,total_amount) VALUES (?,?,?,?)",
            (cid, status, d, amt),
        )
    conn.commit()

    ord_ids = [r[0] for r in conn.execute("SELECT id FROM orders")]
    for oid in ord_ids:
        for _ in range(RNG.randint(1, 4)):
            pid = RNG.choice(prod_ids)
            qty = RNG.randint(1, 5)
            price = round(RNG.uniform(5, 500), 2)
            c.execute(
                "INSERT INTO order_items(order_id,product_id,quantity,unit_price) VALUES (?,?,?,?)",
                (oid, pid, qty, price),
            )
    conn.commit()

    for _ in range(400):
        pid = RNG.choice(prod_ids)
        cid = RNG.choice(cust_ids)
        rating = RNG.randint(1, 5)
        c.execute(
            "INSERT INTO reviews(product_id,customer_id,rating,created_at) VALUES (?,?,?,?)",
            (pid, cid, rating, _date()),
        )
    conn.commit()
"""
def build_db() -> sqlite3.Connection:
    """Build an in-memory SQLite DB with schema + seed data.

    Creates the tables from ``SCHEMA_SQL``, then executes ``SEED_SCRIPT`` and
    calls the ``seed(conn)`` function it defines to populate them.

    Returns:
        The open in-memory connection, with ``row_factory`` set to
        ``sqlite3.Row`` (set after seeding, for callers' queries).
    """
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA_SQL)
    # Executing SEED_SCRIPT only *defines* the seeder (data tables, `_date`,
    # `seed`). It must still be invoked explicitly, otherwise every table
    # stays empty. SEED_SCRIPT is a trusted module constant, not external
    # input, so exec() is acceptable here.
    seed_ns: Dict[str, Any] = {}
    exec(SEED_SCRIPT, seed_ns)
    seed_ns["seed"](conn)
    conn.row_factory = sqlite3.Row
    log.info("In-memory DB built and seeded.")
    return conn
class SQLiteValidator:
    """Execute SQL against the seeded DB; return ``(ok, error)``.

    ``validate`` answers "does this single SELECT statement run?".
    - ``(True, None)``: the statement is a SELECT that executes, and up to 500
      rows are fetched, without a ``sqlite3.Error``.
    - ``(False, message)``: the statement is empty, is not a SELECT, or fails
      during execution.
    """

    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn

    def validate(self, sql: str) -> Tuple[bool, Optional[str]]:
        """Check that *sql* is a SELECT and that it executes cleanly.

        Args:
            sql: A single SQL statement; surrounding whitespace and trailing
                semicolons are ignored.

        Returns:
            ``(True, None)`` on success, otherwise ``(False, reason)``, where
            *reason* is ``"Empty SQL"``, ``"Non-SELECT statement: <kw>"`` or
            the ``sqlite3.Error`` message.
        """
        sql = sql.strip().rstrip(";")
        if not sql:
            return False, "Empty SQL"
        # Use the leading *word* rather than the first whitespace-separated
        # token, so valid forms such as `SELECT*FROM t` are not rejected.
        # Text that doesn't start with a word character falls back to the
        # original token.
        match = re.match(r"\w+", sql)
        first = match.group(0).lower() if match else sql.split()[0].lower()
        if first != "select":
            return False, f"Non-SELECT statement: {first}"
        try:
            cur = self.conn.execute(sql)
            try:
                # Step through a bounded number of rows so errors raised while
                # producing results also surface here, without materialising
                # huge result sets.
                cur.fetchmany(500)
            finally:
                cur.close()
        except sqlite3.Error as exc:
            return False, str(exc)
        return True, None
# ─────────────────────────────────────────────────────────────────────────────
# SQL Template Library (ground-truth, hand-written, execution-validated)
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class SQLTemplate:
    """One hand-written ground-truth query and the description used to phrase it.

    ``@dataclass`` is required: the template lists construct instances with
    keyword arguments (``SQLTemplate(id=..., difficulty=..., ...)``). A plain
    class with only annotations has no such ``__init__``.
    """

    id: str  # unique template id, e.g. "easy_001"
    difficulty: str  # easy | medium | hard
    description: str  # plain-English description fed to the LLM
    sql: str  # ground-truth SQLite query
    # Set on templates whose SQL has an ORDER BY that defines the expected row
    # order. NOTE(review): how consumers compare ordered results is not
    # visible here.
    order_sensitive: bool = False
# NOTE: Every SQL here uses only the 6 tables in the schema and valid SQLite syntax.
# They are intentionally grouped by the SQL pattern they teach, not just by difficulty.
# Every template with order_sensitive=True must state its ordering in the
# description, because the description is what the paraphrasing LLM is given.
# NOTE(review): some order-sensitive templates sort only by columns that can
# repeat (customers.name, *.created_at). Rows tied on those columns have no
# SQL-defined relative order; confirm the grader tolerates this.
EASY_TEMPLATES: List[SQLTemplate] = [
    # ── Equality filter ──────────────────────────────────────────────────────
    SQLTemplate(
        id="easy_001",
        difficulty="easy",
        description=(
            "List all gold-tier customers, ordered alphabetically by name. "
            "Return id, name, email, country."
        ),
        sql=(
            "SELECT id, name, email, country "
            "FROM customers "
            "WHERE tier = 'gold' "
            "ORDER BY name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_002",
        difficulty="easy",
        description=(
            "Show all products priced above $100, sorted by price descending. "
            "Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "WHERE price > 100 "
            "ORDER BY price DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_003",
        difficulty="easy",
        description=(
            "Find all delivered orders with a total_amount greater than $200, "
            "sorted by total_amount descending. "
            "Return id, customer_id, total_amount, created_at."
        ),
        sql=(
            "SELECT id, customer_id, total_amount, created_at "
            "FROM orders "
            "WHERE status = 'delivered' AND total_amount > 200 "
            "ORDER BY total_amount DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_004",
        difficulty="easy",
        description=(
            "Return the top 5 most expensive products. Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "ORDER BY price DESC "
            "LIMIT 5"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_005",
        difficulty="easy",
        description=(
            "List all distinct countries where customers come from, sorted alphabetically. "
            "Return a single column: country."
        ),
        sql=(
            "SELECT DISTINCT country "
            "FROM customers "
            "ORDER BY country ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_006",
        difficulty="easy",
        description=(
            "Show all pending orders, ordered by created_at descending. "
            "Return id, customer_id, total_amount, created_at."
        ),
        sql=(
            "SELECT id, customer_id, total_amount, created_at "
            "FROM orders "
            "WHERE status = 'pending' "
            "ORDER BY created_at DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_007",
        difficulty="easy",
        description=(
            "Find all products with zero stock (stock_quantity = 0). "
            "Return id, name, price, category_id."
        ),
        sql=(
            "SELECT id, name, price, category_id "
            "FROM products "
            "WHERE stock_quantity = 0"
        ),
    ),
    SQLTemplate(
        id="easy_008",
        difficulty="easy",
        description=(
            "How many customers are there in total? Return a single value: total_customers."
        ),
        sql="SELECT COUNT(*) AS total_customers FROM customers",
    ),
    SQLTemplate(
        id="easy_009",
        difficulty="easy",
        description=(
            "What is the most expensive product price in the store? "
            "Return a single value: max_price."
        ),
        sql="SELECT MAX(price) AS max_price FROM products",
    ),
    SQLTemplate(
        id="easy_010",
        difficulty="easy",
        description=(
            "What is the cheapest product price in the store? "
            "Return a single value: min_price."
        ),
        sql="SELECT MIN(price) AS min_price FROM products",
    ),
    SQLTemplate(
        id="easy_011",
        difficulty="easy",
        description=(
            "What is the average price of all products? "
            "Round to 2 decimal places. Return: avg_price."
        ),
        sql="SELECT ROUND(AVG(price), 2) AS avg_price FROM products",
    ),
    SQLTemplate(
        id="easy_012",
        difficulty="easy",
        description=(
            "Show all customers from India, sorted by name ascending. "
            "Return id, name, email, tier."
        ),
        sql=(
            "SELECT id, name, email, tier "
            "FROM customers "
            "WHERE country = 'India' "
            "ORDER BY name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_013",
        difficulty="easy",
        description=(
            "List the 10 most recently placed orders. "
            "Return id, customer_id, status, created_at, total_amount."
        ),
        sql=(
            "SELECT id, customer_id, status, created_at, total_amount "
            "FROM orders "
            "ORDER BY created_at DESC "
            "LIMIT 10"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_014",
        difficulty="easy",
        description=(
            "Find all reviews with a rating of 5 stars. "
            "Return id, product_id, customer_id, created_at."
        ),
        sql=(
            "SELECT id, product_id, customer_id, created_at "
            "FROM reviews "
            "WHERE rating = 5"
        ),
    ),
    SQLTemplate(
        id="easy_015",
        difficulty="easy",
        description=(
            "Find all reviews with a rating of 1 star (lowest possible). "
            "Return id, product_id, customer_id, created_at."
        ),
        sql=(
            "SELECT id, product_id, customer_id, created_at "
            "FROM reviews "
            "WHERE rating = 1"
        ),
    ),
    SQLTemplate(
        id="easy_016",
        difficulty="easy",
        description=(
            "Count the number of cancelled orders. Return: cancelled_count."
        ),
        sql=(
            "SELECT COUNT(*) AS cancelled_count "
            "FROM orders "
            "WHERE status = 'cancelled'"
        ),
    ),
    SQLTemplate(
        id="easy_017",
        difficulty="easy",
        description=(
            "List all products with stock_quantity greater than 100, "
            "sorted by stock_quantity descending. Return id, name, stock_quantity."
        ),
        sql=(
            "SELECT id, name, stock_quantity "
            "FROM products "
            "WHERE stock_quantity > 100 "
            "ORDER BY stock_quantity DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_018",
        difficulty="easy",
        description=(
            "Find all silver-tier customers from the USA. "
            "Return id, name, email."
        ),
        sql=(
            "SELECT id, name, email "
            "FROM customers "
            "WHERE tier = 'silver' AND country = 'USA'"
        ),
    ),
    SQLTemplate(
        id="easy_019",
        difficulty="easy",
        description=(
            "What is the total revenue from all delivered orders? "
            "Round to 2 decimal places. Return: total_revenue."
        ),
        sql=(
            "SELECT ROUND(SUM(total_amount), 2) AS total_revenue "
            "FROM orders "
            "WHERE status = 'delivered'"
        ),
    ),
    SQLTemplate(
        id="easy_020",
        difficulty="easy",
        description=(
            "List all orders placed in 2024, sorted by created_at ascending. "
            "Return id, customer_id, status, total_amount, created_at."
        ),
        sql=(
            "SELECT id, customer_id, status, total_amount, created_at "
            "FROM orders "
            "WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' "
            "ORDER BY created_at ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_021",
        difficulty="easy",
        description=(
            "Show the bottom 5 cheapest products. Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "ORDER BY price ASC "
            "LIMIT 5"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_022",
        difficulty="easy",
        description=(
            "Count how many products exist in the catalogue. Return: product_count."
        ),
        sql="SELECT COUNT(*) AS product_count FROM products",
    ),
    SQLTemplate(
        id="easy_023",
        difficulty="easy",
        # The SQL is ordered and order_sensitive=True, so the description must
        # ask for that order; otherwise paraphrased questions cannot request it.
        description=(
            "List all distinct order statuses that exist in the orders table, "
            "sorted alphabetically. "
            "Return a single column: status."
        ),
        sql="SELECT DISTINCT status FROM orders ORDER BY status ASC",
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_024",
        difficulty="easy",
        description=(
            "Find customers who joined (created_at) in 2023. "
            "Return id, name, country, tier, created_at, sorted by created_at ascending."
        ),
        sql=(
            "SELECT id, name, country, tier, created_at "
            "FROM customers "
            "WHERE created_at >= '2023-01-01' AND created_at < '2024-01-01' "
            "ORDER BY created_at ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_025",
        difficulty="easy",
        description=(
            "Show all orders with total_amount between $50 and $150 inclusive. "
            "Return id, customer_id, total_amount, status."
        ),
        sql=(
            "SELECT id, customer_id, total_amount, status "
            "FROM orders "
            "WHERE total_amount BETWEEN 50 AND 150"
        ),
    ),
    SQLTemplate(
        id="easy_026",
        difficulty="easy",
        description=(
            "How many distinct customers have placed at least one order? "
            "Return a single value: customers_with_orders."
        ),
        sql=(
            "SELECT COUNT(DISTINCT customer_id) AS customers_with_orders "
            "FROM orders"
        ),
    ),
    SQLTemplate(
        id="easy_027",
        difficulty="easy",
        description=(
            "What is the total number of order line items across all orders? "
            "Return: total_line_items."
        ),
        sql="SELECT COUNT(*) AS total_line_items FROM order_items",
    ),
    SQLTemplate(
        id="easy_028",
        difficulty="easy",
        description=(
            "List all products priced between $20 and $80 inclusive, sorted by price ascending. "
            "Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "WHERE price BETWEEN 20 AND 80 "
            "ORDER BY price ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_029",
        difficulty="easy",
        description=(
            "Show all gold-tier customers from Germany. "
            "Return id, name, email, created_at."
        ),
        sql=(
            "SELECT id, name, email, created_at "
            "FROM customers "
            "WHERE tier = 'gold' AND country = 'Germany'"
        ),
    ),
    SQLTemplate(
        id="easy_030",
        difficulty="easy",
        description=(
            "What is the average rating across all reviews in the system? "
            "Round to 2 decimal places. Return: avg_rating."
        ),
        sql="SELECT ROUND(AVG(rating), 2) AS avg_rating FROM reviews",
    ),
]
| MEDIUM_TEMPLATES: List[SQLTemplate] = [ | |
| # ββ JOIN + COUNT βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| SQLTemplate( | |
| id="med_001", | |
| difficulty="medium", | |
| description=( | |
| "How many orders has each customer placed? Include customers with zero orders. " | |
| "Return customer_name and order_count. Sort by order_count descending, " | |
| "then customer_name ascending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS customer_name, COUNT(o.id) AS order_count " | |
| "FROM customers c " | |
| "LEFT JOIN orders o ON c.id = o.customer_id " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY order_count DESC, customer_name ASC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_002", | |
| difficulty="medium", | |
| description=( | |
| "Average product rating per category, only for categories that have at least one review. " | |
| "Return category_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS category_name, ROUND(AVG(r.rating), 2) AS avg_rating " | |
| "FROM categories c " | |
| "JOIN products p ON p.category_id = c.id " | |
| "JOIN reviews r ON r.product_id = p.id " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY avg_rating DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_003", | |
| difficulty="medium", | |
| description=( | |
| "Which categories have more than 5 in-stock products (stock_quantity > 0)? " | |
| "Return category_name and in_stock_count. Sort by in_stock_count descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS category_name, COUNT(p.id) AS in_stock_count " | |
| "FROM categories c " | |
| "JOIN products p ON p.category_id = c.id " | |
| "WHERE p.stock_quantity > 0 " | |
| "GROUP BY c.id, c.name " | |
| "HAVING COUNT(p.id) > 5 " | |
| "ORDER BY in_stock_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_004", | |
| difficulty="medium", | |
| description=( | |
| "Which customers have spent more than $500 on delivered orders? " | |
| "Return customer_name and total_spent (rounded to 2 dp). Sort by total_spent descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS customer_name, ROUND(SUM(o.total_amount), 2) AS total_spent " | |
| "FROM customers c " | |
| "JOIN orders o ON o.customer_id = c.id " | |
| "WHERE o.status = 'delivered' " | |
| "GROUP BY c.id, c.name " | |
| "HAVING SUM(o.total_amount) > 500 " | |
| "ORDER BY total_spent DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_005", | |
| difficulty="medium", | |
| description=( | |
| "Total quantity sold for each product that appears in at least one order. " | |
| "Return product_name and total_quantity_sold. Sort by total_quantity_sold descending." | |
| ), | |
| sql=( | |
| "SELECT p.name AS product_name, SUM(oi.quantity) AS total_quantity_sold " | |
| "FROM products p " | |
| "JOIN order_items oi ON oi.product_id = p.id " | |
| "GROUP BY p.id, p.name " | |
| "ORDER BY total_quantity_sold DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_006", | |
| difficulty="medium", | |
| description=( | |
| "Number of reviews per product, only for products with at least 3 reviews. " | |
| "Return product_name and review_count. Sort by review_count descending." | |
| ), | |
| sql=( | |
| "SELECT p.name AS product_name, COUNT(r.id) AS review_count " | |
| "FROM products p " | |
| "JOIN reviews r ON r.product_id = p.id " | |
| "GROUP BY p.id, p.name " | |
| "HAVING COUNT(r.id) >= 3 " | |
| "ORDER BY review_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_007", | |
| difficulty="medium", | |
| description=( | |
| "Show the total revenue (sum of total_amount) per country from all orders, " | |
| "regardless of status. Return country and total_revenue (rounded to 2 dp). " | |
| "Sort by total_revenue descending." | |
| ), | |
| sql=( | |
| "SELECT c.country, ROUND(SUM(o.total_amount), 2) AS total_revenue " | |
| "FROM customers c " | |
| "JOIN orders o ON o.customer_id = c.id " | |
| "GROUP BY c.country " | |
| "ORDER BY total_revenue DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_008", | |
| difficulty="medium", | |
| description=( | |
| "For each customer tier (bronze, silver, gold) show the average order value " | |
| "from delivered orders. Return tier and avg_order_value (rounded to 2 dp). " | |
| "Sort by avg_order_value descending." | |
| ), | |
| sql=( | |
| "SELECT c.tier, ROUND(AVG(o.total_amount), 2) AS avg_order_value " | |
| "FROM customers c " | |
| "JOIN orders o ON o.customer_id = c.id " | |
| "WHERE o.status = 'delivered' " | |
| "GROUP BY c.tier " | |
| "ORDER BY avg_order_value DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_009", | |
| difficulty="medium", | |
| description=( | |
| "Which products have never been ordered? " | |
| "Return id and name, sorted by name ascending." | |
| ), | |
| sql=( | |
| "SELECT p.id, p.name " | |
| "FROM products p " | |
| "LEFT JOIN order_items oi ON oi.product_id = p.id " | |
| "WHERE oi.id IS NULL " | |
| "ORDER BY p.name ASC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_010", | |
| difficulty="medium", | |
| description=( | |
| "Number of orders per status. " | |
| "Return status and order_count. Sort by order_count descending." | |
| ), | |
| sql=( | |
| "SELECT status, COUNT(*) AS order_count " | |
| "FROM orders " | |
| "GROUP BY status " | |
| "ORDER BY order_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_011", | |
| difficulty="medium", | |
| description=( | |
| "Show the total number of products per category. " | |
| "Return category_name and product_count. Sort by product_count descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS category_name, COUNT(p.id) AS product_count " | |
| "FROM categories c " | |
| "LEFT JOIN products p ON p.category_id = c.id " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY product_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_012", | |
| difficulty="medium", | |
| description=( | |
| "Average rating per product for products with at least one review. " | |
| "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending." | |
| ), | |
| sql=( | |
| "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating " | |
| "FROM products p " | |
| "JOIN reviews r ON r.product_id = p.id " | |
| "GROUP BY p.id, p.name " | |
| "ORDER BY avg_rating DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_013", | |
| difficulty="medium", | |
| description=( | |
| "Which gold-tier customers have placed more than 3 orders? " | |
| "Return customer_name and order_count. Sort by order_count descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS customer_name, COUNT(o.id) AS order_count " | |
| "FROM customers c " | |
| "JOIN orders o ON o.customer_id = c.id " | |
| "WHERE c.tier = 'gold' " | |
| "GROUP BY c.id, c.name " | |
| "HAVING COUNT(o.id) > 3 " | |
| "ORDER BY order_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_014", | |
| difficulty="medium", | |
| description=( | |
| "Total quantity of each product ordered via order_items. " | |
| "Return product_name and total_units. Sort by total_units descending." | |
| ), | |
| sql=( | |
| "SELECT p.name AS product_name, SUM(oi.quantity) AS total_units " | |
| "FROM products p " | |
| "JOIN order_items oi ON oi.product_id = p.id " | |
| "GROUP BY p.id, p.name " | |
| "ORDER BY total_units DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_015", | |
| difficulty="medium", | |
| description=( | |
| "For each country, count the number of gold-tier customers. " | |
| "Only show countries with at least one gold-tier customer. " | |
| "Return country and gold_count. Sort by gold_count descending." | |
| ), | |
| sql=( | |
| "SELECT country, COUNT(*) AS gold_count " | |
| "FROM customers " | |
| "WHERE tier = 'gold' " | |
| "GROUP BY country " | |
| "HAVING COUNT(*) >= 1 " | |
| "ORDER BY gold_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_016", | |
| difficulty="medium", | |
| description=( | |
| "Show how many reviews each customer has submitted. Only include customers " | |
| "who have submitted at least one review. Return customer_name and review_count. " | |
| "Sort by review_count descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS customer_name, COUNT(r.id) AS review_count " | |
| "FROM customers c " | |
| "JOIN reviews r ON r.customer_id = c.id " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY review_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_017", | |
| difficulty="medium", | |
| description=( | |
| "Total revenue generated from order_items (quantity * unit_price) per category. " | |
| "Return category_name and category_revenue (rounded to 2 dp). " | |
| "Sort by category_revenue descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS category_name, " | |
| " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue " | |
| "FROM categories c " | |
| "JOIN products p ON p.category_id = c.id " | |
| "JOIN order_items oi ON oi.product_id = p.id " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY category_revenue DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_018", | |
| difficulty="medium", | |
| description=( | |
| "Which products have an average rating strictly below 3? " | |
| "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating ascending." | |
| ), | |
| sql=( | |
| "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating " | |
| "FROM products p " | |
| "JOIN reviews r ON r.product_id = p.id " | |
| "GROUP BY p.id, p.name " | |
| "HAVING AVG(r.rating) < 3 " | |
| "ORDER BY avg_rating ASC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_019", | |
| difficulty="medium", | |
| description=( | |
| "Find the maximum order value for each customer tier. " | |
| "Return tier and max_order_value (rounded to 2 dp). Sort by max_order_value descending." | |
| ), | |
| sql=( | |
| "SELECT c.tier, ROUND(MAX(o.total_amount), 2) AS max_order_value " | |
| "FROM customers c " | |
| "JOIN orders o ON o.customer_id = c.id " | |
| "GROUP BY c.tier " | |
| "ORDER BY max_order_value DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_020", | |
| difficulty="medium", | |
| description=( | |
| "How many customers per country have placed at least one delivered order? " | |
| "Return country and customer_count. Sort by customer_count descending." | |
| ), | |
| sql=( | |
| "SELECT c.country, COUNT(DISTINCT c.id) AS customer_count " | |
| "FROM customers c " | |
| "JOIN orders o ON o.customer_id = c.id " | |
| "WHERE o.status = 'delivered' " | |
| "GROUP BY c.country " | |
| "ORDER BY customer_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_021", | |
| difficulty="medium", | |
| description=( | |
| "List all products together with their category name. " | |
| "Return product_name, category_name, price. Sort by category_name, then price ascending." | |
| ), | |
| sql=( | |
| "SELECT p.name AS product_name, c.name AS category_name, p.price " | |
| "FROM products p " | |
| "JOIN categories c ON c.id = p.category_id " | |
| "ORDER BY category_name ASC, p.price ASC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_022", | |
| difficulty="medium", | |
| description=( | |
| "For each order, show the total number of line items it contains. " | |
| "Return order_id and line_item_count. Sort by line_item_count descending." | |
| ), | |
| sql=( | |
| "SELECT order_id, COUNT(*) AS line_item_count " | |
| "FROM order_items " | |
| "GROUP BY order_id " | |
| "ORDER BY line_item_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_023", | |
| difficulty="medium", | |
| description=( | |
| "Show the minimum and maximum product price per category. " | |
| "Return category_name, min_price, max_price. Sort by category_name ascending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS category_name, " | |
| " ROUND(MIN(p.price), 2) AS min_price, " | |
| " ROUND(MAX(p.price), 2) AS max_price " | |
| "FROM categories c " | |
| "JOIN products p ON p.category_id = c.id " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY category_name ASC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_024", | |
| difficulty="medium", | |
| description=( | |
| "Find customers who have given a rating of 5 to at least one product. " | |
| "Return customer_name and five_star_count. Sort by five_star_count descending." | |
| ), | |
| sql=( | |
| "SELECT c.name AS customer_name, COUNT(r.id) AS five_star_count " | |
| "FROM customers c " | |
| "JOIN reviews r ON r.customer_id = c.id " | |
| "WHERE r.rating = 5 " | |
| "GROUP BY c.id, c.name " | |
| "ORDER BY five_star_count DESC" | |
| ), | |
| order_sensitive=True, | |
| ), | |
| SQLTemplate( | |
| id="med_025", | |
| difficulty="medium", | |
| description=( | |
| "Show the average number of items per order across all orders. " | |
| "Round to 2 decimal places. Return: avg_items_per_order." | |
| ), | |
| sql=( | |
| "SELECT ROUND(AVG(item_count), 2) AS avg_items_per_order " | |
| "FROM ( " | |
| " SELECT order_id, COUNT(*) AS item_count " | |
| " FROM order_items " | |
| " GROUP BY order_id " | |
| ")" | |
| ), | |
| ), | |
| ] | |
# Hard templates: window functions, CTEs and multi-step aggregation.  Like every
# entry of ALL_TEMPLATES, each SQL here is executed by the validator before use
# (DataFactory.validate_templates) and skipped if it fails.
HARD_TEMPLATES: List[SQLTemplate] = [
    # ── Window functions ────────────────────────────────────────────────────
    SQLTemplate(
        id="hard_001",
        difficulty="hard",
        description=(
            "Rank customers by total spending on delivered orders using DENSE_RANK "
            "(rank 1 = highest spender). "
            "Return customer_name, total_spent (rounded to 2 dp), spending_rank. "
            "Sort by spending_rank ascending."
        ),
        sql=(
            "SELECT customer_name, total_spent, spending_rank "
            "FROM ( "
            " SELECT c.name AS customer_name, "
            " ROUND(SUM(o.total_amount), 2) AS total_spent, "
            " DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
            " FROM customers c "
            " JOIN orders o ON o.customer_id = c.id "
            " WHERE o.status = 'delivered' "
            " GROUP BY c.id, c.name "
            ") sub "
            "ORDER BY spending_rank ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_002",
        difficulty="hard",
        description=(
            "For each reviewed product, show its own average rating and the average rating "
            "of all products in its category (partition window). "
            "Return product_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). "
            "Sort by product_avg_rating descending."
        ),
        sql=(
            "SELECT p.name AS product_name, "
            " ROUND(AVG(r.rating), 2) AS product_avg_rating, "
            " ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) AS category_avg_rating "
            "FROM products p "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY p.id, p.name, p.category_id "
            "ORDER BY product_avg_rating DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_003",
        difficulty="hard",
        description=(
            "Find all customers whose most recent order has status 'cancelled'. "
            "Use a CTE with ROW_NUMBER partitioned by customer_id ordered by created_at DESC "
            "(ties broken by order id DESC). "
            "Return customer_name, last_order_status, last_order_date. Sort by customer_name ascending."
        ),
        sql=(
            # The id tie-break makes "most recent order" deterministic when two
            # orders of one customer share the same created_at.
            "WITH ranked_orders AS ( "
            " SELECT customer_id, status, created_at, "
            " ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at DESC, id DESC) AS rn "
            " FROM orders "
            ") "
            "SELECT c.name AS customer_name, "
            " ro.status AS last_order_status, "
            " ro.created_at AS last_order_date "
            "FROM customers c "
            "JOIN ranked_orders ro ON ro.customer_id = c.id "
            "WHERE ro.rn = 1 AND ro.status = 'cancelled' "
            "ORDER BY customer_name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_004",
        difficulty="hard",
        description=(
            "Monthly revenue from delivered orders and its running total for all months in 2024. "
            "Return month (YYYY-MM format), monthly_revenue, running_total (both rounded to 2 dp). "
            "Sort by month ascending."
        ),
        sql=(
            "WITH monthly AS ( "
            " SELECT strftime('%Y-%m', created_at) AS month, "
            " ROUND(SUM(total_amount), 2) AS monthly_revenue "
            " FROM orders "
            " WHERE status = 'delivered' "
            " AND created_at >= '2024-01-01' AND created_at < '2025-01-01' "
            " GROUP BY strftime('%Y-%m', created_at) "
            ") "
            "SELECT month, monthly_revenue, "
            " ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
            "FROM monthly "
            "ORDER BY month ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_005",
        difficulty="hard",
        description=(
            "Find products whose average rating is strictly above the average rating of all products "
            "in their category. Use two CTEs: one for product-level averages and one for category-level. "
            "Compare the unrounded averages; round only in the output. "
            "Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). "
            "Sort by product_avg_rating descending, then product_name ascending."
        ),
        sql=(
            # Averages stay unrounded inside the CTEs so the "strictly above"
            # comparison is exact; rounding to 2 dp happens only for display.
            "WITH product_ratings AS ( "
            " SELECT p.id AS product_id, p.name AS product_name, "
            " p.category_id, c.name AS category_name, "
            " AVG(r.rating) AS product_avg_rating "
            " FROM products p "
            " JOIN reviews r ON r.product_id = p.id "
            " JOIN categories c ON c.id = p.category_id "
            " GROUP BY p.id, p.name, p.category_id, c.name "
            "), "
            "category_ratings AS ( "
            " SELECT category_id, AVG(product_avg_rating) AS category_avg_rating "
            " FROM product_ratings "
            " GROUP BY category_id "
            ") "
            "SELECT pr.product_name, pr.category_name, "
            " ROUND(pr.product_avg_rating, 2) AS product_avg_rating, "
            " ROUND(cr.category_avg_rating, 2) AS category_avg_rating "
            "FROM product_ratings pr "
            "JOIN category_ratings cr ON cr.category_id = pr.category_id "
            "WHERE pr.product_avg_rating > cr.category_avg_rating "
            "ORDER BY ROUND(pr.product_avg_rating, 2) DESC, pr.product_name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_006",
        difficulty="hard",
        description=(
            "For each customer, find their very first order date using ROW_NUMBER in a CTE. "
            "Return customer_name and first_order_date. Sort by first_order_date ascending."
        ),
        sql=(
            "WITH first_orders AS ( "
            " SELECT customer_id, created_at, "
            " ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at ASC) AS rn "
            " FROM orders "
            ") "
            "SELECT c.name AS customer_name, fo.created_at AS first_order_date "
            "FROM customers c "
            "JOIN first_orders fo ON fo.customer_id = c.id "
            "WHERE fo.rn = 1 "
            "ORDER BY first_order_date ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_007",
        difficulty="hard",
        description=(
            "Rank products by total revenue generated (quantity * unit_price from order_items) "
            "using RANK() window function. "
            "Return product_name, total_revenue (rounded to 2 dp), revenue_rank. "
            "Sort by revenue_rank ascending."
        ),
        sql=(
            "SELECT product_name, total_revenue, revenue_rank "
            "FROM ( "
            " SELECT p.name AS product_name, "
            " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue, "
            " RANK() OVER (ORDER BY SUM(oi.quantity * oi.unit_price) DESC) AS revenue_rank "
            " FROM products p "
            " JOIN order_items oi ON oi.product_id = p.id "
            " GROUP BY p.id, p.name "
            ") sub "
            "ORDER BY revenue_rank ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_008",
        difficulty="hard",
        description=(
            "For each customer, compute the running total of their order amounts ordered by "
            "created_at. Return customer_name, order_date (created_at), order_amount (total_amount), "
            "running_total (rounded to 2 dp). Sort by customer_name, order_date ascending."
        ),
        sql=(
            "SELECT c.name AS customer_name, "
            " o.created_at AS order_date, "
            " o.total_amount AS order_amount, "
            " ROUND(SUM(o.total_amount) OVER "
            " (PARTITION BY c.id ORDER BY o.created_at "
            " ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2) AS running_total "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "ORDER BY customer_name ASC, order_date ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_009",
        difficulty="hard",
        description=(
            "Find customers who have placed orders in every status "
            "(pending, processing, shipped, delivered, cancelled) at least once. "
            "Return customer_name and status_count. Sort by customer_name ascending."
        ),
        sql=(
            "SELECT c.name AS customer_name, COUNT(DISTINCT o.status) AS status_count "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "GROUP BY c.id, c.name "
            "HAVING COUNT(DISTINCT o.status) = 5 "
            "ORDER BY customer_name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_010",
        difficulty="hard",
        description=(
            "Using a CTE, compute the total revenue per product, then rank the top 3 products "
            "in each category by revenue using DENSE_RANK. Only return rows with rank <= 3. "
            "Return category_name, product_name, total_revenue (rounded to 2 dp), rank_in_category. "
            "Sort by category_name, rank_in_category ascending."
        ),
        sql=(
            "WITH product_rev AS ( "
            " SELECT p.id, p.name AS product_name, p.category_id, "
            " c.name AS category_name, "
            " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue "
            " FROM products p "
            " JOIN categories c ON c.id = p.category_id "
            " JOIN order_items oi ON oi.product_id = p.id "
            " GROUP BY p.id, p.name, p.category_id, c.name "
            "), "
            "ranked AS ( "
            " SELECT product_name, category_name, total_revenue, "
            " DENSE_RANK() OVER (PARTITION BY category_id ORDER BY total_revenue DESC) AS rank_in_category "
            " FROM product_rev "
            ") "
            "SELECT category_name, product_name, total_revenue, rank_in_category "
            "FROM ranked "
            "WHERE rank_in_category <= 3 "
            "ORDER BY category_name ASC, rank_in_category ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_011",
        difficulty="hard",
        description=(
            "Compute the percentage of total revenue each category contributes. "
            "Use a CTE for category revenues and a window SUM for the grand total. "
            "Return category_name, category_revenue, pct_of_total (rounded to 2 dp). "
            "Sort by pct_of_total descending."
        ),
        sql=(
            "WITH cat_rev AS ( "
            " SELECT c.name AS category_name, "
            " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue "
            " FROM categories c "
            " JOIN products p ON p.category_id = c.id "
            " JOIN order_items oi ON oi.product_id = p.id "
            " GROUP BY c.id, c.name "
            ") "
            "SELECT category_name, category_revenue, "
            " ROUND(100.0 * category_revenue / SUM(category_revenue) OVER (), 2) AS pct_of_total "
            "FROM cat_rev "
            "ORDER BY pct_of_total DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_012",
        difficulty="hard",
        description=(
            # The SQL ranks every customer with 2023 orders, not only the top
            # ones, so the description says "rank" rather than "find the highest".
            "Rank customers by the number of orders they placed in 2023 using DENSE_RANK "
            "(rank 1 = most orders). Use a CTE to count per-customer orders in 2023. "
            "Return customer_name, order_count_2023, rank. Sort by rank, then customer_name."
        ),
        sql=(
            "WITH counts_2023 AS ( "
            " SELECT c.name AS customer_name, COUNT(o.id) AS order_count_2023 "
            " FROM customers c "
            " JOIN orders o ON o.customer_id = c.id "
            " WHERE o.created_at >= '2023-01-01' AND o.created_at < '2024-01-01' "
            " GROUP BY c.id, c.name "
            ") "
            "SELECT customer_name, order_count_2023, "
            " DENSE_RANK() OVER (ORDER BY order_count_2023 DESC) AS rank "
            "FROM counts_2023 "
            "ORDER BY rank ASC, customer_name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_013",
        difficulty="hard",
        description=(
            "Show a quarterly revenue breakdown for delivered orders across all years. "
            "Use strftime to derive year and quarter. "
            "Return year, quarter, quarterly_revenue (rounded to 2 dp), "
            "and running_total_in_year (running SUM within the same year, rounded to 2 dp). "
            "Sort by year, quarter ascending."
        ),
        sql=(
            "WITH quarterly AS ( "
            " SELECT strftime('%Y', created_at) AS year, "
            " ((CAST(strftime('%m', created_at) AS INTEGER) - 1) / 3 + 1) AS quarter, "
            " ROUND(SUM(total_amount), 2) AS quarterly_revenue "
            " FROM orders "
            " WHERE status = 'delivered' "
            " GROUP BY year, quarter "
            ") "
            "SELECT year, quarter, quarterly_revenue, "
            " ROUND(SUM(quarterly_revenue) OVER (PARTITION BY year ORDER BY quarter), 2) AS running_total_in_year "
            "FROM quarterly "
            "ORDER BY year ASC, quarter ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_014",
        difficulty="hard",
        description=(
            "Find the top-spending customer in each country using ROW_NUMBER. "
            "Return country, customer_name, total_spent (rounded to 2 dp). "
            "Sort by country, total_spent descending."
        ),
        sql=(
            "WITH customer_spend AS ( "
            " SELECT c.id, c.name AS customer_name, c.country, "
            " ROUND(SUM(o.total_amount), 2) AS total_spent "
            " FROM customers c "
            " JOIN orders o ON o.customer_id = c.id "
            " GROUP BY c.id, c.name, c.country "
            "), "
            "ranked AS ( "
            " SELECT country, customer_name, total_spent, "
            " ROW_NUMBER() OVER (PARTITION BY country ORDER BY total_spent DESC) AS rn "
            " FROM customer_spend "
            ") "
            "SELECT country, customer_name, total_spent "
            "FROM ranked "
            "WHERE rn = 1 "
            "ORDER BY country ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="hard_015",
        difficulty="hard",
        description=(
            "Find products that have received both 1-star and 5-star reviews. "
            "Use two CTEs: one for 1-star products, one for 5-star products, then intersect. "
            "Return product_name. Sort by product_name ascending."
        ),
        sql=(
            "WITH one_star AS ( "
            " SELECT DISTINCT product_id FROM reviews WHERE rating = 1 "
            "), "
            "five_star AS ( "
            " SELECT DISTINCT product_id FROM reviews WHERE rating = 5 "
            ") "
            "SELECT p.name AS product_name "
            "FROM products p "
            "JOIN one_star os ON os.product_id = p.id "
            "JOIN five_star fs ON fs.product_id = p.id "
            "ORDER BY product_name ASC"
        ),
        order_sensitive=True,
    ),
]

# Every template, in difficulty order; this is what validate_templates iterates.
ALL_TEMPLATES: List[SQLTemplate] = EASY_TEMPLATES + MEDIUM_TEMPLATES + HARD_TEMPLATES
# ─────────────────────────────────────────────────────────────────────────────
# Personas
# ─────────────────────────────────────────────────────────────────────────────
# Compact schema summary.  It is interpolated into every LLM paraphrasing prompt
# (DataFactory._build_requests) and into the user turn of every training record
# (DataPoint.to_training_prompt), so it must be clean, readable text: the
# set-membership marks below are "∈" (a mis-decoded byte sequence here would be
# copied verbatim into the whole dataset).
SCHEMA_CONTEXT = """
DATABASE SCHEMA (SQLite e-commerce):
categories(id, name)
products(id, name, category_id, price, stock_quantity)
customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
created_at, total_amount)
order_items(id, order_id, product_id, quantity, unit_price)
reviews(id, product_id, customer_id, rating∈1-5, created_at)
"""

# Persona name -> style instructions used as the LLM system prompt when
# paraphrasing a template.  The key is also stored in each record's "persona"
# field and embedded in request ids ("<template_id>__<persona>").
PERSONA_SPECS = {
    "ceo": (
        "You are a senior business executive. Write one SHORT, direct question in active voice, "
        "as if you are asking an analyst to pull a number fast. Be terse, no fluff. "
        "Use business language: 'revenue', 'customers', 'performance', not technical SQL terms."
    ),
    "chatty": (
        "You are a friendly but verbose non-technical employee. Write one long, conversational "
        "question with filler phrases like 'Could you please tell me...', 'I was wondering if...', "
        "passive voice is fine. Use everyday words like 'money' instead of 'revenue', "
        "'people' instead of 'customers'."
    ),
    "lazy": (
        "You are typing quickly on a phone. Write an extremely short question with abbreviations, "
        "lowercase letters, and minor spelling mistakes. Skip articles and punctuation where possible. "
        "Example style: 'top 5 prods by sales?', 'hw many cust in usa'."
    ),
    "confused": (
        "You are a non-technical user who is unsure of the exact terminology. Write one question "
        "using synonyms and vague language. Replace 'revenue' with 'money made', 'customers' with "
        "'people' or 'users' or 'accounts', 'orders' with 'purchases' or 'transactions', "
        "'tier' with 'membership level'. Include a bit of ambiguity."
    ),
    "analyst": (
        "You are a data analyst with technical knowledge. Write one precise, jargon-heavy question "
        "using terms like 'aggregate', 'partition', 'metric', 'fiscal period', 'segmented by', "
        "'cohort', 'granularity'. Be specific about column names and filters."
    ),
}
# ─────────────────────────────────────────────────────────────────────────────
# Rule-based Augmentor
# ─────────────────────────────────────────────────────────────────────────────
class RuleAugmentor:
    """
    Deterministic, non-LLM synonym substitution for generated NL questions.

    ``augment`` returns ONE augmented variant of a question, or ``None`` when no
    word was replaced.  All randomness comes from the caller-supplied ``rng``,
    so a seeded ``random.Random`` makes the output reproducible.
    """

    # Lower-case single-word key -> candidate replacements.  The SQL paired with
    # the question is fixed, so a replacement must not change what is being
    # asked: "platinum" is deliberately NOT a synonym for "gold" because no
    # platinum tier exists in the schema.
    SYNONYMS: Dict[str, List[str]] = {
        "customers": ["clients", "users", "accounts", "shoppers", "buyers"],
        "orders": ["purchases", "transactions", "sales", "bookings"],
        "products": ["items", "goods", "listings", "SKUs"],
        "revenue": ["sales", "income", "earnings", "money made"],
        "spending": ["expenditure", "purchases", "money spent"],
        "delivered": ["completed", "fulfilled", "received"],
        "cancelled": ["canceled", "voided", "aborted"],
        "pending": ["waiting", "unprocessed", "queued"],
        "gold": ["premium", "top-tier", "VIP"],
        "silver": ["mid-tier", "standard-plus"],
        "bronze": ["basic", "standard", "entry-level"],
        "rating": ["score", "star rating", "review score"],
        "country": ["region", "location", "geography", "nation"],
        "category": ["department", "section", "type", "group"],
        "price": ["cost", "value", "amount", "fee"],
        "total": ["sum", "aggregate", "combined", "overall"],
        "average": ["mean", "typical", "avg"],
        "show": ["list", "display", "give me", "get", "fetch"],
        "find": ["identify", "locate", "get", "pull", "retrieve"],
        "return": ["give me", "show", "list", "provide"],
    }

    # Characters peeled off both ends of a token before the lookup and put back
    # around the replacement, so '"customers",' or '(gold)?' still match.
    _EDGE_PUNCT = "\"'()[].,?!;:"

    def augment(self, question: str, rng: random.Random,
                replace_prob: float = 0.4) -> Optional[str]:
        """
        Return *question* with some known words swapped for synonyms.

        Args:
            question: Question to augment; it is split on whitespace and
                re-joined with single spaces.
            rng: Random source. ``rng.random()`` is drawn once per word found in
                ``SYNONYMS`` and ``rng.choice`` picks the replacement.
            replace_prob: Probability of replacing each matching word.

        Returns:
            The augmented question with its first letter capitalised, or
            ``None`` if no word was replaced.
        """
        changed = False
        out: List[str] = []
        for token in question.split():
            core = token.strip(self._EDGE_PUNCT)
            options = self.SYNONYMS.get(core.lower())
            if options is None or rng.random() >= replace_prob:
                out.append(token)
                continue
            # Offset of the word inside the token (length of leading punctuation).
            lead = len(token) - len(token.lstrip(self._EDGE_PUNCT))
            replacement = rng.choice(options)
            if core[:1].isupper():
                # Keep the original word's capitalisation ("Show" -> "List").
                replacement = replacement[:1].upper() + replacement[1:]
            out.append(token[:lead] + replacement + token[lead + len(core):])
            changed = True
        if not changed:
            return None
        new_q = " ".join(out)
        # Capitalise the first letter of the question.
        return new_q[0].upper() + new_q[1:]
# ─────────────────────────────────────────────────────────────────────────────
# vLLM Generator
# ─────────────────────────────────────────────────────────────────────────────
class VLLMGenerator:
    """
    Async batched inference against vLLM's OpenAI-compatible endpoint.

    vLLM exposes the same chat-completions API as OpenAI, so ``AsyncOpenAI`` is
    reused with a dummy API key.  In-flight requests are bounded by a semaphore.
    """

    def __init__(self, base_url: str, model: str, temperature: float = 0.8,
                 max_tokens: int = 256, semaphore: int = 64):
        """
        Args:
            base_url: Base URL of the vLLM server (e.g. "http://localhost:8001/v1").
            model: Model name as registered in the vLLM server.
            temperature: Sampling temperature sent with every request.
            max_tokens: Completion-token cap sent with every request.
            semaphore: Maximum number of concurrent requests.
        """
        self.client = AsyncOpenAI(base_url=base_url, api_key="NONE")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._sem = asyncio.Semaphore(semaphore)

    async def generate_one(
        self,
        system: str,
        user: str,
        retries: int = 3,
    ) -> Optional[str]:
        """
        Send one chat request and return the stripped completion text.

        Exceptions raised by the API call are retried up to *retries* attempts
        with exponential backoff (1s, 2s, ...).  No sleep follows the final
        failed attempt.  A response whose content is missing or blank is not
        an error: it yields ``None`` immediately, without retrying.

        Returns:
            The completion text, or ``None`` if it is empty/missing or every
            attempt raised.
        """
        for attempt in range(retries):
            try:
                async with self._sem:
                    resp = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": system},
                            {"role": "user", "content": user},
                        ],
                        temperature=self.temperature,
                        max_tokens=self.max_tokens,
                    )
            except Exception as exc:  # network/server errors: retry with backoff
                if attempt + 1 >= retries:
                    log.warning(f"vLLM call failed (attempt {attempt+1}): {exc}. Giving up.")
                    break
                wait = 2 ** attempt
                log.warning(f"vLLM call failed (attempt {attempt+1}): {exc}. Retrying in {wait}s.")
                await asyncio.sleep(wait)
                continue
            # The server answered: missing/None content is treated like an
            # empty completion instead of crashing into the retry path.
            content = resp.choices[0].message.content if resp.choices else None
            text = (content or "").strip()
            return text or None
        return None

    async def generate_batch(
        self,
        requests: List[Tuple[str, str, str]],  # (request_id, system, user)
    ) -> Dict[str, Optional[str]]:
        """
        Fire all requests concurrently (bounded by the semaphore).

        Returns:
            Mapping of request_id to generated text (``None`` on failure).
        """
        async def _one(rid: str, system_msg: str, user_msg: str) -> Tuple[str, Optional[str]]:
            return rid, await self.generate_one(system_msg, user_msg)

        results = await asyncio.gather(
            *(_one(rid, system_msg, user_msg) for rid, system_msg, user_msg in requests)
        )
        return dict(results)
# ─────────────────────────────────────────────────────────────────────────────
# Data Factory
# ─────────────────────────────────────────────────────────────────────────────
# DataPoint is built with keyword arguments and serialised with asdict(), both
# of which require a dataclass.  Imported here so this section works even if the
# module header does not already provide these names.
from dataclasses import asdict, dataclass  # noqa: E402


@dataclass
class DataPoint:
    """One generated example; ``to_training_prompt`` turns it into a JSONL record."""

    id: str  # "<template_id>__<persona>", plus "__aug<k>" for augmented variants
    difficulty: str  # copied from the source template
    persona: str  # persona whose prompt produced the question
    question: str  # natural-language question
    sql: str  # ground-truth SQL from the template
    db_result_ok: bool
    augmented: bool  # True when produced by the rule-based augmentor

    def to_training_prompt(self, system_prompt: str) -> Dict[str, Any]:
        """
        Return the dict structure expected by train.py / SFT pipelines.

        Contains every field of this DataPoint plus a chat-format ``messages``
        list: system = *system_prompt*, user = schema + question,
        assistant = the ground-truth SQL.
        """
        user_content = (
            f"SCHEMA:\n{SCHEMA_CONTEXT}\n\nQUESTION: {self.question}"
        )
        return {
            **asdict(self),
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": self.sql},
            ],
        }
# System turn of every training record: the instruction given to the
# SQL-writing model that the generated dataset trains.
SYSTEM_PROMPT = (
    "You are an expert SQL analyst working with a SQLite e-commerce database. "
    "Write a single SELECT query. "
    "Output ONLY the SQL query, nothing else. "
    "No markdown."
)
class DataFactory:
    """
    Orchestrate synthetic-data generation end to end.

    ``run`` validates every template SQL against the seeded DB, asks the LLM
    for persona-styled questions for each valid template, applies rule-based
    augmentation, and appends one JSONL record per example to both the output
    file and a checkpoint file.  On restart, request ids already present in
    the checkpoint are skipped.
    """

    def __init__(
        self,
        generator: VLLMGenerator,
        validator: SQLiteValidator,
        augmentor: RuleAugmentor,
        personas_per_template: int = 5,
        aug_rounds: int = 2,
        seed: int = 42,
    ):
        """
        Args:
            generator: Async LLM client used to turn template SQL into questions.
            validator: Object whose ``validate(sql)`` returns ``(ok, err)``.
            augmentor: Rule-based augmentor applied to each generated question.
            personas_per_template: Personas sampled per template; all personas
                are used when this is >= the number of available personas.
            aug_rounds: Augmentation attempts per generated question.
            seed: Seed for persona sampling and augmentation randomness.
        """
        self.generator = generator
        self.validator = validator
        self.augmentor = augmentor
        self.personas_per_template = personas_per_template
        self.aug_rounds = aug_rounds
        self.rng = random.Random(seed)

    # ── Step 1: Validate all template SQLs ──────────────────────────────────
    def validate_templates(self) -> List[SQLTemplate]:
        """Return the templates of ALL_TEMPLATES whose SQL validates; log the failures."""
        log.info("Validating all SQL templates against seeded DB...")
        valid = []
        failed = []
        for t in ALL_TEMPLATES:
            ok, err = self.validator.validate(t.sql)
            if ok:
                valid.append(t)
            else:
                failed.append((t.id, err))
        if failed:
            log.error(f"FAILED templates (will be skipped): {failed}")
        log.info(f"Templates validated: {len(valid)} ok, {len(failed)} failed.")
        return valid

    # ── Step 2: Build generation requests ───────────────────────────────────
    def _build_requests(
        self,
        templates: List[SQLTemplate],
        persona_names: List[str],
    ) -> List[Tuple[str, str, str]]:
        """
        Return a flat list of ``(request_id, system_prompt, user_prompt)`` tuples.

        For each template, ``personas_per_template`` personas are sampled from
        *persona_names* with ``self.rng`` (all of them when the count is at
        least ``len(persona_names)``).  Request ids are "<template_id>__<persona>".
        """
        requests = []
        for t in templates:
            # Compare against the list actually sampled from, so rng.sample is
            # never asked for more personas than it has.
            if self.personas_per_template >= len(persona_names):
                chosen_personas = persona_names
            else:
                chosen_personas = self.rng.sample(persona_names, self.personas_per_template)
            for persona in chosen_personas:
                rid = f"{t.id}__{persona}"
                system = (
                    f"{PERSONA_SPECS[persona]}\n\n"
                    "Output ONLY the natural language question. "
                    "No explanation, no SQL, no preamble, no quotes around the question."
                )
                user = (
                    f"{SCHEMA_CONTEXT}\n"
                    f"The SQL query that answers this question is:\n{t.sql}\n\n"
                    f"Write ONE natural-language question that a {persona.upper()} user "
                    f"would ask to get this exact result."
                )
                requests.append((rid, system, user))
        return requests

    # ── Step 3: Post-process a generated question ───────────────────────────
    @staticmethod
    def _clean(text: str) -> str:
        """
        Normalise raw LLM output into a bare question.

        Strips surrounding whitespace, one matching markdown wrapper
        (``**q**``, ``*q*`` or a backtick pair), leading numbering ("1. ",
        "2) "), a leading "Q:" / "Question:" label, one pair of matching
        surrounding quotes, and collapses whitespace runs to single spaces.
        """
        text = text.strip()
        # Remove a matching markdown emphasis / inline-code wrapper.
        text = re.sub(r'^(\*\*|\*|`)(.+)\1$', r'\2', text, flags=re.DOTALL).strip()
        # Remove leading numbering like "1. " or a "Q:" / "Question:" label.
        text = re.sub(r'^[\d]+[\.\)]\s+', '', text)
        text = re.sub(r'^(?:[Qq]uestion|[Qq])\s*:\s*', '', text)
        # Strip one pair of surrounding quotes.
        if (text.startswith('"') and text.endswith('"')) or \
           (text.startswith("'") and text.endswith("'")):
            text = text[1:-1].strip()
        # Collapse multiple whitespace.
        text = re.sub(r'\s+', ' ', text)
        return text

    # ── Checkpoint / output file helpers ────────────────────────────────────
    @staticmethod
    def _load_done_ids(checkpoint_path: str) -> set:
        """
        Return the ``id`` of every record in *checkpoint_path* (empty set if absent).

        Malformed lines, such as a final line cut off by a crash mid-write,
        are logged and skipped so that resuming never fails on them.
        """
        done_ids: set = set()
        if not os.path.exists(checkpoint_path):
            return done_ids
        with open(checkpoint_path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    done_ids.add(json.loads(line)["id"])
                except (json.JSONDecodeError, KeyError, TypeError) as exc:
                    log.warning(f"Skipping malformed checkpoint line {lineno}: {exc}")
        log.info(f"Resuming: {len(done_ids)} examples already generated.")
        return done_ids

    @staticmethod
    def _open_append(path: str):
        """
        Open *path* for appending UTF-8 text and return the file handle.

        If the file exists and does not end with a newline (a record truncated
        by a crash), a newline is written first.  This keeps the next record on
        its own line instead of gluing it onto the fragment.
        """
        needs_newline = False
        if os.path.exists(path) and os.path.getsize(path) > 0:
            with open(path, "rb") as existing:
                existing.seek(-1, os.SEEK_END)
                needs_newline = existing.read(1) != b"\n"
        handle = open(path, "a", encoding="utf-8")
        if needs_newline:
            handle.write("\n")
        return handle

    # ── Main pipeline ───────────────────────────────────────────────────────
    async def run(
        self,
        output_path: str,
        checkpoint_path: str,
        batch_size: int = 64,
    ) -> None:
        """
        Generate the dataset.

        Args:
            output_path: JSONL file that records are appended to.
            checkpoint_path: JSONL file mirroring the output; its record ids
                mark finished requests for resume.
            batch_size: Requests passed to ``generate_batch`` per call.

        Raises:
            ValueError: if *batch_size* < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # -- Validate templates
        templates = self.validate_templates()

        # -- Load checkpoint
        done_ids = self._load_done_ids(checkpoint_path)

        # Pass the FULL persona list; _build_requests does the per-template
        # sampling.  Truncating it here would always select the first N
        # personas and make the random sampling a no-op.
        persona_names = list(PERSONA_SPECS)
        all_requests = self._build_requests(templates, persona_names)
        pending = [r for r in all_requests if r[0] not in done_ids]
        log.info(f"Total requests to generate: {len(pending)}")

        tmpl_lookup: Dict[str, SQLTemplate] = {t.id: t for t in templates}
        stats = {"generated": 0, "invalid_llm": 0, "augmented": 0}

        with self._open_append(output_path) as out_f, \
                self._open_append(checkpoint_path) as ckpt_f:

            def write_record(dp: DataPoint) -> None:
                # The same line goes to the dataset and to the resume checkpoint.
                line = json.dumps(dp.to_training_prompt(SYSTEM_PROMPT), ensure_ascii=False)
                out_f.write(line + "\n")
                ckpt_f.write(line + "\n")

            for i in tqdm(range(0, len(pending), batch_size), desc="Batches"):
                batch = pending[i: i + batch_size]
                results = await self.generator.generate_batch(batch)

                for rid, raw_text in results.items():
                    tmpl_id, persona = rid.split("__", 1)
                    tmpl = tmpl_lookup[tmpl_id]

                    question = self._clean(raw_text) if raw_text else ""
                    if len(question) < 8:  # missing or too short to be a question
                        stats["invalid_llm"] += 1
                        continue

                    # Template SQL was validated in Step 1; NL variants reuse it.
                    write_record(DataPoint(
                        id=rid,
                        difficulty=tmpl.difficulty,
                        persona=persona,
                        question=question,
                        sql=tmpl.sql,
                        db_result_ok=True,
                        augmented=False,
                    ))
                    stats["generated"] += 1

                    # -- Rule augmentation rounds (skip no-ops and duplicates)
                    seen = {question}
                    for aug_i in range(self.aug_rounds):
                        aug_q = self.augmentor.augment(question, self.rng)
                        if not aug_q or aug_q in seen:
                            continue
                        seen.add(aug_q)
                        write_record(DataPoint(
                            id=f"{rid}__aug{aug_i}",
                            difficulty=tmpl.difficulty,
                            persona=persona,
                            question=aug_q,
                            sql=tmpl.sql,
                            db_result_ok=True,
                            augmented=True,
                        ))
                        stats["augmented"] += 1

                # Flush once per batch so a crash loses at most one batch.
                out_f.flush()
                ckpt_f.flush()

        log.info(
            f"Done. Generated={stats['generated']} "
            f"Augmented={stats['augmented']} "
            f"LLM failures={stats['invalid_llm']}"
        )
        log.info(f"Output: {output_path}")
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────
def _int_at_least(minimum: int):
    """Return an argparse ``type=`` callable that accepts integers >= *minimum*."""

    def _parse(value: str) -> int:
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
        if number < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {number}")
        return number

    return _parse


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command-line options.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Integer options that would otherwise fail later are rejected up front with
    a usage error (exit code 2).  For example, ``--batch-size 0`` would give a
    zero range step and a zero-size semaphore, and a negative persona count
    breaks persona sampling.
    """
    p = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory — H100 + vLLM",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--vllm-url", default="http://localhost:8001/v1",
                   help="Base URL of the running vLLM server.")
    p.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
                   help="Model name as registered in the vLLM server.")
    p.add_argument("--output", default="nl2sql_train.jsonl",
                   help="Path to write the final JSONL dataset.")
    p.add_argument("--checkpoint", default="nl2sql_checkpoint.jsonl",
                   help="Path for the checkpoint file (enables resume on crash).")
    p.add_argument("--personas-per-template", type=_int_at_least(1), default=5,
                   help="Number of persona variants to generate per SQL template (max 5).")
    p.add_argument("--aug-rounds", type=_int_at_least(0), default=2,
                   help="Number of rule-based augmentation rounds per generated question.")
    p.add_argument("--batch-size", type=_int_at_least(1), default=64,
                   help="Concurrent vLLM requests per batch (tune based on GPU memory).")
    p.add_argument("--temperature", type=float, default=0.85,
                   help="Sampling temperature for vLLM (higher = more diverse).")
    p.add_argument("--max-tokens", type=_int_at_least(1), default=200,
                   help="Max tokens for each generated question.")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--validate-only", action="store_true",
                   help="Only validate SQL templates, do not generate data.")
    return p.parse_args(argv)
def _report_template_validity(validator) -> None:
    """Execute every SQL template once against the seeded DB and print a summary.

    Prints the number of valid and invalid templates, then one
    ``<template id>: <error>`` line per failing template. Each template's SQL is
    validated exactly once, and its error message is taken from that same call.
    """
    n_valid = 0
    failures = []
    for tmpl in ALL_TEMPLATES:
        ok, err = validator.validate(tmpl.sql)
        if ok:
            n_valid += 1
        else:
            failures.append((tmpl.id, err))
    print(f"\n✓ Valid: {n_valid}")
    print(f"✗ Invalid: {len(failures)}")
    for tmpl_id, err in failures:
        print(f" {tmpl_id}: {err}")


async def main() -> None:
    """CLI entry point: build the seeded DB, then validate templates or generate data.

    With ``--validate-only`` the SQL templates are checked against the database
    and a summary is printed. Otherwise the vLLM paraphrasing and rule-augmentation
    pipeline (``DataFactory.run``) writes to ``--output`` and checkpoints to
    ``--checkpoint``. The database connection is closed on every exit path.
    """
    args = parse_args()
    conn = build_db()
    try:
        validator = SQLiteValidator(conn)
        if args.validate_only:
            _report_template_validity(validator)
            return
        # Build pipeline components
        generator = VLLMGenerator(
            base_url=args.vllm_url,
            model=args.model,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            semaphore=args.batch_size,
        )
        augmentor = RuleAugmentor()
        factory = DataFactory(
            generator=generator,
            validator=validator,
            augmentor=augmentor,
            # Clamp so we never ask for more personas than are defined.
            personas_per_template=min(args.personas_per_template, len(PERSONA_SPECS)),
            aug_rounds=args.aug_rounds,
            seed=args.seed,
        )
        await factory.run(
            output_path=args.output,
            checkpoint_path=args.checkpoint,
            batch_size=args.batch_size,
        )
    finally:
        # Release the DB handle even on errors or Ctrl-C during generation.
        # NOTE(review): assumes build_db() returns a sqlite3.Connection — confirm.
        conn.close()


if __name__ == "__main__":
    asyncio.run(main())