diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..d5b54401627cb948cc0a724303acbf79ce70ed3b --- /dev/null +++ b/.env.example @@ -0,0 +1,23 @@ +# nl2sql-bench/.env.example +# ───────────────────────────────────────────────────────────────────────────── +# Copy this file to .env and fill in your values. +# NEVER commit .env to version control. +# +# All three variables below are MANDATORY per competition rules. +# ───────────────────────────────────────────────────────────────────────────── + +# LLM API endpoint (HuggingFace router or any OpenAI-compatible base URL) +API_BASE_URL=https://router.huggingface.co/v1 + +# Model identifier — must be accessible at the above endpoint +MODEL_NAME=Qwen/Qwen2.5-72B-Instruct + +# HuggingFace API token (also used as the OpenAI-client api_key) +HF_TOKEN=hf_your_token_here + +# ── Optional overrides ──────────────────────────────────────────────────── +# LOCAL_IMAGE_NAME=nl2sql-bench:latest # Docker image name for local dev +# SPACE_URL=https://your-space.hf.space # Deployed HF Space URL +# NL2SQL_DEFAULT_TASK=simple-filter # Default task (overridden per episode) +# NL2SQL_MAX_STEPS=5 # Max steps per episode +# ENABLE_WEB_INTERFACE=true # Enable /web UI for debugging diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ea29c8c409915eb126c33803c4c644c010197204 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,56 @@ +# nl2sql-bench/server/Dockerfile +# ───────────────────────────────────────────────────────────────────────────── +# NL2SQL-Bench OpenEnv Server +# Hugging Face Spaces compatible (port 7860, non-root user). +# Build: docker build -t nl2sql-bench:latest . 
+# HF Spaces requires port 7860
+# NOTE(review): the CMD below launches 2 uvicorn workers, but episode state
+# appears to live in each worker process's memory — a client's /reset and
+# /step could land on different processes. Confirm session stickiness or
+# shared state, or drop to --workers 1.
+[![openenv](https://img.shields.io/badge/openenv-compatible-blue)](https://github.com/meta-pytorch/OpenEnv) +[![Python 3.10+](https://img.shields.io/badge/python-3.10+-green)](https://www.python.org) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow)](LICENSE) + +--- + +## What is NL2SQL-Bench? + +NL2SQL-Bench is an OpenEnv-compliant RL training environment where an AI agent must iteratively write and refine **SQLite queries** to answer natural-language business questions against a synthetic e-commerce database. + +This fills a genuine gap in the OpenEnv ecosystem — no SQL query environment currently exists. Every data-driven company employs analysts who translate business questions into SQL. Training agents to do this well (and to recover from errors) is immediately valuable. + +**Why it's a great RL domain:** +- Rewards are **100% deterministic** — no LLM-as-judge, no subjectivity +- Multi-turn episodes create **dense reward signal** across the trajectory +- The error → fix → retry loop is a novel mechanic not present in existing environments +- Three clearly graduated difficulty levels challenge models across the full skill range + +--- + +## Environment Description + +The agent interacts with a **synthetic e-commerce SQLite database** containing ~150 customers, 64 products across 8 categories, ~600 orders, ~1000 order items, and ~400 reviews. The database is seeded deterministically (seed=42) so results are reproducible across any machine. + +The agent receives a natural-language question and iteratively submits SQL queries. Each query is executed, graded against the ground truth, and the reward + error/result is fed back as the next observation. 
+ +--- + +## Database Schema + +``` +categories(id, name) +products(id, name, category_id, price, stock_quantity) +customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at) +orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled}, + created_at, total_amount) +order_items(id, order_id, product_id, quantity, unit_price) +reviews(id, product_id, customer_id, rating∈1-5, created_at) +``` + +All dates are ISO-8601 strings sortable by text comparison. SQLite window functions and CTEs are fully supported. + +--- + +## Action & Observation Space + +### Action +```python +@dataclass +class NL2SQLAction(Action): + query: str # A SQLite SELECT query string +``` + +### Observation +```python +@dataclass +class NL2SQLObservation(Observation): + question: str # The NL question to answer + schema_context: str # Compact schema description + task_name: str # Active task identifier + last_query: str # SQL submitted on previous step + last_result: List[Dict] # Up to 10 result rows + last_error: Optional[str] # SQLite error string or None + result_columns: List[str] # Column names of last_result + step: int # Current step (0 after reset) + max_steps: int # Maximum steps per episode + done: bool # Episode ended? + reward: Optional[float] # Step reward [0.0, 1.0] + score: float # Cumulative normalised score +``` + +--- + +## Tasks & Expected Difficulty + +### Task 1 — `simple-filter` (easy) +Single-table SELECT queries with WHERE, ORDER BY, LIMIT. Tests basic SQL fluency. Example questions: +- "List all gold-tier customers ordered by name alphabetically." +- "Return the top 5 most expensive products." + +**Expected solve rate (frontier model, 5 steps):** ~80% + +### Task 2 — `join-aggregation` (medium) +Multi-table JOINs with GROUP BY, HAVING, and aggregation functions. Example questions: +- "How many orders has each customer placed? Include customers with zero orders." +- "Which customers have spent more than $500 total on delivered orders?" 
+ +**Expected solve rate (frontier model, 5 steps):** ~55% + +### Task 3 — `analytics-window` (hard) +CTEs, window functions (DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. Example questions: +- "Rank customers by total spending using DENSE_RANK." +- "Show monthly revenue and running total for delivered orders in 2024." + +**Expected solve rate (frontier model, 5 steps):** ~30% + +--- + +## Reward Function + +Rewards are computed by deterministic comparison of the agent's result set against the ground truth: + +| Component | Score | Description | +|---|---|---| +| `syntax_ok` | +0.10 | Query runs without SQLite error | +| `columns_match` | +0.20 | Returned column names match ground truth | +| `row_count_match` | +0.20 | Number of rows matches | +| `exact_match` | +0.50 | Full result set equals ground truth | +| `step_penalty` | −0.05/step | Deducted per step beyond the first | + +Final reward is clamped to `[0.0, 1.0]`. Order sensitivity matches the ground-truth query: ORDER BY queries require correct row ordering; others are order-agnostic. 
+Scores obtained by running the `inference.py` baseline script with `Qwen/Qwen2.5-72B-Instruct` via the HuggingFace router:
+NOTE(review): this change removes the HF Spaces YAML frontmatter block
+(`title`, `sdk: docker`, `pinned`, …) from the top of README.md. Spaces
+reads that metadata block to build and serve the Space — confirm it is
+restored (above the new content) before deploying.
+
+## Design Decisions
+**Why SQLite in-memory?** Zero runtime dependency, deterministic, and it runs comfortably within the 2 vCPU / 8 GB constraint. The database loads in ~50ms. + +**Why multi-turn (up to 5 steps)?** A single-shot SQL environment gives binary rewards. Multi-turn with error feedback gives the agent — and the GRPO trainer — a rich signal: the model learns not just to write SQL, but to debug and refine its queries. + +**Why step penalty?** Without it, an agent that accidentally gets the right answer on step 5 scores the same as one that gets it on step 1. The penalty creates pressure to solve efficiently, which is realistic. + +**Why order-sensitive comparison for ORDER BY queries?** Business questions that say "rank by spending" expect a ranked output. Order-agnostic comparison would give spurious credit. + --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +## License + +MIT — see [LICENSE](LICENSE) diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e42272e9313d4414151182f6f1086cbfb724d83 --- /dev/null +++ b/__init__.py @@ -0,0 +1,18 @@ +""" +nl2sql-bench — NL2SQL Analytics OpenEnv Environment +==================================================== +Public API surface for client-side use. 
def main():
    """Audit DATASET_FILE and print a dataset health report.

    Checks structural integrity (JSON parse, required fields), diversity
    (unique SQL / NL questions, persona and domain spread) and — most
    importantly — whether every SQL actually executes against its domain
    schema via SQLValidator.

    Fixes vs. the previous revision:
    - unknown-domain rows are no longer added to the persona / domain /
      uniqueness tallies, so distribution percentages use a denominator
      consistent with ``valid_total``;
    - a SQLValidator error (bad schema, broken validator) now counts as an
      execution failure instead of being lumped into "missing domains"
      while still polluting the distributions;
    - the dataset file is streamed line by line instead of being slurped
      into memory with readlines().
    """
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return

    print("Starting Dataset Quality & Sanity Check...\n")

    total_rows = 0
    corrupt_json = 0
    sql_execution_failures = 0
    empty_outputs = 0
    missing_domains = 0

    persona_counts = Counter()
    unique_sqls = set()
    unique_questions = set()
    domain_counts = Counter()

    # One validator per domain, constructed lazily and closed at the end.
    validators = {}

    with open(DATASET_FILE, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Analyzing Rows"):
            total_rows += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                corrupt_json += 1
                continue

            prompt_block = record.get("prompt", [])
            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})

            if not prompt_block or len(prompt_block) < 2 or not sql:
                empty_outputs += 1
                continue

            user_content = prompt_block[1].get("content", "")
            question = user_content.split("QUESTION: ")[-1]

            # Domain: prefer explicit metadata, fall back to parsing the prompt.
            domain = metadata.get("domain")
            if not domain:
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", user_content)
                domain = match.group(1) if match else "unknown"

            # Rows without a resolvable domain cannot be validated or
            # meaningfully attributed — exclude them from every tally so the
            # percentages below stay internally consistent.
            if domain == "unknown":
                missing_domains += 1
                continue

            persona_counts[metadata.get("persona", "unknown")] += 1
            domain_counts[domain] += 1
            unique_sqls.add(sql)
            unique_questions.add(question)

            # Strict execution check: any validator/schema error is treated
            # the same as a failing SQL — the row is unusable for training.
            try:
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)
                val_result = validators[domain].validate(sql)
                if not val_result.passed or val_result.row_count == 0:
                    sql_execution_failures += 1
            except Exception:
                sql_execution_failures += 1

    for v in validators.values():
        v.close()

    # --- REPORT GENERATION ---
    print("\n" + "=" * 60)
    print("DATASET HEALTH REPORT")
    print("=" * 60)
    print(f"Total Rows Parsed : {total_rows}")
    print(f"Corrupt JSON Lines : {corrupt_json}")
    print(f"Missing SQL/Domains : {empty_outputs + missing_domains}")

    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries : {len(unique_sqls)} (Base logic templates)")
    print(f"Unique NL Questions : {len(unique_questions)}")

    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)
    duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate : {duplication_rate:.2f}% (Should be low!)")

    print("\nPERSONA DISTRIBUTION:")
    for p, count in persona_counts.most_common():
        share = f" ({(count / valid_total) * 100:.1f}%)" if valid_total else ""
        print(f" - {p}: {count}{share}")

    print("\nDOMAIN DISTRIBUTION:")
    for d, count in domain_counts.most_common():
        share = f" ({(count / valid_total) * 100:.1f}%)" if valid_total else ""
        print(f" - {d}: {count}{share}")

    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = (sql_execution_failures / valid_total) * 100 if valid_total else 0
    print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")

    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("=" * 60)
def main():
    """Filter INPUT_FILE down to rows whose SQL executes successfully.

    A row survives only if its domain can be resolved AND SQLValidator
    reports the query as passed with a non-empty result set. Everything
    else (corrupt JSON, unknown domain, failing SQL) is dropped.

    Fix: SQLValidator construction is now inside the try block — previously
    a domain string that parsed from the prompt but had no valid schema
    raised during construction and crashed the whole sweep. (This matches
    how check_quality.py already guards the same call.)
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found!")
        return

    print("Sweeping dataset to remove bad SQLs...")

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    validators = {}
    cleaned_count = 0
    failed_count = 0

    with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
        for line in tqdm(lines, desc="Filtering Garbage"):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                failed_count += 1
                continue

            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})
            domain = metadata.get("domain")

            # Fallback: recover the domain from the user prompt text.
            if not domain or domain == "unknown":
                content = record.get("prompt", [{}, {}])[1].get("content", "")
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", content)
                domain = match.group(1) if match else "unknown"

            if domain == "unknown":
                failed_count += 1
                continue

            try:
                # Construction can raise on a bad/unknown schema — treat
                # that the same as a failing SQL instead of crashing.
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)
                val_result = validators[domain].validate(sql)
                # Keep ONLY if SQL is 100% perfect and returns data
                if val_result.passed and val_result.row_count > 0:
                    out_f.write(line)
                    cleaned_count += 1
                else:
                    failed_count += 1
            except Exception:
                failed_count += 1

    for v in validators.values():
        v.close()

    print("\n" + "=" * 50)
    print("DATASET CLEANUP COMPLETE")
    print("=" * 50)
    print(f"Original Rows : {len(lines)}")
    print(f"Cleaned Rows : {cleaned_count} (100% Valid SQL)")
    print(f"Removed Rows : {failed_count}")
    print(f"Saved To : {OUTPUT_FILE}")
    print("=" * 50)
@dataclass
class NL2SQLAction:
    """Action sent to the environment: one SQLite SELECT query string."""
    query: str


@dataclass
class NL2SQLObservation:
    """Environment observation returned after a reset or step."""
    question: str              # natural-language question to answer
    schema_context: str        # compact schema description
    task_name: str             # active task identifier
    last_query: str            # SQL submitted on the previous step
    last_result: list          # result rows from the last query
    last_error: Optional[str]  # SQLite error string, or None on success
    result_columns: list       # column names of last_result
    step: int                  # current step (0 right after reset)
    max_steps: int             # maximum steps per episode
    done: bool                 # True once the episode has ended
    reward: float              # step reward (None from server is coerced to 0.0)
    score: float               # cumulative normalised score


@dataclass
class StepResult:
    """Bundle returned by reset()/step(): observation + reward + done flag."""
    observation: NL2SQLObservation
    reward: float
    done: bool


class NL2SQLEnv:
    """Async HTTP client for the NL2SQL OpenEnv server.

    Use as an async context manager so the underlying connection pool is
    closed deterministically:

        async with NL2SQLEnv("http://localhost:8000") as env:
            result = await env.reset()
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.client = httpx.AsyncClient(base_url=base_url, timeout=60.0)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()

    async def reset(self) -> StepResult:
        """Start a new episode; task name comes from NL2SQL_DEFAULT_TASK."""
        task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
        resp = await self.client.post("/reset", json={"task_name": task_name})
        # Fail loudly on HTTP errors instead of silently parsing an error body.
        resp.raise_for_status()
        return self._parse_result(resp.json())

    async def step(self, action: NL2SQLAction) -> StepResult:
        """Submit one SQL query and return the graded result."""
        payload = {"query": action.query}
        resp = await self.client.post("/step", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a server JSON payload into a typed StepResult.

        The server may nest fields under "observation" or send them flat;
        null rewards/scores are coerced to 0.0 so callers can always do
        arithmetic on them without None checks.
        """
        obs_data = payload.get("observation", payload)

        # SAFETY CHECK: If reward or score is None/null, default to 0.0
        raw_reward = obs_data.get("reward")
        safe_reward = float(raw_reward) if raw_reward is not None else 0.0

        raw_score = obs_data.get("score")
        safe_score = float(raw_score) if raw_score is not None else 0.0

        obs = NL2SQLObservation(
            question=obs_data.get("question", ""),
            schema_context=obs_data.get("schema_context", ""),
            task_name=obs_data.get("task_name", ""),
            last_query=obs_data.get("last_query", ""),
            last_result=obs_data.get("last_result", []),
            last_error=obs_data.get("last_error"),
            result_columns=obs_data.get("result_columns", []),
            step=obs_data.get("step", 0),
            max_steps=obs_data.get("max_steps", 5),
            done=obs_data.get("done", False),
            reward=safe_reward,
            score=safe_score,
        )
        return StepResult(
            observation=obs,
            reward=safe_reward,
            done=obs.done,
        )
Retrains using the same GRPO setup as train.py + +Run: + python merge_and_train.py + +Flags (env vars): + EDGE_FILE — path to edge cases jsonl (default: edge_cases.jsonl) + BASE_FILE — path to existing cleaned (default: nl2sql_cleaned_ready_to_train.jsonl) + MERGED_FILE — merged output path (default: nl2sql_merged_final.jsonl) + SKIP_MERGE — set "1" to skip merge step and go straight to training +""" + +import os, sys, json, random +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import LoraConfig +from trl import GRPOConfig, GRPOTrainer + +os.environ["CUDA_VISIBLE_DEVICES"] = "0,5,1,6" + +sys.path.insert(0, "./server") +from environment import NL2SQLEnvironment +from models import NL2SQLAction +from tasks import all_task_names, get_task + +# ── Config ─────────────────────────────────────────────────────────────────── +BASE_FILE = os.getenv("BASE_FILE", "nl2sql_cleaned_ready_to_train.jsonl") +EDGE_FILE = os.getenv("EDGE_FILE", "edge_cases.jsonl") +MERGED_FILE = os.getenv("MERGED_FILE", "nl2sql_merged_final.jsonl") +SKIP_MERGE = os.getenv("SKIP_MERGE", "0") == "1" + +MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct" +OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo-v2" + +SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite. +Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries. + +STRICT RULES: +1. Output EXACTLY ONE valid SQLite query. +2. DO NOT wrap the query in markdown formatting (no ```sql or ```). +3. DO NOT output any explanations, conversational text, or preambles. +4. ONLY use standard SQLite functions. +5. If the question implies ordering, use the correct ORDER BY clause. +6. SELECT only the columns explicitly requested — no extras. 
def merge_datasets():
    """Concatenate BASE_FILE and EDGE_FILE, shuffle, write MERGED_FILE.

    Honors SKIP_MERGE: when set, the existing merged file is reused as-is.
    Blank lines are dropped; every other line passes through untouched.
    """
    if SKIP_MERGE:
        print(f"[SKIP_MERGE=1] Using existing {MERGED_FILE}")
        return

    print(f"Loading base: {BASE_FILE}")
    print(f"Loading edges: {EDGE_FILE}")

    def _nonblank_lines(path):
        # Every non-empty line, stripped of surrounding whitespace.
        with open(path, "r", encoding="utf-8") as fh:
            return [row.strip() for row in fh if row.strip()]

    base_rows = _nonblank_lines(BASE_FILE)
    edge_rows = _nonblank_lines(EDGE_FILE)

    merged = base_rows + edge_rows
    random.shuffle(merged)

    with open(MERGED_FILE, "w", encoding="utf-8") as fh:
        fh.writelines(row + "\n" for row in merged)

    print(
        f"Merged: {len(base_rows)} base + {len(edge_rows)} edge "
        f"= {len(merged)} total → {MERGED_FILE}"
    )
+ """ + data = [] + + # Load merged JSONL + with open(MERGED_FILE, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rec = json.loads(line) + # rec has "prompt" (list of messages) and "sql" + # GRPO needs "prompt" and "task_name" — we use a synthetic task_name + data.append({ + "prompt": rec["prompt"], + "task_name": "merged_jsonl" # grader falls back to execution-based reward + }) + + # Also keep the original task examples so GRPO reward env works for them + for t_name in all_task_names(): + task = get_task(t_name) + schema = task.schema_context() + for ex in task.examples: + data.append({ + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": f"SCHEMA:\n{schema}\n\nQUESTION: {ex.question}"} + ], + "task_name": t_name + }) + + random.shuffle(data) + print(f"Dataset size: {len(data)} samples") + return Dataset.from_list(data) + + +# ── Step 3: Reward function ─────────────────────────────────────────────────── + +def sql_reward_func(prompts, completions, task_name, **kwargs): + rewards = [] + env = NL2SQLEnvironment() + + for idx, completion in enumerate(completions): + generated = ( + completion[0]["content"] if isinstance(completion, list) else completion + ) + # Strip code fences defensively + import re + generated = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", generated, flags=re.DOTALL).strip() + + t = task_name[idx] if isinstance(task_name, list) else task_name + + # For merged_jsonl rows the env won't have a matching task → + # reward purely on execution (non-empty result set = +1, error = 0) + if t == "merged_jsonl": + rewards.append(_execution_reward(generated, prompts[idx])) + continue + + env.reset(task_name=t) + try: + obs = env.step(NL2SQLAction(query=generated)) + rewards.append(float(obs.reward)) + except Exception: + rewards.append(0.0) + + return rewards + + +def _execution_reward(sql: str, prompt) -> float: + """Simple execution check for merged_jsonl samples.""" + 
import sqlite3, re as _re + + # Extract schema from the user message + user_content = "" + for msg in (prompt if isinstance(prompt, list) else []): + if isinstance(msg, dict) and msg.get("role") == "user": + user_content = msg.get("content", "") + break + + schema_match = _re.search(r"SCHEMA:\s*(.*?)\nQUESTION:", user_content, _re.DOTALL) + if not schema_match: + return 0.5 # can't verify, neutral reward + + schema_sql = schema_match.group(1).strip() + try: + conn = sqlite3.connect(":memory:") + conn.executescript(schema_sql) + rows = conn.execute(sql).fetchall() + conn.close() + return 1.0 if rows else 0.3 # ran cleanly but empty → partial credit + except Exception: + return 0.0 + + +# ── Step 4: Train ───────────────────────────────────────────────────────────── + +def main(): + merge_datasets() + dataset = build_dataset() + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa" + ) + + peft_config = LoraConfig( + r=128, + lora_alpha=256, + target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], + bias="none", + task_type="CAUSAL_LM" + ) + + training_args = GRPOConfig( + output_dir=OUTPUT_DIR, + learning_rate=1e-5, # lower LR for fine-grained edge case tuning + per_device_train_batch_size=2, + gradient_accumulation_steps=4, + max_completion_length=256, + num_generations=8, + temperature=0.5, + bf16=True, + logging_steps=5, + num_train_epochs=5, # fewer epochs — base knowledge already there + report_to="none", + remove_unused_columns=False, + ddp_find_unused_parameters=False + ) + + trainer = GRPOTrainer( + model=model, + reward_funcs=sql_reward_func, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, + processing_class=tokenizer + ) + + trainer.train() + + if 
trainer.accelerator.is_main_process: + trainer.model.save_pretrained(f"{OUTPUT_DIR}/final") + tokenizer.save_pretrained(f"{OUTPUT_DIR}/final") + print(f"\nSaved to {OUTPUT_DIR}/final") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/data_expander.py b/data_expander.py new file mode 100644 index 0000000000000000000000000000000000000000..b511b09cdb5cfefd740e14bf51f587fe51bea2f6 --- /dev/null +++ b/data_expander.py @@ -0,0 +1,161 @@ +import os +import sys +import json +import torch +import hashlib +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +import sys + +# --- PATCH FOR TRANSFORMERS VERSION MISMATCH --- +try: + import transformers.activations + if not hasattr(transformers.activations, "PytorchGELUTanh"): + # Mapping the old name to the new existing one + transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation +except ImportError: + pass +# ------------------------------------------------------ + +import os +import json +import torch +# ... baaki ke saare purane imports + +# Force script to use only the 2 free GPUs (e.g., 0 and 7) +os.environ["CUDA_VISIBLE_DEVICES"] = "0,7" + +PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__)) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +from data_factory.schemas import SCHEMA_CONTEXT + +# AWQ model is 4x smaller and much faster +MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ" +INPUT_FILE = "llm_hybrid_templates.json" +OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl" +VARIATIONS_PER_SQL = 20 +BATCH_SIZE = 64 # AWQ allows much larger batches! + +SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks." + +EXPANSION_PROMPT = """ +You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query. 
+Generate exactly {count} completely different natural language questions that this exact SQL query answers. + +RULES: +- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative). +- Structure: Completely change sentence flow. +- No direct column/table names. + +DATABASE SCHEMA: +{schema} + +SQL QUERY: +{sql} + +OUTPUT FORMAT: +Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}] +""" + +def extract_json_array(raw_text): + text = raw_text.strip() + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1: + return text[start:end+1] + return "[]" + +def get_hash(text): + return hashlib.md5(text.lower().strip().encode('utf-8')).hexdigest() + +def main(): + if not os.path.exists(INPUT_FILE): + print(f"Error: {INPUT_FILE} not found.") + sys.exit(1) + + with open(INPUT_FILE, "r") as f: + base_templates = json.load(f) + + print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...") + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + + # Model loading (AWQ version automatically handles quantization) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + device_map="auto", + torch_dtype=torch.float16, # AWQ models use float16/bfloat16 for weights + low_cpu_mem_usage=True + ) + + seen_hashes = set() + total_saved = 0 + if os.path.exists(OUTPUT_FILE): + with open(OUTPUT_FILE, "r") as f: + for line in f: + total_saved += 1 # Quick count + + pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved) + + # Batch processing + for i in range(0, len(base_templates), BATCH_SIZE): + batch = base_templates[i:i + BATCH_SIZE] + prompts = [] + + for temp in batch: + msg = [ + {"role": "system", "content": "You output only JSON arrays."}, + {"role": "user", "content": EXPANSION_PROMPT.format(count=VARIATIONS_PER_SQL, schema=SCHEMA_CONTEXT[temp['domain']], sql=temp['sql'])} + ] + 
prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)) + + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + try: + with torch.no_grad(): + # Increased speed: AWQ handles large batches efficiently + outputs = model.generate( + **inputs, + max_new_tokens=2048, + temperature=0.5, + do_sample=True, + pad_token_id=tokenizer.eos_token_id + ) + + responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) + + with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file: + for idx, resp in enumerate(responses): + questions_data = json.loads(extract_json_array(resp)) + sql = batch[idx]["sql"] + domain = batch[idx]["domain"] + + for item in questions_data: + q = item.get("question", "") + if len(q) > 10: + q_hash = get_hash(q + sql) + if q_hash not in seen_hashes: + seen_hashes.add(q_hash) + record = { + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"} + ], + "sql": sql + } + out_file.write(json.dumps(record, ensure_ascii=False) + "\n") + total_saved += 1 + pbar.update(1) + out_file.flush() + except Exception as e: + print(f"Batch failed: {e}") + continue + + pbar.close() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/data_factory/__init__.py b/data_factory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_factory/augmentor.py b/data_factory/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..31b1b02e691901b2cc4de4d7e88da58a6644a969 --- /dev/null +++ b/data_factory/augmentor.py @@ -0,0 +1,288 @@ +""" +data_factory/augmentor.py +========================== +Rule-based Natural Language augmentation. + +These transformations operate ONLY on NL question strings. 
+SQL is NEVER modified — it always comes from the verified template library. + +Three augmentation strategies: + 1. Synonym replacement — swaps domain words with semantically equivalent ones + 2. Condition reordering — shuffles conjunctive phrases (preserves meaning) + 3. Date normalisation — expresses dates in different formats when applicable +""" + +from __future__ import annotations + +import random +import re +from copy import deepcopy +from typing import Iterator + + +# ───────────────────────────────────────────────────────────────────────────── +# SYNONYM DICTIONARIES +# ───────────────────────────────────────────────────────────────────────────── + +# Format: "canonical_term": ["synonym1", "synonym2", ...] +# All synonyms are semantically equivalent in a business context. + +_SYNONYMS: dict[str, list[str]] = { + + # Verbs / action starters + "list": ["show", "display", "return", "give me", "find", "retrieve"], + "show": ["list", "display", "return", "get", "retrieve"], + "find": ["identify", "locate", "get", "show", "retrieve", "look up"], + "return": ["show", "give", "list", "retrieve", "output"], + "retrieve": ["fetch", "get", "return", "pull"], + "get": ["retrieve", "fetch", "return", "give me"], + + # Aggregation words + "total": ["sum", "aggregate", "overall", "cumulative", "combined"], + "average": ["mean", "avg", "typical"], + "count": ["number of", "quantity of", "how many"], + "highest": ["largest", "maximum", "top", "greatest"], + "lowest": ["smallest", "minimum", "least"], + + # Business / domain + "customer": ["client", "buyer", "user", "account holder", "shopper"], + "customers": ["clients", "buyers", "users", "account holders", "shoppers"], + "product": ["item", "SKU", "article", "goods"], + "products": ["items", "SKUs", "articles", "goods"], + "order": ["purchase", "transaction", "sale"], + "orders": ["purchases", "transactions", "sales"], + "revenue": ["income", "earnings", "sales amount", "money earned"], + "spending": ["expenditure", 
"spend", "purchases"], + "amount": ["value", "sum", "total", "figure"], + "price": ["cost", "rate", "charge", "fee"], + + # Healthcare + "patient": ["person", "individual", "case"], + "patients": ["persons", "individuals", "cases"], + "doctor": ["physician", "clinician", "practitioner", "specialist"], + "doctors": ["physicians", "clinicians", "practitioners"], + "appointment": ["visit", "consultation", "session"], + "appointments": ["visits", "consultations", "sessions"], + "medication": ["drug", "medicine", "pharmaceutical", "prescription drug"], + "medications": ["drugs", "medicines", "pharmaceuticals"], + "diagnosis": ["condition", "finding", "medical finding"], + + # Finance + "account": ["bank account", "profile", "portfolio entry"], + "accounts": ["bank accounts", "profiles"], + "loan": ["credit", "borrowing", "debt instrument"], + "loans": ["credits", "borrowings", "debt instruments"], + "transaction": ["transfer", "payment", "operation", "activity"], + "transactions": ["transfers", "payments", "operations"], + "balance": ["funds", "available amount", "account balance"], + + # HR + "employee": ["staff member", "worker", "team member", "headcount"], + "employees": ["staff", "workers", "team members", "workforce"], + "department": ["team", "division", "unit", "group"], + "departments": ["teams", "divisions", "units"], + "salary": ["pay", "compensation", "remuneration", "earnings"], + "project": ["initiative", "program", "assignment", "engagement"], + "projects": ["initiatives", "programs", "assignments"], + + # Adjectives / Qualifiers + "active": ["current", "ongoing", "live", "existing"], + "delivered": ["completed", "fulfilled", "received"], + "cancelled": ["voided", "aborted", "terminated"], + "alphabetically": ["by name", "in alphabetical order", "A to Z"], + "descending": ["from highest to lowest", "in decreasing order", "largest first"], + "ascending": ["from lowest to highest", "in increasing order", "smallest first"], + "distinct": ["unique", 
"different"], + "in stock": ["available", "with available inventory", "not out of stock"], +} + + +# ───────────────────────────────────────────────────────────────────────────── +# DATE PHRASE PATTERNS +# These will be replaced with alternative date expressions. +# ───────────────────────────────────────────────────────────────────────────── + +_DATE_ALTERNATES: list[tuple[str, list[str]]] = [ + # ISO partial + ("2024-01-01", ["January 1st 2024", "Jan 1, 2024", "the start of 2024", "2024 start"]), + ("2023-01-01", ["January 1st 2023", "Jan 1, 2023", "the start of 2023"]), + ("2025-01-01", ["January 1st 2025", "the start of 2025"]), + # Quarter references + ("Q1", ["the first quarter", "January through March", "Jan-Mar"]), + ("Q2", ["the second quarter", "April through June", "Apr-Jun"]), + ("Q3", ["the third quarter", "July through September", "Jul-Sep"]), + ("Q4", ["the fourth quarter", "October through December", "Oct-Dec"]), + # Year references + ("in 2024", ["during 2024", "throughout 2024", "for the year 2024"]), + ("in 2023", ["during 2023", "throughout 2023", "for the year 2023"]), +] + + +# ───────────────────────────────────────────────────────────────────────────── +# CONDITION REORDERING +# Splits on "and" between two conditions and reverses them. +# ───────────────────────────────────────────────────────────────────────────── + +def _reorder_conditions(text: str, rng: random.Random) -> str: + """ + If the text contains ' and ' connecting two distinct clauses, + randomly swap their order 50% of the time. 
+ + Example: + "active employees earning above $100,000" + → "employees earning above $100,000 that are active" + """ + # Only attempt if "and" is present as a clause connector + matches = list(re.finditer(r'\b(?:and|who are|that are|with)\b', text, re.IGNORECASE)) + if not matches or rng.random() > 0.5: + return text + + # Take the first match and swap text around it + m = matches[0] + before = text[:m.start()].strip() + after = text[m.end():].strip() + connector = m.group(0).lower() + + # Build swapped version + if connector in ("and",): + swapped = f"{after} and {before}" + else: + swapped = f"{after} {connector} {before}" + + # Return swapped only if it doesn't break grammar badly + # (heuristic: swapped should not start with a verb) + if swapped and not swapped[0].isupper(): + swapped = swapped[0].upper() + swapped[1:] + return swapped + + +# ───────────────────────────────────────────────────────────────────────────── +# SYNONYM REPLACEMENT +# ───────────────────────────────────────────────────────────────────────────── + +def _apply_synonyms(text: str, rng: random.Random, max_replacements: int = 3) -> str: + """ + Replace up to `max_replacements` words/phrases with synonyms. + Replacement is probabilistic (50% chance per match) to maintain diversity. 
+ """ + result = text + replacements_done = 0 + + # Shuffle the synonym keys to get different replacement targets each call + keys = list(_SYNONYMS.keys()) + rng.shuffle(keys) + + for canonical in keys: + if replacements_done >= max_replacements: + break + synonyms = _SYNONYMS[canonical] + # Case-insensitive match on word boundary + pattern = re.compile(r'\b' + re.escape(canonical) + r'\b', re.IGNORECASE) + if pattern.search(result) and rng.random() < 0.5: + replacement = rng.choice(synonyms) + # Preserve original casing for first character + def _replace(m: re.Match) -> str: + original = m.group(0) + if original[0].isupper(): + return replacement[0].upper() + replacement[1:] + return replacement + result = pattern.sub(_replace, result, count=1) + replacements_done += 1 + + return result + + +# ───────────────────────────────────────────────────────────────────────────── +# DATE FORMAT VARIATION +# ───────────────────────────────────────────────────────────────────────────── + +def _vary_dates(text: str, rng: random.Random) -> str: + """Replace date phrases with alternate representations.""" + result = text + for phrase, alternates in _DATE_ALTERNATES: + if phrase.lower() in result.lower() and rng.random() < 0.6: + alt = rng.choice(alternates) + result = re.sub(re.escape(phrase), alt, result, count=1, flags=re.IGNORECASE) + return result + + +# ───────────────────────────────────────────────────────────────────────────── +# PUBLIC API +# ───────────────────────────────────────────────────────────────────────────── + +def augment_nl( + nl_question: str, + n: int = 3, + seed: int = 42, +) -> list[str]: + """ + Generate `n` rule-based augmented variants of a natural language question. + + Each variant applies a different combination of: + - synonym replacement + - condition reordering + - date format variation + + The original question is NOT included in the output. + + Parameters + ---------- + nl_question : str + The base NL question to augment. 
+ n : int + Number of variants to generate. + seed : int + Random seed for reproducibility. + + Returns + ------- + list[str] + Up to `n` distinct augmented strings. May be fewer if the question + is too short to vary meaningfully. + """ + rng = random.Random(seed) + variants: list[str] = [] + seen: set[str] = {nl_question} + + strategies = [ + # Strategy 1: synonym only + lambda t, r: _apply_synonyms(t, r, max_replacements=2), + # Strategy 2: synonym + date + lambda t, r: _vary_dates(_apply_synonyms(t, r, max_replacements=2), r), + # Strategy 3: condition reorder + synonym + lambda t, r: _apply_synonyms(_reorder_conditions(t, r), r, max_replacements=1), + # Strategy 4: heavy synonym + lambda t, r: _apply_synonyms(t, r, max_replacements=4), + # Strategy 5: date only + lambda t, r: _vary_dates(t, r), + ] + + for i in range(n * 3): # Over-generate, then deduplicate + strategy = strategies[i % len(strategies)] + # Use a different seed offset per variant attempt + local_rng = random.Random(seed + i * 31) + candidate = strategy(nl_question, local_rng).strip() + + # Normalise whitespace + candidate = " ".join(candidate.split()) + + if candidate and candidate not in seen: + seen.add(candidate) + variants.append(candidate) + + if len(variants) >= n: + break + + return variants + + +def generate_all_augmentations( + nl_question: str, + seed: int = 42, + n_per_template: int = 3, +) -> Iterator[str]: + """ + Yield augmented NL variants one at a time (generator). + Suitable for streaming into a large dataset without memory pressure. + """ + yield from augment_nl(nl_question, n=n_per_template, seed=seed) diff --git a/data_factory/config.py b/data_factory/config.py new file mode 100644 index 0000000000000000000000000000000000000000..339275c0759972bfc17ce31c31e99f2fe8a58719 --- /dev/null +++ b/data_factory/config.py @@ -0,0 +1,50 @@ +""" +data_factory/config.py +====================== +Central configuration for the NL2SQL Synthetic Data Factory. 
Design philosophy:
  - SQL ALWAYS comes from human-verified templates → zero SQL errors
  - LLM ONLY generates natural language paraphrases → no SQL hallucination
  - Every SQL is execution-validated before saving → guaranteed correctness
"""

from __future__ import annotations
from pathlib import Path

# ── Paths ────────────────────────────────────────────────────────────────
# ROOT_DIR resolves to the repository root (parent of data_factory/).
ROOT_DIR = Path(__file__).parent.parent
DATA_DIR = ROOT_DIR / "generated_data"
CHECKPOINT_DIR = DATA_DIR / "checkpoints"
OUTPUT_DIR = DATA_DIR / "output"

# ── vLLM / Model ─────────────────────────────────────────────────────────
# For H100 with 80GB VRAM — run Llama-3-70B or Qwen-72B at full bf16
GENERATOR_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct"  # change to your preferred model
TENSOR_PARALLEL = 4        # Number of GPUs for tensor parallelism (H100 cluster)
MAX_MODEL_LEN = 4096       # Max context length
GPU_MEMORY_UTIL = 0.90     # Leave 10% headroom

# ── Generation settings ──────────────────────────────────────────────────
PERSONAS = ["ceo", "chatty", "lazy_typist", "non_techie", "analyst"]
NL_VARIANTS_PER_TEMPLATE = 5   # One per persona
AUGMENTATIONS_PER_NL = 3       # Rule-based variations per NL string
TEMPERATURE = 0.85             # Slightly high for diversity
MAX_NEW_TOKENS = 150           # NL questions are short

# ── Scale targets ────────────────────────────────────────────────────────
# 56 base SQL templates × 5 personas × 3 augmentations = 840 "original" records
# With vLLM generating more NL variants, target: ~500K-1M clean records
VLLM_EXTRA_VARIANTS = 10  # Additional vLLM NL variants per template beyond personas

# ── Validation ───────────────────────────────────────────────────────────
# Single seed shared across the pipeline for reproducibility.
RANDOM_SEED = 42

# ── Domains ──────────────────────────────────────────────────────────────
DOMAINS = ["ecommerce", "healthcare", "finance", "hr"]

# Human-readable description of what each difficulty bucket exercises.
DIFFICULTY_LABELS = {
    "easy": "Single-table SELECT with basic WHERE/ORDER/LIMIT.",
    "medium": "Multi-table JOIN with GROUP BY/HAVING/aggregates.",
    "hard": "CTEs, window functions, subqueries.",
}
diff --git a/data_factory/generate_data.py b/data_factory/generate_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f8058c774c7881e1ab61380ccfb67b8231fc9c
--- /dev/null
+++ b/data_factory/generate_data.py
@@ -0,0 +1,1947 @@
"""
generate_data.py — NL2SQL Synthetic Data Factory
=================================================
Designed for H100 + vLLM. Produces a clean JSONL file ready for SFT or GRPO training
with the nl2sql-bench codebase (schema: e-commerce SQLite).

Architecture
------------
1. SQL_TEMPLATES — 120+ ground-truth SQLs, hand-written and verified, NEVER LLM-generated.
2. SQLiteValidator — executes every SQL against the actual seeded DB; discards any failure.
3. VLLMGenerator — async batched calls to a local vLLM server for NL paraphrasing.
4. RuleAugmentor — pure-Python synonym / date-format / condition-order augmentation.
5. DataFactory — orchestrates the full pipeline; writes JSONL with checkpointing.

Output schema (one JSON object per line)
-----------------------------------------
{
  "id": "easy_001_persona_ceo",
  "difficulty": "easy" | "medium" | "hard",
  "persona": "ceo" | "chatty" | "lazy" | "confused" | "analyst",
  "question": "",
  "sql": "",
  "db_result_ok": true,     # always true — failures are discarded
  "augmented": false        # true when rule-augmentor modified the NL
}

Usage
-----
# 1. Start vLLM server (H100):
#    vllm serve meta-llama/Meta-Llama-3-70B-Instruct \
#        --tensor-parallel-size 4 --port 8001 \
#        --max-model-len 4096 --gpu-memory-utilization 0.92

# 2.
Run this script (place it next to the nl2sql-bench folder): +# python generate_data.py \ +# --vllm-url http://localhost:8001/v1 \ +# --model meta-llama/Meta-Llama-3-70B-Instruct \ +# --output nl2sql_train.jsonl \ +# --personas-per-template 5 \ +# --aug-rounds 2 \ +# --batch-size 64 + +Requirements +------------ +pip install openai tqdm +(vLLM + your model already running separately) + +IMPORTANT: Copy server/db/schema.sql and server/db/seed.py from nl2sql-bench +into the same directory as this script, OR set --bench-root to the repo root. +""" + +from __future__ import annotations + +import argparse +import asyncio +import hashlib +import json +import logging +import os +import random +import re +import sqlite3 +import sys +import time +from copy import deepcopy +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from openai import AsyncOpenAI +from tqdm import tqdm + +# ───────────────────────────────────────────────────────────────────────────── +# Logging +# ───────────────────────────────────────────────────────────────────────────── + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%H:%M:%S", +) +log = logging.getLogger("data-factory") + + +# ───────────────────────────────────────────────────────────────────────────── +# Database: build & validate +# ───────────────────────────────────────────────────────────────────────────── + +SCHEMA_SQL = """ +CREATE TABLE IF NOT EXISTS categories ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category_id INTEGER NOT NULL REFERENCES categories(id), + price REAL NOT NULL CHECK(price >= 0), + stock_quantity INTEGER NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS customers ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + country TEXT NOT NULL, + tier TEXT 
NOT NULL DEFAULT 'bronze' + CHECK(tier IN ('bronze', 'silver', 'gold')), + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES customers(id), + status TEXT NOT NULL DEFAULT 'pending' + CHECK(status IN ('pending','processing','shipped','delivered','cancelled')), + created_at TEXT NOT NULL, + total_amount REAL NOT NULL CHECK(total_amount >= 0) +); + +CREATE TABLE IF NOT EXISTS order_items ( + id INTEGER PRIMARY KEY, + order_id INTEGER NOT NULL REFERENCES orders(id), + product_id INTEGER NOT NULL REFERENCES products(id), + quantity INTEGER NOT NULL CHECK(quantity > 0), + unit_price REAL NOT NULL CHECK(unit_price >= 0) +); + +CREATE TABLE IF NOT EXISTS reviews ( + id INTEGER PRIMARY KEY, + product_id INTEGER NOT NULL REFERENCES products(id), + customer_id INTEGER NOT NULL REFERENCES customers(id), + rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), + created_at TEXT NOT NULL +); +""" + +# Minimal seeder so the validator can run the SQL against real data. +# Mirrors the logic in nl2sql-bench/server/db/seed.py (fixed seed = 42). 
+SEED_SCRIPT = """ +import random, sqlite3 +from datetime import date, timedelta + +RNG = random.Random(42) + +CATEGORIES = ["Electronics","Clothing","Books","Home & Garden", + "Sports & Outdoors","Toys & Games","Beauty","Automotive"] + +PRODUCTS = { + "Electronics": ["Wireless Headphones","USB-C Hub","Mechanical Keyboard", + "Webcam 4K","Portable Charger","Smart Speaker", + "Monitor Stand","HDMI Cable 2.1"], + "Clothing": ["Cotton T-Shirt","Slim Fit Jeans","Hoodie", + "Running Shorts","Winter Jacket","Polo Shirt", + "Casual Sneakers","Wool Socks"], + "Books": ["Clean Code","Designing Data-Intensive Applications", + "The Pragmatic Programmer","System Design Interview", + "Deep Learning Book","Python Cookbook", + "Domain-Driven Design","Refactoring"], + "Home & Garden": ["Coffee Maker","Air Purifier","LED Desk Lamp", + "Plant Pot Set","Storage Organiser","Cutting Board", + "Vacuum Cleaner","Electric Kettle"], + "Sports & Outdoors": ["Yoga Mat","Resistance Bands","Cycling Gloves", + "Trekking Poles","Water Bottle 1L","Jump Rope", + "Foam Roller","Compression Socks"], + "Toys & Games": ["Lego City Set","Card Game Pack","Puzzle 1000pc", + "Remote Control Car","Building Blocks", + "Board Game Strategy","Art Set","Toy Drone"], + "Beauty": ["Face Serum","SPF 50 Sunscreen","Lip Balm", + "Shampoo Pro","Hair Mask","Eye Cream", + "Vitamin C Cream","Toner Mist"], + "Automotive": ["Car Phone Mount","Dash Cam","Tyre Inflator", + "Car Vacuum","Seat Cushion","Steering Wheel Cover", + "OBD Scanner","Jump Starter"], +} + +COUNTRIES = ["India","USA","Germany","UK","Canada", + "Australia","France","Brazil","Japan","Singapore"] +TIERS = ["bronze","silver","gold"] +STATUSES = ["pending","processing","shipped","delivered","cancelled"] + +FIRST = ["Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja", + "Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica", + "Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura", + 
"Yuki","Hana","Wei","Mei","Aiden","Zara"] +LAST = ["Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy", + "Smith","Johnson","Brown","Williams","Jones","Davis","Wilson", + "Müller","Schmidt","Schneider","Fischer","Weber", + "Martin","Bernard","Thomas","Richard","Petit", + "Garcia","Martinez","Lopez","Sanchez","Gonzalez"] + + +def _date(start=2022, end=2025): + s = date(start, 1, 1) + e = date(end, 12, 31) + return str(s + timedelta(days=RNG.randint(0, (e - s).days))) + + +def seed(conn): + c = conn.cursor() + for cat in CATEGORIES: + c.execute("INSERT OR IGNORE INTO categories(name) VALUES (?)", (cat,)) + conn.commit() + + cat_ids = {r[1]: r[0] for r in conn.execute("SELECT id, name FROM categories")} + + for cat, prods in PRODUCTS.items(): + for pname in prods: + c.execute( + "INSERT OR IGNORE INTO products(name,category_id,price,stock_quantity) VALUES (?,?,?,?)", + (pname, cat_ids[cat], round(RNG.uniform(5, 500), 2), RNG.randint(0, 200)), + ) + conn.commit() + + for i in range(200): + name = f"{RNG.choice(FIRST)} {RNG.choice(LAST)}" + email = f"user{i}@example.com" + c.execute( + "INSERT OR IGNORE INTO customers(name,email,country,tier,created_at) VALUES (?,?,?,?,?)", + (name, email, RNG.choice(COUNTRIES), RNG.choice(TIERS), _date()), + ) + conn.commit() + + cust_ids = [r[0] for r in conn.execute("SELECT id FROM customers")] + prod_ids = [r[0] for r in conn.execute("SELECT id FROM products")] + + for _ in range(600): + cid = RNG.choice(cust_ids) + amt = round(RNG.uniform(10, 1000), 2) + status = RNG.choice(STATUSES) + d = _date() + c.execute( + "INSERT INTO orders(customer_id,status,created_at,total_amount) VALUES (?,?,?,?)", + (cid, status, d, amt), + ) + conn.commit() + + ord_ids = [r[0] for r in conn.execute("SELECT id FROM orders")] + for oid in ord_ids: + for _ in range(RNG.randint(1, 4)): + pid = RNG.choice(prod_ids) + qty = RNG.randint(1, 5) + price = round(RNG.uniform(5, 500), 2) + c.execute( + "INSERT INTO 
order_items(order_id,product_id,quantity,unit_price) VALUES (?,?,?,?)", + (oid, pid, qty, price), + ) + conn.commit() + + for _ in range(400): + pid = RNG.choice(prod_ids) + cid = RNG.choice(cust_ids) + rating = RNG.randint(1, 5) + c.execute( + "INSERT INTO reviews(product_id,customer_id,rating,created_at) VALUES (?,?,?,?)", + (pid, cid, rating, _date()), + ) + conn.commit() +""" + + +def build_db() -> sqlite3.Connection: + """Build an in-memory SQLite DB with schema + seed data.""" + conn = sqlite3.connect(":memory:") + conn.executescript(SCHEMA_SQL) + exec(SEED_SCRIPT, {"conn": conn}) # run the seeder inline + conn.row_factory = sqlite3.Row + log.info("In-memory DB built and seeded.") + return conn + + +class SQLiteValidator: + """Execute SQL against the seeded DB; return (rows, error).""" + + def __init__(self, conn: sqlite3.Connection): + self.conn = conn + + def validate(self, sql: str) -> Tuple[bool, Optional[str]]: + sql = sql.strip().rstrip(";") + if not sql: + return False, "Empty SQL" + first = sql.split()[0].lower() + if first != "select": + return False, f"Non-SELECT statement: {first}" + try: + cur = self.conn.execute(sql) + cur.fetchmany(500) + return True, None + except sqlite3.Error as exc: + return False, str(exc) + + +# ───────────────────────────────────────────────────────────────────────────── +# SQL Template Library (ground-truth, hand-written, execution-validated) +# ───────────────────────────────────────────────────────────────────────────── + +@dataclass +class SQLTemplate: + id: str + difficulty: str # easy | medium | hard + description: str # plain-English description fed to the LLM + sql: str + order_sensitive: bool = False + + +# NOTE: Every SQL here uses only the 6 tables in the schema and valid SQLite syntax. +# They are intentionally grouped by the SQL pattern they teach, not just by difficulty. 
+ +EASY_TEMPLATES: List[SQLTemplate] = [ + # ── Equality filter ────────────────────────────────────────────────────── + SQLTemplate( + id="easy_001", + difficulty="easy", + description=( + "List all gold-tier customers, ordered alphabetically by name. " + "Return id, name, email, country." + ), + sql=( + "SELECT id, name, email, country " + "FROM customers " + "WHERE tier = 'gold' " + "ORDER BY name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_002", + difficulty="easy", + description=( + "Show all products priced above $100, sorted by price descending. " + "Return id, name, price." + ), + sql=( + "SELECT id, name, price " + "FROM products " + "WHERE price > 100 " + "ORDER BY price DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_003", + difficulty="easy", + description=( + "Find all delivered orders with a total_amount greater than $200, " + "sorted by total_amount descending. " + "Return id, customer_id, total_amount, created_at." + ), + sql=( + "SELECT id, customer_id, total_amount, created_at " + "FROM orders " + "WHERE status = 'delivered' AND total_amount > 200 " + "ORDER BY total_amount DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_004", + difficulty="easy", + description=( + "Return the top 5 most expensive products. Return id, name, price." + ), + sql=( + "SELECT id, name, price " + "FROM products " + "ORDER BY price DESC " + "LIMIT 5" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_005", + difficulty="easy", + description=( + "List all distinct countries where customers come from, sorted alphabetically. " + "Return a single column: country." + ), + sql=( + "SELECT DISTINCT country " + "FROM customers " + "ORDER BY country ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_006", + difficulty="easy", + description=( + "Show all pending orders, ordered by created_at descending. " + "Return id, customer_id, total_amount, created_at." 
+ ), + sql=( + "SELECT id, customer_id, total_amount, created_at " + "FROM orders " + "WHERE status = 'pending' " + "ORDER BY created_at DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_007", + difficulty="easy", + description=( + "Find all products with zero stock (stock_quantity = 0). " + "Return id, name, price, category_id." + ), + sql=( + "SELECT id, name, price, category_id " + "FROM products " + "WHERE stock_quantity = 0" + ), + ), + SQLTemplate( + id="easy_008", + difficulty="easy", + description=( + "How many customers are there in total? Return a single value: total_customers." + ), + sql="SELECT COUNT(*) AS total_customers FROM customers", + ), + SQLTemplate( + id="easy_009", + difficulty="easy", + description=( + "What is the most expensive product price in the store? " + "Return a single value: max_price." + ), + sql="SELECT MAX(price) AS max_price FROM products", + ), + SQLTemplate( + id="easy_010", + difficulty="easy", + description=( + "What is the cheapest product price in the store? " + "Return a single value: min_price." + ), + sql="SELECT MIN(price) AS min_price FROM products", + ), + SQLTemplate( + id="easy_011", + difficulty="easy", + description=( + "What is the average price of all products? " + "Round to 2 decimal places. Return: avg_price." + ), + sql="SELECT ROUND(AVG(price), 2) AS avg_price FROM products", + ), + SQLTemplate( + id="easy_012", + difficulty="easy", + description=( + "Show all customers from India, sorted by name ascending. " + "Return id, name, email, tier." + ), + sql=( + "SELECT id, name, email, tier " + "FROM customers " + "WHERE country = 'India' " + "ORDER BY name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_013", + difficulty="easy", + description=( + "List the 10 most recently placed orders. " + "Return id, customer_id, status, created_at, total_amount." 
+ ), + sql=( + "SELECT id, customer_id, status, created_at, total_amount " + "FROM orders " + "ORDER BY created_at DESC " + "LIMIT 10" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_014", + difficulty="easy", + description=( + "Find all reviews with a rating of 5 stars. " + "Return id, product_id, customer_id, created_at." + ), + sql=( + "SELECT id, product_id, customer_id, created_at " + "FROM reviews " + "WHERE rating = 5" + ), + ), + SQLTemplate( + id="easy_015", + difficulty="easy", + description=( + "Find all reviews with a rating of 1 star (lowest possible). " + "Return id, product_id, customer_id, created_at." + ), + sql=( + "SELECT id, product_id, customer_id, created_at " + "FROM reviews " + "WHERE rating = 1" + ), + ), + SQLTemplate( + id="easy_016", + difficulty="easy", + description=( + "Count the number of cancelled orders. Return: cancelled_count." + ), + sql=( + "SELECT COUNT(*) AS cancelled_count " + "FROM orders " + "WHERE status = 'cancelled'" + ), + ), + SQLTemplate( + id="easy_017", + difficulty="easy", + description=( + "List all products with stock_quantity greater than 100, " + "sorted by stock_quantity descending. Return id, name, stock_quantity." + ), + sql=( + "SELECT id, name, stock_quantity " + "FROM products " + "WHERE stock_quantity > 100 " + "ORDER BY stock_quantity DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_018", + difficulty="easy", + description=( + "Find all silver-tier customers from the USA. " + "Return id, name, email." + ), + sql=( + "SELECT id, name, email " + "FROM customers " + "WHERE tier = 'silver' AND country = 'USA'" + ), + ), + SQLTemplate( + id="easy_019", + difficulty="easy", + description=( + "What is the total revenue from all delivered orders? " + "Round to 2 decimal places. Return: total_revenue." 
+ ), + sql=( + "SELECT ROUND(SUM(total_amount), 2) AS total_revenue " + "FROM orders " + "WHERE status = 'delivered'" + ), + ), + SQLTemplate( + id="easy_020", + difficulty="easy", + description=( + "List all orders placed in 2024, sorted by created_at ascending. " + "Return id, customer_id, status, total_amount, created_at." + ), + sql=( + "SELECT id, customer_id, status, total_amount, created_at " + "FROM orders " + "WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' " + "ORDER BY created_at ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_021", + difficulty="easy", + description=( + "Show the bottom 5 cheapest products. Return id, name, price." + ), + sql=( + "SELECT id, name, price " + "FROM products " + "ORDER BY price ASC " + "LIMIT 5" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_022", + difficulty="easy", + description=( + "Count how many products exist in the catalogue. Return: product_count." + ), + sql="SELECT COUNT(*) AS product_count FROM products", + ), + SQLTemplate( + id="easy_023", + difficulty="easy", + description=( + "List all distinct order statuses that exist in the orders table. " + "Return a single column: status." + ), + sql="SELECT DISTINCT status FROM orders ORDER BY status ASC", + order_sensitive=True, + ), + SQLTemplate( + id="easy_024", + difficulty="easy", + description=( + "Find customers who joined (created_at) in 2023. " + "Return id, name, country, tier, created_at, sorted by created_at ascending." + ), + sql=( + "SELECT id, name, country, tier, created_at " + "FROM customers " + "WHERE created_at >= '2023-01-01' AND created_at < '2024-01-01' " + "ORDER BY created_at ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_025", + difficulty="easy", + description=( + "Show all orders with total_amount between $50 and $150 inclusive. " + "Return id, customer_id, total_amount, status." 
+ ), + sql=( + "SELECT id, customer_id, total_amount, status " + "FROM orders " + "WHERE total_amount BETWEEN 50 AND 150" + ), + ), + SQLTemplate( + id="easy_026", + difficulty="easy", + description=( + "How many distinct customers have placed at least one order? " + "Return a single value: customers_with_orders." + ), + sql=( + "SELECT COUNT(DISTINCT customer_id) AS customers_with_orders " + "FROM orders" + ), + ), + SQLTemplate( + id="easy_027", + difficulty="easy", + description=( + "What is the total number of order line items across all orders? " + "Return: total_line_items." + ), + sql="SELECT COUNT(*) AS total_line_items FROM order_items", + ), + SQLTemplate( + id="easy_028", + difficulty="easy", + description=( + "List all products priced between $20 and $80 inclusive, sorted by price ascending. " + "Return id, name, price." + ), + sql=( + "SELECT id, name, price " + "FROM products " + "WHERE price BETWEEN 20 AND 80 " + "ORDER BY price ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="easy_029", + difficulty="easy", + description=( + "Show all gold-tier customers from Germany. " + "Return id, name, email, created_at." + ), + sql=( + "SELECT id, name, email, created_at " + "FROM customers " + "WHERE tier = 'gold' AND country = 'Germany'" + ), + ), + SQLTemplate( + id="easy_030", + difficulty="easy", + description=( + "What is the average rating across all reviews in the system? " + "Round to 2 decimal places. Return: avg_rating." + ), + sql="SELECT ROUND(AVG(rating), 2) AS avg_rating FROM reviews", + ), +] + +MEDIUM_TEMPLATES: List[SQLTemplate] = [ + # ── JOIN + COUNT ───────────────────────────────────────────────────────── + SQLTemplate( + id="med_001", + difficulty="medium", + description=( + "How many orders has each customer placed? Include customers with zero orders. " + "Return customer_name and order_count. Sort by order_count descending, " + "then customer_name ascending." 
+ ), + sql=( + "SELECT c.name AS customer_name, COUNT(o.id) AS order_count " + "FROM customers c " + "LEFT JOIN orders o ON c.id = o.customer_id " + "GROUP BY c.id, c.name " + "ORDER BY order_count DESC, customer_name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_002", + difficulty="medium", + description=( + "Average product rating per category, only for categories that have at least one review. " + "Return category_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending." + ), + sql=( + "SELECT c.name AS category_name, ROUND(AVG(r.rating), 2) AS avg_rating " + "FROM categories c " + "JOIN products p ON p.category_id = c.id " + "JOIN reviews r ON r.product_id = p.id " + "GROUP BY c.id, c.name " + "ORDER BY avg_rating DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_003", + difficulty="medium", + description=( + "Which categories have more than 5 in-stock products (stock_quantity > 0)? " + "Return category_name and in_stock_count. Sort by in_stock_count descending." + ), + sql=( + "SELECT c.name AS category_name, COUNT(p.id) AS in_stock_count " + "FROM categories c " + "JOIN products p ON p.category_id = c.id " + "WHERE p.stock_quantity > 0 " + "GROUP BY c.id, c.name " + "HAVING COUNT(p.id) > 5 " + "ORDER BY in_stock_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_004", + difficulty="medium", + description=( + "Which customers have spent more than $500 on delivered orders? " + "Return customer_name and total_spent (rounded to 2 dp). Sort by total_spent descending." 
+ ), + sql=( + "SELECT c.name AS customer_name, ROUND(SUM(o.total_amount), 2) AS total_spent " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "WHERE o.status = 'delivered' " + "GROUP BY c.id, c.name " + "HAVING SUM(o.total_amount) > 500 " + "ORDER BY total_spent DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_005", + difficulty="medium", + description=( + "Total quantity sold for each product that appears in at least one order. " + "Return product_name and total_quantity_sold. Sort by total_quantity_sold descending." + ), + sql=( + "SELECT p.name AS product_name, SUM(oi.quantity) AS total_quantity_sold " + "FROM products p " + "JOIN order_items oi ON oi.product_id = p.id " + "GROUP BY p.id, p.name " + "ORDER BY total_quantity_sold DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_006", + difficulty="medium", + description=( + "Number of reviews per product, only for products with at least 3 reviews. " + "Return product_name and review_count. Sort by review_count descending." + ), + sql=( + "SELECT p.name AS product_name, COUNT(r.id) AS review_count " + "FROM products p " + "JOIN reviews r ON r.product_id = p.id " + "GROUP BY p.id, p.name " + "HAVING COUNT(r.id) >= 3 " + "ORDER BY review_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_007", + difficulty="medium", + description=( + "Show the total revenue (sum of total_amount) per country from all orders, " + "regardless of status. Return country and total_revenue (rounded to 2 dp). " + "Sort by total_revenue descending." + ), + sql=( + "SELECT c.country, ROUND(SUM(o.total_amount), 2) AS total_revenue " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "GROUP BY c.country " + "ORDER BY total_revenue DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_008", + difficulty="medium", + description=( + "For each customer tier (bronze, silver, gold) show the average order value " + "from delivered orders. 
Return tier and avg_order_value (rounded to 2 dp). " + "Sort by avg_order_value descending." + ), + sql=( + "SELECT c.tier, ROUND(AVG(o.total_amount), 2) AS avg_order_value " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "WHERE o.status = 'delivered' " + "GROUP BY c.tier " + "ORDER BY avg_order_value DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_009", + difficulty="medium", + description=( + "Which products have never been ordered? " + "Return id and name, sorted by name ascending." + ), + sql=( + "SELECT p.id, p.name " + "FROM products p " + "LEFT JOIN order_items oi ON oi.product_id = p.id " + "WHERE oi.id IS NULL " + "ORDER BY p.name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_010", + difficulty="medium", + description=( + "Number of orders per status. " + "Return status and order_count. Sort by order_count descending." + ), + sql=( + "SELECT status, COUNT(*) AS order_count " + "FROM orders " + "GROUP BY status " + "ORDER BY order_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_011", + difficulty="medium", + description=( + "Show the total number of products per category. " + "Return category_name and product_count. Sort by product_count descending." + ), + sql=( + "SELECT c.name AS category_name, COUNT(p.id) AS product_count " + "FROM categories c " + "LEFT JOIN products p ON p.category_id = c.id " + "GROUP BY c.id, c.name " + "ORDER BY product_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_012", + difficulty="medium", + description=( + "Average rating per product for products with at least one review. " + "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending." 
+ ), + sql=( + "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating " + "FROM products p " + "JOIN reviews r ON r.product_id = p.id " + "GROUP BY p.id, p.name " + "ORDER BY avg_rating DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_013", + difficulty="medium", + description=( + "Which gold-tier customers have placed more than 3 orders? " + "Return customer_name and order_count. Sort by order_count descending." + ), + sql=( + "SELECT c.name AS customer_name, COUNT(o.id) AS order_count " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "WHERE c.tier = 'gold' " + "GROUP BY c.id, c.name " + "HAVING COUNT(o.id) > 3 " + "ORDER BY order_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_014", + difficulty="medium", + description=( + "Total quantity of each product ordered via order_items. " + "Return product_name and total_units. Sort by total_units descending." + ), + sql=( + "SELECT p.name AS product_name, SUM(oi.quantity) AS total_units " + "FROM products p " + "JOIN order_items oi ON oi.product_id = p.id " + "GROUP BY p.id, p.name " + "ORDER BY total_units DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_015", + difficulty="medium", + description=( + "For each country, count the number of gold-tier customers. " + "Only show countries with at least one gold-tier customer. " + "Return country and gold_count. Sort by gold_count descending." + ), + sql=( + "SELECT country, COUNT(*) AS gold_count " + "FROM customers " + "WHERE tier = 'gold' " + "GROUP BY country " + "HAVING COUNT(*) >= 1 " + "ORDER BY gold_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_016", + difficulty="medium", + description=( + "Show how many reviews each customer has submitted. Only include customers " + "who have submitted at least one review. Return customer_name and review_count. " + "Sort by review_count descending." 
+ ), + sql=( + "SELECT c.name AS customer_name, COUNT(r.id) AS review_count " + "FROM customers c " + "JOIN reviews r ON r.customer_id = c.id " + "GROUP BY c.id, c.name " + "ORDER BY review_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_017", + difficulty="medium", + description=( + "Total revenue generated from order_items (quantity * unit_price) per category. " + "Return category_name and category_revenue (rounded to 2 dp). " + "Sort by category_revenue descending." + ), + sql=( + "SELECT c.name AS category_name, " + " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue " + "FROM categories c " + "JOIN products p ON p.category_id = c.id " + "JOIN order_items oi ON oi.product_id = p.id " + "GROUP BY c.id, c.name " + "ORDER BY category_revenue DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_018", + difficulty="medium", + description=( + "Which products have an average rating strictly below 3? " + "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating ascending." + ), + sql=( + "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating " + "FROM products p " + "JOIN reviews r ON r.product_id = p.id " + "GROUP BY p.id, p.name " + "HAVING AVG(r.rating) < 3 " + "ORDER BY avg_rating ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_019", + difficulty="medium", + description=( + "Find the maximum order value for each customer tier. " + "Return tier and max_order_value (rounded to 2 dp). Sort by max_order_value descending." + ), + sql=( + "SELECT c.tier, ROUND(MAX(o.total_amount), 2) AS max_order_value " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "GROUP BY c.tier " + "ORDER BY max_order_value DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_020", + difficulty="medium", + description=( + "How many customers per country have placed at least one delivered order? " + "Return country and customer_count. Sort by customer_count descending." 
+ ), + sql=( + "SELECT c.country, COUNT(DISTINCT c.id) AS customer_count " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "WHERE o.status = 'delivered' " + "GROUP BY c.country " + "ORDER BY customer_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_021", + difficulty="medium", + description=( + "List all products together with their category name. " + "Return product_name, category_name, price. Sort by category_name, then price ascending." + ), + sql=( + "SELECT p.name AS product_name, c.name AS category_name, p.price " + "FROM products p " + "JOIN categories c ON c.id = p.category_id " + "ORDER BY category_name ASC, p.price ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_022", + difficulty="medium", + description=( + "For each order, show the total number of line items it contains. " + "Return order_id and line_item_count. Sort by line_item_count descending." + ), + sql=( + "SELECT order_id, COUNT(*) AS line_item_count " + "FROM order_items " + "GROUP BY order_id " + "ORDER BY line_item_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_023", + difficulty="medium", + description=( + "Show the minimum and maximum product price per category. " + "Return category_name, min_price, max_price. Sort by category_name ascending." + ), + sql=( + "SELECT c.name AS category_name, " + " ROUND(MIN(p.price), 2) AS min_price, " + " ROUND(MAX(p.price), 2) AS max_price " + "FROM categories c " + "JOIN products p ON p.category_id = c.id " + "GROUP BY c.id, c.name " + "ORDER BY category_name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_024", + difficulty="medium", + description=( + "Find customers who have given a rating of 5 to at least one product. " + "Return customer_name and five_star_count. Sort by five_star_count descending." 
+ ), + sql=( + "SELECT c.name AS customer_name, COUNT(r.id) AS five_star_count " + "FROM customers c " + "JOIN reviews r ON r.customer_id = c.id " + "WHERE r.rating = 5 " + "GROUP BY c.id, c.name " + "ORDER BY five_star_count DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="med_025", + difficulty="medium", + description=( + "Show the average number of items per order across all orders. " + "Round to 2 decimal places. Return: avg_items_per_order." + ), + sql=( + "SELECT ROUND(AVG(item_count), 2) AS avg_items_per_order " + "FROM ( " + " SELECT order_id, COUNT(*) AS item_count " + " FROM order_items " + " GROUP BY order_id " + ")" + ), + ), +] + +HARD_TEMPLATES: List[SQLTemplate] = [ + # ── Window functions ───────────────────────────────────────────────────── + SQLTemplate( + id="hard_001", + difficulty="hard", + description=( + "Rank customers by total spending on delivered orders using DENSE_RANK " + "(rank 1 = highest spender). " + "Return customer_name, total_spent (rounded to 2 dp), spending_rank. " + "Sort by spending_rank ascending." + ), + sql=( + "SELECT customer_name, total_spent, spending_rank " + "FROM ( " + " SELECT c.name AS customer_name, " + " ROUND(SUM(o.total_amount), 2) AS total_spent, " + " DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank " + " FROM customers c " + " JOIN orders o ON o.customer_id = c.id " + " WHERE o.status = 'delivered' " + " GROUP BY c.id, c.name " + ") sub " + "ORDER BY spending_rank ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_002", + difficulty="hard", + description=( + "For each reviewed product, show its own average rating and the average rating " + "of all products in its category (partition window). " + "Return product_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). " + "Sort by product_avg_rating descending." 
+ ), + sql=( + "SELECT p.name AS product_name, " + " ROUND(AVG(r.rating), 2) AS product_avg_rating, " + " ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) AS category_avg_rating " + "FROM products p " + "JOIN reviews r ON r.product_id = p.id " + "GROUP BY p.id, p.name, p.category_id " + "ORDER BY product_avg_rating DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_003", + difficulty="hard", + description=( + "Find all customers whose most recent order has status 'cancelled'. " + "Use a CTE with ROW_NUMBER partitioned by customer_id ordered by created_at DESC. " + "Return customer_name, last_order_status, last_order_date. Sort by customer_name ascending." + ), + sql=( + "WITH ranked_orders AS ( " + " SELECT customer_id, status, created_at, " + " ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at DESC) AS rn " + " FROM orders " + ") " + "SELECT c.name AS customer_name, " + " ro.status AS last_order_status, " + " ro.created_at AS last_order_date " + "FROM customers c " + "JOIN ranked_orders ro ON ro.customer_id = c.id " + "WHERE ro.rn = 1 AND ro.status = 'cancelled' " + "ORDER BY customer_name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_004", + difficulty="hard", + description=( + "Monthly revenue from delivered orders and its running total for all months in 2024. " + "Return month (YYYY-MM format), monthly_revenue, running_total (both rounded to 2 dp). " + "Sort by month ascending." 
+ ), + sql=( + "WITH monthly AS ( " + " SELECT strftime('%Y-%m', created_at) AS month, " + " ROUND(SUM(total_amount), 2) AS monthly_revenue " + " FROM orders " + " WHERE status = 'delivered' " + " AND created_at >= '2024-01-01' AND created_at < '2025-01-01' " + " GROUP BY strftime('%Y-%m', created_at) " + ") " + "SELECT month, monthly_revenue, " + " ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total " + "FROM monthly " + "ORDER BY month ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_005", + difficulty="hard", + description=( + "Find products whose average rating is strictly above the average rating of all products " + "in their category. Use two CTEs: one for product-level averages and one for category-level. " + "Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). " + "Sort by product_avg_rating descending, then product_name ascending." + ), + sql=( + "WITH product_ratings AS ( " + " SELECT p.id AS product_id, p.name AS product_name, " + " p.category_id, c.name AS category_name, " + " ROUND(AVG(r.rating), 2) AS product_avg_rating " + " FROM products p " + " JOIN reviews r ON r.product_id = p.id " + " JOIN categories c ON c.id = p.category_id " + " GROUP BY p.id, p.name, p.category_id, c.name " + "), " + "category_ratings AS ( " + " SELECT category_id, ROUND(AVG(product_avg_rating), 2) AS category_avg_rating " + " FROM product_ratings " + " GROUP BY category_id " + ") " + "SELECT pr.product_name, pr.category_name, " + " pr.product_avg_rating, cr.category_avg_rating " + "FROM product_ratings pr " + "JOIN category_ratings cr ON cr.category_id = pr.category_id " + "WHERE pr.product_avg_rating > cr.category_avg_rating " + "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_006", + difficulty="hard", + description=( + "For each customer, find their very first order date using ROW_NUMBER in a CTE. 
" + "Return customer_name and first_order_date. Sort by first_order_date ascending." + ), + sql=( + "WITH first_orders AS ( " + " SELECT customer_id, created_at, " + " ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at ASC) AS rn " + " FROM orders " + ") " + "SELECT c.name AS customer_name, fo.created_at AS first_order_date " + "FROM customers c " + "JOIN first_orders fo ON fo.customer_id = c.id " + "WHERE fo.rn = 1 " + "ORDER BY first_order_date ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_007", + difficulty="hard", + description=( + "Rank products by total revenue generated (quantity * unit_price from order_items) " + "using RANK() window function. " + "Return product_name, total_revenue (rounded to 2 dp), revenue_rank. " + "Sort by revenue_rank ascending." + ), + sql=( + "SELECT product_name, total_revenue, revenue_rank " + "FROM ( " + " SELECT p.name AS product_name, " + " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue, " + " RANK() OVER (ORDER BY SUM(oi.quantity * oi.unit_price) DESC) AS revenue_rank " + " FROM products p " + " JOIN order_items oi ON oi.product_id = p.id " + " GROUP BY p.id, p.name " + ") sub " + "ORDER BY revenue_rank ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_008", + difficulty="hard", + description=( + "For each customer, compute the running total of their order amounts ordered by " + "created_at. Return customer_name, order_date (created_at), order_amount (total_amount), " + "running_total (rounded to 2 dp). Sort by customer_name, order_date ascending." 
+ ), + sql=( + "SELECT c.name AS customer_name, " + " o.created_at AS order_date, " + " o.total_amount AS order_amount, " + " ROUND(SUM(o.total_amount) OVER " + " (PARTITION BY c.id ORDER BY o.created_at " + " ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2) AS running_total " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "ORDER BY customer_name ASC, order_date ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_009", + difficulty="hard", + description=( + "Find customers who have placed orders in every status " + "(pending, processing, shipped, delivered, cancelled) at least once. " + "Return customer_name and status_count. Sort by customer_name ascending." + ), + sql=( + "SELECT c.name AS customer_name, COUNT(DISTINCT o.status) AS status_count " + "FROM customers c " + "JOIN orders o ON o.customer_id = c.id " + "GROUP BY c.id, c.name " + "HAVING COUNT(DISTINCT o.status) = 5 " + "ORDER BY customer_name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_010", + difficulty="hard", + description=( + "Using a CTE, compute the total revenue per product, then rank the top 3 products " + "in each category by revenue using DENSE_RANK. Only return rows with rank <= 3. " + "Return category_name, product_name, total_revenue (rounded to 2 dp), rank_in_category. " + "Sort by category_name, rank_in_category ascending." 
+ ), + sql=( + "WITH product_rev AS ( " + " SELECT p.id, p.name AS product_name, p.category_id, " + " c.name AS category_name, " + " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue " + " FROM products p " + " JOIN categories c ON c.id = p.category_id " + " JOIN order_items oi ON oi.product_id = p.id " + " GROUP BY p.id, p.name, p.category_id, c.name " + "), " + "ranked AS ( " + " SELECT product_name, category_name, total_revenue, " + " DENSE_RANK() OVER (PARTITION BY category_id ORDER BY total_revenue DESC) AS rank_in_category " + " FROM product_rev " + ") " + "SELECT category_name, product_name, total_revenue, rank_in_category " + "FROM ranked " + "WHERE rank_in_category <= 3 " + "ORDER BY category_name ASC, rank_in_category ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_011", + difficulty="hard", + description=( + "Compute the percentage of total revenue each category contributes. " + "Use a CTE for category revenues and a window SUM for the grand total. " + "Return category_name, category_revenue, pct_of_total (rounded to 2 dp). " + "Sort by pct_of_total descending." + ), + sql=( + "WITH cat_rev AS ( " + " SELECT c.name AS category_name, " + " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue " + " FROM categories c " + " JOIN products p ON p.category_id = c.id " + " JOIN order_items oi ON oi.product_id = p.id " + " GROUP BY c.id, c.name " + ") " + "SELECT category_name, category_revenue, " + " ROUND(100.0 * category_revenue / SUM(category_revenue) OVER (), 2) AS pct_of_total " + "FROM cat_rev " + "ORDER BY pct_of_total DESC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_012", + difficulty="hard", + description=( + "Find the customers who placed the highest number of orders in 2023. " + "Use a CTE to count per-customer orders in 2023, then apply DENSE_RANK. " + "Return customer_name, order_count_2023, rank. Sort by rank, then customer_name." 
+ ), + sql=( + "WITH counts_2023 AS ( " + " SELECT c.name AS customer_name, COUNT(o.id) AS order_count_2023 " + " FROM customers c " + " JOIN orders o ON o.customer_id = c.id " + " WHERE o.created_at >= '2023-01-01' AND o.created_at < '2024-01-01' " + " GROUP BY c.id, c.name " + ") " + "SELECT customer_name, order_count_2023, " + " DENSE_RANK() OVER (ORDER BY order_count_2023 DESC) AS rank " + "FROM counts_2023 " + "ORDER BY rank ASC, customer_name ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_013", + difficulty="hard", + description=( + "Show a quarterly revenue breakdown for delivered orders across all years. " + "Use strftime to derive year and quarter. " + "Return year, quarter, quarterly_revenue (rounded to 2 dp), " + "and running_total_in_year (running SUM within the same year, rounded to 2 dp). " + "Sort by year, quarter ascending." + ), + sql=( + "WITH quarterly AS ( " + " SELECT strftime('%Y', created_at) AS year, " + " ((CAST(strftime('%m', created_at) AS INTEGER) - 1) / 3 + 1) AS quarter, " + " ROUND(SUM(total_amount), 2) AS quarterly_revenue " + " FROM orders " + " WHERE status = 'delivered' " + " GROUP BY year, quarter " + ") " + "SELECT year, quarter, quarterly_revenue, " + " ROUND(SUM(quarterly_revenue) OVER (PARTITION BY year ORDER BY quarter), 2) AS running_total_in_year " + "FROM quarterly " + "ORDER BY year ASC, quarter ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_014", + difficulty="hard", + description=( + "Find the top-spending customer in each country using ROW_NUMBER. " + "Return country, customer_name, total_spent (rounded to 2 dp). " + "Sort by country, total_spent descending." 
+ ), + sql=( + "WITH customer_spend AS ( " + " SELECT c.id, c.name AS customer_name, c.country, " + " ROUND(SUM(o.total_amount), 2) AS total_spent " + " FROM customers c " + " JOIN orders o ON o.customer_id = c.id " + " GROUP BY c.id, c.name, c.country " + "), " + "ranked AS ( " + " SELECT country, customer_name, total_spent, " + " ROW_NUMBER() OVER (PARTITION BY country ORDER BY total_spent DESC) AS rn " + " FROM customer_spend " + ") " + "SELECT country, customer_name, total_spent " + "FROM ranked " + "WHERE rn = 1 " + "ORDER BY country ASC" + ), + order_sensitive=True, + ), + SQLTemplate( + id="hard_015", + difficulty="hard", + description=( + "Find products that have received both 1-star and 5-star reviews. " + "Use two CTEs: one for 1-star products, one for 5-star products, then intersect. " + "Return product_name. Sort by product_name ascending." + ), + sql=( + "WITH one_star AS ( " + " SELECT DISTINCT product_id FROM reviews WHERE rating = 1 " + "), " + "five_star AS ( " + " SELECT DISTINCT product_id FROM reviews WHERE rating = 5 " + ") " + "SELECT p.name AS product_name " + "FROM products p " + "JOIN one_star os ON os.product_id = p.id " + "JOIN five_star fs ON fs.product_id = p.id " + "ORDER BY product_name ASC" + ), + order_sensitive=True, + ), +] + +ALL_TEMPLATES: List[SQLTemplate] = EASY_TEMPLATES + MEDIUM_TEMPLATES + HARD_TEMPLATES + + +# ───────────────────────────────────────────────────────────────────────────── +# Personas +# ───────────────────────────────────────────────────────────────────────────── + +SCHEMA_CONTEXT = """ +DATABASE SCHEMA (SQLite e-commerce): + categories(id, name) + products(id, name, category_id, price, stock_quantity) + customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at) + orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled}, + created_at, total_amount) + order_items(id, order_id, product_id, quantity, unit_price) + reviews(id, product_id, customer_id, rating∈1-5, 
created_at) +""" + +PERSONA_SPECS = { + "ceo": ( + "You are a senior business executive. Write one SHORT, direct question in active voice, " + "as if you are asking an analyst to pull a number fast. Be terse, no fluff. " + "Use business language: 'revenue', 'customers', 'performance', not technical SQL terms." + ), + "chatty": ( + "You are a friendly but verbose non-technical employee. Write one long, conversational " + "question with filler phrases like 'Could you please tell me...', 'I was wondering if...', " + "passive voice is fine. Use everyday words like 'money' instead of 'revenue', " + "'people' instead of 'customers'." + ), + "lazy": ( + "You are typing quickly on a phone. Write an extremely short question with abbreviations, " + "lowercase letters, and minor spelling mistakes. Skip articles and punctuation where possible. " + "Example style: 'top 5 prods by sales?', 'hw many cust in usa'." + ), + "confused": ( + "You are a non-technical user who is unsure of the exact terminology. Write one question " + "using synonyms and vague language. Replace 'revenue' with 'money made', 'customers' with " + "'people' or 'users' or 'accounts', 'orders' with 'purchases' or 'transactions', " + "'tier' with 'membership level'. Include a bit of ambiguity." + ), + "analyst": ( + "You are a data analyst with technical knowledge. Write one precise, jargon-heavy question " + "using terms like 'aggregate', 'partition', 'metric', 'fiscal period', 'segmented by', " + "'cohort', 'granularity'. Be specific about column names and filters." + ), +} + + +# ───────────────────────────────────────────────────────────────────────────── +# Rule-based Augmentor +# ───────────────────────────────────────────────────────────────────────────── + +class RuleAugmentor: + """ + Applies deterministic, non-LLM transformations to a generated NL question. + Returns a list of augmented variants (may be empty if no rule applied). 
+ """ + + SYNONYMS: Dict[str, List[str]] = { + "customers": ["clients", "users", "accounts", "shoppers", "buyers"], + "orders": ["purchases", "transactions", "sales", "bookings"], + "products": ["items", "goods", "listings", "SKUs"], + "revenue": ["sales", "income", "earnings", "money made"], + "spending": ["expenditure", "purchases", "money spent"], + "delivered": ["completed", "fulfilled", "received"], + "cancelled": ["canceled", "voided", "aborted"], + "pending": ["waiting", "unprocessed", "queued"], + "gold": ["premium", "top-tier", "VIP", "platinum"], + "silver": ["mid-tier", "standard-plus"], + "bronze": ["basic", "standard", "entry-level"], + "rating": ["score", "star rating", "review score"], + "country": ["region", "location", "geography", "nation"], + "category": ["department", "section", "type", "group"], + "price": ["cost", "value", "amount", "fee"], + "total": ["sum", "aggregate", "combined", "overall"], + "average": ["mean", "typical", "avg"], + "show": ["list", "display", "give me", "get", "fetch"], + "find": ["identify", "locate", "get", "pull", "retrieve"], + "return": ["give me", "show", "list", "provide"], + } + + def augment(self, question: str, rng: random.Random) -> Optional[str]: + words = question.split() + changed = False + result = [] + for w in words: + clean = w.lower().strip(".,?!;:") + if clean in self.SYNONYMS and rng.random() < 0.4: + replacement = rng.choice(self.SYNONYMS[clean]) + # Preserve trailing punctuation + punct = w[len(clean):] if w.lower().startswith(clean) else "" + result.append(replacement + punct) + changed = True + else: + result.append(w) + if not changed: + return None + new_q = " ".join(result) + # Capitalise first letter + return new_q[0].upper() + new_q[1:] if new_q else new_q + + +# ───────────────────────────────────────────────────────────────────────────── +# vLLM Generator +# ───────────────────────────────────────────────────────────────────────────── + +class VLLMGenerator: + """ + Async batched inference 
using the OpenAI-compatible vLLM endpoint. + vLLM exposes exactly the same API as OpenAI, so we reuse AsyncOpenAI. + """ + + def __init__(self, base_url: str, model: str, temperature: float = 0.8, + max_tokens: int = 256, semaphore: int = 64): + self.client = AsyncOpenAI(base_url=base_url, api_key="NONE") + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + self._sem = asyncio.Semaphore(semaphore) + + async def generate_one( + self, + system: str, + user: str, + retries: int = 3, + ) -> Optional[str]: + for attempt in range(retries): + try: + async with self._sem: + resp = await self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + text = resp.choices[0].message.content.strip() + return text if text else None + except Exception as exc: + wait = 2 ** attempt + log.warning(f"vLLM call failed (attempt {attempt+1}): {exc}. Retrying in {wait}s.") + await asyncio.sleep(wait) + return None + + async def generate_batch( + self, + requests: List[Tuple[str, str, str]], # (request_id, system, user) + ) -> Dict[str, Optional[str]]: + """ + Fire all requests concurrently (bounded by semaphore) and return a dict. 
+ """ + async def _one(rid, sys, usr): + return rid, await self.generate_one(sys, usr) + + tasks = [_one(rid, sys, usr) for rid, sys, usr in requests] + results = await asyncio.gather(*tasks) + return {rid: text for rid, text in results} + + +# ───────────────────────────────────────────────────────────────────────────── +# Data Factory +# ───────────────────────────────────────────────────────────────────────────── + +@dataclass +class DataPoint: + id: str + difficulty: str + persona: str + question: str + sql: str + db_result_ok: bool + augmented: bool + + def to_training_prompt(self, system_prompt: str) -> Dict[str, Any]: + """ + Return the dict structure expected by train.py / SFT pipelines. + Includes both the raw fields and a formatted 'messages' list. + """ + user_content = ( + f"SCHEMA:\n{SCHEMA_CONTEXT}\n\nQUESTION: {self.question}" + ) + return { + **asdict(self), + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": self.sql}, + ], + } + + +SYSTEM_PROMPT = ( + "You are an expert SQL analyst working with a SQLite e-commerce database. " + "Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown." 
+) + + +class DataFactory: + def __init__( + self, + generator: VLLMGenerator, + validator: SQLiteValidator, + augmentor: RuleAugmentor, + personas_per_template: int = 5, + aug_rounds: int = 2, + seed: int = 42, + ): + self.generator = generator + self.validator = validator + self.augmentor = augmentor + self.personas_per_template = personas_per_template + self.aug_rounds = aug_rounds + self.rng = random.Random(seed) + + # ── Step 1: Validate all template SQLs ─────────────────────────────────── + + def validate_templates(self) -> List[SQLTemplate]: + log.info("Validating all SQL templates against seeded DB...") + valid = [] + failed = [] + for t in ALL_TEMPLATES: + ok, err = self.validator.validate(t.sql) + if ok: + valid.append(t) + else: + failed.append((t.id, err)) + if failed: + log.error(f"FAILED templates (will be skipped): {failed}") + log.info(f"Templates validated: {len(valid)} ok, {len(failed)} failed.") + return valid + + # ── Step 2: Build generation requests ──────────────────────────────────── + + def _build_requests( + self, + templates: List[SQLTemplate], + persona_names: List[str], + ) -> List[Tuple[str, str, str]]: + """ + Returns a flat list of (request_id, system_prompt, user_prompt) tuples. + """ + requests = [] + for t in templates: + chosen_personas = ( + persona_names + if self.personas_per_template >= len(PERSONA_SPECS) + else self.rng.sample(persona_names, self.personas_per_template) + ) + for persona in chosen_personas: + rid = f"{t.id}__{persona}" + system = ( + f"{PERSONA_SPECS[persona]}\n\n" + "Output ONLY the natural language question. " + "No explanation, no SQL, no preamble, no quotes around the question." + ) + user = ( + f"{SCHEMA_CONTEXT}\n" + f"The SQL query that answers this question is:\n{t.sql}\n\n" + f"Write ONE natural-language question that a {persona.upper()} user " + f"would ask to get this exact result." 
+ ) + requests.append((rid, system, user)) + return requests + + # ── Step 3: Post-process a generated question ───────────────────────────── + + @staticmethod + def _clean(text: str) -> str: + """Strip quotes, markdown, leading numbers, trailing newlines.""" + text = text.strip() + # Remove leading numbering like "1. " or "Q: " + text = re.sub(r'^[\d]+[\.\)]\s+', '', text) + text = re.sub(r'^[Qq]:\s*', '', text) + # Strip surrounding quotes + if (text.startswith('"') and text.endswith('"')) or \ + (text.startswith("'") and text.endswith("'")): + text = text[1:-1].strip() + # Collapse multiple whitespace + text = re.sub(r'\s+', ' ', text) + return text + + # ── Main pipeline ───────────────────────────────────────────────────────── + + async def run( + self, + output_path: str, + checkpoint_path: str, + batch_size: int = 64, + ) -> None: + # -- Validate templates + templates = self.validate_templates() + + # -- Load checkpoint + done_ids: set = set() + if os.path.exists(checkpoint_path): + with open(checkpoint_path) as f: + done_ids = set(json.loads(line)["id"] for line in f if line.strip()) + log.info(f"Resuming: {len(done_ids)} examples already generated.") + + persona_names = list(PERSONA_SPECS.keys())[: self.personas_per_template] + + all_requests = self._build_requests(templates, persona_names) + # Filter already done + pending = [r for r in all_requests if r[0] not in done_ids] + log.info(f"Total requests to generate: {len(pending)}") + + # -- Build template lookup + tmpl_lookup: Dict[str, SQLTemplate] = {t.id: t for t in templates} + + stats = {"generated": 0, "invalid_llm": 0, "augmented": 0} + + out_f = open(output_path, "a") + ckpt_f = open(checkpoint_path, "a") + + try: + for i in tqdm(range(0, len(pending), batch_size), desc="Batches"): + batch = pending[i: i + batch_size] + results = await self.generator.generate_batch(batch) + + for rid, raw_text in results.items(): + tmpl_id, persona = rid.split("__", 1) + tmpl = tmpl_lookup[tmpl_id] + + if not 
raw_text: + stats["invalid_llm"] += 1 + continue + + question = self._clean(raw_text) + if len(question) < 8: + stats["invalid_llm"] += 1 + continue + + # SQL already validated; no need to re-run for NL variants + dp = DataPoint( + id=rid, + difficulty=tmpl.difficulty, + persona=persona, + question=question, + sql=tmpl.sql, + db_result_ok=True, + augmented=False, + ) + record = dp.to_training_prompt(SYSTEM_PROMPT) + line = json.dumps(record, ensure_ascii=False) + out_f.write(line + "\n") + ckpt_f.write(line + "\n") + stats["generated"] += 1 + + # -- Rule augmentation rounds + for aug_i in range(self.aug_rounds): + aug_q = self.augmentor.augment(question, self.rng) + if aug_q and aug_q != question: + aug_dp = DataPoint( + id=f"{rid}__aug{aug_i}", + difficulty=tmpl.difficulty, + persona=persona, + question=aug_q, + sql=tmpl.sql, + db_result_ok=True, + augmented=True, + ) + aug_record = aug_dp.to_training_prompt(SYSTEM_PROMPT) + aug_line = json.dumps(aug_record, ensure_ascii=False) + out_f.write(aug_line + "\n") + ckpt_f.write(aug_line + "\n") + stats["augmented"] += 1 + + out_f.flush() + ckpt_f.flush() + + finally: + out_f.close() + ckpt_f.close() + + log.info( + f"Done. 
Generated={stats['generated']} "
+            f"Augmented={stats['augmented']} "
+            f"LLM failures={stats['invalid_llm']}"
+        )
+        log.info(f"Output: {output_path}")
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# CLI
+# ─────────────────────────────────────────────────────────────────────────────
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(
+        description="NL2SQL Synthetic Data Factory — H100 + vLLM",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    p.add_argument("--vllm-url", default="http://localhost:8001/v1",
+                   help="Base URL of the running vLLM server.")
+    p.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
+                   help="Model name as registered in the vLLM server.")
+    p.add_argument("--output", default="nl2sql_train.jsonl",
+                   help="Path to write the final JSONL dataset.")
+    p.add_argument("--checkpoint", default="nl2sql_checkpoint.jsonl",
+                   help="Path for the checkpoint file (enables resume on crash).")
+    p.add_argument("--personas-per-template", type=int, default=5,
+                   help="Number of persona variants to generate per SQL template (max 5).")
+    p.add_argument("--aug-rounds", type=int, default=2,
+                   help="Number of rule-based augmentation rounds per generated question.")
+    p.add_argument("--batch-size", type=int, default=64,
+                   help="Concurrent vLLM requests per batch (tune based on GPU memory).")
+    p.add_argument("--temperature", type=float, default=0.85,
+                   help="Sampling temperature for vLLM (higher = more diverse).")
+    p.add_argument("--max-tokens", type=int, default=200,
+                   help="Max tokens for each generated question.")
+    p.add_argument("--seed", type=int, default=42)
+    p.add_argument("--validate-only", action="store_true",
+                   help="Only validate SQL templates, do not generate data.")
+    return p.parse_args()
+
+
+async def main() -> None:
+    args = parse_args()
+
+    # Build DB + validator
+    conn = build_db()
+    validator = SQLiteValidator(conn)
+
+    if args.validate_only:
+        valid = 
[t for t in ALL_TEMPLATES if validator.validate(t.sql)[0]] + invalid = [t for t in ALL_TEMPLATES if not validator.validate(t.sql)[0]] + print(f"\n✅ Valid: {len(valid)}") + print(f"❌ Invalid: {len(invalid)}") + for t in invalid: + _, err = validator.validate(t.sql) + print(f" {t.id}: {err}") + return + + # Build pipeline components + generator = VLLMGenerator( + base_url=args.vllm_url, + model=args.model, + temperature=args.temperature, + max_tokens=args.max_tokens, + semaphore=args.batch_size, + ) + augmentor = RuleAugmentor() + + factory = DataFactory( + generator=generator, + validator=validator, + augmentor=augmentor, + personas_per_template=min(args.personas_per_template, len(PERSONA_SPECS)), + aug_rounds=args.aug_rounds, + seed=args.seed, + ) + + await factory.run( + output_path=args.output, + checkpoint_path=args.checkpoint, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/data_factory/generator.py b/data_factory/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..43e16ef53c46a5da4f3939df69acdbac5304b3a8 --- /dev/null +++ b/data_factory/generator.py @@ -0,0 +1,410 @@ +""" +data_factory/generator.py +========================== +vLLM-based Natural Language question generator for H100. + +This module uses a large LLM (Llama-3-70B or Qwen-72B) served via vLLM +to generate diverse, persona-based natural language paraphrases of the +canonical NL questions in our template library. + +KEY DESIGN: The LLM generates ONLY natural language questions. + SQL is NEVER touched by the LLM. + This guarantees zero SQL errors in the final dataset. + +Persona descriptions: + ceo - Direct, short, active voice. Business executive style. + chatty - Conversational, verbose, passive voice. + lazy_typist - Short, abbreviations, possible informal grammar. + non_techie - Plain English, avoids SQL/tech jargon, uses synonyms. + analyst - Technical, precise, jargon-heavy. 
+
+Usage:
+    Imported by data_factory.pipeline (run: python -m data_factory.pipeline --mode full)
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+from typing import Iterator, Optional
+
+logger = logging.getLogger(__name__)
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# PERSONA SYSTEM PROMPTS
+# ─────────────────────────────────────────────────────────────────────────────
+
+PERSONA_SYSTEM_PROMPTS: dict[str, str] = {
+
+    "ceo": (
+        "You are a busy C-level executive who communicates in short, punchy, "
+        "direct sentences. You use active voice, skip filler words, and get "
+        "straight to the point. You are asking a data analyst for information."
+    ),
+
+    "chatty": (
+        "You are a friendly, conversational person who likes to be thorough "
+        "and explain things fully. You use passive voice sometimes, add context, "
+        "and ask questions in a relaxed, detailed way. You are not technical."
+    ),
+
+    "lazy_typist": (
+        "You type quickly and informally. You use abbreviations (e.g. 'pls', "
+        "'lmk', 'asap'), lowercase, minimal punctuation, and sometimes omit "
+        "words. You get your meaning across without perfect grammar."
+    ),
+
+    "non_techie": (
+        "You have no database or SQL knowledge. You use everyday English words "
+        "instead of technical terms. For example, you say 'customers' not 'rows', "
+        "'most expensive' not 'highest price', 'total money' not 'sum'. "
+        "You describe what you want to see, not how to get it."
+    ),
+
+    "analyst": (
+        "You are a data scientist or BI analyst who is precise and technical. "
+        "You use terms like 'aggregate', 'partition', 'granularity', 'distinct', "
+        "'filter predicate', 'ranked by metric'. Your questions are precise and unambiguous."
+ ), +} + + +# ───────────────────────────────────────────────────────────────────────────── +# PROMPT BUILDER +# ───────────────────────────────────────────────────────────────────────────── + +def build_generation_prompt( + canonical_nl: str, + description: str, + persona: str, + schema_context: str, + n_variants: int = 5, +) -> list[dict[str, str]]: + """ + Build a chat-format prompt asking the LLM to rephrase the canonical NL + question in the style of the given persona. + + Parameters + ---------- + canonical_nl : The base NL question from the template. + description : One-line SQL description (gives the LLM additional context). + persona : One of the 5 persona keys. + schema_context : The compact schema string for the domain. + n_variants : How many rephrased questions to generate. + + Returns + ------- + list[dict] Chat messages in [{"role": ..., "content": ...}] format. + """ + persona_desc = PERSONA_SYSTEM_PROMPTS[persona] + + system = ( + "You are a data labelling specialist. Your task is to rephrase a database " + "question in a specific communication style (persona). The rephrased questions " + "must preserve the EXACT same intent and required information as the original — " + "do not change what data is being asked for, only how it is expressed.\n\n" + f"PERSONA: {persona_desc}\n\n" + "OUTPUT FORMAT: Return ONLY a valid JSON array of strings. " + "No preamble, no markdown, no extra keys. Example: " + '["question 1", "question 2", "question 3"]' + ) + + user = ( + f"DATABASE CONTEXT:\n{schema_context}\n\n" + f"WHAT THE QUERY DOES: {description}\n\n" + f"CANONICAL QUESTION: {canonical_nl}\n\n" + f"Generate {n_variants} different ways a person with the persona described " + f"above would ask this same question. The meaning must stay identical." 
+ ) + + return [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ] + + +# ───────────────────────────────────────────────────────────────────────────── +# RESPONSE PARSER +# ───────────────────────────────────────────────────────────────────────────── + +def parse_llm_response(raw_text: str) -> list[str]: + """ + Extract a list of strings from the LLM's JSON response. + Handles common failures: markdown fences, trailing commas, extra text. + + Returns an empty list if parsing fails completely. + """ + text = raw_text.strip() + + # Strip markdown fences if present + if text.startswith("```"): + lines = text.split("\n") + text = "\n".join(l for l in lines if not l.strip().startswith("```")).strip() + + # Find the JSON array boundaries + start = text.find("[") + end = text.rfind("]") + if start == -1 or end == -1 or end <= start: + logger.warning("LLM response missing JSON array brackets: %s", text[:100]) + return [] + + json_str = text[start:end + 1] + + # Fix trailing commas before ] (common LLM mistake) + json_str = json_str.rstrip() + json_str = json_str.replace(",]", "]").replace(", ]", "]") + + try: + parsed = json.loads(json_str) + if not isinstance(parsed, list): + return [] + # Filter to only non-empty strings + return [s.strip() for s in parsed if isinstance(s, str) and s.strip()] + except json.JSONDecodeError as exc: + logger.warning("JSON parse error: %s | text: %s", exc, json_str[:200]) + return [] + + +# ───────────────────────────────────────────────────────────────────────────── +# VLLM INTERFACE +# ───────────────────────────────────────────────────────────────────────────── + +class VLLMGenerator: + """ + Wrapper around a running vLLM server for high-throughput NL generation. + + Supports two modes: + online : Calls a running vLLM OpenAI-compatible API server. + offline : Uses vllm.LLM directly (loads model in-process, H100 recommended). 
+ + For H100 cluster usage, prefer 'offline' mode with tensor_parallel_size=4 + to saturate all 4 H100s for maximum throughput. + """ + + def __init__( + self, + model_name: str, + mode: str = "offline", + tensor_parallel_size: int = 4, + gpu_memory_utilization: float = 0.90, + max_model_len: int = 4096, + # Online mode only + api_base: str = "http://localhost:8000/v1", + api_key: str = "EMPTY", + ) -> None: + self.model_name = model_name + self.mode = mode + self._llm = None + self._client = None + + if mode == "offline": + self._init_offline(tensor_parallel_size, gpu_memory_utilization, max_model_len) + elif mode == "online": + self._init_online(api_base, api_key) + else: + raise ValueError(f"Unknown mode: {mode!r}. Use 'offline' or 'online'.") + + def _init_offline( + self, + tensor_parallel_size: int, + gpu_memory_utilization: float, + max_model_len: int, + ) -> None: + """Load vLLM engine in-process (best for H100 cluster).""" + try: + from vllm import LLM, SamplingParams + self._LLM = LLM + self._SamplingParams = SamplingParams + except ImportError: + raise ImportError( + "vLLM not installed. 
Run: pip install vllm\n" + "For H100: pip install vllm --extra-index-url https://download.pytorch.org/whl/cu124" + ) + + logger.info("Loading model %s with %d GPUs (offline mode)...", self.model_name, tensor_parallel_size) + t0 = time.time() + self._llm = self._LLM( + model=self.model_name, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + max_model_len=max_model_len, + dtype="bfloat16", + trust_remote_code=True, + ) + logger.info("Model loaded in %.1f seconds.", time.time() - t0) + + def _init_online(self, api_base: str, api_key: str) -> None: + """Use OpenAI-compatible vLLM server (for distributed setups).""" + try: + from openai import OpenAI + self._client = OpenAI(base_url=api_base, api_key=api_key) + except ImportError: + raise ImportError("pip install openai") + logger.info("Connected to vLLM server at %s", api_base) + + def generate_batch( + self, + prompts: list[list[dict[str, str]]], + temperature: float = 0.85, + max_new_tokens: int = 300, + ) -> list[str]: + """ + Generate responses for a batch of chat prompts. + + Parameters + ---------- + prompts : List of chat message lists (one per item in batch). + temperature : Sampling temperature. Higher = more diverse. + max_new_tokens : Max tokens per response. + + Returns + ------- + list[str] Raw text response per prompt (same length as input). 
+ """ + if self.mode == "offline": + return self._generate_offline(prompts, temperature, max_new_tokens) + else: + return self._generate_online(prompts, temperature, max_new_tokens) + + def _generate_offline( + self, + prompts: list[list[dict]], + temperature: float, + max_new_tokens: int, + ) -> list[str]: + """vLLM offline batched generation — maximises H100 throughput.""" + from vllm import SamplingParams + + sampling = SamplingParams( + temperature=temperature, + max_tokens=max_new_tokens, + stop=["", "<|eot_id|>"], # Llama-3 stop tokens + ) + + # Convert chat messages to tokenised prompt strings using the model's template + tokenizer = self._llm.get_tokenizer() + formatted_prompts: list[str] = [] + for msgs in prompts: + if hasattr(tokenizer, "apply_chat_template"): + text = tokenizer.apply_chat_template( + msgs, tokenize=False, add_generation_prompt=True + ) + else: + # Fallback: simple concatenation + text = "\n".join( + f"<|{m['role']}|>\n{m['content']}" for m in msgs + ) + formatted_prompts.append(text) + + outputs = self._llm.generate(formatted_prompts, sampling) + return [o.outputs[0].text for o in outputs] + + def _generate_online( + self, + prompts: list[list[dict]], + temperature: float, + max_new_tokens: int, + ) -> list[str]: + """Sequential generation via OpenAI-compatible API (fallback / debugging).""" + results = [] + for msgs in prompts: + try: + resp = self._client.chat.completions.create( + model=self.model_name, + messages=msgs, + temperature=temperature, + max_tokens=max_new_tokens, + ) + results.append(resp.choices[0].message.content or "") + except Exception as exc: + logger.warning("API call failed: %s", exc) + results.append("") + return results + + +# ───────────────────────────────────────────────────────────────────────────── +# HIGH-LEVEL GENERATION LOOP +# ───────────────────────────────────────────────────────────────────────────── + +def generate_persona_variants_batch( + templates_subset: list[dict], + generator: VLLMGenerator, + 
personas: list[str], + n_variants_per_persona: int = 5, + batch_size: int = 64, + temperature: float = 0.85, + max_new_tokens: int = 300, +) -> Iterator[dict]: + """ + For each template × persona combination, generate `n_variants_per_persona` + NL question variants using the LLM. + + Yields dicts: + { + "template_idx": int, + "persona": str, + "nl_variants": list[str], # successfully parsed NL questions + } + + Parameters + ---------- + templates_subset : List of template dicts (from templates.py). + generator : VLLMGenerator instance. + personas : List of persona keys to use. + n_variants_per_persona : How many NL variants per (template, persona) pair. + batch_size : How many LLM calls to batch together. + temperature : Sampling temperature. + max_new_tokens : Max tokens for LLM response (should be ~300 for JSON array). + """ + from data_factory.schemas import SCHEMA_CONTEXT + + # Build all (template_idx, persona) prompt pairs + all_jobs: list[tuple[int, str, list[dict]]] = [] + + for t_idx, template in enumerate(templates_subset): + schema_ctx = SCHEMA_CONTEXT[template["domain"]] + for persona in personas: + prompt = build_generation_prompt( + canonical_nl=template["base_nl"], + description=template["description"], + persona=persona, + schema_context=schema_ctx, + n_variants=n_variants_per_persona, + ) + all_jobs.append((t_idx, persona, prompt)) + + total_jobs = len(all_jobs) + logger.info("Starting LLM generation: %d jobs (templates × personas).", total_jobs) + + # Process in batches + for batch_start in range(0, total_jobs, batch_size): + batch = all_jobs[batch_start: batch_start + batch_size] + prompts = [job[2] for job in batch] + + t0 = time.time() + raw_responses = generator.generate_batch( + prompts, temperature=temperature, max_new_tokens=max_new_tokens + ) + elapsed = time.time() - t0 + logger.info( + "Batch %d-%d completed in %.1fs (%.1f jobs/s).", + batch_start, batch_start + len(batch), elapsed, len(batch) / max(elapsed, 0.001) + ) + + for (t_idx, 
persona, _), raw in zip(batch, raw_responses): + nl_variants = parse_llm_response(raw) + if not nl_variants: + logger.debug( + "Empty parse for template_idx=%d persona=%s. raw=%s", + t_idx, persona, raw[:100] + ) + # Fall back to the canonical NL rather than losing this entry + nl_variants = [templates_subset[t_idx]["base_nl"]] + + yield { + "template_idx": t_idx, + "persona": persona, + "nl_variants": nl_variants, + } diff --git a/data_factory/pipeline.py b/data_factory/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b921f24ce2727fd77d4e377fbbdf60ff591dc22a --- /dev/null +++ b/data_factory/pipeline.py @@ -0,0 +1,443 @@ +""" +data_factory/pipeline.py +========================= +Master orchestration pipeline for the NL2SQL Synthetic Data Factory. + +This module ties together: + 1. Template library (66 verified SQL templates across 4 domains) + 2. Rule-based NL augmentation (augmentor.py) + 3. vLLM persona-based NL generation (generator.py) + 4. SQL execution validation (validator.py) + 5. Output serialisation (JSONL + Parquet) + +Run modes: + --mode base : Only uses template base_nl + rule augmentation (no GPU required) + --mode full : base + vLLM persona generation (requires H100) + +Output dataset format (JSONL, one record per line): + { + "prompt": [{"role": "system", ...}, {"role": "user", ...}], + "sql": "SELECT ...", + "metadata": { "domain", "difficulty", "persona", ... 
} + } + +This format is directly loadable by: + datasets.load_dataset("json", data_files="output/train.jsonl") +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import random +import time +from pathlib import Path +from typing import Any, Iterator, Optional + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger("pipeline") + + +# ───────────────────────────────────────────────────────────────────────────── +# HELPERS +# ───────────────────────────────────────────────────────────────────────────── + +def _ensure_dirs(*dirs: Path) -> None: + for d in dirs: + d.mkdir(parents=True, exist_ok=True) + + +def _write_jsonl(records: list[dict], path: Path) -> None: + with open(path, "w", encoding="utf-8") as f: + for rec in records: + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + logger.info("Wrote %d records to %s", len(records), path) + + +def _write_parquet(records: list[dict], path: Path) -> None: + try: + import pandas as pd + df = pd.DataFrame(records) + df.to_parquet(path, index=False, engine="pyarrow", compression="snappy") + logger.info("Wrote %d records to %s (Parquet)", len(records), path) + except ImportError: + logger.warning("pandas/pyarrow not installed — skipping Parquet output.") + + +def _train_val_test_split( + records: list[dict], + train_frac: float = 0.90, + val_frac: float = 0.05, + seed: int = 42, +) -> tuple[list[dict], list[dict], list[dict]]: + """ + Stratified split by (domain, difficulty) to ensure all combinations + are represented in every split. 
+ """ + rng = random.Random(seed) + from collections import defaultdict + + buckets: dict[str, list[dict]] = defaultdict(list) + for rec in records: + key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}" + buckets[key].append(rec) + + train, val, test = [], [], [] + for key, bucket in buckets.items(): + rng.shuffle(bucket) + n = len(bucket) + n_train = max(1, int(n * train_frac)) + n_val = max(1, int(n * val_frac)) + train.extend(bucket[:n_train]) + val.extend(bucket[n_train:n_train + n_val]) + test.extend(bucket[n_train + n_val:]) + + rng.shuffle(train) + rng.shuffle(val) + rng.shuffle(test) + return train, val, test + + +# ───────────────────────────────────────────────────────────────────────────── +# PHASE 1: BASE + RULE AUGMENTATION (no GPU required) +# ───────────────────────────────────────────────────────────────────────────── + +def run_base_pipeline( + templates: list, + n_augmentations: int = 5, + seed: int = 42, +) -> list[dict]: + """ + Generate training records from: + (a) the canonical base_nl of each template + (b) rule-based augmented NL variants + + Returns a list of training dicts (ready to write to JSONL). 
+ """ + from data_factory.augmentor import augment_nl + from data_factory.validator import SQLValidator, build_record + from data_factory.schemas import SCHEMA_MAP + + # Build one validator per domain (reuse connection across templates) + validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP} + records: list[dict] = [] + + for t_idx, template in enumerate(templates): + v = validators[template["domain"]] + + # (a) Canonical base_nl + rec = build_record( + template=template, + template_idx=t_idx, + nl_question=template["base_nl"], + persona="canonical", + source="template_base", + validator=v, + ) + if rec: + records.append(rec.to_training_dict()) + + # (b) Rule-augmented variants + augmented = augment_nl( + nl_question=template["base_nl"], + n=n_augmentations, + seed=seed + t_idx, + ) + for nl_variant in augmented: + rec = build_record( + template=template, + template_idx=t_idx, + nl_question=nl_variant, + persona="rule_augmented", + source="rule_augmented", + validator=v, + ) + if rec: + records.append(rec.to_training_dict()) + + for v in validators.values(): + v.close() + + logger.info("Base pipeline: %d records generated from %d templates.", len(records), len(templates)) + return records + + +# ───────────────────────────────────────────────────────────────────────────── +# PHASE 2: vLLM PERSONA GENERATION (H100 required) +# ───────────────────────────────────────────────────────────────────────────── + +def run_vllm_pipeline( + templates: list, + generator, # VLLMGenerator instance + personas: list[str], + n_variants_per_persona: int = 10, + batch_size: int = 64, + temperature: float = 0.85, + max_new_tokens: int = 350, + seed: int = 42, +) -> list[dict]: + """ + Generate additional NL variants using the LLM, then validate SQL. + + Returns a list of training dicts. 
+ """ + from data_factory.generator import generate_persona_variants_batch + from data_factory.validator import SQLValidator, build_record + from data_factory.schemas import SCHEMA_MAP + + validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP} + records: list[dict] = [] + + gen_iter = generate_persona_variants_batch( + templates_subset=templates, + generator=generator, + personas=personas, + n_variants_per_persona=n_variants_per_persona, + batch_size=batch_size, + temperature=temperature, + max_new_tokens=max_new_tokens, + ) + + for job_result in gen_iter: + t_idx = job_result["template_idx"] + persona = job_result["persona"] + template = templates[t_idx] + v = validators[template["domain"]] + + for nl_variant in job_result["nl_variants"]: + rec = build_record( + template=template, + template_idx=t_idx, + nl_question=nl_variant, + persona=persona, + source="vllm_persona", + validator=v, + ) + if rec: + records.append(rec.to_training_dict()) + + for v in validators.values(): + v.close() + + logger.info("vLLM pipeline: %d records generated.", len(records)) + return records + + +# ───────────────────────────────────────────────────────────────────────────── +# CHECKPOINT UTILITIES +# ───────────────────────────────────────────────────────────────────────────── + +def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path: + path = checkpoint_dir / f"{name}.jsonl" + _write_jsonl(records, path) + return path + + +def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]: + path = checkpoint_dir / f"{name}.jsonl" + if not path.exists(): + return None + records = [] + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + records.append(json.loads(line)) + logger.info("Loaded %d records from checkpoint %s", len(records), path) + return records + + +# ───────────────────────────────────────────────────────────────────────────── +# DATASET STATISTICS +# 
# ─────────────────────────────────────────────────────────────────────────────

def print_dataset_stats(records: list[dict]) -> None:
    """Print domain / difficulty / persona / source breakdowns to stdout.

    Safe on an empty record list (prints a header with zero counts).
    """
    from collections import Counter
    total = len(records)  # hoisted: reused in header and every percentage
    domains = Counter(r["metadata"]["domain"] for r in records)
    diffs = Counter(r["metadata"]["difficulty"] for r in records)
    personas = Counter(r["metadata"]["persona"] for r in records)
    sources = Counter(r["metadata"]["source"] for r in records)

    # Guard division defensively; with total == 0 the loops below are empty,
    # but `denom` keeps the expression safe if that invariant ever changes.
    denom = total or 1

    print("\n" + "=" * 55)
    print(f" DATASET STATISTICS ({total:,} total records)")
    print("=" * 55)
    print("\nBy Domain:")
    for k, v in sorted(domains.items()):
        print(f"   {k:20s}: {v:6,}  ({v/denom*100:.1f}%)")
    print("\nBy Difficulty:")
    for k, v in sorted(diffs.items()):
        print(f"   {k:20s}: {v:6,}  ({v/denom*100:.1f}%)")
    print("\nBy Persona/Source:")
    for k, v in sorted(personas.items()):
        print(f"   {k:20s}: {v:6,}")
    print("\nBy Source:")
    for k, v in sorted(sources.items()):
        print(f"   {k:20s}: {v:6,}")
    print("=" * 55 + "\n")


# ─────────────────────────────────────────────────────────────────────────────
# MAIN ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────

def main() -> None:
    """CLI entry point: generate, deduplicate, split, and write the dataset."""
    parser = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory — generates verified training data."
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
                        help="HuggingFace model name for vLLM (full mode only).")
    parser.add_argument("--tensor-parallel", type=int, default=4,
                        help="Tensor parallel size for vLLM (number of H100s).")
    parser.add_argument("--n-rule-augments", type=int, default=5,
                        help="Number of rule-based NL augmentations per template.")
    parser.add_argument("--n-persona-variants", type=int, default=10,
                        help="Number of vLLM NL variants per (template, persona) pair.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="vLLM batch size (larger = faster on H100).")
    parser.add_argument("--temperature", type=float, default=0.85,
                        help="Sampling temperature for vLLM generation.")
    parser.add_argument("--output-dir", type=str, default="generated_data/output",
                        help="Directory to write final dataset files.")
    parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
                        help="Directory for intermediate checkpoints.")
    parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
    parser.add_argument("--no-parquet", action="store_true",
                        help="Skip Parquet output (write only JSONL).")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from latest checkpoint if available.")
    parser.add_argument("--domains", nargs="+",
                        choices=["ecommerce", "healthcare", "finance", "hr"],
                        default=["ecommerce", "healthcare", "finance", "hr"],
                        help="Domains to include (default: all 4).")
    parser.add_argument("--difficulties", nargs="+",
                        choices=["easy", "medium", "hard"],
                        default=["easy", "medium", "hard"],
                        help="Difficulty levels to include (default: all 3).")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    checkpoint_dir = Path(args.checkpoint_dir)
    _ensure_dirs(output_dir, checkpoint_dir)

    # ── Load templates ─────────────────────────────────────────────────────
    from data_factory.templates import ALL_TEMPLATES

    templates = [
        t for t in ALL_TEMPLATES
        if t["domain"] in args.domains and t["difficulty"] in args.difficulties
    ]
    logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
                len(templates), args.domains, args.difficulties)

    # ── Phase 1: Base + rule augmentation ─────────────────────────────────
    all_records: list[dict] = []

    ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
    if ckpt_base is not None:
        all_records.extend(ckpt_base)
        logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
    else:
        logger.info("=== Phase 1: Base + Rule Augmentation ===")
        base_records = run_base_pipeline(
            templates=templates,
            n_augmentations=args.n_rule_augments,
            seed=args.seed,
        )
        all_records.extend(base_records)
        save_checkpoint(base_records, checkpoint_dir, "phase1_base")

    # ── Phase 2: vLLM persona generation (full mode only) ─────────────────
    if args.mode == "full":
        ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
        if ckpt_vllm is not None:
            all_records.extend(ckpt_vllm)
            logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
        else:
            logger.info("=== Phase 2: vLLM Persona Generation ===")

            from data_factory.generator import VLLMGenerator
            from data_factory.config import PERSONAS

            generator = VLLMGenerator(
                model_name=args.model,
                mode="offline",
                tensor_parallel_size=args.tensor_parallel,
                gpu_memory_utilization=0.90,
            )

            vllm_records = run_vllm_pipeline(
                templates=templates,
                generator=generator,
                personas=PERSONAS,
                n_variants_per_persona=args.n_persona_variants,
                batch_size=args.batch_size,
                temperature=args.temperature,
                max_new_tokens=350,
                seed=args.seed,
            )
            all_records.extend(vllm_records)
            save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")

    # ── Deduplication ──────────────────────────────────────────────────────
    logger.info("Deduplicating %d records...", len(all_records))
    seen_nl: set[str] = set()
    deduped: list[dict] = []
    for rec in all_records:
        nl = rec["prompt"][1]["content"]  # user message contains the NL question
        if nl not in seen_nl:
            seen_nl.add(nl)
            deduped.append(rec)
    logger.info("After dedup: %d unique records (removed %d duplicates).",
                len(deduped), len(all_records) - len(deduped))

    # ── Statistics ─────────────────────────────────────────────────────────
    print_dataset_stats(deduped)

    # ── Train / Val / Test split ───────────────────────────────────────────
    train, val, test = _train_val_test_split(deduped, seed=args.seed)
    logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))

    # ── Write outputs ─────────────────────────────────────────────────────
    _write_jsonl(train, output_dir / "train.jsonl")
    _write_jsonl(val, output_dir / "val.jsonl")
    _write_jsonl(test, output_dir / "test.jsonl")

    if not args.no_parquet:
        _write_parquet(train, output_dir / "train.parquet")
        _write_parquet(val, output_dir / "val.parquet")
        _write_parquet(test, output_dir / "test.parquet")

    # ── Write dataset card ─────────────────────────────────────────────────
    card = {
        "name": "NL2SQL-Bench Synthetic Training Dataset",
        "version": "1.0",
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "total_records": len(deduped),
        "splits": {"train": len(train), "val": len(val), "test": len(test)},
        "domains": args.domains,
        "difficulties": args.difficulties,
        "mode": args.mode,
        "seed": args.seed,
        "sql_guarantee": (
            "Every SQL in this dataset was human-authored and execution-validated "
            "against a seeded SQLite database. Zero LLM-generated SQL."
        ),
    }
    # Explicit encoding for consistency with the rest of the module's I/O.
    with open(output_dir / "dataset_card.json", "w", encoding="utf-8") as f:
        json.dump(card, f, indent=2)

    logger.info("=== Done! Dataset written to %s ===", output_dir)
Dataset written to %s ===", output_dir) + + +if __name__ == "__main__": + main() diff --git a/data_factory/run_data_factory.py b/data_factory/run_data_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff7224cc886f20198b5d92ee3cf07b190dc1518 --- /dev/null +++ b/data_factory/run_data_factory.py @@ -0,0 +1,260 @@ +""" +run_data_factory.py +==================== +Entry point and smoke-test runner for the NL2SQL Data Factory. + +Run this FIRST before running the full pipeline to verify: + 1. All 66 SQL templates execute without errors + 2. Rule augmentation produces diverse NL variants + 3. Validators correctly accept/reject queries + 4. Base pipeline generates well-formed JSONL records + +Usage: + # Smoke test only (fast, ~10 seconds) + python run_data_factory.py --smoke-test + + # Base mode (no GPU, generates all rule-augmented records) + python run_data_factory.py --mode base + + # Full mode (H100 required) + python run_data_factory.py --mode full --model meta-llama/Meta-Llama-3-70B-Instruct --tensor-parallel 4 + + # Preview what the dataset looks like + python run_data_factory.py --smoke-test --show-samples 3 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import textwrap +from pathlib import Path + +# Allow running from project root +sys.path.insert(0, str(Path(__file__).parent)) + + +# ───────────────────────────────────────────────────────────────────────────── +# SMOKE TEST +# ───────────────────────────────────────────────────────────────────────────── + +def run_smoke_test(show_samples: int = 0) -> bool: + print("\n" + "=" * 60) + print(" NL2SQL DATA FACTORY — SMOKE TEST") + print("=" * 60) + + all_passed = True + + # 1. 
Template validation + print("\n[1/4] Validating all SQL templates against seeded data...") + from data_factory.templates import ALL_TEMPLATES, template_stats + from data_factory.validator import validate_all_templates + + stats = template_stats() + result = validate_all_templates(ALL_TEMPLATES) + + print(f" Templates: {stats}") + print(f" Validation: {result['passed']}/{result['total']} passed", end="") + + if result["failed"]: + print(f" ← {result['failed']} FAILURES:") + for f in result["failures"]: + print(f" [{f['domain']}] {f['sql']}... → {f['error']}") + all_passed = False + else: + print(" ✓") + + # 2. Rule augmentation + print("\n[2/4] Testing rule-based augmentation...") + from data_factory.augmentor import augment_nl + + test_nls = [ + "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.", + "Which medications are prescribed most often? Return medication_name, category, times_prescribed.", + "Rank active employees by salary within their department. Return salary_rank.", + ] + for nl in test_nls: + variants = augment_nl(nl, n=3, seed=42) + if not variants: + print(f" FAIL: No variants generated for: {nl[:50]}") + all_passed = False + else: + print(f" ✓ {len(variants)} variants from: '{nl[:45]}...'") + if show_samples > 0: + for i, v in enumerate(variants[:show_samples]): + print(f" [{i+1}] {v}") + + # 3. 
Validator accept/reject + print("\n[3/4] Testing SQL validator accept/reject logic...") + from data_factory.validator import SQLValidator + + v = SQLValidator("ecommerce") + tests = [ + ("SELECT id, name FROM customers WHERE tier = 'gold'", True, "valid SELECT"), + ("INSERT INTO customers VALUES (1,'x','x@x.com','IN','gold','2024-01-01')", False, "rejected INSERT"), + ("SELECT nonexistent_col FROM customers", False, "bad column name"), + ("", False, "empty string"), + ] + for sql, expect_pass, label in tests: + vr = v.validate(sql) + status = "✓" if vr.passed == expect_pass else "✗" + print(f" {status} {label}: passed={vr.passed}", end="") + if not vr.passed: + print(f" (error: {vr.error})", end="") + print() + if vr.passed != expect_pass: + all_passed = False + v.close() + + # 4. Mini base pipeline (first 5 templates only) + print("\n[4/4] Running mini base pipeline (first 5 templates)...") + from data_factory.pipeline import run_base_pipeline + + mini_templates = ALL_TEMPLATES[:5] + records = run_base_pipeline(mini_templates, n_augmentations=2, seed=42) + expected_min = 5 # at least canonical NLs + if len(records) < expected_min: + print(f" FAIL: Only {len(records)} records (expected ≥{expected_min})") + all_passed = False + else: + print(f" ✓ Generated {len(records)} records from 5 templates") + + # Validate structure + required_keys = {"prompt", "sql", "metadata"} + for rec in records[:3]: + missing = required_keys - rec.keys() + if missing: + print(f" FAIL: Record missing keys: {missing}") + all_passed = False + break + else: + print(" ✓ Record structure validated") + + if show_samples > 0 and records: + print(f"\n --- Sample Record ---") + sample = records[0] + print(f" Domain: {sample['metadata']['domain']}") + print(f" Difficulty: {sample['metadata']['difficulty']}") + print(f" Persona: {sample['metadata']['persona']}") + print(f" NL: {sample['prompt'][1]['content'].split('QUESTION: ')[-1][:100]}") + print(f" SQL: {sample['sql'][:80]}...") + + # Summary + 
print("\n" + "=" * 60) + if all_passed: + print(" ALL SMOKE TESTS PASSED ✓") + print(" Safe to run: python run_data_factory.py --mode base") + else: + print(" SOME TESTS FAILED ✗ — fix errors before running pipeline") + print("=" * 60 + "\n") + + return all_passed + + +# ───────────────────────────────────────────────────────────────────────────── +# INSPECT DATASET +# ───────────────────────────────────────────────────────────────────────────── + +def inspect_dataset(jsonl_path: str, n: int = 5) -> None: + """Pretty-print N records from an output JSONL file.""" + path = Path(jsonl_path) + if not path.exists(): + print(f"File not found: {path}") + return + + records = [] + with open(path, encoding="utf-8") as f: + for i, line in enumerate(f): + if i >= n: + break + records.append(json.loads(line)) + + print(f"\n{'='*65}") + print(f" Showing {len(records)} records from {path.name}") + print(f"{'='*65}") + + for i, rec in enumerate(records): + nl = rec["prompt"][1]["content"].split("QUESTION:")[-1].strip() + sql = rec["sql"] + meta = rec["metadata"] + print(f"\n[{i+1}] Domain={meta['domain']} | Difficulty={meta['difficulty']} | " + f"Persona={meta['persona']} | Source={meta['source']}") + print(f" NL: {textwrap.shorten(nl, 90)}") + print(f" SQL: {textwrap.shorten(sql, 90)}") + + print() + + +# ───────────────────────────────────────────────────────────────────────────── +# MAIN +# ───────────────────────────────────────────────────────────────────────────── + +def main() -> None: + parser = argparse.ArgumentParser( + description="NL2SQL Data Factory — entry point.", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--smoke-test", action="store_true", + help="Run smoke test only (validates all templates, no output written).", + ) + parser.add_argument( + "--show-samples", type=int, default=0, + help="During smoke test, show N sample NL variants and records.", + ) + parser.add_argument( + "--inspect", type=str, default=None, + help="Path to 
a JSONL output file to inspect.", + ) + parser.add_argument( + "--inspect-n", type=int, default=5, + help="Number of records to show when inspecting.", + ) + parser.add_argument( + "--mode", choices=["base", "full"], default="base", + help=( + "base: rule augmentation only, ~450 records, no GPU needed.\n" + "full: + vLLM persona variants, 500K+ records, H100 required." + ), + ) + parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct") + parser.add_argument("--tensor-parallel", type=int, default=4) + parser.add_argument("--n-rule-augments", type=int, default=5) + parser.add_argument("--n-persona-variants", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--temperature", type=float, default=0.85) + parser.add_argument("--output-dir", default="generated_data/output") + parser.add_argument("--checkpoint-dir", default="generated_data/checkpoints") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--no-parquet", action="store_true") + parser.add_argument("--resume", action="store_true") + parser.add_argument( + "--domains", nargs="+", + choices=["ecommerce","healthcare","finance","hr"], + default=["ecommerce","healthcare","finance","hr"], + ) + parser.add_argument( + "--difficulties", nargs="+", + choices=["easy","medium","hard"], + default=["easy","medium","hard"], + ) + + args = parser.parse_args() + + if args.smoke_test: + ok = run_smoke_test(show_samples=args.show_samples) + sys.exit(0 if ok else 1) + + if args.inspect: + inspect_dataset(args.inspect, n=args.inspect_n) + sys.exit(0) + + # Forward to pipeline + from data_factory.pipeline import main as pipeline_main + # Re-parse with pipeline's own parser by forwarding sys.argv + pipeline_main() + + +if __name__ == "__main__": + main() diff --git a/data_factory/schemas.py b/data_factory/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..db8e12dfe58c4ce9e0c6f473807cd0b857595021 --- 
/dev/null +++ b/data_factory/schemas.py @@ -0,0 +1,564 @@ +""" +data_factory/schemas.py +======================== +SQLite CREATE TABLE statements for all four domains. +Each schema is fully self-contained and has been verified to create +without errors in SQLite 3.x. +""" + +from __future__ import annotations +import sqlite3 +import random +from datetime import date, timedelta +from typing import Callable + +# ───────────────────────────────────────────────────────────────────────────── +# SQL SCHEMAS +# ───────────────────────────────────────────────────────────────────────────── + +ECOMMERCE_SCHEMA = """ +CREATE TABLE IF NOT EXISTS categories ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category_id INTEGER NOT NULL REFERENCES categories(id), + price REAL NOT NULL CHECK(price >= 0), + stock_quantity INTEGER NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS customers ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + country TEXT NOT NULL, + tier TEXT NOT NULL DEFAULT 'bronze' + CHECK(tier IN ('bronze', 'silver', 'gold')), + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES customers(id), + status TEXT NOT NULL DEFAULT 'pending' + CHECK(status IN ('pending','processing','shipped','delivered','cancelled')), + created_at TEXT NOT NULL, + total_amount REAL NOT NULL CHECK(total_amount >= 0) +); + +CREATE TABLE IF NOT EXISTS order_items ( + id INTEGER PRIMARY KEY, + order_id INTEGER NOT NULL REFERENCES orders(id), + product_id INTEGER NOT NULL REFERENCES products(id), + quantity INTEGER NOT NULL CHECK(quantity > 0), + unit_price REAL NOT NULL CHECK(unit_price >= 0) +); + +CREATE TABLE IF NOT EXISTS reviews ( + id INTEGER PRIMARY KEY, + product_id INTEGER NOT NULL REFERENCES products(id), + customer_id INTEGER NOT NULL REFERENCES customers(id), + 
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), + created_at TEXT NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_products_category ON products(category_id); +CREATE INDEX IF NOT EXISTS idx_orders_customer ON orders(customer_id); +CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status); +CREATE INDEX IF NOT EXISTS idx_orders_created ON orders(created_at); +CREATE INDEX IF NOT EXISTS idx_order_items_order ON order_items(order_id); +CREATE INDEX IF NOT EXISTS idx_order_items_product ON order_items(product_id); +CREATE INDEX IF NOT EXISTS idx_reviews_product ON reviews(product_id); +CREATE INDEX IF NOT EXISTS idx_customers_tier ON customers(tier); +""" + +HEALTHCARE_SCHEMA = """ +CREATE TABLE IF NOT EXISTS patients ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + date_of_birth TEXT NOT NULL, + gender TEXT NOT NULL CHECK(gender IN ('M','F','Other')), + blood_type TEXT NOT NULL, + country TEXT NOT NULL, + registered_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS doctors ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + specialization TEXT NOT NULL, + department TEXT NOT NULL, + experience_years INTEGER NOT NULL CHECK(experience_years >= 0), + consultation_fee REAL NOT NULL CHECK(consultation_fee >= 0) +); + +CREATE TABLE IF NOT EXISTS appointments ( + id INTEGER PRIMARY KEY, + patient_id INTEGER NOT NULL REFERENCES patients(id), + doctor_id INTEGER NOT NULL REFERENCES doctors(id), + scheduled_at TEXT NOT NULL, + status TEXT NOT NULL + CHECK(status IN ('scheduled','completed','cancelled','no_show')), + notes TEXT +); + +CREATE TABLE IF NOT EXISTS diagnoses ( + id INTEGER PRIMARY KEY, + appointment_id INTEGER NOT NULL REFERENCES appointments(id), + icd_code TEXT NOT NULL, + description TEXT NOT NULL, + severity TEXT NOT NULL CHECK(severity IN ('mild','moderate','severe')) +); + +CREATE TABLE IF NOT EXISTS medications ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category TEXT NOT NULL, + unit_price REAL NOT NULL CHECK(unit_price >= 0) +); + 
+CREATE TABLE IF NOT EXISTS prescriptions ( + id INTEGER PRIMARY KEY, + appointment_id INTEGER NOT NULL REFERENCES appointments(id), + medication_id INTEGER NOT NULL REFERENCES medications(id), + dosage TEXT NOT NULL, + duration_days INTEGER NOT NULL CHECK(duration_days > 0), + quantity INTEGER NOT NULL CHECK(quantity > 0) +); + +CREATE INDEX IF NOT EXISTS idx_appt_patient ON appointments(patient_id); +CREATE INDEX IF NOT EXISTS idx_appt_doctor ON appointments(doctor_id); +CREATE INDEX IF NOT EXISTS idx_appt_status ON appointments(status); +CREATE INDEX IF NOT EXISTS idx_diag_appt ON diagnoses(appointment_id); +CREATE INDEX IF NOT EXISTS idx_presc_appt ON prescriptions(appointment_id); +CREATE INDEX IF NOT EXISTS idx_presc_med ON prescriptions(medication_id); +""" + +FINANCE_SCHEMA = """ +CREATE TABLE IF NOT EXISTS fin_customers ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + country TEXT NOT NULL, + kyc_status TEXT NOT NULL CHECK(kyc_status IN ('pending','verified','rejected')), + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS accounts ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES fin_customers(id), + account_type TEXT NOT NULL + CHECK(account_type IN ('savings','current','fixed_deposit','loan')), + balance REAL NOT NULL DEFAULT 0, + currency TEXT NOT NULL DEFAULT 'USD', + status TEXT NOT NULL CHECK(status IN ('active','dormant','closed')), + opened_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS transactions ( + id INTEGER PRIMARY KEY, + account_id INTEGER NOT NULL REFERENCES accounts(id), + txn_type TEXT NOT NULL CHECK(txn_type IN ('credit','debit')), + amount REAL NOT NULL CHECK(amount > 0), + currency TEXT NOT NULL DEFAULT 'USD', + merchant TEXT, + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS loans ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES fin_customers(id), + loan_type TEXT NOT NULL + CHECK(loan_type IN ('personal','home','auto','business')), + 
principal_amount REAL NOT NULL, + interest_rate REAL NOT NULL, + tenure_months INTEGER NOT NULL, + status TEXT NOT NULL CHECK(status IN ('active','closed','defaulted')), + disbursed_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS loan_payments ( + id INTEGER PRIMARY KEY, + loan_id INTEGER NOT NULL REFERENCES loans(id), + amount_paid REAL NOT NULL CHECK(amount_paid > 0), + payment_date TEXT NOT NULL, + is_late INTEGER NOT NULL DEFAULT 0 CHECK(is_late IN (0,1)) +); + +CREATE INDEX IF NOT EXISTS idx_acct_customer ON accounts(customer_id); +CREATE INDEX IF NOT EXISTS idx_txn_account ON transactions(account_id); +CREATE INDEX IF NOT EXISTS idx_txn_type ON transactions(txn_type); +CREATE INDEX IF NOT EXISTS idx_loan_customer ON loans(customer_id); +CREATE INDEX IF NOT EXISTS idx_lp_loan ON loan_payments(loan_id); +""" + +HR_SCHEMA = """ +CREATE TABLE IF NOT EXISTS departments ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + location TEXT NOT NULL, + budget REAL NOT NULL CHECK(budget >= 0) +); + +CREATE TABLE IF NOT EXISTS employees ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + department_id INTEGER NOT NULL REFERENCES departments(id), + job_title TEXT NOT NULL, + hire_date TEXT NOT NULL, + salary REAL NOT NULL CHECK(salary >= 0), + status TEXT NOT NULL CHECK(status IN ('active','resigned','terminated')) +); + +CREATE TABLE IF NOT EXISTS performance_reviews ( + id INTEGER PRIMARY KEY, + employee_id INTEGER NOT NULL REFERENCES employees(id), + review_year INTEGER NOT NULL, + rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), + reviewer_id INTEGER NOT NULL REFERENCES employees(id), + comments TEXT +); + +CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + department_id INTEGER NOT NULL REFERENCES departments(id), + start_date TEXT NOT NULL, + end_date TEXT, + budget REAL NOT NULL, + status TEXT NOT NULL + CHECK(status IN ('planned','active','completed','cancelled')) +); + +CREATE TABLE 
IF NOT EXISTS project_assignments ( + id INTEGER PRIMARY KEY, + employee_id INTEGER NOT NULL REFERENCES employees(id), + project_id INTEGER NOT NULL REFERENCES projects(id), + role TEXT NOT NULL, + hours_allocated INTEGER NOT NULL CHECK(hours_allocated > 0) +); + +CREATE INDEX IF NOT EXISTS idx_emp_dept ON employees(department_id); +CREATE INDEX IF NOT EXISTS idx_emp_status ON employees(status); +CREATE INDEX IF NOT EXISTS idx_pr_employee ON performance_reviews(employee_id); +CREATE INDEX IF NOT EXISTS idx_proj_dept ON projects(department_id); +CREATE INDEX IF NOT EXISTS idx_pa_employee ON project_assignments(employee_id); +CREATE INDEX IF NOT EXISTS idx_pa_project ON project_assignments(project_id); +""" + +# ───────────────────────────────────────────────────────────────────────────── +# SCHEMA REGISTRY +# ───────────────────────────────────────────────────────────────────────────── + +SCHEMA_MAP: dict[str, str] = { + "ecommerce": ECOMMERCE_SCHEMA, + "healthcare": HEALTHCARE_SCHEMA, + "finance": FINANCE_SCHEMA, + "hr": HR_SCHEMA, +} + +# ───────────────────────────────────────────────────────────────────────────── +# COMPACT SCHEMA CONTEXT (injected into every training prompt) +# ───────────────────────────────────────────────────────────────────────────── + +SCHEMA_CONTEXT: dict[str, str] = { + "ecommerce": """\ +Database: ecommerce (SQLite, read-only) + +TABLES +------ +categories(id INTEGER PK, name TEXT) +products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id, price REAL, stock_quantity INTEGER) +customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601) +orders(id INTEGER PK, customer_id INTEGER FK→customers.id, status TEXT ∈ {pending|processing|shipped|delivered|cancelled}, created_at TEXT ISO-8601, total_amount REAL) +order_items(id INTEGER PK, order_id INTEGER FK→orders.id, product_id INTEGER FK→products.id, quantity INTEGER, unit_price REAL) +reviews(id INTEGER PK, product_id 
INTEGER FK→products.id, customer_id INTEGER FK→customers.id, rating INTEGER 1-5, created_at TEXT ISO-8601) + +NOTES +----- +- Use created_at >= '2024-01-01' for date filtering (ISO text sort works) +- SQLite window functions: RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD +- strftime('%Y-%m', created_at) returns 'YYYY-MM' +- All monetary values in USD +""", + + "healthcare": """\ +Database: healthcare (SQLite, read-only) + +TABLES +------ +patients(id INTEGER PK, name TEXT, date_of_birth TEXT ISO-8601, gender TEXT ∈ {M|F|Other}, blood_type TEXT, country TEXT, registered_at TEXT ISO-8601) +doctors(id INTEGER PK, name TEXT, specialization TEXT, department TEXT, experience_years INTEGER, consultation_fee REAL) +appointments(id INTEGER PK, patient_id INTEGER FK→patients.id, doctor_id INTEGER FK→doctors.id, scheduled_at TEXT ISO-8601, status TEXT ∈ {scheduled|completed|cancelled|no_show}, notes TEXT nullable) +diagnoses(id INTEGER PK, appointment_id INTEGER FK→appointments.id, icd_code TEXT, description TEXT, severity TEXT ∈ {mild|moderate|severe}) +medications(id INTEGER PK, name TEXT, category TEXT, unit_price REAL) +prescriptions(id INTEGER PK, appointment_id INTEGER FK→appointments.id, medication_id INTEGER FK→medications.id, dosage TEXT, duration_days INTEGER, quantity INTEGER) + +NOTES +----- +- consultation_fee is in USD per visit +- ICD codes follow WHO ICD-10 format (e.g. 
'I10', 'E11') +- SQLite window functions available +""", + + "finance": """\ +Database: finance (SQLite, read-only) + +TABLES +------ +fin_customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, kyc_status TEXT ∈ {pending|verified|rejected}, created_at TEXT ISO-8601) +accounts(id INTEGER PK, customer_id INTEGER FK→fin_customers.id, account_type TEXT ∈ {savings|current|fixed_deposit|loan}, balance REAL, currency TEXT, status TEXT ∈ {active|dormant|closed}, opened_at TEXT ISO-8601) +transactions(id INTEGER PK, account_id INTEGER FK→accounts.id, txn_type TEXT ∈ {credit|debit}, amount REAL, currency TEXT, merchant TEXT nullable, created_at TEXT ISO-8601) +loans(id INTEGER PK, customer_id INTEGER FK→fin_customers.id, loan_type TEXT ∈ {personal|home|auto|business}, principal_amount REAL, interest_rate REAL, tenure_months INTEGER, status TEXT ∈ {active|closed|defaulted}, disbursed_at TEXT ISO-8601) +loan_payments(id INTEGER PK, loan_id INTEGER FK→loans.id, amount_paid REAL, payment_date TEXT ISO-8601, is_late INTEGER ∈ {0|1}) + +NOTES +----- +- All monetary values in USD unless currency column specifies otherwise +- is_late = 1 means the payment was overdue +- SQLite window functions available +""", + + "hr": """\ +Database: hr (SQLite, read-only) + +TABLES +------ +departments(id INTEGER PK, name TEXT, location TEXT, budget REAL) +employees(id INTEGER PK, name TEXT, email TEXT, department_id INTEGER FK→departments.id, job_title TEXT, hire_date TEXT ISO-8601, salary REAL, status TEXT ∈ {active|resigned|terminated}) +performance_reviews(id INTEGER PK, employee_id INTEGER FK→employees.id, review_year INTEGER, rating INTEGER 1-5, reviewer_id INTEGER FK→employees.id, comments TEXT nullable) +projects(id INTEGER PK, name TEXT, department_id INTEGER FK→departments.id, start_date TEXT ISO-8601, end_date TEXT nullable, budget REAL, status TEXT ∈ {planned|active|completed|cancelled}) +project_assignments(id INTEGER PK, employee_id INTEGER FK→employees.id, project_id INTEGER 
# ─────────────────────────────────────────────────────────────────────────────
# SEED FUNCTIONS (deterministic, SEED=42)
# ─────────────────────────────────────────────────────────────────────────────
# NOTE: every row is derived from a seeded random.Random instance, and the
# ORDER of the rng draws is part of the contract — reordering any draw changes
# the generated dataset and silently invalidates the verified ground-truth SQL.

def _rdate(rng: random.Random, start: str = "2022-01-01", end: str = "2024-12-31") -> str:
    """Return a uniformly random ISO-8601 date in [start, end], drawn from *rng*."""
    lo = date.fromisoformat(start)
    hi = date.fromisoformat(end)
    return (lo + timedelta(days=rng.randint(0, (hi - lo).days))).isoformat()


def seed_ecommerce(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the ecommerce schema with deterministic demo rows."""
    rng = random.Random(seed)

    cats = ["Electronics", "Clothing", "Books", "Home & Garden",
            "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive"]
    conn.executemany("INSERT INTO categories(id,name) VALUES(?,?)", enumerate(cats, 1))

    # Fixed catalogue: (id, name, category_id, price, stock_quantity).
    products = [
        (1,"Wireless Headphones",1,149.99,50),(2,"Laptop Stand",1,59.99,120),
        (3,"USB-C Hub",1,49.99,90),(4,"Webcam 4K",1,89.99,30),
        (5,"Cotton T-Shirt",2,19.99,200),(6,"Winter Jacket",2,129.99,60),
        (7,"Running Shorts",2,34.99,150),(8,"Clean Code",3,39.99,80),
        (9,"Deep Learning Book",3,59.99,45),(10,"Coffee Maker",4,89.99,40),
        (11,"Air Purifier",4,199.99,25),(12,"Yoga Mat",5,29.99,150),
        (13,"Resistance Bands",5,14.99,200),(14,"Lego City Set",6,79.99,60),
        (15,"Face Serum",7,34.99,100),(16,"Dash Cam",8,119.99,35),
    ]
    conn.executemany("INSERT INTO products VALUES(?,?,?,?,?)", products)

    countries = ["India","USA","Germany","UK","Canada","Australia","France","Brazil"]
    tiers = ["bronze","silver","gold"]
    conn.executemany(
        "INSERT INTO customers VALUES(?,?,?,?,?,?)",
        [(i, f"Customer {i}", f"cust{i}@shop.com",
          rng.choice(countries), rng.choice(tiers), _rdate(rng))
         for i in range(1, 51)])

    statuses = ["pending","processing","shipped","delivered","cancelled"]
    conn.executemany(
        "INSERT INTO orders VALUES(?,?,?,?,?)",
        [(i, rng.randint(1, 50), rng.choice(statuses),
          _rdate(rng), round(rng.uniform(20, 800), 2))
         for i in range(1, 201)])

    conn.executemany(
        "INSERT INTO order_items VALUES(?,?,?,?,?)",
        [(i, rng.randint(1, 200), rng.randint(1, 16),
          rng.randint(1, 5), round(rng.uniform(10, 200), 2))
         for i in range(1, 301)])

    conn.executemany(
        "INSERT INTO reviews VALUES(?,?,?,?,?)",
        [(i, rng.randint(1, 16), rng.randint(1, 50),
          rng.randint(1, 5), _rdate(rng))
         for i in range(1, 151)])
    conn.commit()


def seed_healthcare(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the healthcare schema with deterministic demo rows."""
    rng = random.Random(seed)

    surnames = ['Smith','Patel','Kim','Müller','Okafor','Chen','Lopez','Roy']
    specs = [("Cardiology","Cardiology"), ("Neurology","Neurology"),
             ("Orthopedics","Orthopedics"), ("Dermatology","Dermatology"),
             ("Pediatrics","Pediatrics"), ("Oncology","Oncology"),
             ("Endocrinology","Endocrinology"), ("Gastroenterology","Gastroenterology")]
    conn.executemany(
        "INSERT INTO doctors VALUES(?,?,?,?,?,?)",
        [(i, f"Dr. {surnames[i-1]}",
          spec, dept, rng.randint(2, 25), round(rng.uniform(50, 350), 2))
         for i, (spec, dept) in enumerate(specs, 1)])

    genders = ["M", "F", "Other"]
    blood_types = ["A+","A-","B+","B-","O+","O-","AB+","AB-"]
    countries = ["India","USA","Germany","UK","Canada","Australia"]
    conn.executemany(
        "INSERT INTO patients VALUES(?,?,?,?,?,?,?)",
        [(i, f"Patient {i}", _rdate(rng, "1950-01-01", "2010-01-01"),
          rng.choice(genders), rng.choice(blood_types),
          rng.choice(countries), _rdate(rng, "2020-01-01", "2024-12-31"))
         for i in range(1, 101)])

    # Weighted status mix: most appointments end up 'completed'.
    appt_statuses = ["scheduled", "completed", "cancelled", "no_show"]
    weights = [0.15, 0.60, 0.15, 0.10]
    conn.executemany(
        "INSERT INTO appointments VALUES(?,?,?,?,?,?)",
        [(i, rng.randint(1, 100), rng.randint(1, 8),
          _rdate(rng, "2022-01-01", "2024-12-31"),
          rng.choices(appt_statuses, weights)[0], None)
         for i in range(1, 301)])

    icd_codes = ["I10","E11","J45","M54","K21","F32","G43","L30","N39","R05",
                 "C50","Z87","I25","E78","J18"]
    descs = ["Hypertension","Type 2 Diabetes","Asthma","Back Pain","GERD",
             "Depression","Migraine","Dermatitis","UTI","Cough",
             "Breast Cancer","Family History","Coronary Artery Disease",
             "Hyperlipidemia","Pneumonia"]
    severities = ["mild","moderate","severe"]
    conn.executemany(
        "INSERT INTO diagnoses VALUES(?,?,?,?,?)",
        [(i, rng.randint(1, 300), rng.choice(icd_codes),
          rng.choice(descs), rng.choice(severities))
         for i in range(1, 201)])

    meds = [("Metformin","Antidiabetic",0.15),("Lisinopril","Antihypertensive",0.20),
            ("Atorvastatin","Statin",0.25),("Amoxicillin","Antibiotic",0.30),
            ("Ibuprofen","NSAID",0.10),("Omeprazole","PPI",0.18),
            ("Sertraline","Antidepressant",0.35),("Cetirizine","Antihistamine",0.08),
            ("Paracetamol","Analgesic",0.05),("Aspirin","Antiplatelet",0.07)]
    conn.executemany(
        "INSERT INTO medications VALUES(?,?,?,?)",
        [(i, name, cat, price) for i, (name, cat, price) in enumerate(meds, 1)])

    dosages = ["1x daily","2x daily","3x daily","once at night","as needed"]
    conn.executemany(
        "INSERT INTO prescriptions VALUES(?,?,?,?,?,?)",
        [(i, rng.randint(1, 300), rng.randint(1, 10),
          rng.choice(dosages), rng.randint(5, 60), rng.randint(10, 90))
         for i in range(1, 251)])
    conn.commit()


def seed_finance(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the finance schema with deterministic demo rows."""
    rng = random.Random(seed)

    countries = ["India","USA","Germany","UK","Singapore","UAE","Canada"]
    # Repeated 'verified' entries bias the draw (~60% verified).
    kyc = ["pending","verified","verified","verified","rejected"]
    conn.executemany(
        "INSERT INTO fin_customers VALUES(?,?,?,?,?,?)",
        [(i, f"FinClient {i}", f"fincli{i}@bank.com",
          rng.choice(countries), rng.choice(kyc), _rdate(rng))
         for i in range(1, 51)])

    acct_types = ["savings","savings","current","fixed_deposit"]
    statuses = ["active","active","active","dormant","closed"]
    conn.executemany(
        "INSERT INTO accounts VALUES(?,?,?,?,?,?,?)",
        [(i, rng.randint(1, 50), rng.choice(acct_types),
          round(rng.uniform(100, 100000), 2), "USD",
          rng.choice(statuses), _rdate(rng))
         for i in range(1, 101)])

    # None merchant models ATM/transfer-style transactions.
    merchants = [None, "Amazon", "Walmart", "Netflix", "Uber", "Apple",
                 "Google Pay", "Zomato", "Flipkart", "Airbnb"]
    conn.executemany(
        "INSERT INTO transactions VALUES(?,?,?,?,?,?,?)",
        [(i, rng.randint(1, 100), rng.choice(["credit","debit"]),
          round(rng.uniform(5, 10000), 2), "USD",
          rng.choice(merchants), _rdate(rng))
         for i in range(1, 501)])

    loan_types = ["personal","home","auto","business"]
    loan_statuses = ["active","active","closed","defaulted"]
    conn.executemany(
        "INSERT INTO loans VALUES(?,?,?,?,?,?,?,?)",
        [(i, rng.randint(1, 50), rng.choice(loan_types),
          round(rng.uniform(5000, 500000), 2),
          round(rng.uniform(5, 18), 2), rng.randint(12, 360),
          rng.choice(loan_statuses), _rdate(rng))
         for i in range(1, 51)])

    conn.executemany(
        "INSERT INTO loan_payments VALUES(?,?,?,?,?)",
        [(i, rng.randint(1, 50), round(rng.uniform(500, 10000), 2),
          _rdate(rng), rng.randint(0, 1))
         for i in range(1, 201)])
    conn.commit()


def seed_hr(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the hr schema with deterministic demo rows."""
    rng = random.Random(seed)

    depts = [("Engineering","Bangalore",8000000),("Marketing","Mumbai",3000000),
             ("Finance","Delhi",2000000),("HR","Chennai",1500000),
             ("Sales","Hyderabad",5000000),("Product","Pune",4000000),
             ("Legal","Delhi",1000000),("Operations","Kolkata",2500000)]
    conn.executemany(
        "INSERT INTO departments VALUES(?,?,?,?)",
        [(i, name, loc, bud) for i, (name, loc, bud) in enumerate(depts, 1)])

    titles = ["Software Engineer","Senior Engineer","Staff Engineer","Principal Engineer",
              "Engineering Manager","Product Manager","Data Analyst","Data Scientist",
              "Marketing Specialist","Sales Executive","HR Specialist","Finance Analyst",
              "Director","VP","Legal Counsel"]
    statuses = ["active","active","active","active","resigned","terminated"]
    conn.executemany(
        "INSERT INTO employees VALUES(?,?,?,?,?,?,?,?)",
        [(i, f"Employee {i}", f"emp{i}@corp.com",
          rng.randint(1, 8), rng.choice(titles),
          _rdate(rng, "2015-01-01", "2024-01-01"),
          round(rng.uniform(25000, 200000), 2), rng.choice(statuses))
         for i in range(1, 101)])

    conn.executemany(
        "INSERT INTO performance_reviews VALUES(?,?,?,?,?,?)",
        [(i, rng.randint(1, 100), rng.randint(2019, 2024),
          rng.randint(1, 5), rng.randint(1, 100),
          rng.choice(["Excellent work","Good performance","Needs improvement",
                      "Outstanding","Meeting expectations"]))
         for i in range(1, 201)])

    # Kept as an explicit loop: the start date is drawn first and reused both
    # as a column value and as the lower bound of the optional end date.
    proj_statuses = ["planned","active","active","completed","cancelled"]
    for i in range(1, 51):
        sd = _rdate(rng, "2021-01-01", "2024-01-01")
        conn.execute("INSERT INTO projects VALUES(?,?,?,?,?,?,?)",
                     (i, f"Project {i}", rng.randint(1, 8), sd,
                      _rdate(rng, sd, "2025-06-01") if rng.random() > 0.25 else None,
                      round(rng.uniform(50000, 2000000), 2), rng.choice(proj_statuses)))

    roles = ["Lead","Senior Developer","Developer","Tester","Analyst","DevOps"]
    conn.executemany(
        "INSERT INTO project_assignments VALUES(?,?,?,?,?)",
        [(i, rng.randint(1, 100), rng.randint(1, 50),
          rng.choice(roles), rng.randint(20, 400))
         for i in range(1, 251)])
    conn.commit()


# ─────────────────────────────────────────────────────────────────────────────
# REGISTRY
# ─────────────────────────────────────────────────────────────────────────────

SEED_MAP: dict[str, Callable] = {
    "ecommerce": seed_ecommerce,
    "healthcare": seed_healthcare,
    "finance": seed_finance,
    "hr": seed_hr,
}


def build_connection(domain: str, seed: int = 42) -> sqlite3.Connection:
    """Return a seeded in-memory SQLite connection for the given domain.

    Raises KeyError if *domain* is not a known schema/seed key.
    """
    # check_same_thread=False: the connection may be used from threads other
    # than the one that created it.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON")
    conn.executescript(SCHEMA_MAP[domain])
    SEED_MAP[domain](conn, seed=seed)
    return conn

Structure per entry:
    {
      "domain": str,        # ecommerce | healthcare | finance | hr
      "difficulty": str,    # easy | medium | hard
      "sql": str,           # verified ground-truth SQL
      "description": str,   # one-line English summary (seed for NL generation)
      "base_nl": str,       # canonical natural-language question
      "has_order": bool,    # True → comparison is order-sensitive
    }
"""

from __future__ import annotations
from typing import TypedDict


class Template(TypedDict):
    """One hand-authored, execution-verified NL↔SQL benchmark entry."""
    domain: str        # ecommerce | healthcare | finance | hr
    difficulty: str    # easy | medium | hard
    sql: str           # verified ground-truth SQL (never LLM-generated)
    description: str   # one-line English summary (seed for NL generation)
    base_nl: str       # canonical natural-language question
    has_order: bool    # True → result comparison is order-sensitive


# ─────────────────────────────────────────────────────────────────────────────
# DOMAIN: ECOMMERCE
# ─────────────────────────────────────────────────────────────────────────────

# NOTE(review): every "sql" string below is ground truth, verified by executing
# it against the seeded SQLite data. Do not edit a query without re-verifying.
ECOMMERCE_TEMPLATES: list[Template] = [

    # ── EASY ──────────────────────────────────────────────────────────────── 

    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "List gold-tier customers sorted alphabetically with id, name, email, country",
        "base_nl": "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
        "sql": (
            "SELECT id, name, email, country "
            "FROM customers "
            "WHERE tier = 'gold' "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Products priced above $100, sorted by price descending",
        "base_nl": "Show all products with a price above $100, sorted from highest to lowest price. Return id, name, price.",
        "sql": (
            "SELECT id, name, price "
            "FROM products "
            "WHERE price > 100 "
            "ORDER BY price DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Delivered orders with total_amount > 200, sorted by amount descending",
        "base_nl": "Find all delivered orders with a total amount greater than $200, sorted by total amount descending. Return id, customer_id, total_amount, created_at.",
        "sql": (
            "SELECT id, customer_id, total_amount, created_at "
            "FROM orders "
            "WHERE status = 'delivered' "
            "  AND total_amount > 200 "
            "ORDER BY total_amount DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Top 5 most expensive products",
        "base_nl": "Return the top 5 most expensive products. Return id, name, price.",
        "sql": (
            "SELECT id, name, price "
            "FROM products "
            "ORDER BY price DESC "
            "LIMIT 5"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Distinct countries where customers come from, sorted alphabetically",
        "base_nl": "List all distinct countries our customers come from, sorted alphabetically. Return country.",
        "sql": (
            "SELECT DISTINCT country "
            "FROM customers "
            "ORDER BY country ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": False,
        "description": "Count total number of customers",
        "base_nl": "How many customers do we have in total? Return a single column total_customers.",
        "sql": "SELECT COUNT(*) AS total_customers FROM customers",
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Products with zero stock",
        "base_nl": "List all out-of-stock products. Return id, name, stock_quantity.",
        "sql": (
            "SELECT id, name, stock_quantity "
            "FROM products "
            "WHERE stock_quantity = 0 "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Customers from India sorted by name",
        "base_nl": "Show all customers from India, sorted by name. Return id, name, email.",
        "sql": (
            "SELECT id, name, email "
            "FROM customers "
            "WHERE country = 'India' "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Products in a price range of $20 to $100 sorted by price ascending",
        "base_nl": "Which products are priced between $20 and $100? Sort by price ascending. Return id, name, price.",
        "sql": (
            "SELECT id, name, price "
            "FROM products "
            "WHERE price BETWEEN 20 AND 100 "
            "ORDER BY price ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": False,
        "description": "Count orders by status",
        "base_nl": "How many orders are there for each status? Return status, order_count.",
        "sql": (
            "SELECT status, COUNT(*) AS order_count "
            "FROM orders "
            "GROUP BY status"
        ),
    },

    # ── MEDIUM ─────────────────────────────────────────────────────────────── 

    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Order count per customer including those with zero orders, sorted by count desc",
        "base_nl": "How many orders has each customer placed? Include customers with zero orders. Return customer_name, order_count, sorted by order_count descending then customer_name ascending.",
        # LEFT JOIN + COUNT(o.id) keeps zero-order customers (COUNT ignores NULLs).
        "sql": (
            "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
            "FROM customers c "
            "LEFT JOIN orders o ON c.id = o.customer_id "
            "GROUP BY c.id, c.name "
            "ORDER BY order_count DESC, customer_name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Average product rating per category sorted descending",
        "base_nl": "What is the average product rating per category? Only include categories with at least one review. Return category_name, avg_rating (rounded to 2 decimal places), sorted by avg_rating descending.",
        "sql": (
            "SELECT c.name AS category_name, "
            "       ROUND(AVG(r.rating), 2) AS avg_rating "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY c.id, c.name "
            "ORDER BY avg_rating DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Customers who spent more than $500 on delivered orders",
        "base_nl": "Which customers have spent more than $500 total on delivered orders? Return customer_name, total_spent (rounded to 2 decimal places), sorted by total_spent descending.",
        # HAVING filters on the unrounded sum; only the output is rounded.
        "sql": (
            "SELECT c.name AS customer_name, "
            "       ROUND(SUM(o.total_amount), 2) AS total_spent "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "WHERE o.status = 'delivered' "
            "GROUP BY c.id, c.name "
            "HAVING SUM(o.total_amount) > 500 "
            "ORDER BY total_spent DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Total quantity sold per product sorted descending",
        "base_nl": "Show the total quantity sold for each product that appears in at least one order. Return product_name, total_quantity_sold, sorted by total_quantity_sold descending.",
        "sql": (
            "SELECT p.name AS product_name, "
            "       SUM(oi.quantity) AS total_quantity_sold "
            "FROM products p "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY total_quantity_sold DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Product count and average price per category sorted by count desc",
        "base_nl": "For each category, show the number of products and their average price. Return category_name, product_count, avg_price (rounded to 2 decimal places), sorted by product_count descending.",
        "sql": (
            "SELECT cat.name AS category_name, "
            "       COUNT(p.id) AS product_count, "
            "       ROUND(AVG(p.price), 2) AS avg_price "
            "FROM categories cat "
            "JOIN products p ON p.category_id = cat.id "
            "GROUP BY cat.id, cat.name "
            "ORDER BY product_count DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Categories with more than 5 in-stock products sorted by count desc",
        "base_nl": "Which categories have more than 5 products in stock (stock_quantity > 0)? Return category_name, in_stock_count, sorted by in_stock_count descending.",
        "sql": (
            "SELECT c.name AS category_name, "
            "       COUNT(p.id) AS in_stock_count "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "WHERE p.stock_quantity > 0 "
            "GROUP BY c.id, c.name "
            "HAVING COUNT(p.id) > 5 "
            "ORDER BY in_stock_count DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Total revenue per product from order items, sorted descending",
        "base_nl": "What is the total revenue generated by each product from order items? Return product_name, total_revenue (rounded to 2 decimal places), sorted by total_revenue descending.",
        "sql": (
            "SELECT p.name AS product_name, "
            "       ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue "
            "FROM products p "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY total_revenue DESC"
        ),
    },

    # ── HARD ───────────────────────────────────────────────────────────────── 

    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Customer spending rank using DENSE_RANK on delivered orders",
        "base_nl": "Rank customers by total spending on delivered orders using DENSE_RANK (rank 1 = highest spender). Return customer_name, total_spent (rounded to 2 decimal places), spending_rank, sorted by spending_rank ascending.",
        "sql": (
            "SELECT customer_name, total_spent, spending_rank "
            "FROM ( "
            "    SELECT c.name AS customer_name, "
            "           ROUND(SUM(o.total_amount), 2) AS total_spent, "
            "           DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
            "    FROM customers c "
            "    JOIN orders o ON o.customer_id = c.id "
            "    WHERE o.status = 'delivered' "
            "    GROUP BY c.id, c.name "
            ") sub "
            "ORDER BY spending_rank ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Monthly delivered revenue with running total using window SUM",
        "base_nl": "Show the monthly revenue from delivered orders and its running cumulative total. Return month (YYYY-MM), monthly_revenue, running_total (both rounded to 2 decimal places), sorted by month ascending.",
        # The running total sums the already-rounded monthly figures.
        "sql": (
            "WITH monthly AS ( "
            "    SELECT strftime('%Y-%m', created_at) AS month, "
            "           ROUND(SUM(total_amount), 2) AS monthly_revenue "
            "    FROM orders "
            "    WHERE status = 'delivered' "
            "    GROUP BY strftime('%Y-%m', created_at) "
            ") "
            "SELECT month, "
            "       monthly_revenue, "
            "       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
            "FROM monthly "
            "ORDER BY month ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Customers whose most recent order was cancelled, using ROW_NUMBER CTE",
        "base_nl": "Find all customers whose most recent order has status 'cancelled'. Use ROW_NUMBER to identify the latest order per customer. Return customer_name, last_order_status, last_order_date, sorted by customer_name ascending.",
        "sql": (
            "WITH ranked_orders AS ( "
            "    SELECT customer_id, status, created_at, "
            "           ROW_NUMBER() OVER (PARTITION BY customer_id "
            "                              ORDER BY created_at DESC) AS rn "
            "    FROM orders "
            ") "
            "SELECT c.name AS customer_name, "
            "       ro.status AS last_order_status, "
            "       ro.created_at AS last_order_date "
            "FROM customers c "
            "JOIN ranked_orders ro ON ro.customer_id = c.id "
            "WHERE ro.rn = 1 "
            "  AND ro.status = 'cancelled' "
            "ORDER BY customer_name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Products above their category average rating, using two CTEs",
        "base_nl": "Find products whose average rating is strictly above the average rating of all products in their category. Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 decimal places), sorted by product_avg_rating descending then product_name ascending.",
        # category_avg_rating is the average of the per-product averages
        # (computed over the rounded product figures), not a global row average.
        "sql": (
            "WITH product_ratings AS ( "
            "    SELECT p.id AS product_id, p.name AS product_name, "
            "           p.category_id, c.name AS category_name, "
            "           ROUND(AVG(r.rating), 2) AS product_avg_rating "
            "    FROM products p "
            "    JOIN reviews r ON r.product_id = p.id "
            "    JOIN categories c ON c.id = p.category_id "
            "    GROUP BY p.id, p.name, p.category_id, c.name "
            "), "
            "category_ratings AS ( "
            "    SELECT category_id, "
            "           ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
            "    FROM product_ratings "
            "    GROUP BY category_id "
            ") "
            "SELECT pr.product_name, pr.category_name, "
            "       pr.product_avg_rating, cr.category_avg_rating "
            "FROM product_ratings pr "
            "JOIN category_ratings cr ON cr.category_id = pr.category_id "
            "WHERE pr.product_avg_rating > cr.category_avg_rating "
            "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
        ),
    },
]
# ─────────────────────────────────────────────────────────────────────────────
# DOMAIN: HEALTHCARE
# ─────────────────────────────────────────────────────────────────────────────

# NOTE(review): every "sql" string below is ground truth, verified by executing
# it against the seeded SQLite data. Do not edit a query without re-verifying.
HEALTHCARE_TEMPLATES: list[Template] = [

    # ── EASY ──────────────────────────────────────────────────────────────── 

    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Doctors sorted by consultation fee descending",
        "base_nl": "List all doctors sorted by consultation fee from highest to lowest. Return id, name, specialization, consultation_fee.",
        "sql": (
            "SELECT id, name, specialization, consultation_fee "
            "FROM doctors "
            "ORDER BY consultation_fee DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Doctors with more than 10 years experience sorted desc",
        "base_nl": "Show doctors with more than 10 years of experience, sorted by experience descending. Return id, name, specialization, experience_years.",
        "sql": (
            "SELECT id, name, specialization, experience_years "
            "FROM doctors "
            "WHERE experience_years > 10 "
            "ORDER BY experience_years DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Patients from India sorted by name",
        "base_nl": "List all patients from India sorted alphabetically by name. Return id, name, country, blood_type.",
        "sql": (
            "SELECT id, name, country, blood_type "
            "FROM patients "
            "WHERE country = 'India' "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Medications with unit price under $0.20 sorted ascending",
        "base_nl": "Which medications cost less than $0.20 per unit? Sort by price ascending. Return id, name, category, unit_price.",
        "sql": (
            "SELECT id, name, category, unit_price "
            "FROM medications "
            "WHERE unit_price < 0.20 "
            "ORDER BY unit_price ASC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Top 5 most expensive medications",
        "base_nl": "What are the top 5 most expensive medications? Return id, name, unit_price.",
        "sql": (
            "SELECT id, name, unit_price "
            "FROM medications "
            "ORDER BY unit_price DESC "
            "LIMIT 5"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": False,
        "description": "Count of completed appointments",
        "base_nl": "How many appointments have been completed? Return a single value total_completed.",
        "sql": (
            "SELECT COUNT(*) AS total_completed "
            "FROM appointments "
            "WHERE status = 'completed'"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Severe diagnoses sorted by ICD code",
        "base_nl": "List all severe diagnoses sorted by ICD code. Return id, icd_code, description, severity.",
        "sql": (
            "SELECT id, icd_code, description, severity "
            "FROM diagnoses "
            "WHERE severity = 'severe' "
            "ORDER BY icd_code ASC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": False,
        "description": "Count patients by gender",
        "base_nl": "How many patients are there by gender? Return gender, patient_count.",
        "sql": (
            "SELECT gender, COUNT(*) AS patient_count "
            "FROM patients "
            "GROUP BY gender"
        ),
    },

    # ── MEDIUM ─────────────────────────────────────────────────────────────── 

    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Appointment count per doctor including those with no appointments",
        "base_nl": "How many appointments has each doctor had (including those with none)? Return doctor_name, appointment_count, sorted by appointment_count descending.",
        # LEFT JOIN + COUNT(a.id) keeps doctors with zero appointments.
        "sql": (
            "SELECT d.name AS doctor_name, COUNT(a.id) AS appointment_count "
            "FROM doctors d "
            "LEFT JOIN appointments a ON a.doctor_id = d.id "
            "GROUP BY d.id, d.name "
            "ORDER BY appointment_count DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Most prescribed medications by count",
        "base_nl": "Which medications are prescribed most often? Return medication_name, category, times_prescribed, sorted by times_prescribed descending.",
        "sql": (
            "SELECT m.name AS medication_name, m.category, COUNT(p.id) AS times_prescribed "
            "FROM medications m "
            "JOIN prescriptions p ON p.medication_id = m.id "
            "GROUP BY m.id, m.name, m.category "
            "ORDER BY times_prescribed DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Patients with more than one completed visit",
        "base_nl": "Which patients have had more than one completed appointment? Return patient_name, visit_count, sorted by visit_count descending.",
        "sql": (
            "SELECT pat.name AS patient_name, COUNT(DISTINCT a.id) AS visit_count "
            "FROM patients pat "
            "JOIN appointments a ON a.patient_id = pat.id "
            "WHERE a.status = 'completed' "
            "GROUP BY pat.id, pat.name "
            "HAVING COUNT(DISTINCT a.id) > 1 "
            "ORDER BY visit_count DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Estimated revenue per doctor from completed appointments",
        "base_nl": "What is the estimated total revenue per doctor from completed appointments (based on consultation fee)? Return doctor_name, specialization, estimated_revenue (rounded to 2 decimal places), sorted by estimated_revenue descending.",
        # Revenue estimate = flat consultation fee × number of completed visits.
        "sql": (
            "SELECT d.name AS doctor_name, d.specialization, "
            "       ROUND(SUM(d.consultation_fee), 2) AS estimated_revenue "
            "FROM doctors d "
            "JOIN appointments a ON a.doctor_id = d.id "
            "WHERE a.status = 'completed' "
            "GROUP BY d.id, d.name, d.specialization "
            "ORDER BY estimated_revenue DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Diagnosis count per severity level",
        "base_nl": "How many diagnoses are there per severity level? Return severity, diagnosis_count, sorted by diagnosis_count descending.",
        "sql": (
            "SELECT severity, COUNT(*) AS diagnosis_count "
            "FROM diagnoses "
            "GROUP BY severity "
            "ORDER BY diagnosis_count DESC"
        ),
    },

    # ── HARD ───────────────────────────────────────────────────────────────── 

    {
        "domain": "healthcare", "difficulty": "hard", "has_order": True,
        "description": "Doctors ranked by appointment count within specialization using RANK",
        "base_nl": "Rank doctors by appointment count within their specialization (rank 1 = most appointments). Return doctor_name, specialization, appointment_count, rank_in_spec, sorted by specialization then rank_in_spec ascending.",
        "sql": (
            "SELECT doctor_name, specialization, appointment_count, "
            "       RANK() OVER (PARTITION BY specialization ORDER BY appointment_count DESC) AS rank_in_spec "
            "FROM ( "
            "    SELECT d.name AS doctor_name, d.specialization, COUNT(a.id) AS appointment_count "
            "    FROM doctors d "
            "    JOIN appointments a ON a.doctor_id = d.id "
            "    GROUP BY d.id, d.name, d.specialization "
            ") sub "
            "ORDER BY specialization, rank_in_spec"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "hard", "has_order": True,
        "description": "Top 10 patients by total completed visits using CTE",
        "base_nl": "Find the top 10 patients by number of completed appointments. Return patient_name, total_visits, last_visit, sorted by total_visits descending.",
        "sql": (
            "WITH patient_visits AS ( "
            "    SELECT a.patient_id, COUNT(a.id) AS total_visits, "
            "           MAX(a.scheduled_at) AS last_visit "
            "    FROM appointments a "
            "    WHERE a.status = 'completed' "
            "    GROUP BY a.patient_id "
            ") "
            "SELECT p.name AS patient_name, pv.total_visits, pv.last_visit "
            "FROM patients p "
            "JOIN patient_visits pv ON pv.patient_id = p.id "
            "ORDER BY pv.total_visits DESC "
            "LIMIT 10"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "hard", "has_order": True,
        "description": "Medications total prescription cost per category using window SUM",
        "base_nl": "For each medication, show its total prescription cost (unit_price × quantity) and the running total of cost within its category. Return medication_name, category, total_cost, category_running_cost (both rounded to 2 decimal places), sorted by category then total_cost descending.",
        # The window ORDER BY mirrors the outer sort, so the running sum
        # accumulates in the same order the rows are returned.
        "sql": (
            "WITH med_costs AS ( "
            "    SELECT m.name AS medication_name, m.category, "
            "           ROUND(SUM(m.unit_price * pr.quantity), 2) AS total_cost "
            "    FROM medications m "
            "    JOIN prescriptions pr ON pr.medication_id = m.id "
            "    GROUP BY m.id, m.name, m.category "
            ") "
            "SELECT medication_name, category, total_cost, "
            "       ROUND(SUM(total_cost) OVER (PARTITION BY category ORDER BY total_cost DESC), 2) "
            "           AS category_running_cost "
            "FROM med_costs "
            "ORDER BY category, total_cost DESC"
        ),
    },
]
Return id, name, country, kyc_status.", + "sql": ( + "SELECT id, name, country, kyc_status " + "FROM fin_customers " + "WHERE kyc_status = 'verified' " + "ORDER BY name ASC" + ), + }, + { + "domain": "finance", "difficulty": "easy", "has_order": True, + "description": "Accounts with balance over $10,000 sorted by balance descending", + "base_nl": "Which accounts have a balance greater than $10,000? Return id, customer_id, account_type, balance, sorted by balance descending.", + "sql": ( + "SELECT id, customer_id, account_type, balance " + "FROM accounts " + "WHERE balance > 10000 " + "ORDER BY balance DESC" + ), + }, + { + "domain": "finance", "difficulty": "easy", "has_order": True, + "description": "Large credit transactions above $1,000 sorted by amount descending", + "base_nl": "Show all credit transactions with an amount greater than $1,000. Return id, account_id, txn_type, amount, created_at, sorted by amount descending.", + "sql": ( + "SELECT id, account_id, txn_type, amount, created_at " + "FROM transactions " + "WHERE txn_type = 'credit' AND amount > 1000 " + "ORDER BY amount DESC" + ), + }, + { + "domain": "finance", "difficulty": "easy", "has_order": True, + "description": "Defaulted loans sorted by principal amount descending", + "base_nl": "List all defaulted loans, sorted by principal amount descending. Return id, loan_type, principal_amount, interest_rate, status.", + "sql": ( + "SELECT id, loan_type, principal_amount, interest_rate, status " + "FROM loans " + "WHERE status = 'defaulted' " + "ORDER BY principal_amount DESC" + ), + }, + { + "domain": "finance", "difficulty": "easy", "has_order": False, + "description": "Count of late loan payments", + "base_nl": "How many loan payments were made late? 
Return a single value late_payments.", + "sql": "SELECT COUNT(*) AS late_payments FROM loan_payments WHERE is_late = 1", + }, + { + "domain": "finance", "difficulty": "easy", "has_order": True, + "description": "Top 5 highest principal loans", + "base_nl": "What are the top 5 loans by principal amount? Return id, customer_id, loan_type, principal_amount.", + "sql": ( + "SELECT id, customer_id, loan_type, principal_amount " + "FROM loans " + "ORDER BY principal_amount DESC " + "LIMIT 5" + ), + }, + { + "domain": "finance", "difficulty": "easy", "has_order": False, + "description": "Count of accounts by account type", + "base_nl": "How many accounts exist for each account type? Return account_type, account_count.", + "sql": ( + "SELECT account_type, COUNT(*) AS account_count " + "FROM accounts " + "GROUP BY account_type" + ), + }, + + # ── MEDIUM ─────────────────────────────────────────────────────────────── + + { + "domain": "finance", "difficulty": "medium", "has_order": True, + "description": "Total active account balance per customer sorted by balance descending", + "base_nl": "What is the total active account balance per customer? Return customer_name, account_count, total_balance (rounded to 2 decimal places), sorted by total_balance descending.", + "sql": ( + "SELECT fc.name AS customer_name, COUNT(a.id) AS account_count, " + " ROUND(SUM(a.balance), 2) AS total_balance " + "FROM fin_customers fc " + "JOIN accounts a ON a.customer_id = fc.id " + "WHERE a.status = 'active' " + "GROUP BY fc.id, fc.name " + "ORDER BY total_balance DESC" + ), + }, + { + "domain": "finance", "difficulty": "medium", "has_order": True, + "description": "Total credit transaction amount by account type", + "base_nl": "What is the total credit amount per account type? 
Return account_type, total_credits (rounded to 2 decimal places), sorted by total_credits descending.", + "sql": ( + "SELECT a.account_type, ROUND(SUM(t.amount), 2) AS total_credits " + "FROM accounts a " + "JOIN transactions t ON t.account_id = a.id " + "WHERE t.txn_type = 'credit' " + "GROUP BY a.account_type " + "ORDER BY total_credits DESC" + ), + }, + { + "domain": "finance", "difficulty": "medium", "has_order": True, + "description": "Total loan borrowing per customer sorted descending", + "base_nl": "How much has each customer borrowed in total across all loans? Return customer_name, loan_count, total_borrowed (rounded to 2 decimal places), sorted by total_borrowed descending.", + "sql": ( + "SELECT fc.name AS customer_name, COUNT(l.id) AS loan_count, " + " ROUND(SUM(l.principal_amount), 2) AS total_borrowed " + "FROM fin_customers fc " + "JOIN loans l ON l.customer_id = fc.id " + "GROUP BY fc.id, fc.name " + "ORDER BY total_borrowed DESC" + ), + }, + { + "domain": "finance", "difficulty": "medium", "has_order": True, + "description": "Late payment count and total amount by loan type", + "base_nl": "For each loan type, how many late payments were there and what was the total amount paid late? Return loan_type, late_payments, total_late_paid (rounded to 2 decimal places), sorted by late_payments descending.", + "sql": ( + "SELECT l.loan_type, COUNT(lp.id) AS late_payments, " + " ROUND(SUM(lp.amount_paid), 2) AS total_late_paid " + "FROM loans l " + "JOIN loan_payments lp ON lp.loan_id = l.id " + "WHERE lp.is_late = 1 " + "GROUP BY l.loan_type " + "ORDER BY late_payments DESC" + ), + }, + + # ── HARD ───────────────────────────────────────────────────────────────── + + { + "domain": "finance", "difficulty": "hard", "has_order": True, + "description": "Customer balance rank using DENSE_RANK on active accounts", + "base_nl": "Rank customers by their total active account balance using DENSE_RANK. 
Return customer_name, total_balance, balance_rank, sorted by balance_rank ascending.", + "sql": ( + "SELECT customer_name, total_balance, " + " DENSE_RANK() OVER (ORDER BY total_balance DESC) AS balance_rank " + "FROM ( " + " SELECT fc.name AS customer_name, " + " ROUND(SUM(a.balance), 2) AS total_balance " + " FROM fin_customers fc " + " JOIN accounts a ON a.customer_id = fc.id " + " WHERE a.status = 'active' " + " GROUP BY fc.id, fc.name " + ") sub " + "ORDER BY balance_rank" + ), + }, + { + "domain": "finance", "difficulty": "hard", "has_order": True, + "description": "Monthly transaction totals by type with running total using window SUM", + "base_nl": "Show monthly transaction totals per type (credit/debit) with a running cumulative total. Return month (YYYY-MM), txn_type, total, running_total (rounded to 2 decimal places), sorted by month then txn_type.", + "sql": ( + "WITH monthly_txn AS ( " + " SELECT strftime('%Y-%m', created_at) AS month, " + " txn_type, " + " ROUND(SUM(amount), 2) AS total " + " FROM transactions " + " GROUP BY strftime('%Y-%m', created_at), txn_type " + ") " + "SELECT month, txn_type, total, " + " ROUND(SUM(total) OVER (PARTITION BY txn_type ORDER BY month), 2) AS running_total " + "FROM monthly_txn " + "ORDER BY month, txn_type" + ), + }, + { + "domain": "finance", "difficulty": "hard", "has_order": True, + "description": "Customers with only defaulted loans using conditional aggregation", + "base_nl": "Find customers who have at least one loan and ALL their loans are defaulted.
Return customer_name, loan_count, sorted by customer_name ascending.", + "sql": ( + "SELECT fc.name AS customer_name, COUNT(l.id) AS loan_count " + "FROM fin_customers fc " + "JOIN loans l ON l.customer_id = fc.id " + "GROUP BY fc.id, fc.name " + "HAVING COUNT(l.id) > 0 " + " AND SUM(CASE WHEN l.status != 'defaulted' THEN 1 ELSE 0 END) = 0 " + "ORDER BY customer_name ASC" + ), + }, +] + + +# ───────────────────────────────────────────────────────────────────────────── +# DOMAIN: HR +# ───────────────────────────────────────────────────────────────────────────── + +HR_TEMPLATES: list[Template] = [ + + # ── EASY ──────────────────────────────────────────────────────────────── + + { + "domain": "hr", "difficulty": "easy", "has_order": True, + "description": "Active employees sorted by salary descending", + "base_nl": "List all active employees sorted by salary from highest to lowest. Return id, name, job_title, salary.", + "sql": ( + "SELECT id, name, job_title, salary " + "FROM employees " + "WHERE status = 'active' " + "ORDER BY salary DESC" + ), + }, + { + "domain": "hr", "difficulty": "easy", "has_order": True, + "description": "Departments sorted by budget descending", + "base_nl": "Show all departments sorted by budget from largest to smallest. Return id, name, location, budget.", + "sql": ( + "SELECT id, name, location, budget " + "FROM departments " + "ORDER BY budget DESC" + ), + }, + { + "domain": "hr", "difficulty": "easy", "has_order": True, + "description": "Employees hired in 2023 or later sorted by hire date descending", + "base_nl": "Which employees were hired on or after January 1st 2023? Sort by hire date descending. 
Return id, name, job_title, hire_date.", + "sql": ( + "SELECT id, name, job_title, hire_date " + "FROM employees " + "WHERE hire_date >= '2023-01-01' " + "ORDER BY hire_date DESC" + ), + }, + { + "domain": "hr", "difficulty": "easy", "has_order": True, + "description": "Active projects sorted by budget descending", + "base_nl": "Show all currently active projects sorted by budget descending. Return id, name, status, budget.", + "sql": ( + "SELECT id, name, status, budget " + "FROM projects " + "WHERE status = 'active' " + "ORDER BY budget DESC" + ), + }, + { + "domain": "hr", "difficulty": "easy", "has_order": True, + "description": "Active employees earning above $100,000 sorted by salary descending", + "base_nl": "Which active employees earn more than $100,000? Return id, name, email, job_title, sorted by salary descending.", + "sql": ( + "SELECT id, name, email, job_title " + "FROM employees " + "WHERE status = 'active' AND salary > 100000 " + "ORDER BY salary DESC" + ), + }, + { + "domain": "hr", "difficulty": "easy", "has_order": False, + "description": "Count of active employees", + "base_nl": "How many active employees do we currently have? Return active_employees.", + "sql": "SELECT COUNT(*) AS active_employees FROM employees WHERE status = 'active'", + }, + { + "domain": "hr", "difficulty": "easy", "has_order": True, + "description": "Projects with no end date (ongoing) sorted by budget descending", + "base_nl": "List all ongoing projects that have no end date set. 
Return id, name, start_date, budget, sorted by budget descending.", + "sql": ( + "SELECT id, name, start_date, budget " + "FROM projects " + "WHERE end_date IS NULL " + "ORDER BY budget DESC" + ), + }, + + # ── MEDIUM ─────────────────────────────────────────────────────────────── + + { + "domain": "hr", "difficulty": "medium", "has_order": True, + "description": "Headcount and average salary per department for active employees", + "base_nl": "For each department, what is the headcount and average salary of active employees? Return department_name, headcount, avg_salary (rounded to 2 decimal places), sorted by headcount descending.", + "sql": ( + "SELECT d.name AS department_name, COUNT(e.id) AS headcount, " + " ROUND(AVG(e.salary), 2) AS avg_salary " + "FROM departments d " + "LEFT JOIN employees e ON e.department_id = d.id AND e.status = 'active' " + "GROUP BY d.id, d.name " + "ORDER BY headcount DESC" + ), + }, + { + "domain": "hr", "difficulty": "medium", "has_order": True, + "description": "Average performance rating per employee sorted descending", + "base_nl": "What is the average performance review rating per active employee? Return employee_name, job_title, avg_rating (rounded to 2 decimal places), sorted by avg_rating descending.", + "sql": ( + "SELECT e.name AS employee_name, e.job_title, " + " ROUND(AVG(pr.rating), 2) AS avg_rating " + "FROM employees e " + "JOIN performance_reviews pr ON pr.employee_id = e.id " + "WHERE e.status = 'active' " + "GROUP BY e.id, e.name, e.job_title " + "ORDER BY avg_rating DESC" + ), + }, + { + "domain": "hr", "difficulty": "medium", "has_order": True, + "description": "Employees with the most total allocated project hours", + "base_nl": "Which employees have the most total hours allocated across projects? 
Return employee_name, total_hours, sorted by total_hours descending, top 10.", + "sql": ( + "SELECT e.name AS employee_name, SUM(pa.hours_allocated) AS total_hours " + "FROM employees e " + "JOIN project_assignments pa ON pa.employee_id = e.id " + "GROUP BY e.id, e.name " + "ORDER BY total_hours DESC " + "LIMIT 10" + ), + }, + { + "domain": "hr", "difficulty": "medium", "has_order": True, + "description": "Departments with distinct employees assigned to active projects", + "base_nl": "For each department, how many distinct employees are assigned to active projects? Return department_name, assigned_employees, sorted by assigned_employees descending.", + "sql": ( + "SELECT d.name AS department_name, " + " COUNT(DISTINCT pa.employee_id) AS assigned_employees " + "FROM departments d " + "JOIN projects p ON p.department_id = d.id " + "JOIN project_assignments pa ON pa.project_id = p.id " + "WHERE p.status = 'active' " + "GROUP BY d.id, d.name " + "ORDER BY assigned_employees DESC" + ), + }, + { + "domain": "hr", "difficulty": "medium", "has_order": True, + "description": "Total project budget per department sorted descending", + "base_nl": "What is the total project budget per department? Return department_name, total_project_budget (rounded to 2 decimal places), sorted by total_project_budget descending.", + "sql": ( + "SELECT d.name AS department_name, " + " ROUND(SUM(p.budget), 2) AS total_project_budget " + "FROM departments d " + "JOIN projects p ON p.department_id = d.id " + "GROUP BY d.id, d.name " + "ORDER BY total_project_budget DESC" + ), + }, + + # ── HARD ───────────────────────────────────────────────────────────────── + + { + "domain": "hr", "difficulty": "hard", "has_order": True, + "description": "Salary rank within department using DENSE_RANK", + "base_nl": "Rank active employees by salary within their department using DENSE_RANK (rank 1 = highest paid). 
Return employee_name, salary, department_name, salary_rank, sorted by department_name then salary_rank ascending.", + "sql": ( + "SELECT employee_name, salary, department_name, " + " DENSE_RANK() OVER (PARTITION BY department_name ORDER BY salary DESC) AS salary_rank " + "FROM ( " + " SELECT e.name AS employee_name, e.salary, d.name AS department_name " + " FROM employees e " + " JOIN departments d ON d.id = e.department_id " + " WHERE e.status = 'active' " + ") sub " + "ORDER BY department_name, salary_rank" + ), + }, + { + "domain": "hr", "difficulty": "hard", "has_order": True, + "description": "Employee performance band classification using CASE with avg rating CTE", + "base_nl": "Classify active employees into performance bands (High Performer: avg rating >= 4, Average: >= 3, Needs Improvement: < 3) based on their average review rating. Return employee_name, salary, avg_rating, performance_band, sorted by avg_rating descending.", + "sql": ( + "WITH avg_ratings AS ( " + " SELECT employee_id, ROUND(AVG(rating), 2) AS avg_rating " + " FROM performance_reviews " + " GROUP BY employee_id " + ") " + "SELECT e.name AS employee_name, e.salary, ar.avg_rating, " + " CASE WHEN ar.avg_rating >= 4 THEN 'High Performer' " + " WHEN ar.avg_rating >= 3 THEN 'Average' " + " ELSE 'Needs Improvement' " + " END AS performance_band " + "FROM employees e " + "JOIN avg_ratings ar ON ar.employee_id = e.id " + "WHERE e.status = 'active' " + "ORDER BY ar.avg_rating DESC" + ), + }, + { + "domain": "hr", "difficulty": "hard", "has_order": True, + "description": "Employees above their department average salary using CTE", + "base_nl": "Find active employees whose salary is above their department's average. 
Return employee_name, department_name, salary, dept_avg_salary (rounded to 2 decimal places), sorted by salary descending.", + "sql": ( + "WITH dept_avg AS ( " + " SELECT department_id, ROUND(AVG(salary), 2) AS dept_avg_salary " + " FROM employees " + " WHERE status = 'active' " + " GROUP BY department_id " + ") " + "SELECT e.name AS employee_name, d.name AS department_name, " + " e.salary, da.dept_avg_salary " + "FROM employees e " + "JOIN departments d ON d.id = e.department_id " + "JOIN dept_avg da ON da.department_id = e.department_id " + "WHERE e.status = 'active' AND e.salary > da.dept_avg_salary " + "ORDER BY e.salary DESC" + ), + }, +] + + +# ───────────────────────────────────────────────────────────────────────────── +# MASTER TEMPLATE REGISTRY +# ───────────────────────────────────────────────────────────────────────────── + +ALL_TEMPLATES: list[Template] = ( + ECOMMERCE_TEMPLATES + + HEALTHCARE_TEMPLATES + + FINANCE_TEMPLATES + + HR_TEMPLATES +) + +TEMPLATES_BY_DOMAIN: dict[str, list[Template]] = { + "ecommerce": ECOMMERCE_TEMPLATES, + "healthcare": HEALTHCARE_TEMPLATES, + "finance": FINANCE_TEMPLATES, + "hr": HR_TEMPLATES, +} + +TEMPLATES_BY_DIFFICULTY: dict[str, list[Template]] = { + "easy": [t for t in ALL_TEMPLATES if t["difficulty"] == "easy"], + "medium": [t for t in ALL_TEMPLATES if t["difficulty"] == "medium"], + "hard": [t for t in ALL_TEMPLATES if t["difficulty"] == "hard"], +} + + +def template_stats() -> dict: + stats: dict = {"total": len(ALL_TEMPLATES), "by_domain": {}, "by_difficulty": {}} + for d in ["ecommerce","healthcare","finance","hr"]: + stats["by_domain"][d] = len(TEMPLATES_BY_DOMAIN[d]) + for diff in ["easy","medium","hard"]: + stats["by_difficulty"][diff] = len(TEMPLATES_BY_DIFFICULTY[diff]) + return stats diff --git a/data_factory/validator.py b/data_factory/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef25dba97daddae264b463e147bd1de1e9d1fba --- /dev/null +++ b/data_factory/validator.py @@ 
-0,0 +1,221 @@ +""" +data_factory/validator.py +========================== +SQL execution validation layer. + +GUARANTEE: Every SQL that passes this validator: + 1. Is a read-only statement (write operations are rejected up front) + 2. Runs without error against the actual seeded SQLite schema + 3. Has its row count and result columns captured, so callers can enforce + non-empty results and expected column names + +SQL reaching this validator comes either from the human-verified template +library or from an LLM generation pipeline (see generate_data.py). In both +cases this validator is the safety net that rejects invalid or destructive +statements and copy-paste or formatting regressions. +""" + +from __future__ import annotations + +import sqlite3 +from dataclasses import dataclass, field +from typing import Any, Optional + +from data_factory.schemas import build_connection, SCHEMA_CONTEXT +from data_factory.templates import Template + + +# ───────────────────────────────────────────────────────────────────────────── +# DATA CLASSES +# ───────────────────────────────────────────────────────────────────────────── + +@dataclass +class ValidationResult: + passed: bool + sql: str + error: Optional[str] = None + row_count: int = 0 + columns: list[str] = field(default_factory=list) + + +@dataclass +class DataRecord: + """One training example ready to be written to JSONL/Parquet.""" + domain: str + difficulty: str + sql: str + nl_question: str # The NL paraphrase used as prompt + persona: str # ceo | chatty | lazy_typist | non_techie | analyst | augmented + has_order: bool + schema_context: str + row_count: int # From validation run + columns: list[str] # From validation run + source: str # "template_base" | "vllm_persona" | "rule_augmented" + template_id: int # Index into ALL_TEMPLATES + + def to_training_dict(self) -> dict[str, Any]: + """ + Returns the dictionary that will be written to the output dataset.
+ + Format is compatible with TRL / HuggingFace `datasets`: + prompt : chat-format messages list (system + user) + sql : ground-truth SQL (label / reward reference) + metadata: auxiliary fields for curriculum or filtering + """ + system_msg = ( + "You are an expert SQL analyst. " + "Write a single SELECT query that answers the question. " + "Output ONLY the SQL query — no markdown, no explanation, no backticks." + ) + user_msg = ( + f"DATABASE SCHEMA\n" + f"---------------\n" + f"{self.schema_context}\n\n" + f"QUESTION: {self.nl_question}" + ) + return { + "prompt": [ + {"role": "system", "content": system_msg}, + {"role": "user", "content": user_msg}, + ], + "sql": self.sql, + "metadata": { + "domain": self.domain, + "difficulty": self.difficulty, + "persona": self.persona, + "has_order": self.has_order, + "row_count": self.row_count, + "columns": self.columns, + "source": self.source, + "template_id": self.template_id, + }, + } + + +# ───────────────────────────────────────────────────────────────────────────── +# VALIDATOR +# ───────────────────────────────────────────────────────────────────────────── + +class SQLValidator: + """ + Validates SQL against a seeded in-memory SQLite connection. + + One validator per domain to reuse the same connection for all templates + in that domain (performance optimization). + """ + + def __init__(self, domain: str, seed: int = 42) -> None: + self.domain = domain + self._conn = build_connection(domain, seed=seed) + + def validate(self, sql: str) -> ValidationResult: + """ + Execute SQL and return a ValidationResult. + Never raises — always returns a result object. 
+ """ + sql = sql.strip().rstrip(";") + if not sql: + return ValidationResult(passed=False, sql=sql, error="Empty SQL string.") + + # Block any write operations + first_word = sql.split()[0].lower() if sql.split() else "" + forbidden = {"insert","update","delete","drop","alter","create","replace","truncate","pragma"} + if first_word in forbidden: + return ValidationResult( + passed=False, sql=sql, + error=f"Write operation '{first_word.upper()}' is not permitted." + ) + + try: + cur = self._conn.execute(sql) + cols = [d[0] for d in cur.description] if cur.description else [] + rows = cur.fetchall() + return ValidationResult( + passed=True, + sql=sql, + row_count=len(rows), + columns=cols, + ) + except sqlite3.Error as exc: + return ValidationResult(passed=False, sql=sql, error=str(exc)) + + def close(self) -> None: + self._conn.close() + + +def validate_template(template: Template, seed: int = 42) -> ValidationResult: + """Convenience function: validate a single template.""" + v = SQLValidator(template["domain"], seed=seed) + result = v.validate(template["sql"]) + v.close() + return result + + +def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]: + """ + Run validation across all templates. Returns a summary dict. + Used during CI / smoke testing. 
+ """ + from data_factory.schemas import SCHEMA_MAP + + validators = {domain: SQLValidator(domain, seed) for domain in SCHEMA_MAP} + passed = [] + failed = [] + + for i, t in enumerate(templates): + v = validators[t["domain"]] + result = v.validate(t["sql"]) + if result.passed: + passed.append(i) + else: + failed.append({"index": i, "domain": t["domain"], + "sql": t["sql"][:80], "error": result.error}) + + for v in validators.values(): + v.close() + + return { + "total": len(templates), + "passed": len(passed), + "failed": len(failed), + "failures": failed, + } + + +def build_record( + template: Template, + template_idx: int, + nl_question: str, + persona: str, + source: str, + validator: SQLValidator, +) -> Optional[DataRecord]: + """ + Validate the template SQL and, if it passes, build a DataRecord. + + Parameters + ---------- + template : The source template (contains SQL, domain, difficulty). + template_idx : Index of template in ALL_TEMPLATES (for deduplication). + nl_question : The NL paraphrase to use as the prompt. + persona : Which persona/strategy generated this NL. + source : 'template_base' | 'vllm_persona' | 'rule_augmented' + validator : Pre-built SQLValidator for this domain. + + Returns None if validation fails. 
+ """ + vr = validator.validate(template["sql"]) + if not vr.passed: + return None + + return DataRecord( + domain=template["domain"], + difficulty=template["difficulty"], + sql=template["sql"], + nl_question=nl_question, + persona=persona, + has_order=template["has_order"], + schema_context=SCHEMA_CONTEXT[template["domain"]], + row_count=vr.row_count, + columns=vr.columns, + source=source, + template_id=template_idx, + ) diff --git a/env_server b/env_server new file mode 100644 index 0000000000000000000000000000000000000000..a5c9b5a9e669562b52e45750f4fd69c1b5107ef4 --- /dev/null +++ b/env_server @@ -0,0 +1,33 @@ +import os +import sys +from fastapi import FastAPI, Request +import uvicorn + +sys.path.insert(0, "./server") +from environment import NL2SQLEnvironment +from models import NL2SQLAction + +app = FastAPI() +env = NL2SQLEnvironment() + +@app.post("/reset") +async def reset(request: Request): + data = await request.json() + # Now we take task_name directly from the API call + task_name = data.get("task_name", "simple-filter") + print(f"🔄 Environment Resetting for Task: {task_name}") + obs = env.reset(task_name=task_name) + return {"observation": obs.__dict__} + +@app.post("/step") +async def step(request: Request): + data = await request.json() + query = data.get("query", "") + print(f"⏩ Executing SQL: {query[:60]}...") + + action = NL2SQLAction(query=query) + obs = env.step(action) + return {"observation": obs.__dict__} + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/folder.txt b/folder.txt new file mode 100644 index 0000000000000000000000000000000000000000..90a2f5905a94548ed2777d878e2262b39a2269ee --- /dev/null +++ b/folder.txt @@ -0,0 +1,95 @@ +. 
+├── check_quality.py +├── clean_dataset.py +├── client.py +├── custom_train.py +├── data_expander.py +├── data_factory +│   ├── augmentor.py +│   ├── config.py +│   ├── generate_data.py +│   ├── generator.py +│   ├── __init__.py +│   ├── pipeline.py +│   ├── run_data_factory.py +│   ├── schemas.py +│   ├── templates.py +│   └── validator.py +├── Dockerfile +├── edge_cases.jsonl +├── env_server +├── folder.txt +├── generate_data.py +├── generate_edge_cases.py +├── inference.py +├── __init__.py +├── llm_hybrid_templates.json +├── local_test.py +├── merge_model.py +├── mini_server.py +├── models.py +├── nl2sql_50k_elite_dataset_1.jsonl +├── nl2sql_50k_elite_dataset.jsonl +├── nl2sql_cleaned_ready_to_train.jsonl +├── nl2sql_merged_final.jsonl +├── openenv.yaml +├── pyproject.toml +├── qwen-7b-coder-nl2sql-grpo +│   ├── checkpoint-70 +│   │   ├── adapter_config.json +│   │   ├── adapter_model.safetensors +│   │   ├── chat_template.jinja +│   │   ├── optimizer.pt +│   │   ├── README.md +│   │   ├── rng_state_0.pth +│   │   ├── rng_state_1.pth +│   │   ├── scheduler.pt +│   │   ├── tokenizer_config.json +│   │   ├── tokenizer.json +│   │   ├── trainer_state.json +│   │   └── training_args.bin +│   ├── final +│   │   ├── adapter_config.json +│   │   ├── adapter_model.safetensors +│   │   ├── chat_template.jinja +│   │   ├── README.md +│   │   ├── tokenizer_config.json +│   │   └── tokenizer.json +│   └── README.md +├── qwen-7b-coder-nl2sql-grpo-v2 +├── qwen-7b-nl2sql-merged +│   ├── chat_template.jinja +│   ├── config.json +│   ├── generation_config.json +│   ├── model.safetensors +│   ├── tokenizer_config.json +│   └── tokenizer.json +├── README.md +├── scripts +│   ├── run_local.sh +│   └── smoke_test.sh +├── server +│   ├── app.py +│   ├── db +│   │   ├── __init__.py +│   │   ├── schema.sql +│   │   └── seed.py +│   ├── environment.py +│   ├── grader.py +│   ├── __init__.py +│   ├── requirements.txt +│   └── tasks +│   ├── base.py +│   ├── easy.py +│   ├── hard.py +│   
├── __init__.py +│   └── medium.py +├── swapped_templates.json +├── tests +│   ├── conftest.py +│   ├── __init__.py +│   └── test_all.py +├── train.py +└── value_swapper.py + +11 directories, 81 files diff --git a/generate_data.py b/generate_data.py new file mode 100644 index 0000000000000000000000000000000000000000..c532f4738011dfc6bd358b66fe783cc569fbfb76 --- /dev/null +++ b/generate_data.py @@ -0,0 +1,263 @@ +import os +import sys +import json +import torch +import hashlib +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +# GPU CONFIG - six H100s engaged (physical devices 0,1,2,3,4,7 exposed as logical 0-5) +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,7" + +PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__)) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +from data_factory.schemas import SCHEMA_CONTEXT +from data_factory.validator import SQLValidator + +# CONFIG +MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct" +TARGET_TEMPLATES = 10000 +OUTPUT_FILE = "llm_10k_base_templates.json" +BATCH_SIZE = 64 + +PROMPT_TEMPLATE = """ +You are a senior expert in SQLite schema design and NL2SQL dataset generation. + +TASK +Generate exactly 10 UNIQUE, COMPLEX, and FULLY VALID SQLite SQL SELECT queries for the given schema. +For each query, also write a natural language question that a real user might ask. + +HARD RULES +- Output ONLY a valid JSON array. +- Do NOT wrap output in markdown, code fences, or explanations. +- Every item must be a JSON object with exactly these keys: + - "sql" + - "base_nl" + - "difficulty" + - "has_order" +- All SQL must be a single SELECT statement. +- Do NOT use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, PRAGMA, ATTACH, DETACH, or any DDL/DML. +- Every table and column used in SQL must exist in the provided schema. +- Do NOT invent columns, tables, aliases, or constraints. +- SQL must be valid for SQLite. +- Prefer queries that are meaningfully different from each other. +- Avoid repetitive templates.
+- Each SQL should test a different reasoning pattern. +- Each base_nl should sound natural and distinct from the others. +- Use advanced SQL patterns where appropriate: + - multiple JOINs + - CTEs + - subqueries + - window functions such as ROW_NUMBER, RANK, DENSE_RANK, LAG, LEAD + - GROUP BY and HAVING + - conditional aggregation + - anti-joins / exclusion logic + - top-N per group + - time-based filtering +- Exactly 3 of the 10 queries must be "easy" (basic filtering, simple lookups, 1-2 tables). +- Exactly 3 of the 10 queries must be "medium" (moderate complexity, standard JOINs, basic aggregation). +- Exactly 4 of the 10 queries must be genuinely "hard" (advanced patterns, CTEs, subqueries, window functions). +- Ensure the "difficulty" key strictly contains one of these exact string values: "easy", "medium", or "hard". + +QUALITY TARGETS +- The SQL should be executable as written. +- The question should be answerable from the schema alone. +- Prefer business-like, realistic analytics questions. +- Prefer queries that require combining 2 to 4 tables. +- If a query uses aggregation, ensure the NL clearly implies aggregation. +- If a query uses ordering, include "has_order": true. +- If a query does not require ordering, set "has_order": false. +- Make the 10 queries cover diverse intent types: + 1. ranking + 2. comparison against average or median + 3. top/bottom-N + 4. grouped aggregation + 5. time filtering + 6. multi-join analysis + 7. exclusion / NOT EXISTS + 8. window-function based analysis + 9. conditional counting + 10. trend or interval-based logic + +SCHEMA +{schema} + +OUTPUT FORMAT +Return ONLY a valid JSON array of 10 objects. + +Example structure: +[ + {{ + "sql": "SELECT ...", + "base_nl": "Show ...", + "difficulty": "hard", + "has_order": true + }} +] + +FINAL SELF-CHECK BEFORE RESPONDING +- Confirm the output is valid JSON. +- Confirm there are exactly 10 objects. +- Confirm every SQL is a single SELECT. 
+- Confirm no hallucinated schema elements exist. +- Confirm the 10 questions are not paraphrases of each other. +""" + +def extract_json(raw_text): + text = raw_text.strip() + if text.startswith("```json"): + text = text[7:-3].strip() + elif text.startswith("```"): + text = text[3:-3].strip() + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1: + return text[start:end+1] + return None + +def main(): + print("Loading Model Qwen-72B (SDPA) for 10K Mining...") + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + custom_max_memory = { + 0: "60GiB", # System GPU 0 (Has 13GB used, ~67GB free) + 1: "75GiB", # System GPU 1 (Fully free) + 2: "75GiB", # System GPU 2 (Fully free) + 3: "75GiB", # System GPU 3 (Fully free) + 4: "75GiB", # System GPU 4 (Fully free) + 5: "45GiB" # System GPU 7 (Has 25GB used, ~55GB free) + } + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + device_map="auto", + max_memory = custom_max_memory, + torch_dtype=torch.bfloat16, + attn_implementation="sdpa" + ) + + domains = list(SCHEMA_CONTEXT.keys()) + valid_templates = [] + seen_sql_hashes = set() + + # Resume support: Load existing templates to prevent duplicates + if os.path.exists(OUTPUT_FILE): + with open(OUTPUT_FILE, "r") as f: + valid_templates = json.load(f) + for t in valid_templates: + seen_sql_hashes.add(hashlib.md5(t["sql"].lower().encode()).hexdigest()) + + pbar = tqdm(total=TARGET_TEMPLATES, initial=len(valid_templates), desc="Mining 10K Base Templates") + + validators = {} + domain_idx = 0 + + while len(valid_templates) < TARGET_TEMPLATES: + batch_prompts = [] + batch_domains = [] + + # Prepare Batch + for _ in range(BATCH_SIZE): + domain = domains[domain_idx % len(domains)] + schema_string = SCHEMA_CONTEXT[domain] + domain_idx += 1 + + messages = [ + {"role": "system", "content": "You output only valid JSON arrays. 
Do not include markdown."}, + {"role": "user", "content": PROMPT_TEMPLATE.format(schema=schema_string)} + ] + chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + batch_prompts.append(chat_text) + batch_domains.append(domain) + + inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device) + + try: + tqdm.write(f"\n[DEBUG] Sending batch of {BATCH_SIZE} to model.generate(). Please wait...") + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=5000, + do_sample=True, + temperature=0.55, + top_p=0.9, + pad_token_id=tokenizer.eos_token_id + ) + tqdm.write("[DEBUG] Model generation finished. Decoding responses...") + + # Output Slicing + input_length = inputs.input_ids.shape[1] + generated_tokens = outputs[:, input_length:] + responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + + batch_added = 0 + for i, (response, domain) in enumerate(zip(responses, batch_domains)): + tqdm.write(f"\n[DEBUG] Processing Response {i+1}/{BATCH_SIZE} for domain: {domain}") + + json_text = extract_json(response) + if not json_text: + tqdm.write(f"[DEBUG] extract_json failed. Raw text snippet: {response[:200]}...") + continue + + try: + generated_data = json.loads(json_text) + tqdm.write(f"[DEBUG] JSON loaded successfully. Found {len(generated_data)} items.") + except Exception as e: + tqdm.write(f"[DEBUG] json.loads failed. 
Error: {e}") + tqdm.write(f"[DEBUG] Bad JSON snippet: {json_text[:200]}...") + continue + + if domain not in validators: + validators[domain] = SQLValidator(domain, seed=42) + validator = validators[domain] + + for item in generated_data: + if not isinstance(item, dict): continue + + sql = item.get("sql", "").strip() + if not sql: continue + + # Check for duplicates using hash + sql_hash = hashlib.md5(sql.lower().encode()).hexdigest() + if sql_hash in seen_sql_hashes: + tqdm.write("[DEBUG] Duplicate query skipped.") + continue + + val_result = validator.validate(sql) + + # Hard validation rule: SQL must execute AND return rows + if val_result.passed and val_result.row_count > 0: + tqdm.write(f"[DEBUG] SQL Passed (Rows: {val_result.row_count}): {sql[:50]}...") + item["domain"] = domain + item["id"] = f"base_{len(valid_templates)}" + valid_templates.append(item) + seen_sql_hashes.add(sql_hash) + batch_added += 1 + else: + tqdm.write(f"[DEBUG] SQL Failed Validation or 0 Rows (Passed: {val_result.passed}, Rows: {val_result.row_count}): {sql[:50]}...") + + if batch_added > 0: + pbar.update(batch_added) + tqdm.write(f"[DEBUG] Auto-saving {batch_added} new templates to JSON...") + # Auto-save after every successful batch + with open(OUTPUT_FILE, "w") as f: + json.dump(valid_templates, f, indent=2) + + if len(valid_templates) >= TARGET_TEMPLATES: + break + + except Exception as e: + tqdm.write(f"\n[DEBUG] CRITICAL EXCEPTION CAUGHT: {e}") + continue + + # Close validators + for v in validators.values(): + v.close() + + pbar.close() + print(f"\nBoom! 
Generated {len(valid_templates)} Elite Base Templates!")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/generate_edge_cases.py b/generate_edge_cases.py
new file mode 100644
index 0000000000000000000000000000000000000000..9345125b54864c2b253b754c23043014cfac50c6
--- /dev/null
+++ b/generate_edge_cases.py
@@ -0,0 +1,319 @@
+"""
+generate_edge_cases.py
+======================
+Targeted edge-case data generator for the failure patterns found in eval:
+  1. ROW_NUMBER vs RANK vs DENSE_RANK (tie-breaking semantics)
+  2. strftime month as INTEGER (not '%Y-%m' string)
+  3. SELECT column discipline (no unrequested extras)
+  4. LAG/LEAD period-over-period
+  5. HAVING vs WHERE placement
+  6. COUNT(DISTINCT) vs COUNT
+
+Produces: edge_cases.jsonl (same chat format as nl2sql_cleaned_ready_to_train.jsonl)
+Run: python generate_edge_cases.py
+"""
+
+import os, sys, json, re, hashlib
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, BitsAndBytesConfig
+import torch
+import transformers.activations
+# Compatibility shim: alias PytorchGELUTanh so AutoAWQ's import does not crash
+if not hasattr(transformers.activations, 'PytorchGELUTanh'):
+    transformers.activations.PytorchGELUTanh = transformers.activations.NewGELUActivation
+os.environ["CUDA_VISIBLE_DEVICES"] = "3,1,6,7"
+
+PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
+if PROJECT_ROOT not in sys.path:
+    sys.path.insert(0, PROJECT_ROOT)
+
+from data_factory.schemas import SCHEMA_CONTEXT
+
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4"
+)
+
+MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
+OUTPUT_FILE = "edge_cases.jsonl"
+BATCH_SIZE = 8 # smaller — edge prompts are long
+SAMPLES_PER_PATTERN = 715 # ~6 batches per pattern → 5005 total edge samples
+
+SYSTEM_PROMPT = (
+    "You are a Senior SQL Architect. "
+    "Output ONLY the SQL query. 
Use SQLite syntax." +) + +# ── Edge-case prompt templates ────────────────────────────────────────────── +# Each entry: (pattern_tag, user_prompt_template) +# {schema} is filled at runtime with a random domain schema. + +EDGE_PATTERNS = [ + + # 1. ROW_NUMBER tie-breaking — the #1 failure + ("row_number_tiebreak", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs that REQUIRE ROW_NUMBER() (not RANK or DENSE_RANK) \ +because the question explicitly says "pick one winner when there is a tie" \ +using a tiebreaker column (e.g. lower id, earlier date). + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- Every SQL must use ROW_NUMBER() OVER (...) not RANK(). +- The OVER clause ORDER BY must include the tiebreaker column. +- WHERE rn = 1 must appear in an outer query or CTE. +- No markdown. No explanation. Just the JSON array."""), + + # 2. RANK / DENSE_RANK — when ties SHOULD persist + ("rank_dense_rank", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs where RANK() or DENSE_RANK() is the CORRECT choice \ +because the question says "show all tied records at the same rank". + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- Use RANK() when question implies gaps after ties, DENSE_RANK() when no gaps. +- NL must make the tie-semantics explicit ("same rank", "tied positions"). +- No markdown. No explanation. Just the JSON array."""), + + # 3. strftime integer month output + ("strftime_integer_month", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs where the question asks for a numeric month number \ +(1–12), NOT a 'YYYY-MM' string. + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- SQL must use CAST(strftime('%m', ) AS INTEGER) to produce integer month. +- Do NOT use strftime('%Y-%m', ...) when the question asks for month number. +- NL questions must say "month number", "which month (1–12)", or similar. 
+- No markdown. No explanation. Just the JSON array."""), + + # 4. SELECT column discipline + ("select_column_discipline", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs where the question explicitly names ONLY the columns \ +to return. The SQL must select EXACTLY those columns — no extras like avg_salary, \ +row counts, or intermediate aggregates. + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- NL must say "return only X, Y, Z" or "show me only the name and total". +- SQL SELECT list must contain only those columns. +- If aggregation is needed internally (e.g. for HAVING), do NOT expose it in SELECT. +- No markdown. No explanation. Just the JSON array."""), + + # 5. LAG / LEAD period-over-period + ("lag_lead_period", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs that require LAG() or LEAD() window functions \ +for period-over-period comparison (e.g. month-over-month revenue change, \ +previous order amount, next appointment date). + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- Use LAG(, 1) OVER (ORDER BY ...) or LEAD(...) correctly. +- NL must imply comparison with previous or next row/period. +- No markdown. No explanation. Just the JSON array."""), + + # 6. HAVING vs WHERE + ("having_vs_where", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs that test correct placement of filter conditions: +- Conditions on raw columns → WHERE +- Conditions on aggregates → HAVING +Include 4 pairs where a wrong model might put an aggregate condition in WHERE (trap). + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- SQL must never filter an aggregate (COUNT, SUM, AVG) inside WHERE. +- SQL must never put a raw column filter inside HAVING. +- No markdown. No explanation. Just the JSON array."""), + + # 7. 
COUNT(DISTINCT) vs COUNT + ("count_distinct", """SCHEMA: +{schema} + +Generate exactly 8 NL2SQL pairs where the question specifically asks for \ +"unique", "distinct", or "different" counts — requiring COUNT(DISTINCT col). +Also include 2 pairs where COUNT(*) is correct to reinforce the contrast. + +Output ONLY a valid JSON array: +[ + {{"nl": "...", "sql": "SELECT ..."}}, + ... +] + +Rules: +- When NL says "unique/distinct", SQL must use COUNT(DISTINCT ). +- When NL says "total orders placed" (not distinct), use COUNT(*) or COUNT(id). +- No markdown. No explanation. Just the JSON array."""), +] + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +def extract_json_array(text: str) -> str: + text = text.strip() + # strip code fences if model leaks them + text = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", text, flags=re.DOTALL).strip() + s, e = text.find("["), text.rfind("]") + return text[s:e+1] if s != -1 and e != -1 else "[]" + +def get_hash(text: str) -> str: + return hashlib.md5(text.lower().strip().encode()).hexdigest() + +def build_record(nl: str, sql: str, domain: str) -> dict: + return { + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {nl}"} + ], + "sql": sql + } + + +# ── Main ───────────────────────────────────────────────────────────────────── + +def main(): + print(f"Loading {MODEL_NAME}...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + custom_memory = {0:"30GiB",1:"75GiB",2:"45GiB",3:"45GiB"} + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + device_map="auto", + max_memory=custom_memory, + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + attn_implementation = "sdpa" + ) + + domains = list(SCHEMA_CONTEXT.keys()) + seen = set() + total = 0 + + out = open(OUTPUT_FILE, "a", encoding="utf-8") + + 
for pattern_tag, prompt_tmpl in EDGE_PATTERNS: + print(f"\n[PATTERN] {pattern_tag}") + collected = 0 + domain_idx = 0 + pbar = tqdm(total=SAMPLES_PER_PATTERN, desc=pattern_tag) + + while collected < SAMPLES_PER_PATTERN: + # Build a batch of prompts, cycling through domains + batch_domains = [] + batch_prompts = [] + for _ in range(BATCH_SIZE): + domain = domains[domain_idx % len(domains)] + domain_idx += 1 + user_msg = prompt_tmpl.format(schema=SCHEMA_CONTEXT[domain]) + msgs = [ + {"role": "system", "content": "You output only valid JSON arrays. No markdown."}, + {"role": "user", "content": user_msg} + ] + batch_prompts.append( + tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) + ) + batch_domains.append(domain) + + inputs = tokenizer( + batch_prompts, return_tensors="pt", padding=True, truncation=True + ).to(model.device) + + try: + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=2048, + do_sample=True, + temperature=0.5, + top_p=0.9, + pad_token_id=tokenizer.eos_token_id + ) + responses = tokenizer.batch_decode( + outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True + ) + + for resp, domain in zip(responses, batch_domains): + raw = extract_json_array(resp) + try: + pairs = json.loads(raw) + except Exception: + continue + + for pair in pairs: + nl = pair.get("nl", "").strip() + sql = pair.get("sql", "").strip() + if not nl or not sql: + continue + # strip fences in sql just in case + sql = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", sql, flags=re.DOTALL).strip() + + h = get_hash(nl + sql) + if h in seen: + continue + seen.add(h) + + record = build_record(nl, sql, domain) + out.write(json.dumps(record, ensure_ascii=False) + "\n") + out.flush() + collected += 1 + total += 1 + pbar.update(1) + + if collected >= SAMPLES_PER_PATTERN: + break + + except Exception as e: + tqdm.write(f"[WARN] Batch failed: {e}") + continue + + pbar.close() + + out.close() + print(f"\nDone! 
{total} edge-case records saved to {OUTPUT_FILE}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd04c68d124ed3efa6f78352446d9319ee744e9 --- /dev/null +++ b/inference.py @@ -0,0 +1,265 @@ +""" +inference.py — NL2SQL-Bench Baseline Inference Script +======================================================== + +MANDATORY COMPLIANCE +-------------------- +- Named `inference.py`, placed in project root. +- Uses OpenAI client for all LLM calls. +- Reads: API_BASE_URL, MODEL_NAME, HF_TOKEN from environment. +- Emits [START] / [STEP] / [END] lines to stdout in the exact format below. +- Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB. + +STDOUT FORMAT (exact — any deviation breaks scoring) +---------------------------------------------------- +[START] task= env=nl2sql-bench model= +[STEP] step= action= reward=<0.00> done= error= +[END] success= steps= score=<0.000> rewards= +""" + +from __future__ import annotations + +import asyncio +import os +import sys +import textwrap +from typing import List, Optional + +from openai import OpenAI + +# ── Configuration ────────────────────────────────────────────────────────── +API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") +MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") +API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "") +IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest") +SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000") + +BENCHMARK = "nl2sql-bench" +MAX_STEPS = 5 +TEMPERATURE = 0.2 # Low temp for SQL generation +MAX_TOKENS = 512 +SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success + +TASKS = ["simple-filter", "join-aggregation", "analytics-window"] + +# ── System prompt ────────────────────────────────────────────────────────── +SYSTEM_PROMPT = textwrap.dedent(""" +You are an expert SQL analyst working 
with a SQLite e-commerce database. + +DATABASE SCHEMA +--------------- +categories(id, name) +products(id, name, category_id, price, stock_quantity) +customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at) +orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled}, + created_at, total_amount) +order_items(id, order_id, product_id, quantity, unit_price) +reviews(id, product_id, customer_id, rating∈1-5, created_at) + +RULES +----- +1. Write a single SELECT query — no INSERT/UPDATE/DELETE. +2. Output ONLY the SQL query, nothing else. No markdown, no explanation. +3. Use SQLite syntax: strftime('%Y-%m', date_col) for month, ROUND(x, 2) for decimals. +4. Window functions (RANK, DENSE_RANK, ROW_NUMBER, running SUM) are supported. +5. CTEs (WITH ... AS (...)) are supported. +6. If you receive an error, fix it carefully in your next attempt. +7. If you receive partial results, refine your query to match the expected output. +""").strip() + + +# ── Stdout logging (mandatory format) ───────────────────────────────────── + +def log_start(task: str, model: str) -> None: + print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True) + + +def log_step( + step: int, action: str, reward: float, done: bool, error: Optional[str] +) -> None: + # Collapse multi-line SQL to single line for log compliance + action_single = " ".join(action.split()) + error_val = error.replace("\n", " ") if error else "null" + print( + f"[STEP] step={step} action={action_single!r} " + f"reward={reward:.2f} done={str(done).lower()} error={error_val}", + flush=True, + ) + + +def log_end( + success: bool, steps: int, score: float, rewards: List[float] +) -> None: + rewards_str = ",".join(f"{r:.2f}" for r in rewards) + print( + f"[END] success={str(success).lower()} steps={steps} " + f"score={score:.3f} rewards={rewards_str}", + flush=True, + ) + + +# ── LLM interaction ──────────────────────────────────────────────────────── + +def build_user_prompt( + 
question: str, + schema_context: str, + step: int, + last_query: str, + last_error: Optional[str], + last_result: list, + result_columns: list, +) -> str: + parts = [f"QUESTION: {question}", ""] + + if step > 1: + parts.append(f"Your previous SQL (step {step - 1}):") + parts.append(f" {' '.join(last_query.split())}") + parts.append("") + if last_error: + parts.append(f"ERROR: {last_error}") + elif last_result: + preview = str(last_result[:3]).replace("\n", " ") + parts.append(f"RESULT PREVIEW (first 3 rows): {preview}") + parts.append(f"COLUMNS: {result_columns}") + parts.append("") + parts.append("Please correct or refine your query.") + else: + parts.append("Write a SQL query to answer the question.") + + return "\n".join(parts) + + +def call_llm(client: OpenAI, user_prompt: str) -> str: + try: + resp = client.chat.completions.create( + model=MODEL_NAME, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + temperature=TEMPERATURE, + max_tokens=MAX_TOKENS, + stream=False, + ) + text = (resp.choices[0].message.content or "").strip() + # Strip markdown code fences if model wraps in ```sql ... ``` + if text.startswith("```"): + lines = text.split("\n") + text = "\n".join( + l for l in lines + if not l.strip().startswith("```") + ).strip() + return text if text else "SELECT 1" + except Exception as exc: + print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True) + return "SELECT 1" + + +# ── Single-task episode ──────────────────────────────────────────────────── + +async def run_task(client: OpenAI, env, task_name: str) -> dict: + """Run one full episode for the given task. 
Returns result dict.""" + rewards: List[float] = [] + steps_taken = 0 + score = 0.0 + success = False + + log_start(task_name, MODEL_NAME) + + try: + # Reset — pass task_name via action payload or query param + # OpenEnv reset() may not accept task args via HTTP; we rely on + # NL2SQL_DEFAULT_TASK env-var being set before calling, OR we + # pass it as a reset parameter if the server supports it. + result = await env.reset() + obs = result.observation + + for step in range(1, MAX_STEPS + 1): + if result.done: + break + + user_prompt = build_user_prompt( + question=obs.question, + schema_context=obs.schema_context, + step=step, + last_query=obs.last_query, + last_error=obs.last_error, + last_result=obs.last_result, + result_columns=obs.result_columns, + ) + + sql = call_llm(client, user_prompt) + + from client import NL2SQLAction # local to avoid circular at module level + result = await env.step(NL2SQLAction(query=sql)) + obs = result.observation + + reward = obs.reward or 0.0 + done = obs.done + error = obs.last_error + + rewards.append(reward) + steps_taken = step + + log_step(step=step, action=sql, reward=reward, done=done, error=error) + + if done: + break + + # Compute final score + score = sum(rewards) / max(len(rewards), 1) + score = round(min(max(score, 0.0), 1.0), 4) + success = score >= SUCCESS_THRESHOLD + + except Exception as exc: + print(f"[DEBUG] Episode error for {task_name}: {exc}", file=sys.stderr, flush=True) + finally: + log_end(success=success, steps=steps_taken, score=score, rewards=rewards) + + return {"task": task_name, "success": success, "score": score, "rewards": rewards} + + +# ── Main ─────────────────────────────────────────────────────────────────── + +async def main() -> None: + client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) + + # Import here to avoid import errors if openenv not installed during lint + from client import NL2SQLEnv + + all_results = [] + + for task_name in TASKS: + # Set the default task for the server session 
via env-var approach. + # For the hosted Space, we rely on the task cycling implemented in + # the task registry's round-robin iterator. + os.environ["NL2SQL_DEFAULT_TASK"] = task_name + + try: + async with NL2SQLEnv(base_url=SPACE_URL) as env: + result = await run_task(client, env, task_name) + all_results.append(result) + except Exception as exc: + print( + f"[DEBUG] Failed to connect for task {task_name}: {exc}", + file=sys.stderr, + flush=True, + ) + # Emit a zero-score END to keep log format valid + log_end(success=False, steps=0, score=0.0, rewards=[]) + all_results.append({"task": task_name, "success": False, "score": 0.0}) + + # Summary to stderr (not scored, for human readability) + print("\n=== Baseline Summary ===", file=sys.stderr) + for r in all_results: + print( + f" {r['task']:20s} score={r['score']:.3f} " + f"success={r['success']}", + file=sys.stderr, + ) + avg = sum(r["score"] for r in all_results) / max(len(all_results), 1) + print(f" {'AVERAGE':20s} score={avg:.3f}", file=sys.stderr) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/local_test.py b/local_test.py new file mode 100644 index 0000000000000000000000000000000000000000..85b2c1758ba000e730c6dda1c434ff642320402f --- /dev/null +++ b/local_test.py @@ -0,0 +1,118 @@ +import asyncio +import os +import sys +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import PeftModel + +# --- Configuration --- +BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct" +LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50" +SPACE_URL = "http://localhost:8000" # Local server URL +TASKS = ["simple-filter", "join-aggregation", "analytics-window"] +MAX_STEPS = 5 + +print("Loading Base Model and LoRA weights...") +tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) +base_model = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, + torch_dtype=torch.bfloat16, + device_map="auto" +) +model = PeftModel.from_pretrained(base_model, LORA_DIR) + +# --- System Prompt & LLM Call --- 
+SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database. +Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown.""" + +def call_local_llm(user_prompt: str) -> str: + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt} + ] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = tokenizer([text], return_tensors="pt").to(model.device) + + with torch.no_grad(): + outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.2, do_sample=True) + + response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + + # Strip markdown code fences if model wraps in ```sql ... ``` + if response.startswith("```"): + lines = response.split("\n") + response = "\n".join(l for l in lines if not l.strip().startswith("```")).strip() + return response if response else "SELECT 1" + +def build_user_prompt(question, schema_context, step, last_query, last_error, last_result, result_columns): + parts = [f"QUESTION: {question}", ""] + if step > 1: + parts.append(f"Your previous SQL (step {step - 1}):") + parts.append(f" {' '.join(last_query.split())}") + parts.append("") + if last_error: + parts.append(f"ERROR: {last_error}") + elif last_result: + preview = str(last_result[:3]).replace("\n", " ") + parts.append(f"RESULT PREVIEW (first 3 rows): {preview}") + parts.append(f"COLUMNS: {result_columns}") + parts.append("") + parts.append("Please correct or refine your query.") + else: + parts.append("Write a SQL query to answer the question.") + return "\n".join(parts) + +async def main(): + from client import NL2SQLEnv, NL2SQLAction + + all_results = [] + + for task_name in TASKS: + print(f"\n--- Starting Task: {task_name} ---") + os.environ["NL2SQL_DEFAULT_TASK"] = task_name + + try: + async with NL2SQLEnv(base_url=SPACE_URL) as env: + result = await env.reset() + obs = result.observation + + 
rewards = [] + success = False + + for step in range(1, MAX_STEPS + 1): + if obs.done: + break + + user_prompt = build_user_prompt( + obs.question, obs.schema_context, step, + obs.last_query, obs.last_error, obs.last_result, obs.result_columns + ) + + sql = call_local_llm(user_prompt) + + print(f"Step {step} Agent Output: {sql}") + + step_result = await env.step(NL2SQLAction(query=sql)) + obs = step_result.observation + + reward = obs.reward or 0.0 + rewards.append(reward) + print(f"Step {step} Reward: {reward}") + + if obs.done: + break + + score = sum(rewards) / max(len(rewards), 1) + success = score >= 0.7 + print(f"Final Score for {task_name}: {score:.3f}") + all_results.append({"task": task_name, "score": score, "success": success}) + + except Exception as e: + print(f"Error testing task {task_name}: {e}") + + print("\n=== Final Results ===") + for r in all_results: + print(f"{r['task']}: Score {r['score']:.3f} | Success: {r['success']}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/merge_model.py b/merge_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc19c32142762c7f87862c9fd0642d2b7834ea1 --- /dev/null +++ b/merge_model.py @@ -0,0 +1,32 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import PeftModel + +BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct" +ADAPTER_DIR = "./qwen-7b-coder-nl2sql-grpo/final" +OUTPUT_DIR = "./qwen-7b-nl2sql-merged" + +def main(): + print("Loading Base Model...") + base_model = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + + print("Loading Adapters and Merging...") + # Load the LoRA adapters into the base model + model = PeftModel.from_pretrained(base_model, ADAPTER_DIR) + + # Merge weights permanently + merged_model = model.merge_and_unload() + + print("Saving Merged Model...") + merged_model.save_pretrained(OUTPUT_DIR) + + tokenizer = 
AutoTokenizer.from_pretrained(BASE_MODEL)
+    tokenizer.save_pretrained(OUTPUT_DIR)
+    print(f"Done! Merged model saved to {OUTPUT_DIR}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/mini_server.py b/mini_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd8d0a798b08aca5d9670a50e2a3818f02ae188
--- /dev/null
+++ b/mini_server.py
@@ -0,0 +1,74 @@
+import os
+import torch
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel
+from typing import List
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# CRITICAL: pin this server process to GPU 7
+os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+
+app = FastAPI()
+
+# Path to the merged (LoRA-fused) model
+MODEL_PATH = "./qwen-7b-nl2sql-merged"
+
+print("🚀 Loading Local Model for Inference API... (Takes a minute)")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="sdpa" # Super stable, no vLLM crashes
+)
+print("✅ Server Ready! 
Acting as OpenAI on Port 8000.") + +# OpenAI Request Schemas +class Message(BaseModel): + role: str + content: str + +class ChatRequest(BaseModel): + model: str + messages: List[Message] + temperature: float = 0.2 + max_tokens: int = 512 + +@app.post("/v1/chat/completions") +async def chat(request: ChatRequest): + # Convert OpenAI messages to Qwen format + messages = [{"role": m.role, "content": m.content} for m in request.messages] + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # Generate SQL + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=request.max_tokens, + temperature=request.temperature, + do_sample=True if request.temperature > 0 else False, + pad_token_id=tokenizer.eos_token_id + ) + + # Decode only the newly generated text + response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) + + # Return EXACT OpenAI JSON Structure + return { + "id": "chatcmpl-local-hackathon", + "object": "chat.completion", + "created": 1700000000, + "model": request.model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": response_text}, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + } + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8001) \ No newline at end of file diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..03468bd6e564861a381db881f785494e04034a6f --- /dev/null +++ b/models.py @@ -0,0 +1,79 @@ +""" +nl2sql-bench/models.py +====================== +Typed contracts for the NL2SQL-Bench OpenEnv environment. + +Action : The SQL query the agent submits. +Observation : What the agent sees after each step. +State : Episode-level metadata (for state() endpoint). 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from openenv.core.env_server import Action, Observation, State + + +# --------------------------------------------------------------------------- +# Action +# --------------------------------------------------------------------------- + +class NL2SQLAction(Action): + """A single SQL query submitted by the agent.""" + query: str = "" + + +# --------------------------------------------------------------------------- +# Observation +# --------------------------------------------------------------------------- + + +class NL2SQLObservation(Observation): + """ + Everything the agent needs to reason about and iterate its SQL query. + + Fields + ------ + question : The natural-language question to answer. + schema_context : Relevant table/column descriptions as a string block. + task_name : Identifier of the current task (easy / medium / hard). + last_query : The SQL the agent submitted on the last step (empty on reset). + last_result : Up to 10 rows returned by the last query (list of dicts). + last_error : SQLite error string if the query failed, else None. + result_columns : Column names of last_result rows. + step : Current step number (1-indexed). + max_steps : Maximum steps allowed per episode. + done : True when the episode is over (success or step exhausted). + reward : Reward for the most recent action (None on reset). + score : Normalised cumulative score so far [0.0, 1.0]. 
+ """ + question: str = "" + schema_context: str = "" + task_name: str = "" + last_query: str = "" + last_result: List[Dict[str, Any]] = field(default_factory=list) + last_error: Optional[str] = None + result_columns: List[str] = field(default_factory=list) + step: int = 0 + max_steps: int = 5 + done: bool = False + reward: Optional[float] = None + score: float = 0.0 + + +# --------------------------------------------------------------------------- +# State +# --------------------------------------------------------------------------- + +class NL2SQLState(State): + """Episode-level state (returned by the /state endpoint).""" + episode_id: Optional[str] = None + step_count: int = 0 + task_name: str = "" + task_difficulty: str = "" # easy | medium | hard + question: str = "" + best_reward: float = 0.0 # highest reward seen this episode + cumulative_reward: float = 0.0 + solved: bool = False # True if exact match was achieved diff --git a/openenv.yaml b/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..923ca254a2cb2e03d80b68307f7adebdc745b469 --- /dev/null +++ b/openenv.yaml @@ -0,0 +1,135 @@ +# nl2sql-bench/openenv.yaml +# OpenEnv environment manifest — validated by `openenv validate` + +name: nl2sql-bench +version: "0.1.0" +description: > + Natural Language to SQL query generation environment for RL training. + An agent iteratively writes and refines SQLite queries against a synthetic + e-commerce database to answer business questions. Multi-turn episodes with + dense, shaped rewards. Three difficulty tasks: easy (single-table), + medium (JOIN + GROUP BY), hard (window functions + CTEs). + +author: "nl2sql-bench team" +license: MIT +tags: + - openenv + - nl2sql + - sql + - analytics + - rl-training + - deterministic + - multi-turn + +# ── Task definitions ──────────────────────────────────────────────────────── +tasks: + - name: simple-filter + difficulty: easy + description: > + Single-table SELECT with WHERE, ORDER BY, and LIMIT. 
+ Tests basic SQL fluency. Expected solve rate: high. + max_steps: 5 + reward_range: [0.0, 1.0] + + - name: join-aggregation + difficulty: medium + description: > + Multi-table JOINs with GROUP BY, HAVING, and aggregation functions + (COUNT, SUM, AVG, ROUND). Tests relational reasoning. + max_steps: 5 + reward_range: [0.0, 1.0] + + - name: analytics-window + difficulty: hard + description: > + Advanced analytics using CTEs, window functions (DENSE_RANK, + ROW_NUMBER, running SUM), and nested subqueries. Tests multi-step + planning and SQLite-specific syntax. + max_steps: 5 + reward_range: [0.0, 1.0] + +# ── Action / Observation space ────────────────────────────────────────────── +action_space: + type: object + properties: + query: + type: string + description: "A SQLite SELECT query string." + +observation_space: + type: object + properties: + question: + type: string + description: "Natural-language question the agent must answer." + schema_context: + type: string + description: "Compact database schema description for the agent." + task_name: + type: string + description: "Active task identifier." + last_query: + type: string + description: "The SQL query submitted on the previous step." + last_result: + type: array + description: "Up to 10 rows returned by the last query (list of dicts)." + last_error: + type: string + nullable: true + description: "SQLite error string if last query failed, else null." + result_columns: + type: array + description: "Column names of last_result." + step: + type: integer + description: "Current step number (1-indexed; 0 after reset)." + max_steps: + type: integer + description: "Maximum steps per episode." + done: + type: boolean + description: "True when episode ends (exact match or step limit reached)." + reward: + type: number + nullable: true + description: "Reward for the most recent step [0.0, 1.0]." + score: + type: number + description: "Normalised cumulative episode score [0.0, 1.0]." 
+
+# ── Reward function description ─────────────────────────────────────────────
+reward:
+  type: shaped
+  range: [0.0, 1.0]
+  components:
+    - name: syntax_ok
+      weight: 0.10
+      description: "Query executes without SQLite error."
+    - name: columns_match
+      weight: 0.20
+      description: "Returned column names match ground truth exactly."
+    - name: row_count_match
+      weight: 0.20
+      description: "Number of returned rows matches ground truth."
+    - name: exact_match
+      weight: 0.50
+      description: "Full result set matches ground truth (order-aware for ORDER BY)."
+    - name: step_penalty
+      weight: -0.05
+      description: "Deducted per step beyond the first (encourages efficiency)."
+
+# ── Deployment ────────────────────────────────────────────────────────────
+server:
+  port: 7860
+  dockerfile: Dockerfile
+  healthcheck: /health
+
+# ── Baseline ──────────────────────────────────────────────────────────────
+baseline:
+  script: inference.py
+  model: Qwen/Qwen2.5-72B-Instruct
+  expected_scores:
+    simple-filter: 0.70
+    join-aggregation: 0.45
+    analytics-window: 0.25
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..1477d945b340e122f91561e7ef4630894422a8ac
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,39 @@
+[build-system]
+requires = ["setuptools>=69", "wheel"]
+# NOTE: "setuptools.backends.legacy:build" is not a real PEP 517 backend and
+# makes `pip install .` fail; the canonical setuptools backend is build_meta.
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "nl2sql-bench"
+version = "0.1.0"
+description = "NL2SQL-Bench: Natural Language to SQL Analytics OpenEnv environment for RL training"
+readme = "README.md"
+requires-python = ">=3.10"
+license = { text = "MIT" }
+
+dependencies = [
+    "openenv-core>=0.2.3",
+    "fastapi>=0.110.0",
+    "uvicorn[standard]>=0.29.0",
+    "pydantic>=2.0.0",
+    "openai>=1.0.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0",
+    "pytest-asyncio>=0.23",
+    "httpx>=0.27",
+    "black",
+    "ruff",
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["nl2sql*", "server*"]
+
+[tool.ruff]
+line-length = 
100 +target-version = "py310" + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/scripts/run_local.sh b/scripts/run_local.sh new file mode 100644 index 0000000000000000000000000000000000000000..33ca17a1310dbed4f5afe83b15f73b76ada573ce --- /dev/null +++ b/scripts/run_local.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash +# nl2sql-bench/scripts/run_local.sh +# ───────────────────────────────────────────────────────────────────────────── +# Quick local development server (no Docker needed). +# Prerequisites: Python 3.10+, pip or uv +# +# Usage: +# chmod +x scripts/run_local.sh +# ./scripts/run_local.sh +# ───────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$REPO_ROOT" + +echo "═══════════════════════════════════════════════" +echo " NL2SQL-Bench — Local Dev Server" +echo "═══════════════════════════════════════════════" + +# ── Check Python ──────────────────────────────────────────────────────────── +if ! command -v python3 &>/dev/null; then + echo "ERROR: python3 not found. Install Python 3.10+." && exit 1 +fi +PY_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') +echo "Python: $PY_VERSION" + +# ── Virtual environment ────────────────────────────────────────────────────── +if [ ! -d ".venv" ]; then + echo "Creating virtualenv..." + python3 -m venv .venv +fi +source .venv/bin/activate + +# ── Install deps ───────────────────────────────────────────────────────────── +echo "Installing dependencies..." +pip install -q --upgrade pip +pip install -q openenv-core fastapi "uvicorn[standard]" openai pydantic pytest pytest-asyncio + +# ── Load .env if present ───────────────────────────────────────────────────── +if [ -f ".env" ]; then + echo "Loading .env..." 
+ set -a + source .env + set +a +fi + +# ── Export PYTHONPATH ───────────────────────────────────────────────────────── +export PYTHONPATH="$REPO_ROOT:$REPO_ROOT/server" + +echo "" +echo "Starting server at http://localhost:8000" +echo " /reset → POST (start episode)" +echo " /step → POST (submit SQL)" +echo " /state → GET (episode metadata)" +echo " /health → GET (liveness probe)" +echo " /docs → GET (Swagger UI)" +echo "" +echo "Press Ctrl+C to stop." +echo "───────────────────────────────────────────────" + +cd "$REPO_ROOT/server" +uvicorn app:app \ + --host 0.0.0.0 \ + --port 8000 \ + --reload \ + --reload-dir "$REPO_ROOT" \ + --log-level info diff --git a/scripts/smoke_test.sh b/scripts/smoke_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..c630d13e35e6afdd947accff5922d7eb167acb80 --- /dev/null +++ b/scripts/smoke_test.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# nl2sql-bench/scripts/smoke_test.sh +# ───────────────────────────────────────────────────────────────────────────── +# Smoke tests against a running server (local or HF Space). +# Verifies all endpoints return expected HTTP codes and JSON shapes. 
+# +# Usage: +# ./scripts/smoke_test.sh # default localhost:8000 +# ./scripts/smoke_test.sh https://your.hf.space # HF Space URL +# ───────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +BASE_URL="${1:-http://localhost:8000}" +BASE_URL="${BASE_URL%/}" +PASS=0; FAIL=0 + +GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'; BOLD='\033[1m' + +pass() { echo -e "${GREEN}✓${NC} $1"; PASS=$((PASS+1)); } +fail() { echo -e "${RED}✗${NC} $1"; FAIL=$((FAIL+1)); } + +echo "" +echo -e "${BOLD}NL2SQL-Bench Smoke Tests${NC}" +echo "Target: $BASE_URL" +echo "────────────────────────────────────────" + +# ── /health ────────────────────────────────────────────────────────────────── +CODE=$(curl -s -o /dev/null -w "%{http_code}" "$BASE_URL/health") +[ "$CODE" = "200" ] && pass "/health → 200" || fail "/health → $CODE (expected 200)" + +# ── /reset ─────────────────────────────────────────────────────────────────── +RESET_BODY=$(curl -s -X POST "$BASE_URL/reset" \ + -H "Content-Type: application/json" -d '{}') +echo "$RESET_BODY" | grep -q "question" && pass "/reset → has 'question' field" \ + || fail "/reset → missing 'question' field. Body: $RESET_BODY" + +# ── /step (valid SQL) ───────────────────────────────────────────────────────── +STEP_BODY=$(curl -s -X POST "$BASE_URL/step" \ + -H "Content-Type: application/json" \ + -d '{"query": "SELECT id, name FROM customers LIMIT 3"}') +echo "$STEP_BODY" | grep -q "reward" && pass "/step valid SQL → has 'reward'" \ + || fail "/step valid SQL → missing 'reward'. Body: $STEP_BODY" +echo "$STEP_BODY" | grep -q '"done"' && pass "/step valid SQL → has 'done'" \ + || fail "/step valid SQL → missing 'done'. 
Body: $STEP_BODY" + +# ── /step (syntax error SQL) ────────────────────────────────────────────────── +STEP_ERR=$(curl -s -X POST "$BASE_URL/step" \ + -H "Content-Type: application/json" \ + -d '{"query": "SELCT * FORM broken_tbl"}') +echo "$STEP_ERR" | grep -q "last_error" && pass "/step bad SQL → has 'last_error'" \ + || fail "/step bad SQL → missing 'last_error'. Body: $STEP_ERR" + +# ── /state ──────────────────────────────────────────────────────────────────── +STATE_BODY=$(curl -s "$BASE_URL/state") +echo "$STATE_BODY" | grep -q "step_count" && pass "/state → has 'step_count'" \ + || fail "/state → missing 'step_count'. Body: $STATE_BODY" + +echo "────────────────────────────────────────" +echo -e "${BOLD}Results: ${GREEN}${PASS} passed${NC}, ${RED}${FAIL} failed${NC}" +echo "" + +[ "$FAIL" -eq 0 ] && exit 0 || exit 1 diff --git a/server/__init__.py b/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..807a203fc4b92a6657c78c286059816a5784087b --- /dev/null +++ b/server/__init__.py @@ -0,0 +1 @@ +# server/__init__.py diff --git a/server/app.py b/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4fdc4bc02fe773a8e12ac66f289574c95c46ff --- /dev/null +++ b/server/app.py @@ -0,0 +1,31 @@ +""" +nl2sql-bench/server/app.py +============================ +FastAPI application entry point for the NL2SQL-Bench OpenEnv server. 
+ +create_fastapi_app() auto-creates all required OpenEnv endpoints: + POST /reset — start a new episode + POST /step — submit an action + GET /state — retrieve episode state + GET /health — health check + GET /web — interactive web UI (if ENABLE_WEB_INTERFACE=true) + GET /docs — Swagger UI +""" + +import sys +from pathlib import Path +from openenv.core.env_server import create_fastapi_app +from environment import NL2SQLEnvironment + +# Ensure models can be imported from the parent directory +_HERE = Path(__file__).parent +sys.path.insert(0, str(_HERE.parent)) + +from models import NL2SQLAction, NL2SQLObservation + +# Pass the explicitly required action and observation classes +app = create_fastapi_app( + NL2SQLEnvironment, + action_cls=NL2SQLAction, + observation_cls=NL2SQLObservation +) \ No newline at end of file diff --git a/server/app.py.bak b/server/app.py.bak new file mode 100644 index 0000000000000000000000000000000000000000..8a4fdc4bc02fe773a8e12ac66f289574c95c46ff --- /dev/null +++ b/server/app.py.bak @@ -0,0 +1,31 @@ +""" +nl2sql-bench/server/app.py +============================ +FastAPI application entry point for the NL2SQL-Bench OpenEnv server. 
+ +create_fastapi_app() auto-creates all required OpenEnv endpoints: + POST /reset — start a new episode + POST /step — submit an action + GET /state — retrieve episode state + GET /health — health check + GET /web — interactive web UI (if ENABLE_WEB_INTERFACE=true) + GET /docs — Swagger UI +""" + +import sys +from pathlib import Path +from openenv.core.env_server import create_fastapi_app +from environment import NL2SQLEnvironment + +# Ensure models can be imported from the parent directory +_HERE = Path(__file__).parent +sys.path.insert(0, str(_HERE.parent)) + +from models import NL2SQLAction, NL2SQLObservation + +# Pass the explicitly required action and observation classes +app = create_fastapi_app( + NL2SQLEnvironment, + action_cls=NL2SQLAction, + observation_cls=NL2SQLObservation +) \ No newline at end of file diff --git a/server/db/__init__.py b/server/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31ddf848df68f6a967cf376a35e6d0138ffc338d --- /dev/null +++ b/server/db/__init__.py @@ -0,0 +1 @@ +# server/db/__init__.py diff --git a/server/db/schema.sql b/server/db/schema.sql new file mode 100644 index 0000000000000000000000000000000000000000..f6316cc34e1a258cae56838a21d3e9fa8c740810 --- /dev/null +++ b/server/db/schema.sql @@ -0,0 +1,62 @@ +-- nl2sql-bench/server/db/schema.sql +-- E-commerce database schema for NL2SQL-Bench +-- Designed for in-memory SQLite: realistic, universally understood domain. 
+ +CREATE TABLE IF NOT EXISTS categories ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category_id INTEGER NOT NULL REFERENCES categories(id), + price REAL NOT NULL CHECK(price >= 0), + stock_quantity INTEGER NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS customers ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, + country TEXT NOT NULL, + tier TEXT NOT NULL DEFAULT 'bronze' -- bronze | silver | gold + CHECK(tier IN ('bronze', 'silver', 'gold')), + created_at TEXT NOT NULL -- ISO-8601 date string +); + +CREATE TABLE IF NOT EXISTS orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER NOT NULL REFERENCES customers(id), + status TEXT NOT NULL DEFAULT 'pending' + CHECK(status IN ('pending','processing','shipped','delivered','cancelled')), + created_at TEXT NOT NULL, + total_amount REAL NOT NULL CHECK(total_amount >= 0) +); + +CREATE TABLE IF NOT EXISTS order_items ( + id INTEGER PRIMARY KEY, + order_id INTEGER NOT NULL REFERENCES orders(id), + product_id INTEGER NOT NULL REFERENCES products(id), + quantity INTEGER NOT NULL CHECK(quantity > 0), + unit_price REAL NOT NULL CHECK(unit_price >= 0) +); + +CREATE TABLE IF NOT EXISTS reviews ( + id INTEGER PRIMARY KEY, + product_id INTEGER NOT NULL REFERENCES products(id), + customer_id INTEGER NOT NULL REFERENCES customers(id), + rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), + created_at TEXT NOT NULL +); + +-- Indexes for common join/filter patterns +CREATE INDEX IF NOT EXISTS idx_products_category ON products(category_id); +CREATE INDEX IF NOT EXISTS idx_orders_customer ON orders(customer_id); +CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status); +CREATE INDEX IF NOT EXISTS idx_orders_created ON orders(created_at); +CREATE INDEX IF NOT EXISTS idx_order_items_order ON order_items(order_id); +CREATE INDEX IF NOT EXISTS idx_order_items_product ON 
order_items(product_id); +CREATE INDEX IF NOT EXISTS idx_reviews_product ON reviews(product_id); +CREATE INDEX IF NOT EXISTS idx_customers_country ON customers(country); +CREATE INDEX IF NOT EXISTS idx_customers_tier ON customers(tier); diff --git a/server/db/seed.py b/server/db/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4e2159d2c4088c81035aaccf5e028eddfce65c --- /dev/null +++ b/server/db/seed.py @@ -0,0 +1,225 @@ +""" +nl2sql-bench/server/db/seed.py +============================== +Deterministic synthetic data generator for the NL2SQL-Bench SQLite database. + +Uses a fixed random seed so every fresh environment build produces +IDENTICAL data, which is essential for reproducible grader scores across +different machines, runs, and Docker containers. + +Call: seed_database(conn) once after creating tables. +""" + +from __future__ import annotations + +import random +import sqlite3 +from datetime import date, timedelta +from typing import List + +# ── Deterministic seed ──────────────────────────────────────────────────── +SEED = 42 +RNG = random.Random(SEED) + +# ── Domain constants ────────────────────────────────────────────────────── +CATEGORIES = [ + "Electronics", "Clothing", "Books", "Home & Garden", + "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive", +] + +PRODUCT_NAMES = { + "Electronics": ["Wireless Headphones", "USB-C Hub", "Mechanical Keyboard", + "Webcam 4K", "Portable Charger", "Smart Speaker", + "Monitor Stand", "HDMI Cable 2.1"], + "Clothing": ["Cotton T-Shirt", "Slim Fit Jeans", "Hoodie", + "Running Shorts", "Winter Jacket", "Polo Shirt", + "Casual Sneakers", "Wool Socks"], + "Books": ["Clean Code", "Designing Data-Intensive Applications", + "The Pragmatic Programmer", "System Design Interview", + "Deep Learning Book", "Python Cookbook", + "Domain-Driven Design", "Refactoring"], + "Home & Garden": ["Coffee Maker", "Air Purifier", "LED Desk Lamp", + "Plant Pot Set", "Storage Organiser", "Cutting Board", + 
"Vacuum Cleaner", "Electric Kettle"], + "Sports & Outdoors":["Yoga Mat", "Resistance Bands", "Cycling Gloves", + "Trekking Poles", "Water Bottle 1L", "Jump Rope", + "Foam Roller", "Compression Socks"], + "Toys & Games": ["Lego City Set", "Card Game Pack", "Puzzle 1000pc", + "Remote Control Car", "Building Blocks", + "Board Game Strategy", "Art Set", "Toy Drone"], + "Beauty": ["Face Serum", "SPF 50 Sunscreen", "Lip Balm", + "Shampoo Pro", "Hair Mask", "Eye Cream", + "Vitamin C Cream", "Toner Mist"], + "Automotive": ["Car Phone Mount", "Dash Cam", "Tyre Inflator", + "Car Vacuum", "Seat Cushion", "Steering Wheel Cover", + "OBD Scanner", "Jump Starter"], +} + +COUNTRIES = ["India", "USA", "Germany", "UK", "Canada", + "Australia", "France", "Brazil", "Japan", "Singapore"] + +TIERS = ["bronze", "silver", "gold"] +STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"] + +FIRST_NAMES = [ + "Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja", + "Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica", + "Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura", + "Yuki","Hana","Wei","Mei","Aiden","Zara", +] +LAST_NAMES = [ + "Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy", + "Smith","Johnson","Brown","Williams","Jones","Davis","Wilson", + "Müller","Schmidt","Schneider","Fischer","Weber", + "Martin","Bernard","Thomas","Richard","Petit", + "Garcia","Martinez","Lopez","Sanchez","Gonzalez", +] + + +def _random_date(start_year: int = 2022, end_year: int = 2025) -> str: + start = date(start_year, 1, 1) + end = date(end_year, 12, 31) + delta = (end - start).days + return (start + timedelta(days=RNG.randint(0, delta))).isoformat() + + +def seed_database(conn: sqlite3.Connection) -> None: + """Populate the database with deterministic synthetic data.""" + conn.execute("PRAGMA foreign_keys = ON") + cur = conn.cursor() + + # ── Categories ──────────────────────────────────────────────────────── + for i, name in 
enumerate(CATEGORIES, 1): + cur.execute( + "INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)", + (i, name), + ) + + # ── Products (8 per category → 64 total) ───────────────────────────── + pid = 1 + for cat_id, (cat_name, names) in enumerate(PRODUCT_NAMES.items(), 1): + for pname in names: + price = round(RNG.uniform(5.0, 250.0), 2) + stock = RNG.randint(0, 500) + cur.execute( + "INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) " + "VALUES (?, ?, ?, ?, ?)", + (pid, pname, cat_id, price, stock), + ) + pid += 1 + + # ── Customers (150 total) ───────────────────────────────────────────── + used_emails: set = set() + for cid in range(1, 151): + fname = RNG.choice(FIRST_NAMES) + lname = RNG.choice(LAST_NAMES) + name = f"{fname} {lname}" + email_base = f"{fname.lower()}.{lname.lower()}" + email = f"{email_base}{cid}@example.com" + while email in used_emails: + email = f"{email_base}{cid}x@example.com" + used_emails.add(email) + + # Bias: 60% bronze, 30% silver, 10% gold + tier = RNG.choices(TIERS, weights=[60, 30, 10])[0] + country = RNG.choice(COUNTRIES) + created = _random_date(2021, 2023) + cur.execute( + "INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (cid, name, email, country, tier, created), + ) + + # ── Orders + Order items ────────────────────────────────────────────── + oid = 1 + item_id = 1 + for cid in range(1, 151): + # Each customer has 0–8 orders; gold customers tend to have more + tier_row = cur.execute( + "SELECT tier FROM customers WHERE id=?", (cid,) + ).fetchone() + tier = tier_row[0] if tier_row else "bronze" + n_orders = RNG.choices( + range(9), + weights=[5, 20, 20, 15, 15, 10, 8, 5, 2] if tier == "bronze" + else ([2, 10, 15, 20, 20, 15, 10, 5, 3] if tier == "silver" + else [1, 5, 10, 15, 20, 20, 15, 10, 4]), + )[0] + + for _ in range(n_orders): + status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0] + order_date = _random_date(2022, 2025) 
+ # Pick 1–4 products for this order + n_items = RNG.randint(1, 4) + chosen_pids = RNG.sample(range(1, 65), k=min(n_items, 64)) + total = 0.0 + + cur.execute( + "INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) " + "VALUES (?, ?, ?, ?, ?)", + (oid, cid, status, order_date, 0.0), # update total after items + ) + + for cpid in chosen_pids: + qty = RNG.randint(1, 5) + price_row = cur.execute( + "SELECT price FROM products WHERE id=?", (cpid,) + ).fetchone() + unit_price = price_row[0] if price_row else 10.0 + total += round(qty * unit_price, 2) + cur.execute( + "INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) " + "VALUES (?, ?, ?, ?, ?)", + (item_id, oid, cpid, qty, unit_price), + ) + item_id += 1 + + cur.execute( + "UPDATE orders SET total_amount=? WHERE id=?", + (round(total, 2), oid), + ) + oid += 1 + + # ── Reviews ─────────────────────────────────────────────────────────── + # Each customer reviews 0–6 products they (may have) ordered + rev_id = 1 + reviewed: set = set() # (customer_id, product_id) pairs + for cid in range(1, 151): + n_reviews = RNG.randint(0, 6) + for _ in range(n_reviews): + rpid = RNG.randint(1, 64) + if (cid, rpid) in reviewed: + continue + reviewed.add((cid, rpid)) + rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0] + rev_date = _random_date(2022, 2025) + cur.execute( + "INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) " + "VALUES (?, ?, ?, ?, ?)", + (rev_id, rpid, cid, rating, rev_date), + ) + rev_id += 1 + + conn.commit() + + +def get_db_summary(conn: sqlite3.Connection) -> dict: + """Return row counts per table for debugging / README stats.""" + tables = ["categories", "products", "customers", "orders", "order_items", "reviews"] + summary = {} + for t in tables: + row = conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone() + summary[t] = row[0] if row else 0 + return summary + + +if __name__ == "__main__": + import os + 
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql") + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + with open(schema_path) as f: + conn.executescript(f.read()) + seed_database(conn) + print("Seed stats:", get_db_summary(conn)) + conn.close() diff --git a/server/environment.py b/server/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..e98f69bcf7807bc28e5896d6798341269c42d4e3 --- /dev/null +++ b/server/environment.py @@ -0,0 +1,223 @@ +""" +nl2sql-bench/server/environment.py +==================================== +NL2SQL-Bench core environment — implements the OpenEnv Environment interface. + +Episode flow +------------ +1. reset(task_name?) → picks a task + question, returns initial observation +2. step(action) → executes the SQL, grades it, returns observation + reward +3. state() → returns episode metadata +4. Episode ends when: exact_match OR step count reaches max_steps + +The environment manages its own SQLite connection (in-memory, seeded +deterministically). One connection per Environment instance; the FastAPI +server creates one Environment per WebSocket session. 
+"""
+
+from __future__ import annotations
+
+import os
+import sqlite3
+import uuid
+from pathlib import Path
+from typing import Optional
+
+from openenv.core.env_server import Environment
+
+# Directory containing this file; used below to locate the project root
+# (for models.py) and server/db/schema.sql.
+_HERE = Path(__file__).parent
+
+# Task registry and grading helpers.  NOTE(review): these are ordinary
+# top-level imports (not lazy); they resolve because server/ is on
+# sys.path / PYTHONPATH (see Dockerfile ENV PYTHONPATH and run_local.sh).
+from tasks import get_task, all_task_names, BaseTask
+from tasks.base import TaskExample
+from grader import (
+    GradeResult,
+    compute_ground_truth,
+    execute_query,
+    grade,
+    has_order_by,
+)
+
+# models.py lives one level up (project root); make it importable from here.
+import sys
+sys.path.insert(0, str(_HERE.parent))
+from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
+
+# ── Constants ──────────────────────────────────────────────────────────────
+DEFAULT_TASK = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
+MAX_STEPS = int(os.getenv("NL2SQL_MAX_STEPS", "5"))
+RESULT_LIMIT = 10  # Max rows shown to agent per step
+
+
+class NL2SQLEnvironment(Environment):
+    """
+    OpenEnv-compliant environment for NL-to-SQL query generation.
+
+    One instance per WebSocket session (created by create_fastapi_app).
+ """ + + def __init__(self) -> None: + self._conn: Optional[sqlite3.Connection] = None + self._task: Optional[BaseTask] = None + self._example: Optional[TaskExample] = None + self._ground_truth: list = [] + self._order_sensitive: bool = False + self._state = NL2SQLState( + episode_id=None, + step_count=0, + task_name="", + task_difficulty="", + question="", + best_reward=0.0, + cumulative_reward=0.0, + solved=False + ) + self._last_obs = NL2SQLObservation( + question="", + schema_context="", + task_name="", + last_query="", + last_result=[], + last_error=None, + result_columns=[], + step=0, + max_steps=5, + done=False, + reward=None, + score=0.0 + ) + self._episode_rewards: list = [] + self._setup_db() + + # ── DB lifecycle ─────────────────────────────────────────────────────── + + def _setup_db(self) -> None: + """Create in-memory SQLite DB and seed it.""" + schema_path = _HERE / "db" / "schema.sql" + from db.seed import seed_database # local import after sys.path setup + conn = sqlite3.connect(":memory:", check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + conn.executescript(schema_path.read_text()) + seed_database(conn) + self._conn = conn + + # ── OpenEnv interface ────────────────────────────────────────────────── + + def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation: + """ + Start a new episode. + + task_name: one of 'simple-filter', 'join-aggregation', 'analytics-window'. + Defaults to NL2SQL_DEFAULT_TASK env-var or 'simple-filter'. 
+ """ + task_name = task_name or DEFAULT_TASK + if task_name not in all_task_names(): + task_name = DEFAULT_TASK + + self._task = get_task(task_name) + self._example = self._task.next_example() + self._order_sensitive = has_order_by(self._example.sql) + + # Pre-compute ground truth once per episode + self._ground_truth = compute_ground_truth(self._conn, self._example.sql) + + self._episode_rewards = [] + self._state = NL2SQLState( + episode_id=str(uuid.uuid4()), + step_count=0, + task_name=self._task.name, + task_difficulty=self._task.difficulty, + question=self._example.question, + best_reward=0.0, + cumulative_reward=0.0, + solved=False, + ) + + obs = NL2SQLObservation( + question=self._example.question, + schema_context=self._task.schema_context(), + task_name=self._task.name, + last_query="", + last_result=[], + last_error=None, + result_columns=[], + step=0, + max_steps=MAX_STEPS, + done=False, + reward=None, + score=0.0, + ) + self._last_obs = obs + return obs + + def step(self, action: NL2SQLAction) -> NL2SQLObservation: + """Execute the agent's SQL and return graded observation.""" + if self._task is None or self._example is None: + # Called before reset — auto-reset + self.reset() + + self._state.step_count += 1 + current_step = self._state.step_count + done = False + + # Execute the query + rows, error = execute_query(self._conn, action.query) + + # Grade it + result: GradeResult = grade( + actual_rows=rows, + ground_truth_rows=self._ground_truth, + error=error, + step=current_step, + order_sensitive=self._order_sensitive, + ) + + reward = result.reward + self._episode_rewards.append(reward) + self._state.cumulative_reward += reward + self._state.best_reward = max(self._state.best_reward, reward) + + if result.exact_match: + self._state.solved = True + done = True + elif current_step >= MAX_STEPS: + done = True + + # Prepare result rows for observation (truncated for agent readability) + display_rows = (rows or [])[:RESULT_LIMIT] + result_columns = 
list(display_rows[0].keys()) if display_rows else [] + # Convert sqlite3.Row objects if needed + display_rows = [dict(r) for r in display_rows] + + # Normalised cumulative score + n = len(self._episode_rewards) + score = self._state.cumulative_reward / max(n, 1) if n else 0.0 + score = round(min(max(score, 0.0), 1.0), 4) + + obs = NL2SQLObservation( + question=self._example.question, + schema_context=self._task.schema_context(), + task_name=self._task.name, + last_query=action.query, + last_result=display_rows, + last_error=error, + result_columns=result_columns, + step=current_step, + max_steps=MAX_STEPS, + done=done, + reward=reward, + score=score, + ) + self._last_obs = obs + return obs + + @property + def state(self) -> NL2SQLState: + return self._state + + # ── Helpers ──────────────────────────────────────────────────────────── + + def available_tasks(self) -> list: + return all_task_names() diff --git a/server/grader.py b/server/grader.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1124937e856874a51f0ed1a3de355765e85220 --- /dev/null +++ b/server/grader.py @@ -0,0 +1,214 @@ +""" +nl2sql-bench/server/grader.py +============================== +Deterministic, programmatic reward grader. + +No LLM-as-judge. Every reward is computed by comparing the agent's SQL +execution results against a ground-truth result set. + +Reward decomposition (sums to 1.0 for a perfect first-attempt answer): + +0.10 syntax_ok — query runs without SQLite error + +0.20 columns_match — returned column names match ground truth exactly + +0.20 row_count_match — number of returned rows matches + +0.50 exact_match — full result set equals ground truth (order-aware + for ORDER BY queries, order-agnostic otherwise) + +Step penalty: + -0.05 per step beyond the first (encourages solving in fewer steps), + clamped so the minimum is always 0.0. + +All rewards are floats in [0.0, 1.0]. 
+""" + +from __future__ import annotations + +import sqlite3 +from typing import Any, Dict, List, Optional, Tuple + + +# ── Result normalisation ─────────────────────────────────────────────────── + +def _normalise_value(v: Any) -> Any: + """Round floats for comparison so 1.2300000001 == 1.23.""" + if isinstance(v, float): + return round(v, 4) + if isinstance(v, str): + return v.strip() + return v + + +def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]: + return {k: _normalise_value(v) for k, v in row.items()} + + +def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [_normalise_row(r) for r in rows] + + +# ── SQL execution ────────────────────────────────────────────────────────── + +def execute_query( + conn: sqlite3.Connection, + query: str, + max_rows: int = 200, +) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]: + """ + Execute a SQL query safely. + + Returns (rows, error_string). + rows is None on error. + """ + query = query.strip().rstrip(";") + if not query: + return None, "Empty query." + + # Block write operations — the environment is read-only from the agent's view. + forbidden = ("insert", "update", "delete", "drop", "alter", + "create", "replace", "truncate", "pragma") + first_word = query.split()[0].lower() if query.split() else "" + if first_word in forbidden: + return None, ( + f"Operation '{first_word.upper()}' is not allowed. " + "Only SELECT queries are permitted." 
+ ) + + try: + cur = conn.execute(query) + cols = [d[0] for d in cur.description] if cur.description else [] + rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)] + return rows, None + except sqlite3.Error as exc: + return None, str(exc) + + +# ── Grading logic ────────────────────────────────────────────────────────── + +class GradeResult: + __slots__ = ( + "reward", "syntax_ok", "columns_match", + "row_count_match", "exact_match", "step_penalty", + "breakdown", + ) + + def __init__( + self, + reward: float, + syntax_ok: bool, + columns_match: bool, + row_count_match: bool, + exact_match: bool, + step_penalty: float, + ) -> None: + self.reward = reward + self.syntax_ok = syntax_ok + self.columns_match = columns_match + self.row_count_match = row_count_match + self.exact_match = exact_match + self.step_penalty = step_penalty + self.breakdown = { + "syntax_ok": 0.10 if syntax_ok else 0.0, + "columns_match": 0.20 if (syntax_ok and columns_match) else 0.0, + "row_count_match": 0.20 if (syntax_ok and row_count_match) else 0.0, + "exact_match": 0.50 if (syntax_ok and exact_match) else 0.0, + "step_penalty": -step_penalty, + } + + def __repr__(self) -> str: # pragma: no cover + return ( + f"GradeResult(reward={self.reward:.3f}, " + f"exact={self.exact_match}, cols={self.columns_match}, " + f"rows={self.row_count_match}, syntax={self.syntax_ok})" + ) + + +def grade( + actual_rows: Optional[List[Dict[str, Any]]], + ground_truth_rows: List[Dict[str, Any]], + error: Optional[str], + step: int, + order_sensitive: bool = False, +) -> GradeResult: + """ + Grade the agent's query result against ground truth. + + Parameters + ---------- + actual_rows : Rows returned by the agent's query (None on error). + ground_truth_rows : Expected rows (pre-computed at task load time). + error : SQLite error string (None if query ran successfully). + step : Current step number (1-indexed) for penalty calculation. 
+    order_sensitive   : If True, row order matters (queries with ORDER BY).
+    """
+    # ── Syntax ──────────────────────────────────────────────────────────
+    syntax_ok = error is None and actual_rows is not None
+
+    if not syntax_ok:
+        # Hard zero: no partial credit (and no step penalty) when the
+        # query did not run at all.
+        return GradeResult(
+            reward=0.0,
+            syntax_ok=False,
+            columns_match=False,
+            row_count_match=False,
+            exact_match=False,
+            step_penalty=0.0,
+        )
+
+    gt_norm = _normalise_rows(ground_truth_rows)
+    act_norm = _normalise_rows(actual_rows)
+
+    # Column names are compared as *sets* taken from the first row, so
+    # SELECT column order is ignored and duplicate names collapse; when
+    # both result sets are empty the columns trivially "match".
+    gt_cols = set(gt_norm[0].keys()) if gt_norm else set()
+    act_cols = set(act_norm[0].keys()) if act_norm else set()
+    columns_match = act_cols == gt_cols
+    row_count_match = len(act_norm) == len(gt_norm)
+
+    # Exact match: if order matters, compare list; otherwise compare sorted sets
+    if columns_match and row_count_match:
+        if order_sensitive:
+            exact_match = act_norm == gt_norm
+        else:
+            # Sort rows by their string representation for order-agnostic compare
+            def _sort_key(r: Dict) -> str:
+                return str(sorted(r.items()))
+            exact_match = (
+                sorted(act_norm, key=_sort_key) == sorted(gt_norm, key=_sort_key)
+            )
+    else:
+        exact_match = False
+
+    # ── Score assembly ──────────────────────────────────────────────
+    raw = (
+        0.10  # syntax
+        + (0.20 if columns_match else 0.0)
+        + (0.20 if row_count_match else 0.0)
+        + (0.50 if exact_match else 0.0)
+    )
+
+    # 0.05 deducted per step after the first; final reward clamped to [0, 1].
+    penalty = max(0.0, step - 1) * 0.05
+    reward = float(max(0.0, min(1.0, raw - penalty)))
+
+    return GradeResult(
+        reward=reward,
+        syntax_ok=syntax_ok,
+        columns_match=columns_match,
+        row_count_match=row_count_match,
+        exact_match=exact_match,
+        step_penalty=penalty,
+    )
+
+
+# ── Convenience: pre-compute ground truth rows ─────────────────────────────
+
+def compute_ground_truth(
+    conn: sqlite3.Connection,
+    sql: str,
+) -> List[Dict[str, Any]]:
+    """Execute the ground-truth SQL and return normalised rows."""
+    rows, error = execute_query(conn, sql)
+    if error or rows is None:
+        raise ValueError(f"Ground-truth SQL failed: {error}\nSQL: {sql}")
+    
return _normalise_rows(rows) + + +def has_order_by(sql: str) -> bool: + """Heuristic: does the top-level query have an ORDER BY?""" + # Simple check sufficient for our controlled task SQL + return "ORDER BY" in sql.upper() diff --git a/server/requirements.txt b/server/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f3dc726d820d278b8b4a5aa151ef7cc01e0c338 --- /dev/null +++ b/server/requirements.txt @@ -0,0 +1,13 @@ +# nl2sql-bench/server/requirements.txt +# Minimal dependency set for 2 vCPU / 8 GB constraint. +# SQLite is part of the Python stdlib — no extra DB dependency needed. + +# OpenEnv core framework +openenv-core>=0.2.3 + +# Web server +fastapi>=0.110.0 +uvicorn[standard]>=0.29.0 + +# Typing helpers (already included in openenv-core but listed explicitly) +pydantic>=2.0.0 diff --git a/server/tasks/__init__.py b/server/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5643870c4ceb93f36a7aecd8b3f249fddd8426ac --- /dev/null +++ b/server/tasks/__init__.py @@ -0,0 +1,14 @@ +"""Auto-registers all tasks by importing them.""" +from .easy import SimpleFilterTask +from .medium import JoinAggregationTask +from .hard import AnalyticsWindowTask +from .base import get_task, all_task_names, BaseTask + +__all__ = [ + "SimpleFilterTask", + "JoinAggregationTask", + "AnalyticsWindowTask", + "get_task", + "all_task_names", + "BaseTask", +] diff --git a/server/tasks/base.py b/server/tasks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..60b963af8c6df67f1eae7a2a2bfa684e3ac67300 --- /dev/null +++ b/server/tasks/base.py @@ -0,0 +1,108 @@ +""" +nl2sql-bench/server/tasks/base.py +================================== +Abstract base for all NL2SQL tasks and the global task registry. + +Each task holds a list of (question, ground_truth_sql) pairs. 
+The environment picks one pair per episode via a deterministic round-robin +so that the same task always cycles through the same question sequence — +this keeps grader results reproducible across runs. +""" + +from __future__ import annotations + +import sqlite3 +from abc import ABC, abstractmethod +from typing import Dict, List, NamedTuple, Tuple, Type + + +class TaskExample(NamedTuple): + question: str + sql: str + # Human-readable description of what makes this question that difficulty + notes: str = "" + + +class BaseTask(ABC): + """Abstract base class for all tasks.""" + + name: str = "" + difficulty: str = "" # easy | medium | hard + examples: List[TaskExample] = [] + + def __init__(self) -> None: + if not self.examples: + raise ValueError(f"Task {self.name!r} has no examples defined.") + self._cursor = 0 # round-robin index + + def next_example(self) -> TaskExample: + """Return the next question in round-robin order.""" + example = self.examples[self._cursor % len(self.examples)] + self._cursor += 1 + return example + + @classmethod + def schema_context(cls) -> str: + """Return a compact schema description for the agent system prompt.""" + return _SCHEMA_CONTEXT + + @abstractmethod + def description(self) -> str: + """One-sentence description for openenv.yaml.""" + + +# ── Global schema context string (injected into every observation) ───────── + +_SCHEMA_CONTEXT = """\ +Database: ecommerce (SQLite, read-only) + +TABLES +------ +categories(id INTEGER PK, name TEXT) + +products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id, + price REAL, stock_quantity INTEGER) + +customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, + tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601) + +orders(id INTEGER PK, customer_id INTEGER FK→customers.id, + status TEXT ∈ {pending|processing|shipped|delivered|cancelled}, + created_at TEXT ISO-8601, total_amount REAL) + +order_items(id INTEGER PK, order_id INTEGER FK→orders.id, + product_id INTEGER 
FK→products.id, + quantity INTEGER, unit_price REAL) + +reviews(id INTEGER PK, product_id INTEGER FK→products.id, + customer_id INTEGER FK→customers.id, + rating INTEGER 1-5, created_at TEXT ISO-8601) + +NOTES +----- +- Date comparisons: use created_at >= '2024-01-01' (text ISO sort works) +- SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available +- strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings +- All monetary values are in USD +""" + + +# ── Task registry ────────────────────────────────────────────────────────── + +_REGISTRY: Dict[str, Type[BaseTask]] = {} + + +def register(cls: Type[BaseTask]) -> Type[BaseTask]: + """Class decorator to register a task.""" + _REGISTRY[cls.name] = cls + return cls + + +def get_task(name: str) -> BaseTask: + if name not in _REGISTRY: + raise KeyError(f"Unknown task {name!r}. Available: {list(_REGISTRY)}") + return _REGISTRY[name]() + + +def all_task_names() -> List[str]: + return list(_REGISTRY.keys()) diff --git a/server/tasks/easy.py b/server/tasks/easy.py new file mode 100644 index 0000000000000000000000000000000000000000..0c052d3abd580c4c16d239f380b277b562c4ac0e --- /dev/null +++ b/server/tasks/easy.py @@ -0,0 +1,93 @@ +""" +nl2sql-bench/server/tasks/easy.py +=================================== +Task 1 — Simple Filter (difficulty: easy) + +All questions target a SINGLE table with basic WHERE / ORDER BY / LIMIT. +A competent small model should solve these in 1–2 steps. +""" + +from __future__ import annotations + +from .base import BaseTask, TaskExample, register + + +@register +class SimpleFilterTask(BaseTask): + name = "simple-filter" + difficulty = "easy" + + examples = [ + TaskExample( + question=( + "List all gold-tier customers ordered by their name alphabetically. " + "Return columns: id, name, email, country." 
@register
class SimpleFilterTask(BaseTask):
    """Easy task: single-table SELECTs with WHERE / ORDER BY / LIMIT.

    Every ground-truth SQL string below is executed verbatim by the grader,
    so the exact column lists and sort orders stated in each question must
    match the SQL exactly.
    """

    name = "simple-filter"
    difficulty = "easy"

    examples = [
        TaskExample(
            question=(
                "List all gold-tier customers ordered by their name alphabetically. "
                "Return columns: id, name, email, country."
            ),
            sql=(
                "SELECT id, name, email, country "
                "FROM customers "
                "WHERE tier = 'gold' "
                "ORDER BY name ASC"
            ),
            notes="Single table, equality filter, text sort.",
        ),
        TaskExample(
            question=(
                "Show all products with a price above $100, sorted by price from "
                "highest to lowest. Return columns: id, name, price."
            ),
            sql=(
                "SELECT id, name, price "
                "FROM products "
                "WHERE price > 100 "
                "ORDER BY price DESC"
            ),
            notes="Numeric range filter, descending sort.",
        ),
        TaskExample(
            question=(
                "Find all delivered orders with a total_amount greater than $200, "
                "ordered by total_amount descending. "
                "Return columns: id, customer_id, total_amount, created_at."
            ),
            sql=(
                "SELECT id, customer_id, total_amount, created_at "
                "FROM orders "
                "WHERE status = 'delivered' "
                "  AND total_amount > 200 "
                "ORDER BY total_amount DESC"
            ),
            notes="Two-condition WHERE on a single table.",
        ),
        TaskExample(
            question=(
                "Return the top 5 most expensive products. "
                "Return columns: id, name, price."
            ),
            sql=(
                "SELECT id, name, price "
                "FROM products "
                "ORDER BY price DESC "
                "LIMIT 5"
            ),
            notes="ORDER BY + LIMIT, no WHERE clause.",
        ),
        TaskExample(
            question=(
                "List all distinct countries where our customers come from, "
                "sorted alphabetically. Return a single column: country."
            ),
            sql=(
                "SELECT DISTINCT country "
                "FROM customers "
                "ORDER BY country ASC"
            ),
            notes="DISTINCT on a single column.",
        ),
    ]

    def description(self) -> str:
        # One-sentence summary surfaced in openenv.yaml.
        return (
            "Single-table SELECT queries with WHERE filters, ORDER BY, and LIMIT. "
            "Tests basic SQL fluency."
        )
@register
class AnalyticsWindowTask(BaseTask):
    """Hard task: CTEs, window functions, and nested subqueries.

    Ground-truth SQL is executed verbatim by the grader; each question spells
    out the exact columns, rounding, and sort order its SQL produces.
    """

    name = "analytics-window"
    difficulty = "hard"

    examples = [
        TaskExample(
            question=(
                "Rank customers by their total spending on delivered orders "
                "using DENSE_RANK (rank 1 = highest spender). "
                "Return columns: customer_name, total_spent, spending_rank. "
                "Round total_spent to 2 decimal places. "
                "Sort by spending_rank ascending."
            ),
            sql=(
                "SELECT customer_name, total_spent, spending_rank "
                "FROM ( "
                "  SELECT c.name AS customer_name, "
                "         ROUND(SUM(o.total_amount), 2) AS total_spent, "
                "         DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
                "  FROM customers c "
                "  JOIN orders o ON o.customer_id = c.id "
                "  WHERE o.status = 'delivered' "
                "  GROUP BY c.id, c.name "
                ") sub "
                "ORDER BY spending_rank ASC"
            ),
            notes="Window function DENSE_RANK inside a subquery wrapping a GROUP BY.",
        ),
        TaskExample(
            question=(
                "For each product that has been reviewed, show its name, its own "
                "average rating, and the average rating of all products in its category. "
                "Return columns: product_name, product_avg_rating, category_avg_rating. "
                "Round both averages to 2 decimal places. "
                "Sort by product_avg_rating descending."
            ),
            sql=(
                "SELECT p.name AS product_name, "
                "       ROUND(AVG(r.rating), 2) AS product_avg_rating, "
                "       ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) "
                "           AS category_avg_rating "
                "FROM products p "
                "JOIN reviews r ON r.product_id = p.id "
                "GROUP BY p.id, p.name, p.category_id "
                "ORDER BY product_avg_rating DESC"
            ),
            notes="AVG of AVG via window PARTITION BY — requires nested aggregate understanding.",
        ),
        TaskExample(
            question=(
                "Find all customers whose most recent order has status 'cancelled'. "
                "Use a CTE with ROW_NUMBER to identify the latest order per customer. "
                "Return columns: customer_name, last_order_status, last_order_date. "
                "Sort by customer_name ascending."
            ),
            sql=(
                "WITH ranked_orders AS ( "
                "  SELECT customer_id, status, created_at, "
                "         ROW_NUMBER() OVER (PARTITION BY customer_id "
                "                            ORDER BY created_at DESC) AS rn "
                "  FROM orders "
                ") "
                "SELECT c.name AS customer_name, "
                "       ro.status AS last_order_status, "
                "       ro.created_at AS last_order_date "
                "FROM customers c "
                "JOIN ranked_orders ro ON ro.customer_id = c.id "
                "WHERE ro.rn = 1 "
                "  AND ro.status = 'cancelled' "
                "ORDER BY customer_name ASC"
            ),
            notes="CTE + ROW_NUMBER window partitioned by customer_id.",
        ),
        TaskExample(
            question=(
                "Show the monthly revenue from delivered orders and its running total, "
                "for all months in 2024. "
                "Return columns: month (format YYYY-MM), monthly_revenue, running_total. "
                "Round both revenue columns to 2 decimal places. "
                "Sort by month ascending."
            ),
            sql=(
                "WITH monthly AS ( "
                "  SELECT strftime('%Y-%m', created_at) AS month, "
                "         ROUND(SUM(total_amount), 2) AS monthly_revenue "
                "  FROM orders "
                "  WHERE status = 'delivered' "
                "    AND created_at >= '2024-01-01' "
                "    AND created_at < '2025-01-01' "
                "  GROUP BY strftime('%Y-%m', created_at) "
                ") "
                "SELECT month, "
                "       monthly_revenue, "
                "       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
                "FROM monthly "
                "ORDER BY month ASC"
            ),
            notes="CTE + cumulative SUM window ordered by month string.",
        ),
        TaskExample(
            question=(
                "Find products whose average rating is strictly above the average "
                "rating of all products in their category. "
                "Return columns: product_name, category_name, "
                "product_avg_rating, category_avg_rating. "
                "Round both averages to 2 decimal places. "
                "Sort by product_avg_rating descending, then product_name ascending."
            ),
            sql=(
                "WITH product_ratings AS ( "
                "  SELECT p.id AS product_id, p.name AS product_name, "
                "         p.category_id, c.name AS category_name, "
                "         ROUND(AVG(r.rating), 2) AS product_avg_rating "
                "  FROM products p "
                "  JOIN reviews r ON r.product_id = p.id "
                "  JOIN categories c ON c.id = p.category_id "
                "  GROUP BY p.id, p.name, p.category_id, c.name "
                "), "
                "category_ratings AS ( "
                "  SELECT category_id, "
                "         ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
                "  FROM product_ratings "
                "  GROUP BY category_id "
                ") "
                "SELECT pr.product_name, pr.category_name, "
                "       pr.product_avg_rating, cr.category_avg_rating "
                "FROM product_ratings pr "
                "JOIN category_ratings cr ON cr.category_id = pr.category_id "
                "WHERE pr.product_avg_rating > cr.category_avg_rating "
                "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
            ),
            notes="Two CTEs, correlated comparison between product and category averages.",
        ),
    ]

    def description(self) -> str:
        # One-sentence summary surfaced in openenv.yaml.
        return (
            "Advanced analytics queries using CTEs, window functions "
            "(DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. "
            "Tests multi-step reasoning and SQLite-specific syntax."
        )
@register
class JoinAggregationTask(BaseTask):
    """Medium task: multi-table JOINs with GROUP BY / HAVING aggregation.

    Ground-truth SQL is executed verbatim by the grader; each question spells
    out the exact columns, rounding, and sort order its SQL produces.
    """

    name = "join-aggregation"
    difficulty = "medium"

    examples = [
        TaskExample(
            question=(
                "How many orders has each customer placed? "
                "Return columns: customer_name, order_count. "
                "Include customers with zero orders. "
                "Sort by order_count descending, then customer_name ascending."
            ),
            sql=(
                "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
                "FROM customers c "
                "LEFT JOIN orders o ON c.id = o.customer_id "
                "GROUP BY c.id, c.name "
                "ORDER BY order_count DESC, customer_name ASC"
            ),
            notes="LEFT JOIN to include zero-order customers, COUNT aggregate.",
        ),
        TaskExample(
            question=(
                "What is the average product rating per category? "
                "Only include categories that have at least one review. "
                "Return columns: category_name, avg_rating. "
                "Round avg_rating to 2 decimal places. "
                "Sort by avg_rating descending."
            ),
            sql=(
                "SELECT c.name AS category_name, "
                "       ROUND(AVG(r.rating), 2) AS avg_rating "
                "FROM categories c "
                "JOIN products p ON p.category_id = c.id "
                "JOIN reviews r ON r.product_id = p.id "
                "GROUP BY c.id, c.name "
                "ORDER BY avg_rating DESC"
            ),
            notes="Two JOINs, AVG aggregate, ROUND function.",
        ),
        TaskExample(
            question=(
                "Which categories have more than 5 products in stock "
                "(i.e., stock_quantity > 0)? "
                "Return columns: category_name, in_stock_count. "
                "Sort by in_stock_count descending."
            ),
            sql=(
                "SELECT c.name AS category_name, "
                "       COUNT(p.id) AS in_stock_count "
                "FROM categories c "
                "JOIN products p ON p.category_id = c.id "
                "WHERE p.stock_quantity > 0 "
                "GROUP BY c.id, c.name "
                "HAVING COUNT(p.id) > 5 "
                "ORDER BY in_stock_count DESC"
            ),
            notes="WHERE before GROUP BY, HAVING filter on aggregate.",
        ),
        TaskExample(
            question=(
                "Which customers have spent more than $500 total on delivered orders? "
                "Return columns: customer_name, total_spent. "
                "Round total_spent to 2 decimal places. "
                "Sort by total_spent descending."
            ),
            sql=(
                "SELECT c.name AS customer_name, "
                "       ROUND(SUM(o.total_amount), 2) AS total_spent "
                "FROM customers c "
                "JOIN orders o ON o.customer_id = c.id "
                "WHERE o.status = 'delivered' "
                "GROUP BY c.id, c.name "
                "HAVING SUM(o.total_amount) > 500 "
                "ORDER BY total_spent DESC"
            ),
            notes="SUM aggregate, HAVING on SUM, status filter.",
        ),
        TaskExample(
            question=(
                "Show the total quantity sold for each product. "
                "Only include products that appear in at least one order item. "
                "Return columns: product_name, total_quantity_sold. "
                "Sort by total_quantity_sold descending."
            ),
            sql=(
                "SELECT p.name AS product_name, "
                "       SUM(oi.quantity) AS total_quantity_sold "
                "FROM products p "
                "JOIN order_items oi ON oi.product_id = p.id "
                "GROUP BY p.id, p.name "
                "ORDER BY total_quantity_sold DESC"
            ),
            notes="JOIN on order_items, SUM aggregate.",
        ),
    ]

    def description(self) -> str:
        # One-sentence summary surfaced in openenv.yaml.
        return (
            "Multi-table JOIN queries with GROUP BY, HAVING, and aggregation "
            "functions (COUNT, SUM, AVG, ROUND). Tests relational reasoning."
        )
+""" +import sys +from pathlib import Path + +ROOT = Path(__file__).parent.parent +SERVER = ROOT / "server" + +for p in [str(ROOT), str(SERVER)]: + if p not in sys.path: + sys.path.insert(0, p) diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 0000000000000000000000000000000000000000..c381ac5af34902dc9ede716ed29c961cecc4ea99 --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,493 @@ +""" +nl2sql-bench/tests/test_all.py +================================ +Comprehensive test suite covering: + - Database seeder (determinism + row counts) + - Grader (all reward components, step penalty, edge cases) + - Task registry (all 3 tasks load and produce valid examples) + - Environment (reset, step, episode boundary, done logic) + - Inference log format (regex checks on START / STEP / END) + +Run with: + pytest tests/ -v +or from project root: + PYTHONPATH=.:server pytest tests/ -v +""" + +from __future__ import annotations + +import re +import sqlite3 +import sys +import os +from pathlib import Path + +import pytest + +# ── Path setup so tests can import from both project root and server/ ────── +ROOT = Path(__file__).parent.parent +SERVER = ROOT / "server" +sys.path.insert(0, str(ROOT)) +sys.path.insert(0, str(SERVER)) + +# ── Fixtures ─────────────────────────────────────────────────────────────── + +@pytest.fixture(scope="session") +def db_conn(): + """Shared in-memory SQLite connection with full schema + seed data.""" + from db.seed import seed_database + schema = (SERVER / "db" / "schema.sql").read_text() + conn = sqlite3.connect(":memory:", check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.executescript(schema) + seed_database(conn) + yield conn + conn.close() + + +@pytest.fixture +def fresh_env(): + """A fresh NL2SQLEnvironment instance per test.""" + from environment import NL2SQLEnvironment + return NL2SQLEnvironment() + + +# ══════════════════════════════════════════════════════════════════════════════ +# 1. 
DATABASE SEEDER +# ══════════════════════════════════════════════════════════════════════════════ + +class TestSeeder: + + def test_categories_count(self, db_conn): + row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone() + assert row[0] == 8, "Should have exactly 8 categories" + + def test_products_count(self, db_conn): + row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone() + assert row[0] == 64, "Should have 8 products × 8 categories = 64" + + def test_customers_count(self, db_conn): + row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone() + assert row[0] == 150 + + def test_orders_exist(self, db_conn): + row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone() + assert row[0] > 100, "Should have a meaningful number of orders" + + def test_order_items_exist(self, db_conn): + row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone() + assert row[0] > 200 + + def test_reviews_exist(self, db_conn): + row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone() + assert row[0] > 50 + + def test_determinism(self, db_conn): + """Seeding a second connection with the same seed gives identical counts.""" + from db.seed import seed_database + schema = (SERVER / "db" / "schema.sql").read_text() + conn2 = sqlite3.connect(":memory:") + conn2.executescript(schema) + seed_database(conn2) + + for tbl in ["categories", "products", "customers", "orders", + "order_items", "reviews"]: + c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0] + c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0] + assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}" + conn2.close() + + def test_tiers_valid(self, db_conn): + bad = db_conn.execute( + "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')" + ).fetchone()[0] + assert bad == 0 + + def test_statuses_valid(self, db_conn): + bad = db_conn.execute( + "SELECT COUNT(*) FROM orders " + "WHERE status NOT IN 
('pending','processing','shipped','delivered','cancelled')" + ).fetchone()[0] + assert bad == 0 + + def test_ratings_valid(self, db_conn): + bad = db_conn.execute( + "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5" + ).fetchone()[0] + assert bad == 0 + + def test_referential_integrity(self, db_conn): + """Order items should reference valid orders and products.""" + orphan_orders = db_conn.execute( + "SELECT COUNT(*) FROM order_items oi " + "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL" + ).fetchone()[0] + assert orphan_orders == 0 + + orphan_products = db_conn.execute( + "SELECT COUNT(*) FROM order_items oi " + "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL" + ).fetchone()[0] + assert orphan_products == 0 + + +# ══════════════════════════════════════════════════════════════════════════════ +# 2. GRADER +# ══════════════════════════════════════════════════════════════════════════════ + +class TestGrader: + + def test_exact_match_first_step(self): + from grader import grade + gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + result = grade( + actual_rows=gt.copy(), + ground_truth_rows=gt, + error=None, + step=1, + order_sensitive=False, + ) + assert result.reward == pytest.approx(1.0) + assert result.exact_match is True + assert result.syntax_ok is True + assert result.columns_match is True + assert result.row_count_match is True + assert result.step_penalty == 0.0 + + def test_syntax_error_gives_zero(self): + from grader import grade + result = grade( + actual_rows=None, + ground_truth_rows=[{"x": 1}], + error="near 'SELCT': syntax error", + step=1, + ) + assert result.reward == 0.0 + assert result.syntax_ok is False + + def test_step_penalty_applied(self): + from grader import grade + gt = [{"n": 1}] + result = grade( + actual_rows=gt.copy(), + ground_truth_rows=gt, + error=None, + step=3, # penalty = (3-1)*0.05 = 0.10 + ) + assert result.reward == pytest.approx(1.0 - 0.10) + assert result.step_penalty == 
pytest.approx(0.10) + + def test_columns_wrong_zero_higher_components(self): + from grader import grade + gt = [{"name": "Alice", "score": 10}] + actual = [{"user": "Alice", "points": 10}] # wrong column names + result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1) + assert result.columns_match is False + assert result.exact_match is False + # Only syntax score: 0.10 + assert result.reward == pytest.approx(0.10) + + def test_correct_columns_wrong_rows(self): + from grader import grade + gt = [{"name": "Alice"}, {"name": "Bob"}] + actual = [{"name": "Charlie"}, {"name": "Dave"}] + result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1) + assert result.columns_match is True + assert result.row_count_match is True + assert result.exact_match is False + # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50 + assert result.reward == pytest.approx(0.50) + + def test_order_sensitive_wrong_order_is_not_exact(self): + from grader import grade + gt = [{"id": 1}, {"id": 2}] + actual = [{"id": 2}, {"id": 1}] # reversed + result = grade( + actual_rows=actual, + ground_truth_rows=gt, + error=None, + step=1, + order_sensitive=True, + ) + assert result.exact_match is False + + def test_order_insensitive_accepts_different_row_order(self): + from grader import grade + gt = [{"id": 1}, {"id": 2}] + actual = [{"id": 2}, {"id": 1}] # different order but same content + result = grade( + actual_rows=actual, + ground_truth_rows=gt, + error=None, + step=1, + order_sensitive=False, + ) + assert result.exact_match is True + + def test_penalty_never_makes_reward_negative(self): + from grader import grade + # Step 99 with syntax error → reward must be >= 0 + result = grade( + actual_rows=None, + ground_truth_rows=[{"x": 1}], + error="some error", + step=99, + ) + assert result.reward >= 0.0 + + def test_execute_query_blocks_writes(self, db_conn): + from grader import execute_query + rows, err = execute_query(db_conn, "INSERT INTO categories(name) 
VALUES ('x')") + assert rows is None + assert "not allowed" in err.lower() or "INSERT" in err + + def test_execute_query_returns_rows(self, db_conn): + from grader import execute_query + rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id") + assert err is None + assert len(rows) == 8 + assert "id" in rows[0] + assert "name" in rows[0] + + def test_compute_ground_truth(self, db_conn): + from grader import compute_ground_truth + rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers") + assert len(rows) == 1 + assert rows[0]["n"] == 150 + + +# ══════════════════════════════════════════════════════════════════════════════ +# 3. TASK REGISTRY +# ══════════════════════════════════════════════════════════════════════════════ + +class TestTasks: + + def test_all_tasks_registered(self): + from tasks import all_task_names + names = all_task_names() + assert "simple-filter" in names + assert "join-aggregation" in names + assert "analytics-window" in names + + @pytest.mark.parametrize("task_name", [ + "simple-filter", "join-aggregation", "analytics-window" + ]) + def test_task_has_examples(self, task_name): + from tasks import get_task + task = get_task(task_name) + assert len(task.examples) >= 3, f"{task_name} needs at least 3 examples" + + @pytest.mark.parametrize("task_name", [ + "simple-filter", "join-aggregation", "analytics-window" + ]) + def test_task_sql_runs_on_real_db(self, task_name, db_conn): + """Every ground-truth SQL must execute cleanly against the seeded DB.""" + from tasks import get_task + from grader import execute_query + task = get_task(task_name) + for ex in task.examples: + rows, error = execute_query(db_conn, ex.sql) + assert error is None, ( + f"Task {task_name!r} SQL failed:\n{ex.sql}\nError: {error}" + ) + assert rows is not None + + @pytest.mark.parametrize("task_name", [ + "simple-filter", "join-aggregation", "analytics-window" + ]) + def test_task_roundrobin(self, task_name): + from tasks import 
get_task + task = get_task(task_name) + n = len(task.examples) + seen = [task.next_example() for _ in range(n * 2)] + # After n calls, second half should repeat first half + assert seen[:n] == seen[n:] + + def test_schema_context_non_empty(self): + from tasks import get_task + task = get_task("simple-filter") + ctx = task.schema_context() + assert "customers" in ctx + assert "orders" in ctx + assert "products" in ctx + + +# ══════════════════════════════════════════════════════════════════════════════ +# 4. ENVIRONMENT +# ══════════════════════════════════════════════════════════════════════════════ + +class TestEnvironment: + + def test_reset_returns_observation(self, fresh_env): + obs = fresh_env.reset(task_name="simple-filter") + assert obs.question != "" + assert obs.schema_context != "" + assert obs.task_name == "simple-filter" + assert obs.done is False + assert obs.step == 0 + assert obs.reward is None + + def test_reset_state(self, fresh_env): + fresh_env.reset(task_name="join-aggregation") + state = fresh_env.state + assert state.task_name == "join-aggregation" + assert state.task_difficulty == "medium" + assert state.step_count == 0 + assert state.solved is False + + def test_step_increments_step_count(self, fresh_env): + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + fresh_env.step(NL2SQLAction(query="SELECT 1")) + assert fresh_env.state.step_count == 1 + + def test_step_syntax_error_gives_nonzero_error(self, fresh_env): + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + obs = fresh_env.step(NL2SQLAction(query="SELCT * FORM broken")) + assert obs.last_error is not None + assert obs.reward == 0.0 + + def test_step_valid_query_returns_result(self, fresh_env): + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + obs = fresh_env.step(NL2SQLAction( + query="SELECT id, name FROM customers ORDER BY name LIMIT 5" + )) + assert obs.last_error is None + assert 
len(obs.last_result) <= 5 + assert obs.reward >= 0.0 + + def test_exact_match_ends_episode(self, fresh_env): + """Submitting the exact ground-truth SQL should solve the episode.""" + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + # Get the ground truth SQL from the internal example + gt_sql = fresh_env._example.sql + obs = fresh_env.step(NL2SQLAction(query=gt_sql)) + assert obs.done is True + assert fresh_env.state.solved is True + assert obs.reward == pytest.approx(1.0) # step 1, full score + + def test_max_steps_ends_episode(self, fresh_env): + """Exhausting all steps should end the episode even without solving.""" + from models import NL2SQLAction + from environment import MAX_STEPS + fresh_env.reset(task_name="analytics-window") + obs = None + for _ in range(MAX_STEPS): + obs = fresh_env.step(NL2SQLAction(query="SELECT 1")) + assert obs is not None + assert obs.done is True + + def test_reset_clears_previous_episode(self, fresh_env): + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + fresh_env.step(NL2SQLAction(query="SELECT 1")) + # Second reset should clear state + obs = fresh_env.reset(task_name="join-aggregation") + assert fresh_env.state.step_count == 0 + assert obs.step == 0 + assert obs.task_name == "join-aggregation" + + @pytest.mark.parametrize("task_name", [ + "simple-filter", "join-aggregation", "analytics-window" + ]) + def test_all_tasks_solvable(self, task_name): + """Ground-truth SQL should always produce reward == 1.0 on step 1.""" + from environment import NL2SQLEnvironment + from models import NL2SQLAction + env = NL2SQLEnvironment() + env.reset(task_name=task_name) + gt_sql = env._example.sql + obs = env.step(NL2SQLAction(query=gt_sql)) + assert obs.done is True + assert obs.reward == pytest.approx(1.0), ( + f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n" + f"SQL: {gt_sql}\nError: {obs.last_error}\nReward: {obs.reward}" + ) + + def test_score_normalised_to_0_1(self, 
fresh_env): + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + for _ in range(3): + obs = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x")) + assert 0.0 <= obs.score <= 1.0 + + def test_write_query_blocked(self, fresh_env): + from models import NL2SQLAction + fresh_env.reset(task_name="simple-filter") + obs = fresh_env.step(NL2SQLAction( + query="INSERT INTO categories(name) VALUES ('hack')" + )) + assert obs.last_error is not None + assert "not allowed" in obs.last_error.lower() or "INSERT" in obs.last_error + + +# ══════════════════════════════════════════════════════════════════════════════ +# 5. LOG FORMAT COMPLIANCE +# ══════════════════════════════════════════════════════════════════════════════ + +class TestLogFormat: + """Validate that the inference.py log helpers emit correct format.""" + + START_RE = re.compile( + r"^\[START\] task=\S+ env=\S+ model=\S+$" + ) + STEP_RE = re.compile( + r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} " + r"done=(true|false) error=.+$" + ) + END_RE = re.compile( + r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} " + r"rewards=[\d.,]+$" + ) + + def _capture(self, func, *args, **kwargs) -> str: + import io + from contextlib import redirect_stdout + buf = io.StringIO() + with redirect_stdout(buf): + func(*args, **kwargs) + return buf.getvalue().strip() + + def test_log_start_format(self): + sys.path.insert(0, str(ROOT)) + from inference import log_start + out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B") + assert self.START_RE.match(out), f"Bad [START] format: {out!r}" + + def test_log_step_format_null_error(self): + from inference import log_step + out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None) + assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}" + + def test_log_step_format_with_error(self): + from inference import log_step + out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error") + assert self.STEP_RE.match(out), f"Bad 
[STEP] format: {out!r}" + + def test_log_end_format_success(self): + from inference import log_end + out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0]) + assert self.END_RE.match(out), f"Bad [END] format: {out!r}" + + def test_log_end_format_failure(self): + from inference import log_end + out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0]) + assert self.END_RE.match(out), f"Bad [END] format: {out!r}" + + def test_reward_two_decimal_places(self): + from inference import log_step + out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None) + # reward= field must have exactly 2 decimal places + match = re.search(r"reward=(\d+\.\d+)", out) + assert match, "No reward= field found" + assert len(match.group(1).split(".")[1]) == 2 + + def test_score_three_decimal_places(self): + from inference import log_end + out = self._capture(log_end, True, 1, 1.0, [1.0]) + match = re.search(r"score=(\d+\.\d+)", out) + assert match + assert len(match.group(1).split(".")[1]) == 3 diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..3323eba8cfe2d6e18dc42b2689cf98cf18f15d07 --- /dev/null +++ b/train.py @@ -0,0 +1,125 @@ +import os +# CRITICAL: Ye line sabse upar honi chahiye kisi bhi PyTorch import se pehle! +os.environ["CUDA_VISIBLE_DEVICES"] = "0,7" + +import sys +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import LoraConfig +from trl import GRPOConfig, GRPOTrainer + +sys.path.insert(0, "./server") +from environment import NL2SQLEnvironment +from models import NL2SQLAction +from tasks import all_task_names, get_task + +MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct" +OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo" + +SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite. +Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries. + +STRICT RULES: +1. 
def _extract_sql(text: str) -> str:
    """Strip markdown code fences from a model completion.

    Removes any ```-fence lines wherever they appear — not only when the
    whole completion *starts* with one — so outputs with leading whitespace
    or a stray opening fence still yield runnable SQL.
    """
    text = text.strip()
    if "```" not in text:
        return text
    kept = [line for line in text.split("\n") if not line.strip().startswith("```")]
    return "\n".join(kept).strip()


def sql_reward_func(prompts, completions, task_name, **kwargs):
    """GRPO reward function: run each completion in the env, collect rewards.

    Args:
        prompts: prompts batch (unused; required by the TRL reward-func API).
        completions: generated completions; each item is either a plain
            string or a chat-format list [{"role": ..., "content": ...}].
        task_name: task identifier, either one string for the whole batch or
            a per-sample list aligned with ``completions``.
        **kwargs: extra columns forwarded by the trainer (ignored).

    Returns:
        list[float]: one reward per completion; any failure scores 0.0.
    """
    rewards = []
    env = NL2SQLEnvironment()  # one env reused across the batch; reset() per episode

    for idx, completion in enumerate(completions):
        # Chat-format completions arrive as a list of message dicts.
        generated_text = completion[0]['content'] if isinstance(completion, list) else completion
        generated_text = _extract_sql(generated_text)

        current_task = task_name[idx] if isinstance(task_name, list) else task_name

        try:
            # reset() inside the try: a bad task name must score 0.0 for this
            # sample, not crash the whole reward batch.
            env.reset(task_name=current_task)
            action = NL2SQLAction(query=generated_text)
            obs = env.step(action)
            rewards.append(float(obs.reward))
        except Exception:
            # Invalid SQL, bad task, or env failure — zero reward, keep going.
            rewards.append(0.0)

    return rewards
r=128, + lora_alpha=256, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + bias="none", + task_type="CAUSAL_LM" + ) + + training_args = GRPOConfig( + output_dir=OUTPUT_DIR, + learning_rate=2e-5, + per_device_train_batch_size=2, + gradient_accumulation_steps=4, + max_completion_length=256, + num_generations=8, + temperature=0.5, + bf16=True, + logging_steps=5, + num_train_epochs=10, + report_to="none", + remove_unused_columns=False, + ddp_find_unused_parameters=False + ) + + trainer = GRPOTrainer( + model=model, + reward_funcs=sql_reward_func, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, + processing_class=tokenizer + ) + + trainer.train() + + if trainer.accelerator.is_main_process: + trainer.model.save_pretrained(f"{OUTPUT_DIR}/final") + tokenizer.save_pretrained(f"{OUTPUT_DIR}/final") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/value_swapper.py b/value_swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2e3425b8f61983ea4f86c5f017ae81d6a527ce95 --- /dev/null +++ b/value_swapper.py @@ -0,0 +1,66 @@ +import json +import re +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) +from data_factory.templates import ALL_TEMPLATES + +# Define strict categorical swaps based on the exact schemas +SWAP_RULES = { + "ecommerce": [ + (r"'gold'", r"gold", ["'silver'", "'bronze'"], ["silver", "bronze"]), + (r"'delivered'", r"delivered", ["'pending'", "'processing'", "'shipped'", "'cancelled'"], ["pending", "processing", "shipped", "cancelled"]), + (r"'India'", r"India", ["'USA'", "'Germany'", "'UK'", "'Canada'"], ["USA", "Germany", "UK", "Canada"]) + ], + "healthcare": [ + (r"'severe'", r"severe", ["'mild'", "'moderate'"], ["mild", "moderate"]), + (r"'completed'", r"completed", ["'scheduled'", "'cancelled'", "'no_show'"], ["scheduled", "cancelled", "no-show"]) + ], + "finance": [ + (r"'active'", r"active", 
def generate_swaps(templates=None, rules=None):
    """Expand templates by swapping categorical literal values.

    For every template whose SQL contains one of its domain's known
    categorical literals, emit one extra template per alternative value,
    rewriting the SQL, the natural-language question (``base_nl``) and the
    ``description``. The original template is always kept first.

    Args:
        templates: iterable of template dicts; defaults to ALL_TEMPLATES.
        rules: mapping of domain -> swap-rule tuples; defaults to SWAP_RULES.

    Returns:
        list: original templates plus all generated variants.
    """
    if templates is None:
        templates = ALL_TEMPLATES
    if rules is None:
        rules = SWAP_RULES

    expanded_templates = []

    for template in templates:
        expanded_templates.append(template)  # keep the original
        domain = template["domain"]

        for sql_target, nl_target, sql_replacements, nl_replacements in rules.get(domain, []):
            if not re.search(sql_target, template["sql"], re.IGNORECASE):
                continue

            # \b guards keep NL swaps from rewriting partial words
            # (e.g. 'active' inside 'inactive' or 'actively').
            nl_pattern = rf"\b{nl_target}\b"

            for sql_repl, nl_repl in zip(sql_replacements, nl_replacements):
                new_template = template.copy()

                # Swap in SQL (literals are quoted, so no boundary needed)
                new_template["sql"] = re.sub(
                    sql_target, sql_repl, template["sql"], flags=re.IGNORECASE
                )

                # Swap in NL question and description
                new_template["base_nl"] = re.sub(
                    nl_pattern, nl_repl, template["base_nl"], flags=re.IGNORECASE
                )
                new_template["description"] = re.sub(
                    nl_pattern, nl_repl, template["description"], flags=re.IGNORECASE
                )

                # Create a unique ID
                new_template["id"] = f"{template.get('id', 'temp')}_swap_{nl_repl.replace(' ', '_')}"

                expanded_templates.append(new_template)

    return expanded_templates