Spaces:
Running
Running
Melika Kheirieh
commited on
Commit
·
570f7bd
1
Parent(s):
5eeca35
init: NL2SQL Copilot base with API and Dockerfile
Browse files- .github/workflows/ci.yml +59 -0
- Dockerfile +40 -0
- adapters/db/base.py +13 -0
- adapters/db/postgres_adapter.py +44 -0
- adapters/db/sqlite_adapter.py +32 -0
- adapters/llm/base.py +16 -0
- adapters/llm/openai_provider.py +113 -0
- app.py +0 -235
- app/__init__.py +0 -0
- app/main.py +29 -0
- app/routers/__init__.py +0 -0
- app/routers/nl2sql.py +83 -0
- app/schemas.py +29 -0
- benchmarks/results/demo.jsonl +3 -0
- benchmarks/run.py +143 -0
- docker-compose.yml +40 -0
- infra/migrate.sql +8 -0
- logs/spider_eval/dev_gold_1760430884.txt +0 -10
- logs/spider_eval/dev_metrics_1760430884.json +0 -15
- logs/spider_eval/dev_pred_1760430884.txt +0 -10
- logs/spider_eval/dev_results_1760430884.jsonl +0 -11
- nl2sql/__init__.py +0 -0
- nl2sql/ambiguity_detector.py +16 -0
- nl2sql/executor.py +21 -0
- nl2sql/generator.py +49 -0
- nl2sql/pipeline.py +165 -0
- nl2sql/planner.py +16 -0
- nl2sql/repair.py +29 -0
- nl2sql/safety.py +75 -0
- nl2sql/stubs.py +31 -0
- nl2sql/types.py +19 -0
- nl2sql/verifier.py +33 -0
- requirements.txt +11 -8
- tests/conftest.py +7 -0
- tests/test_ambiguity.py +19 -0
- tests/test_executor.py +16 -0
- tests/test_generator.py +96 -0
- tests/test_nl2sql_router.py +95 -0
- tests/test_openai_provider.py +87 -0
- tests/test_pipeline_integration.py +142 -0
- tests/test_safety.py +103 -0
- tests/test_stage_types.py +18 -0
- ui/benchmark_app.py +45 -0
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ main, develop ]
|
| 6 |
+
pull_request:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
build-test:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
|
| 12 |
+
env:
|
| 13 |
+
PIP_NO_CACHE_DIR: 1
|
| 14 |
+
|
| 15 |
+
steps:
|
| 16 |
+
- name: Checkout repository
|
| 17 |
+
uses: actions/checkout@v4
|
| 18 |
+
|
| 19 |
+
- name: Set up Python
|
| 20 |
+
uses: actions/setup-python@v5
|
| 21 |
+
with:
|
| 22 |
+
python-version: "3.12"
|
| 23 |
+
|
| 24 |
+
- name: Install dependencies
|
| 25 |
+
run: |
|
| 26 |
+
python -m pip install --upgrade pip
|
| 27 |
+
pip install -r requirements.txt
|
| 28 |
+
|
| 29 |
+
- name: Lint (ruff)
|
| 30 |
+
run: ruff check .
|
| 31 |
+
|
| 32 |
+
- name: Type check (mypy)
|
| 33 |
+
run: mypy .
|
| 34 |
+
|
| 35 |
+
- name: Run tests
|
| 36 |
+
run: pytest -q
|
| 37 |
+
|
| 38 |
+
docker-build:
|
| 39 |
+
needs: build-test
|
| 40 |
+
runs-on: ubuntu-latest
|
| 41 |
+
if: github.ref == 'refs/heads/main'
|
| 42 |
+
|
| 43 |
+
steps:
|
| 44 |
+
- name: Checkout code
|
| 45 |
+
uses: actions/checkout@v4
|
| 46 |
+
|
| 47 |
+
- name: Login to GHCR
|
| 48 |
+
if: secrets.GHCR_TOKEN != ''
|
| 49 |
+
run: echo "${{ secrets.GHCR_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
|
| 50 |
+
|
| 51 |
+
- name: Build Docker image
|
| 52 |
+
run: |
|
| 53 |
+
IMAGE=ghcr.io/${{ github.repository_owner }}/nl2sql-copilot:${{ github.sha }}
|
| 54 |
+
docker build -t $IMAGE .
|
| 55 |
+
echo "IMAGE=$IMAGE" >> $GITHUB_ENV
|
| 56 |
+
|
| 57 |
+
- name: Push image
|
| 58 |
+
if: secrets.GHCR_TOKEN != ''
|
| 59 |
+
run: docker push $IMAGE
|
Dockerfile
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---------- Stage 1: Build wheels ----------
|
| 2 |
+
FROM python:3.12-slim AS builder
|
| 3 |
+
|
| 4 |
+
# Set working directory for the build stage
|
| 5 |
+
WORKDIR /build
|
| 6 |
+
|
| 7 |
+
# Install system dependencies required to compile some Python packages
|
| 8 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 9 |
+
build-essential libpq-dev && \
|
| 10 |
+
rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# Copy only requirements first (so Docker caching works efficiently)
|
| 13 |
+
COPY requirements.txt .
|
| 14 |
+
|
| 15 |
+
# Build all dependencies as wheel files inside /wheels
|
| 16 |
+
RUN pip install --upgrade pip && \
|
| 17 |
+
pip wheel --wheel-dir /wheels -r requirements.txt
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------- Stage 2: Runtime image ----------
|
| 21 |
+
FROM python:3.12-slim AS runtime
|
| 22 |
+
|
| 23 |
+
# Set working directory for the application
|
| 24 |
+
WORKDIR /app
|
| 25 |
+
|
| 26 |
+
# Copy prebuilt wheels from the builder stage
|
| 27 |
+
COPY --from=builder /wheels /wheels
|
| 28 |
+
|
| 29 |
+
# Install dependencies from prebuilt wheels (no need to compile again)
|
| 30 |
+
COPY requirements.txt .
|
| 31 |
+
RUN pip install --no-cache-dir --find-links=/wheels -r requirements.txt
|
| 32 |
+
|
| 33 |
+
# Copy the actual application code
|
| 34 |
+
COPY . .
|
| 35 |
+
|
| 36 |
+
# Expose the FastAPI port
|
| 37 |
+
EXPOSE 8000
|
| 38 |
+
|
| 39 |
+
# Start FastAPI with Uvicorn
|
| 40 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers"]
|
adapters/db/base.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, Dict, Any, Protocol
|
| 2 |
+
from typing import List, Tuple, Any
|
| 3 |
+
|
| 4 |
+
class DBAdapter(Protocol):
|
| 5 |
+
"""Abstract database adapter for read-only queries."""
|
| 6 |
+
name: str
|
| 7 |
+
dialect: str
|
| 8 |
+
|
| 9 |
+
def preview_schema(self, limit_per_table: int = 0) -> str:
|
| 10 |
+
"""Generate a readable summary of the database schema with optional sample rows per table."""
|
| 11 |
+
|
| 12 |
+
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 13 |
+
"""Execute a SELECT query and return (rows, columns)."""
|
adapters/db/postgres_adapter.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import psycopg
|
| 2 |
+
from typing import Any, List, Tuple
|
| 3 |
+
from adapters.db.base import DBAdapter
|
| 4 |
+
|
| 5 |
+
class PostgresAdapter(DBAdapter):
|
| 6 |
+
name = "postgres"
|
| 7 |
+
dialect = "postgres"
|
| 8 |
+
|
| 9 |
+
def __init__(self, dsn: str):
|
| 10 |
+
"""
|
| 11 |
+
DSN example:
|
| 12 |
+
"dbname=demo user=postgres password=postgres host=localhost port=5432"
|
| 13 |
+
"""
|
| 14 |
+
self.dsn = dsn
|
| 15 |
+
|
| 16 |
+
def preview_schema(self, limit_per_table: int = 0) -> str:
|
| 17 |
+
with psycopg.connect(self.dsn) as conn:
|
| 18 |
+
cur = conn.cursor()
|
| 19 |
+
cur.execute("""
|
| 20 |
+
SELECT table_name
|
| 21 |
+
FROM information_schema.tables
|
| 22 |
+
WHERE table_schema = 'public';
|
| 23 |
+
""")
|
| 24 |
+
tables = [t[0] for t in cur.fetchall()]
|
| 25 |
+
lines = []
|
| 26 |
+
for t in tables:
|
| 27 |
+
cur.execute(f"""
|
| 28 |
+
SELECT column_name, data_type
|
| 29 |
+
FROM information_schema.columns
|
| 30 |
+
WHERE table_name = %s;
|
| 31 |
+
""", (t,))
|
| 32 |
+
cols = [f"{c[0]}:{c[1]}" for c in cur.fetchall()]
|
| 33 |
+
lines.append(f"- {t} ({', '.join(cols)})")
|
| 34 |
+
return "\n".join(lines)
|
| 35 |
+
|
| 36 |
+
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 37 |
+
if not sql.strip().lower().startswith("select"):
|
| 38 |
+
raise ValueError("Only SELECT statements are allowed.")
|
| 39 |
+
with psycopg.connect(self.dsn) as conn:
|
| 40 |
+
cur = conn.cursor()
|
| 41 |
+
cur.execute(sql)
|
| 42 |
+
rows = cur.fetchall()
|
| 43 |
+
cols = [desc[0] for desc in cur.description]
|
| 44 |
+
return rows, cols
|
adapters/db/sqlite_adapter.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
from typing import List, Tuple, Any
|
| 3 |
+
from adapters.db.base import DBAdapter
|
| 4 |
+
|
| 5 |
+
class SQLiteAdapter(DBAdapter):
|
| 6 |
+
name = "sqlite"
|
| 7 |
+
dialect = "sqlite"
|
| 8 |
+
|
| 9 |
+
def __init__(self, path: str):
|
| 10 |
+
self.path = path
|
| 11 |
+
|
| 12 |
+
def preview_schema(self, limit_per_table: int = 0) -> str:
|
| 13 |
+
with sqlite3.connect(self.path, uri=True) as conn:
|
| 14 |
+
cur = conn.cursor()
|
| 15 |
+
cur.execute("PRAGMA foreign_keys = ON")
|
| 16 |
+
tables = [t[0] for t in cur.fetchall()]
|
| 17 |
+
lines = []
|
| 18 |
+
for t in tables:
|
| 19 |
+
cur.execute(f"PRAGMA table_info({t});")
|
| 20 |
+
cols = [f"{c[1]}:{c[2]}" for c in cur.fetchall()]
|
| 21 |
+
lines.append(f"- {t} ({', '.join(cols)})")
|
| 22 |
+
return "\n".join(lines)
|
| 23 |
+
|
| 24 |
+
def execute(self, sql: str) -> Tuple[List[Tuple[Any, ...]], List[str]]:
|
| 25 |
+
# enforce read-only connection
|
| 26 |
+
uri = f"file:{self.path}?mode=ro&uri=true"
|
| 27 |
+
with sqlite3.connect(uri, uri=True, timeout=3) as conn:
|
| 28 |
+
cur = conn.cursor()
|
| 29 |
+
cur.execute(sql)
|
| 30 |
+
rows = cur.fetchall()
|
| 31 |
+
cols = [desc[0] for desc in cur.description]
|
| 32 |
+
return rows, cols
|
adapters/llm/base.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# adapters/llm/base.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from typing import Tuple, List, Dict, Any, Protocol
|
| 4 |
+
|
| 5 |
+
class LLMProvider(Protocol):
|
| 6 |
+
provider_id: str
|
| 7 |
+
|
| 8 |
+
def plan(self, *, user_query: str, schema_preview: str) -> Tuple[str, int, int, float]:
|
| 9 |
+
"""Return (plan_text, token_in, token_out, cost_usd)."""
|
| 10 |
+
|
| 11 |
+
def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str,
|
| 12 |
+
clarify_answers: Dict[str, Any] | None = None) -> Tuple[str, str, int, int, float]:
|
| 13 |
+
"""Return (sql, rationale, token_in, token_out, cost_usd)."""
|
| 14 |
+
|
| 15 |
+
def repair(self, *, sql: str, error_msg: str, schema_preview: str) -> Tuple[str, int, int, float]:
|
| 16 |
+
"""Return (patched_sql, token_in, token_out, cost_usd)."""
|
adapters/llm/openai_provider.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import os
|
| 3 |
+
from typing import Tuple, Dict, Any, List
|
| 4 |
+
import json
|
| 5 |
+
from adapters.llm.base import LLMProvider
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
|
| 8 |
+
# NOTE: Read keys/base URL from env. Do NOT pass base_url in constructors.
|
| 9 |
+
# - OPENAI_API_KEY (required)
|
| 10 |
+
# - OPENAI_BASE_URL (optional; defaults to OpenAI public)
|
| 11 |
+
# - OPENAI_MODEL_ID (e.g., "gpt-4o-mini")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class OpenAIProvider(LLMProvider):
|
| 16 |
+
provider_id = "openai"
|
| 17 |
+
|
| 18 |
+
def __init__(self) -> None:
|
| 19 |
+
self.client = OpenAI(
|
| 20 |
+
api_key=os.environ["OPENAI_API_KEY"],
|
| 21 |
+
base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
| 22 |
+
)
|
| 23 |
+
self.model = os.getenv("OPENAI_MODEL_ID", "gpt-4o-mini")
|
| 24 |
+
|
| 25 |
+
def plan(self, *, user_query, schema_preview):
|
| 26 |
+
completion = self.client.chat.completions.create(
|
| 27 |
+
model=self.model,
|
| 28 |
+
messages=[
|
| 29 |
+
{"role": "system", "content": "You create SQL query plans."},
|
| 30 |
+
{"role": "user", "content": f"Query: {user_query}\nSchema:\n{schema_preview}"}
|
| 31 |
+
],
|
| 32 |
+
temperature=0
|
| 33 |
+
)
|
| 34 |
+
msg = completion.choices[0].message.content
|
| 35 |
+
usage = completion.usage
|
| 36 |
+
return msg, usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def generate_sql(self, *, user_query, schema_preview, plan_text, clarify_answers=None):
|
| 40 |
+
prompt = f"""
|
| 41 |
+
You are a precise SQL generator.
|
| 42 |
+
Return ONLY valid JSON with two keys: "sql" and "rationale".
|
| 43 |
+
Do not include any markdown, backticks, or extra text.
|
| 44 |
+
|
| 45 |
+
Example:
|
| 46 |
+
{{
|
| 47 |
+
"sql": "SELECT * FROM singer;",
|
| 48 |
+
"rationale": "The user requested to list all singers."
|
| 49 |
+
}}
|
| 50 |
+
|
| 51 |
+
Now generate JSON for this input:
|
| 52 |
+
|
| 53 |
+
User query: {user_query}
|
| 54 |
+
Schema preview:
|
| 55 |
+
{schema_preview}
|
| 56 |
+
Plan: {plan_text}
|
| 57 |
+
Clarifications: {clarify_answers}
|
| 58 |
+
"""
|
| 59 |
+
completion = self.client.chat.completions.create(
|
| 60 |
+
model=self.model,
|
| 61 |
+
messages=[
|
| 62 |
+
{"role": "system", "content": "You convert natural language to SQL."},
|
| 63 |
+
{"role": "user", "content": prompt}
|
| 64 |
+
],
|
| 65 |
+
temperature=0
|
| 66 |
+
)
|
| 67 |
+
content = completion.choices[0].message.content.strip()
|
| 68 |
+
usage = completion.usage # ← لازم داریم
|
| 69 |
+
t_in = usage.prompt_tokens if usage else None
|
| 70 |
+
t_out = usage.completion_tokens if usage else None
|
| 71 |
+
cost = self._estimate_cost(usage) if usage else None
|
| 72 |
+
|
| 73 |
+
# Robust JSON parse (with fallback to substring)
|
| 74 |
+
try:
|
| 75 |
+
parsed = json.loads(content)
|
| 76 |
+
except json.JSONDecodeError:
|
| 77 |
+
start = content.find("{")
|
| 78 |
+
end = content.rfind("}")
|
| 79 |
+
if start != -1 and end != -1:
|
| 80 |
+
try:
|
| 81 |
+
parsed = json.loads(content[start:end + 1])
|
| 82 |
+
except Exception:
|
| 83 |
+
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError(f"Invalid LLM JSON output: {content[:200]}")
|
| 86 |
+
|
| 87 |
+
sql = (parsed.get("sql") or "").strip()
|
| 88 |
+
rationale = parsed.get("rationale") or ""
|
| 89 |
+
|
| 90 |
+
if not sql:
|
| 91 |
+
raise ValueError("LLM returned empty 'sql'")
|
| 92 |
+
|
| 93 |
+
# IMPORTANT: return the expected 5-tuple
|
| 94 |
+
return sql, rationale, t_in, t_out, cost
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def repair(self, *, sql, error_msg, schema_preview):
|
| 98 |
+
completion = self.client.chat.completions.create(
|
| 99 |
+
model=self.model,
|
| 100 |
+
messages=[
|
| 101 |
+
{"role": "system", "content": "You fix SQL queries keeping them SELECT-only."},
|
| 102 |
+
{"role": "user", "content": f"SQL:\n{sql}\nError:\n{error_msg}\nSchema:\n{schema_preview}"}
|
| 103 |
+
],
|
| 104 |
+
temperature=0
|
| 105 |
+
)
|
| 106 |
+
msg = completion.choices[0].message.content
|
| 107 |
+
usage = completion.usage
|
| 108 |
+
return msg, usage.prompt_tokens, usage.completion_tokens, self._estimate_cost(usage)
|
| 109 |
+
|
| 110 |
+
def _estimate_cost(self, usage):
|
| 111 |
+
# Rough estimation example — can be refined with official token pricing
|
| 112 |
+
total = usage.prompt_tokens + usage.completion_tokens
|
| 113 |
+
return total * 0.000001
|
app.py
DELETED
|
@@ -1,235 +0,0 @@
|
|
| 1 |
-
from config import (
|
| 2 |
-
LLM_MODEL,
|
| 3 |
-
LLM_TEMPERATURE,
|
| 4 |
-
FORBIDDEN_KEYWORDS,
|
| 5 |
-
FORBIDDEN_TABLES
|
| 6 |
-
)
|
| 7 |
-
import os
|
| 8 |
-
import sqlite3
|
| 9 |
-
import json
|
| 10 |
-
import re
|
| 11 |
-
from typing import Optional, Tuple, List
|
| 12 |
-
|
| 13 |
-
import gradio as gr
|
| 14 |
-
import sqlglot
|
| 15 |
-
from sqlglot import exp
|
| 16 |
-
|
| 17 |
-
from langchain_openai import ChatOpenAI
|
| 18 |
-
from langchain_community.utilities import SQLDatabase
|
| 19 |
-
from langchain.chains import create_sql_query_chain
|
| 20 |
-
from langchain.prompts import ChatPromptTemplate
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def get_readonly_sqlite_url(db_path: str) -> str:
|
| 24 |
-
return f"file:{db_path}?mode=ro&uri=true"
|
| 25 |
-
|
| 26 |
-
def get_schema_preview(db_path: str, limit_per_table: int = 0) -> str:
|
| 27 |
-
uri = get_readonly_sqlite_url(db_path)
|
| 28 |
-
with sqlite3.connect(uri, uri=True, timeout=3) as conn:
|
| 29 |
-
conn.row_factory = sqlite3.Row
|
| 30 |
-
cur = conn.cursor()
|
| 31 |
-
cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
|
| 32 |
-
tables = [r["name"] for r in cur.fetchall()]
|
| 33 |
-
lines = []
|
| 34 |
-
for t in tables:
|
| 35 |
-
# skip SQLite internals
|
| 36 |
-
if t in FORBIDDEN_TABLES:
|
| 37 |
-
continue
|
| 38 |
-
cur.execute(f"PRAGMA table_info({t});")
|
| 39 |
-
cols = cur.fetchall()
|
| 40 |
-
col_line = ", ".join([f"{c['name']}:{c['type']}" for c in cols])
|
| 41 |
-
lines.append(f"- {t} ({col_line})")
|
| 42 |
-
if limit_per_table > 0:
|
| 43 |
-
try:
|
| 44 |
-
cur.execute(f"SELECT * FROM {t} LIMIT {limit_per_table};")
|
| 45 |
-
sample = cur.fetchall()
|
| 46 |
-
if sample:
|
| 47 |
-
lines.append(f" sample rows: {len(sample)}")
|
| 48 |
-
except Exception:
|
| 49 |
-
pass
|
| 50 |
-
if not lines:
|
| 51 |
-
return "(no user tables found)"
|
| 52 |
-
return "\n".join(lines)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def validate_sql_safe(sql: str) -> Tuple[bool, str]:
|
| 56 |
-
if sql.count(";") > 0:
|
| 57 |
-
if sql.strip().endswith(";"):
|
| 58 |
-
if sql.strip()[:-1].count(";") > 0:
|
| 59 |
-
return False, "Multiple statements are not allowed."
|
| 60 |
-
else:
|
| 61 |
-
return False, "Multiple statements are not allowed."
|
| 62 |
-
|
| 63 |
-
upper = re.sub(r"\s+", " ", sql).strip()
|
| 64 |
-
for kw in FORBIDDEN_KEYWORDS:
|
| 65 |
-
if re.search(rf"\b{kw}\b", upper):
|
| 66 |
-
return False, f"Keyword '{kw}' is not allowed."
|
| 67 |
-
|
| 68 |
-
try:
|
| 69 |
-
parsed = sqlglot.parse(sql, read='sqlite')
|
| 70 |
-
except Exception as e:
|
| 71 |
-
return False, f"SQL parse error: {e}"
|
| 72 |
-
|
| 73 |
-
if not parsed or len(parsed) != 1:
|
| 74 |
-
return False, "Exactly one SQL statement is allowed."
|
| 75 |
-
|
| 76 |
-
stmt = parsed[0]
|
| 77 |
-
if not isinstance(stmt, exp.Select):
|
| 78 |
-
return False, "Only SELECT statements are allowed."
|
| 79 |
-
|
| 80 |
-
for table in stmt.find_all(exp.Table):
|
| 81 |
-
table_name = table.name.lower() if table.name else ""
|
| 82 |
-
if table_name in FORBIDDEN_TABLES:
|
| 83 |
-
return False, f"Access to {table_name} is not allowed."
|
| 84 |
-
|
| 85 |
-
return True, "OK"
|
| 86 |
-
|
| 87 |
-
def execute_select(db_path: str, sql: str, max_rows: int = 1000, timeout: float = 5.0) -> Tuple[list[str], List[List]]:
|
| 88 |
-
uri = get_readonly_sqlite_url(db_path)
|
| 89 |
-
if not re.search(r"\bLIMIT\b", sql, re.IGNORECASE):
|
| 90 |
-
sql = f"{sql.rstrip(';')} LIMIT {max_rows}"
|
| 91 |
-
|
| 92 |
-
with sqlite3.connect(uri, uri=True, timeout=timeout) as conn:
|
| 93 |
-
conn.row_factory = sqlite3.Row
|
| 94 |
-
cur = conn.cursor()
|
| 95 |
-
cur.execute(sql)
|
| 96 |
-
rows = cur.fetchall()
|
| 97 |
-
if rows:
|
| 98 |
-
cols = rows[0].keys()
|
| 99 |
-
data = [list(r) for r in rows]
|
| 100 |
-
return list(cols), data
|
| 101 |
-
else:
|
| 102 |
-
return [], []
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
custom_prompt = ChatPromptTemplate.from_template("""
|
| 107 |
-
Given the following question, return ONLY a valid SQL query in JSON form.
|
| 108 |
-
|
| 109 |
-
Question: {input}
|
| 110 |
-
Database schema: {table_info}
|
| 111 |
-
|
| 112 |
-
You may sample/preview at most {top_k} rows if you need examples.
|
| 113 |
-
|
| 114 |
-
Respond in this exact JSON format:
|
| 115 |
-
{{
|
| 116 |
-
"sql": "<SQL_QUERY_HERE>"
|
| 117 |
-
}}
|
| 118 |
-
""")
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def make_sql_chain(sql_db: SQLDatabase):
|
| 122 |
-
assert hasattr(sql_db, "get_table_info"), "Expected LangChain SQLDatabase"
|
| 123 |
-
llm = ChatOpenAI(model=LLM_MODEL, temperature=LLM_TEMPERATURE)
|
| 124 |
-
chain = create_sql_query_chain(llm, sql_db, prompt=custom_prompt, k=20)
|
| 125 |
-
return chain
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def on_upload_database(db_file, state):
|
| 129 |
-
if db_file is None:
|
| 130 |
-
return state, "No file provided.", "(no schema)"
|
| 131 |
-
path = db_file.name
|
| 132 |
-
|
| 133 |
-
sql_db = SQLDatabase.from_uri(f"sqlite:///{path}")
|
| 134 |
-
|
| 135 |
-
schema_text = get_schema_preview(path, limit_per_table=0)
|
| 136 |
-
|
| 137 |
-
chain = make_sql_chain(sql_db)
|
| 138 |
-
|
| 139 |
-
new_state = {
|
| 140 |
-
"db_path": path,
|
| 141 |
-
"sql_db": sql_db,
|
| 142 |
-
"schema_text": schema_text,
|
| 143 |
-
"chain": chain,
|
| 144 |
-
}
|
| 145 |
-
return new_state, f"Database '{os.path.basename(path)}' uploaded successfully.", schema_text
|
| 146 |
-
|
| 147 |
-
def extract_sql_safe(output_text: str) -> str:
|
| 148 |
-
try:
|
| 149 |
-
obj = json.loads(output_text)
|
| 150 |
-
if isinstance(obj, dict) and "sql" in obj:
|
| 151 |
-
return obj["sql"].strip()
|
| 152 |
-
except Exception:
|
| 153 |
-
pass
|
| 154 |
-
m = re.search(r"```sql\s*(.*?)\s*```", output_text, re.DOTALL | re.IGNORECASE)
|
| 155 |
-
if m:
|
| 156 |
-
return m.group(1).strip()
|
| 157 |
-
return output_text.strip()
|
| 158 |
-
|
| 159 |
-
def on_generate_query(question , max_rows, state):
|
| 160 |
-
if not state or not state.get("db_path") or not state.get("chain"):
|
| 161 |
-
return "Please upload a database first.", "", ""
|
| 162 |
-
if not question or not question.strip():
|
| 163 |
-
return "Please enter a question.", "", ""
|
| 164 |
-
|
| 165 |
-
try:
|
| 166 |
-
generated_sql = state["chain"].invoke({"question": question})
|
| 167 |
-
|
| 168 |
-
sql = extract_sql_safe(str(generated_sql))
|
| 169 |
-
|
| 170 |
-
ok, msg = validate_sql_safe(sql)
|
| 171 |
-
if not ok:
|
| 172 |
-
return f"Blocked SQL: {msg}", sql, ""
|
| 173 |
-
|
| 174 |
-
cols, rows = execute_select(state["db_path"], sql, max_rows=max_rows)
|
| 175 |
-
if not cols:
|
| 176 |
-
return f"No rows returned.", sql, "[]"
|
| 177 |
-
|
| 178 |
-
sample = [dict(zip(cols, r)) for r in rows[:50]]
|
| 179 |
-
return f"Returned {len(rows)} row(s). Showing up to 50.", sql, json.dumps(sample, indent=2)
|
| 180 |
-
|
| 181 |
-
except Exception as e:
|
| 182 |
-
return f"Error: {e}", "", ""
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
with gr.Blocks(title="nl2sql-copilot-prototype (safe)") as demo:
|
| 186 |
-
gr.Markdown("# nl2sql-copilot-prototype (Sqlite, safe)")
|
| 187 |
-
gr.Markdown(
|
| 188 |
-
"Upload a **SQLite** file, ask a question in natural language, "
|
| 189 |
-
"and I will: (1) generate SQL, (2) validate it (SELECT-only), (3) execute read-only, "
|
| 190 |
-
"and (4) show you the results."
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
state = gr.State({"db_path": None, "sql_db": None, "schema_text": "", "chain": None})
|
| 194 |
-
|
| 195 |
-
with gr.Row():
|
| 196 |
-
db_file = gr.File(label="Upload SQlite Database", file_types=[".sqlite", ".db"])
|
| 197 |
-
upload_status = gr.Textbox(label="upload Status", interactive=False)
|
| 198 |
-
|
| 199 |
-
schema_box = gr.Accordion("Database schema (preview)", open=False)
|
| 200 |
-
with schema_box:
|
| 201 |
-
schema_md = gr.Markdown("(no schema)")
|
| 202 |
-
|
| 203 |
-
gr.Markdown("---")
|
| 204 |
-
|
| 205 |
-
with gr.Row():
|
| 206 |
-
question = gr.Textbox(label="Your question", placeholder="e.g., Top 10 tracks by total sales")
|
| 207 |
-
with gr.Row():
|
| 208 |
-
max_row= gr.Slider(10, 5000, value=1000, step=10, label="Max rows")
|
| 209 |
-
|
| 210 |
-
with gr.Row():
|
| 211 |
-
run_btn = gr.Button("Generate & Run SQL", variant="primary")
|
| 212 |
-
|
| 213 |
-
with gr.Row():
|
| 214 |
-
status_out = gr.Textbox(label="Status")
|
| 215 |
-
with gr.Row():
|
| 216 |
-
sql_out = gr.Code(label="Generated SQL (validated)")
|
| 217 |
-
with gr.Row():
|
| 218 |
-
result_out = gr.Code(label="Result (JSON sample)")
|
| 219 |
-
|
| 220 |
-
db_file.change(
|
| 221 |
-
fn=on_upload_database,
|
| 222 |
-
inputs=[db_file, state],
|
| 223 |
-
outputs=[state, upload_status, schema_md],
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
run_btn.click(
|
| 227 |
-
fn=on_generate_query,
|
| 228 |
-
inputs=[question, max_row, state],
|
| 229 |
-
outputs=[status_out, sql_out, result_out],
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
if __name__ == "__main__":
|
| 235 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/__init__.py
ADDED
|
File without changes
|
app/main.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
load_dotenv()
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI
|
| 5 |
+
from app.routers import nl2sql
|
| 6 |
+
app = FastAPI(
|
| 7 |
+
title="NL2SQL Copilot Prototype",
|
| 8 |
+
version="0.1.0",
|
| 9 |
+
description="Natural Language -> SQL Copilot API"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
app.include_router(nl2sql.router, prefix="/api/v1")
|
| 13 |
+
|
| 14 |
+
@app.get("/healthz")
|
| 15 |
+
def health_check():
|
| 16 |
+
return {"status": "ok"}
|
| 17 |
+
|
| 18 |
+
@app.get("/")
|
| 19 |
+
def root():
|
| 20 |
+
return {"status": "ok", "message": "NL2SQL Copilot API is running"}
|
| 21 |
+
|
| 22 |
+
@app.get("/health")
|
| 23 |
+
def health():
|
| 24 |
+
return {
|
| 25 |
+
"status": "ok",
|
| 26 |
+
"db": "connected",
|
| 27 |
+
"llm": "reachable",
|
| 28 |
+
"uptime_sec": 123.4
|
| 29 |
+
}
|
app/routers/__init__.py
ADDED
|
File without changes
|
app/routers/nl2sql.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import asdict, is_dataclass
|
| 2 |
+
from fastapi import APIRouter, HTTPException
|
| 3 |
+
from app.schemas import NL2SQLRequest, NL2SQLResponse, ClarifyResponse
|
| 4 |
+
from nl2sql.pipeline import Pipeline
|
| 5 |
+
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 6 |
+
from nl2sql.safety import Safety
|
| 7 |
+
from nl2sql.planner import Planner
|
| 8 |
+
from nl2sql.generator import Generator
|
| 9 |
+
from adapters.llm.openai_provider import OpenAIProvider
|
| 10 |
+
from nl2sql.types import StageResult
|
| 11 |
+
from nl2sql.executor import Executor
|
| 12 |
+
from nl2sql.verifier import Verifier
|
| 13 |
+
from nl2sql.repair import Repair
|
| 14 |
+
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 15 |
+
from adapters.db.postgres_adapter import PostgresAdapter
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
router = APIRouter(prefix="/nl2sql")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if os.getenv("DB_MODE", "sqlite") == "postgres":
|
| 24 |
+
_db = PostgresAdapter(os.environ["POSTGRES_DSN"])
|
| 25 |
+
else:
|
| 26 |
+
_db = SQLiteAdapter("data/chinook.db")
|
| 27 |
+
|
| 28 |
+
# --- Composition Root ---
|
| 29 |
+
_llm = OpenAIProvider()
|
| 30 |
+
# _db = SQLiteAdapter("data/chinook.db")
|
| 31 |
+
_executor = Executor(_db)
|
| 32 |
+
_verifier = Verifier()
|
| 33 |
+
_repair = Repair(_llm)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_pipeline = Pipeline(
|
| 37 |
+
detector=AmbiguityDetector(),
|
| 38 |
+
planner=Planner(_llm),
|
| 39 |
+
generator=Generator(_llm),
|
| 40 |
+
safety=Safety(),
|
| 41 |
+
executor=_executor,
|
| 42 |
+
verifier=_verifier,
|
| 43 |
+
repair=_repair
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _to_dict(obj):
|
| 48 |
+
"""Helper: safely convert dataclass → dict."""
|
| 49 |
+
return asdict(obj) if is_dataclass(obj) else obj
|
| 50 |
+
|
| 51 |
+
def _round_trace(t: dict) -> dict:
|
| 52 |
+
if t.get("cost_usd") is not None:
|
| 53 |
+
t["cost_usd"] = round(t["cost_usd"], 6)
|
| 54 |
+
if t.get("duration_ms") is not None:
|
| 55 |
+
t["duration_ms"] = round(t["duration_ms"], 2)
|
| 56 |
+
return t
|
| 57 |
+
|
| 58 |
+
@router.post("", name="nl2sql_handler")
|
| 59 |
+
def nl2sql_handler(request: NL2SQLRequest):
|
| 60 |
+
result = _pipeline.run(user_query=request.query, schema_preview=request.schema_preview)
|
| 61 |
+
|
| 62 |
+
# --- Ensure result type ---
|
| 63 |
+
if not isinstance(result, StageResult):
|
| 64 |
+
raise HTTPException(status_code=500, detail="Pipeline returned unexpected type")
|
| 65 |
+
|
| 66 |
+
data = result.data or {}
|
| 67 |
+
|
| 68 |
+
# --- Handle ambiguity ---
|
| 69 |
+
if isinstance(data, dict) and data.get("ambiguous") and data.get("questions"):
|
| 70 |
+
return ClarifyResponse(ambiguous=True, questions=data["questions"])
|
| 71 |
+
|
| 72 |
+
# --- Handle error ---
|
| 73 |
+
if not result.ok:
|
| 74 |
+
detail = "; ".join(result.error) if result.error else "Unknown error"
|
| 75 |
+
raise HTTPException(status_code=400, detail=detail)
|
| 76 |
+
|
| 77 |
+
# --- Success case ---
|
| 78 |
+
return NL2SQLResponse(
|
| 79 |
+
ambiguous=False,
|
| 80 |
+
sql=data.get("sql"),
|
| 81 |
+
rationale=data.get("rationale"),
|
| 82 |
+
traces=[_to_dict(t) for t in data.get("traces", [])],
|
| 83 |
+
)
|
app/schemas.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional, Any, Dict
|
| 3 |
+
|
| 4 |
+
class NL2SQLRequest(BaseModel):
|
| 5 |
+
query: str
|
| 6 |
+
schema_preview: str
|
| 7 |
+
db_name: Optional[str] = "default"
|
| 8 |
+
|
| 9 |
+
class TraceModel(BaseModel):
|
| 10 |
+
stage: str
|
| 11 |
+
duration_ms: float
|
| 12 |
+
token_in: int | None = 0
|
| 13 |
+
token_out: int | None = 0
|
| 14 |
+
cost_usd: float | None = 0
|
| 15 |
+
notes: Dict[str, Any] | None = None
|
| 16 |
+
|
| 17 |
+
class NL2SQLResponse(BaseModel):
|
| 18 |
+
ambiguous: bool = False
|
| 19 |
+
sql: str
|
| 20 |
+
rationale: Optional[str] = None
|
| 21 |
+
traces: List[TraceModel] = []
|
| 22 |
+
|
| 23 |
+
class ClarifyResponse(BaseModel):
|
| 24 |
+
ambiguous: bool = True
|
| 25 |
+
questions: List[str]
|
| 26 |
+
|
| 27 |
+
class ErrorResponse(BaseModel):
|
| 28 |
+
error: str
|
| 29 |
+
details: List[str] | None = None
|
benchmarks/results/demo.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"query": "show all users", "exec_acc": 0.0, "safe_fail": 0.0, "latency_ms": 0.610041999607347, "cost_usd": 0.0, "repair_attempts": 0, "provider": "dummy-llm"}
|
| 2 |
+
{"query": "top spenders", "exec_acc": 0.0, "safe_fail": 0.0, "latency_ms": 0.005625000085274223, "cost_usd": 0.0, "repair_attempts": 0, "provider": "dummy-llm"}
|
| 3 |
+
{"query": "sum of spend", "exec_acc": 0.0, "safe_fail": 0.0, "latency_ms": 0.20833300004596822, "cost_usd": 0.0, "repair_attempts": 0, "provider": "dummy-llm"}
|
benchmarks/run.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# benchmarks/run.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# ---- app imports
|
| 10 |
+
from nl2sql.pipeline import Pipeline
|
| 11 |
+
from nl2sql.ambiguity_detector import AmbiguityDetector
|
| 12 |
+
from nl2sql.planner import Planner
|
| 13 |
+
from nl2sql.generator import Generator
|
| 14 |
+
from nl2sql.safety import Safety
|
| 15 |
+
from nl2sql.executor import Executor
|
| 16 |
+
from nl2sql.verifier import Verifier
|
| 17 |
+
from nl2sql.repair import Repair
|
| 18 |
+
|
| 19 |
+
# ---- adapters
|
| 20 |
+
from adapters.db.sqlite_adapter import SQLiteAdapter
|
| 21 |
+
from adapters.llm.openai_provider import OpenAIProvider
|
| 22 |
+
|
| 23 |
+
# ---- fallbacks: Dummy LLM (so it runs without API keys)
|
| 24 |
+
class DummyLLM:
|
| 25 |
+
provider_id = "dummy-llm"
|
| 26 |
+
|
| 27 |
+
def plan(self, *, user_query: str, schema_preview: str):
|
| 28 |
+
text = f"- understand question: {user_query}\n- identify tables\n- join if needed\n- filter\n- order/limit"
|
| 29 |
+
return text, 0, 0, 0.0
|
| 30 |
+
|
| 31 |
+
def generate_sql(self, *, user_query: str, schema_preview: str, plan_text: str, clarify_answers=None):
|
| 32 |
+
# naive demo SQL (so pipeline flows end-to-end)
|
| 33 |
+
sql = "SELECT 1 AS one;"
|
| 34 |
+
rationale = "Demo SQL from DummyLLM"
|
| 35 |
+
return sql, rationale, 0, 0, 0.0
|
| 36 |
+
|
| 37 |
+
def repair(self, *, sql: str, error_msg: str, schema_preview: str):
|
| 38 |
+
return sql, 0, 0, 0.0
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def ensure_demo_db(path: Path) -> None:
|
| 42 |
+
"""Create a tiny SQLite db if missing, so executor has something to run."""
|
| 43 |
+
if path.exists():
|
| 44 |
+
return
|
| 45 |
+
import sqlite3
|
| 46 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
con = sqlite3.connect(path)
|
| 48 |
+
cur = con.cursor()
|
| 49 |
+
cur.execute("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT, spend REAL);")
|
| 50 |
+
cur.executemany("INSERT INTO users(id,name,spend) VALUES(?,?,?)",
|
| 51 |
+
[(1,"Alice",120.5),(2,"Bob",80.0),(3,"Carol",155.0)])
|
| 52 |
+
con.commit()
|
| 53 |
+
con.close()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def build_pipeline(db_path: Path, use_openai: bool) -> Pipeline:
|
| 57 |
+
# DB adapter
|
| 58 |
+
db = SQLiteAdapter(str(db_path))
|
| 59 |
+
executor = Executor(db)
|
| 60 |
+
# LLM provider
|
| 61 |
+
if use_openai and os.getenv("OPENAI_API_KEY"):
|
| 62 |
+
llm = OpenAIProvider()
|
| 63 |
+
else:
|
| 64 |
+
llm = DummyLLM()
|
| 65 |
+
# stages
|
| 66 |
+
detector = AmbiguityDetector()
|
| 67 |
+
planner = Planner(llm)
|
| 68 |
+
generator = Generator(llm)
|
| 69 |
+
safety = Safety()
|
| 70 |
+
verifier = Verifier()
|
| 71 |
+
repair = Repair(llm)
|
| 72 |
+
# pipeline
|
| 73 |
+
return Pipeline(
|
| 74 |
+
detector=detector,
|
| 75 |
+
planner=planner,
|
| 76 |
+
generator=generator,
|
| 77 |
+
safety=safety,
|
| 78 |
+
executor=executor,
|
| 79 |
+
verifier=verifier,
|
| 80 |
+
repair=repair,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def run_benchmark(queries, schema_preview, pipeline: Pipeline, outfile: Path):
|
| 85 |
+
results = []
|
| 86 |
+
for q in queries:
|
| 87 |
+
t0 = time.perf_counter()
|
| 88 |
+
r = pipeline.run(user_query=q, schema_preview=schema_preview)
|
| 89 |
+
latency_ms = (time.perf_counter()-t0)*1000
|
| 90 |
+
ok = (not r.get("ambiguous")) and ("error" not in r)
|
| 91 |
+
|
| 92 |
+
traces = r.get("traces", [])
|
| 93 |
+
cost_sum = 0.0
|
| 94 |
+
for t in traces:
|
| 95 |
+
try:
|
| 96 |
+
cost_sum += float(t.get("cost_usd", 0.0))
|
| 97 |
+
except Exception:
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
results.append({
|
| 101 |
+
"query": q,
|
| 102 |
+
"exec_acc": 1.0 if ok else 0.0,
|
| 103 |
+
"safe_fail": 0.0 if ok else 1.0 if "unsafe" in str(r).lower() else 0.0,
|
| 104 |
+
"latency_ms": latency_ms,
|
| 105 |
+
"cost_usd": cost_sum,
|
| 106 |
+
"repair_attempts": sum(1 for t in traces if t.get("stage") == "repair"),
|
| 107 |
+
"provider": pipeline.generator.llm.provider_id if hasattr(pipeline.generator, "llm") else "unknown",
|
| 108 |
+
})
|
| 109 |
+
|
| 110 |
+
outfile.parent.mkdir(parents=True, exist_ok=True)
|
| 111 |
+
with open(outfile, "w") as f:
|
| 112 |
+
for row in results:
|
| 113 |
+
f.write(json.dumps(row) + "\n")
|
| 114 |
+
print(f"[OK] wrote {len(results)} rows → {outfile}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main():
|
| 118 |
+
parser = argparse.ArgumentParser()
|
| 119 |
+
parser.add_argument("--outfile", default="benchmarks/results/demo.jsonl")
|
| 120 |
+
parser.add_argument("--db", default="data/bench_demo.db")
|
| 121 |
+
parser.add_argument("--use-openai", action="store_true", help="Use OpenAI provider if API key present")
|
| 122 |
+
args = parser.parse_args()
|
| 123 |
+
|
| 124 |
+
ROOT = Path(__file__).resolve().parents[1] # project root
|
| 125 |
+
outfile = (ROOT / args.outfile).resolve()
|
| 126 |
+
db_path = (ROOT / args.db).resolve()
|
| 127 |
+
|
| 128 |
+
ensure_demo_db(db_path)
|
| 129 |
+
pipe = build_pipeline(db_path, use_openai=args.use_openai)
|
| 130 |
+
|
| 131 |
+
# a small demo set; replace with Spider when ready
|
| 132 |
+
queries = [
|
| 133 |
+
"show all users",
|
| 134 |
+
"top spenders",
|
| 135 |
+
"sum of spend",
|
| 136 |
+
]
|
| 137 |
+
schema_preview = "CREATE TABLE users(id INT, name TEXT, spend REAL);"
|
| 138 |
+
|
| 139 |
+
run_benchmark(queries, schema_preview, pipe, outfile)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
main()
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: "3.9"
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
postgres:
|
| 5 |
+
image: postgres:16
|
| 6 |
+
container_name: nl2sql_pg
|
| 7 |
+
environment:
|
| 8 |
+
POSTGRES_USER: postgres
|
| 9 |
+
POSTGRES_PASSWORD: postgres
|
| 10 |
+
POSTGRES_DB: demo
|
| 11 |
+
volumes:
|
| 12 |
+
- pgdata:/var/lib/postgresql/data
|
| 13 |
+
- ./infra/migrate.sql:/docker-entrypoint-initdb.d/00_init.sql:ro
|
| 14 |
+
ports:
|
| 15 |
+
- "5432:5432"
|
| 16 |
+
healthcheck:
|
| 17 |
+
test: ["CMD-SHELL", "pg_isready -U postgres -d demo"]
|
| 18 |
+
interval: 5s
|
| 19 |
+
timeout: 3s
|
| 20 |
+
retries: 10
|
| 21 |
+
|
| 22 |
+
api:
|
| 23 |
+
build:
|
| 24 |
+
context: .
|
| 25 |
+
dockerfile: Dockerfile
|
| 26 |
+
container_name: nl2sql_api
|
| 27 |
+
depends_on:
|
| 28 |
+
postgres:
|
| 29 |
+
condition: service_healthy
|
| 30 |
+
environment:
|
| 31 |
+
DB_MODE: postgres
|
| 32 |
+
POSTGRES_DSN: dbname=demo user=postgres password=postgres host=postgres port=5432
|
| 33 |
+
OPENAI_MODEL_ID: gpt-4o-mini
|
| 34 |
+
OPENAI_API_KEY: ${OPENAI_API_KEY}
|
| 35 |
+
ports:
|
| 36 |
+
- "8000:8000"
|
| 37 |
+
command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--proxy-headers"]
|
| 38 |
+
|
| 39 |
+
volumes:
|
| 40 |
+
pgdata:
|
infra/migrate.sql
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CREATE TABLE IF NOT EXISTS users (
|
| 2 |
+
id SERIAL PRIMARY KEY,
|
| 3 |
+
name TEXT NOT NULL,
|
| 4 |
+
city TEXT
|
| 5 |
+
);
|
| 6 |
+
|
| 7 |
+
INSERT INTO users (name, city)
|
| 8 |
+
VALUES ('Alice', 'Tehran'), ('Bob', 'Karaj'), ('Caro', 'Isfahan');
|
logs/spider_eval/dev_gold_1760430884.txt
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
SELECT count(*) FROM singer concert_singer
|
| 2 |
-
SELECT count(*) FROM singer concert_singer
|
| 3 |
-
SELECT name , country , age FROM singer ORDER BY age DESC concert_singer
|
| 4 |
-
SELECT name , country , age FROM singer ORDER BY age DESC concert_singer
|
| 5 |
-
SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France' concert_singer
|
| 6 |
-
SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France' concert_singer
|
| 7 |
-
SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1 concert_singer
|
| 8 |
-
SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1 concert_singer
|
| 9 |
-
SELECT DISTINCT country FROM singer WHERE age > 20 concert_singer
|
| 10 |
-
SELECT DISTINCT country FROM singer WHERE age > 20 concert_singer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logs/spider_eval/dev_metrics_1760430884.json
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"commit_hash": "e207f417ac5923220817e3c3f61c72e51a98c63b",
|
| 3 |
-
"split": "dev",
|
| 4 |
-
"limit": 10,
|
| 5 |
-
"total_examples": 10,
|
| 6 |
-
"valid_examples": 10,
|
| 7 |
-
"exact_match_rate": 0.2,
|
| 8 |
-
"exact_match_structural_rate": 0.0,
|
| 9 |
-
"execution_accuracy_rate": 0.8,
|
| 10 |
-
"error_rate": 0.0,
|
| 11 |
-
"safe_check_fail_rate": 0.0,
|
| 12 |
-
"avg_gen_time": 1.4374850749969483,
|
| 13 |
-
"avg_exec_time": 0.0007865667343139648,
|
| 14 |
-
"run_id": 1760430884
|
| 15 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logs/spider_eval/dev_pred_1760430884.txt
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
SELECT COUNT(*) AS total_singers FROM singer; concert_singer
|
| 2 |
-
SELECT COUNT(*) AS total_singers FROM singer; concert_singer
|
| 3 |
-
SELECT Name, Country, Age FROM singer ORDER BY Age DESC concert_singer
|
| 4 |
-
SELECT Name, Country, Age FROM singer ORDER BY Age DESC concert_singer
|
| 5 |
-
SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France' concert_singer
|
| 6 |
-
SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France'; concert_singer
|
| 7 |
-
SELECT Name, Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer) concert_singer
|
| 8 |
-
SELECT Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer) concert_singer
|
| 9 |
-
SELECT DISTINCT Country FROM singer WHERE Age > 20 concert_singer
|
| 10 |
-
SELECT DISTINCT Country FROM singer WHERE Age > 20 concert_singer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logs/spider_eval/dev_results_1760430884.jsonl
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
# {"commit_hash": "e207f417ac5923220817e3c3f61c72e51a98c63b", "split": "dev", "limit": 10, "start_time": 1760430884}
|
| 2 |
-
{"db_id": "concert_singer", "question": "How many singers do we have?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "SELECT COUNT(*) AS total_singers FROM singer;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"total_singers\": 6\n }\n]", "gen_time": 1.2182981967926025, "exec_time": 0.0008916854858398438, "error": null, "gold_error": null, "pred_rows": "[(6,)]", "gold_rows": "[(6,)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 3 |
-
{"db_id": "concert_singer", "question": "What is the total number of singers?", "gold_sql": "SELECT count(*) FROM singer", "pred_sql": "SELECT COUNT(*) AS total_singers FROM singer;", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"total_singers\": 6\n }\n]", "gen_time": 1.261944055557251, "exec_time": 0.00044798851013183594, "error": null, "gold_error": null, "pred_rows": "[(6,)]", "gold_rows": "[(6,)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 4 |
-
{"db_id": "concert_singer", "question": "Show name, country, age for all singers ordered by age from the oldest to the youngest.", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "SELECT Name, Country, Age FROM singer ORDER BY Age DESC", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Country\": \"Netherlands\",\n \"Age\": 52\n },\n {\n \"Name\": \"John Nizinik\",\n \"Country\": \"France\",\n \"Age\": 43\n },\n {\n \"Name\": \"Rose White\",\n \"Country\": \"France\",\n \"Age\": 41\n },\n {\n \"Name\": \"Timbaland\",\n \"Country\": \"United States\",\n \"Age\": 32\n },\n {\n \"Name\": \"Justin Brown\",\n \"Country\": \"France\",\n \"Age\": 29\n },\n {\n \"Name\": \"Tribal King\",\n \"Country\": \"France\",\n \"Age\": 25\n }\n]", "gen_time": 1.0276496410369873, "exec_time": 0.0006437301635742188, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "gold_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 5 |
-
{"db_id": "concert_singer", "question": "What are the names, countries, and ages for every singer in descending order of age?", "gold_sql": "SELECT name , country , age FROM singer ORDER BY age DESC", "pred_sql": "SELECT Name, Country, Age FROM singer ORDER BY Age DESC", "status": "Returned 6 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Country\": \"Netherlands\",\n \"Age\": 52\n },\n {\n \"Name\": \"John Nizinik\",\n \"Country\": \"France\",\n \"Age\": 43\n },\n {\n \"Name\": \"Rose White\",\n \"Country\": \"France\",\n \"Age\": 41\n },\n {\n \"Name\": \"Timbaland\",\n \"Country\": \"United States\",\n \"Age\": 32\n },\n {\n \"Name\": \"Justin Brown\",\n \"Country\": \"France\",\n \"Age\": 29\n },\n {\n \"Name\": \"Tribal King\",\n \"Country\": \"France\",\n \"Age\": 25\n }\n]", "gen_time": 1.4854280948638916, "exec_time": 0.0007121562957763672, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "gold_rows": "[('Joe Sharp', 'Netherlands', 52), ('John Nizinik', 'France', 43), ('Rose White', 'France', 41), ('Timbaland', 'United States', 32), ('Justin Brown', 'France', 29), ('Tribal King', 'France', 25)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 6 |
-
{"db_id": "concert_singer", "question": "What is the average, minimum, and maximum age of all singers from France?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France'", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_age\": 34.5,\n \"minimum_age\": 25,\n \"maximum_age\": 43\n }\n]", "gen_time": 1.8432340621948242, "exec_time": 0.0009491443634033203, "error": null, "gold_error": null, "pred_rows": "[(34.5, 25, 43)]", "gold_rows": "[(34.5, 25, 43)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 7 |
-
{"db_id": "concert_singer", "question": "What is the average, minimum, and maximum age for all French singers?", "gold_sql": "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "pred_sql": "SELECT AVG(Age) AS average_age, MIN(Age) AS minimum_age, MAX(Age) AS maximum_age FROM singer WHERE Country = 'France';", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"average_age\": 34.5,\n \"minimum_age\": 25,\n \"maximum_age\": 43\n }\n]", "gen_time": 2.5529589653015137, "exec_time": 0.0012962818145751953, "error": null, "gold_error": null, "pred_rows": "[(34.5, 25, 43)]", "gold_rows": "[(34.5, 25, 43)]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 8 |
-
{"db_id": "concert_singer", "question": "Show the name and the release year of the song by the youngest singer.", "gold_sql": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "pred_sql": "SELECT Name, Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer)", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Name\": \"Joe Sharp\",\n \"Song_Name\": \"You\",\n \"Song_release_year\": \"1992\"\n }\n]", "gen_time": 1.2428169250488281, "exec_time": 0.0007181167602539062, "error": null, "gold_error": null, "pred_rows": "[('Joe Sharp', 'You', '1992')]", "gold_rows": "[('Love', '2016')]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": false, "safe_check_failed": false}
|
| 9 |
-
{"db_id": "concert_singer", "question": "What are the names and release years for all the songs of the youngest singer?", "gold_sql": "SELECT song_name , song_release_year FROM singer ORDER BY age LIMIT 1", "pred_sql": "SELECT Song_Name, Song_release_year FROM singer WHERE Age = (SELECT MAX(Age) FROM singer)", "status": "Returned 1 row(s). Showing up to 50.", "output": "[\n {\n \"Song_Name\": \"You\",\n \"Song_release_year\": \"1992\"\n }\n]", "gen_time": 1.4568238258361816, "exec_time": 0.0009098052978515625, "error": null, "gold_error": null, "pred_rows": "[('You', '1992')]", "gold_rows": "[('Love', '2016')]", "exact_match": false, "exact_match_structural": false, "execution_accuracy": false, "safe_check_failed": false}
|
| 10 |
-
{"db_id": "concert_singer", "question": "What are all distinct countries where singers above age 20 are from?", "gold_sql": "SELECT DISTINCT country FROM singer WHERE age > 20", "pred_sql": "SELECT DISTINCT Country FROM singer WHERE Age > 20", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Country\": \"Netherlands\"\n },\n {\n \"Country\": \"United States\"\n },\n {\n \"Country\": \"France\"\n }\n]", "gen_time": 0.9801719188690186, "exec_time": 0.0007050037384033203, "error": null, "gold_error": null, "pred_rows": "[('Netherlands',), ('United States',), ('France',)]", "gold_rows": "[('Netherlands',), ('United States',), ('France',)]", "exact_match": true, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
| 11 |
-
{"db_id": "concert_singer", "question": "What are the different countries with singers above age 20?", "gold_sql": "SELECT DISTINCT country FROM singer WHERE age > 20", "pred_sql": "SELECT DISTINCT Country FROM singer WHERE Age > 20", "status": "Returned 3 row(s). Showing up to 50.", "output": "[\n {\n \"Country\": \"Netherlands\"\n },\n {\n \"Country\": \"United States\"\n },\n {\n \"Country\": \"France\"\n }\n]", "gen_time": 1.3055250644683838, "exec_time": 0.0005917549133300781, "error": null, "gold_error": null, "pred_rows": "[('Netherlands',), ('United States',), ('France',)]", "gold_rows": "[('Netherlands',), ('United States',), ('France',)]", "exact_match": true, "exact_match_structural": false, "execution_accuracy": true, "safe_check_failed": false}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nl2sql/__init__.py
ADDED
|
File without changes
|
nl2sql/ambiguity_detector.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
class AmbiguityDetector:
|
| 5 |
+
"""Lightweight AmbiSQL-style ambiguity detection."""
|
| 6 |
+
|
| 7 |
+
AMBIGUOUS_TERMS = ["recent", "top", "name", "rank", "latest"]
|
| 8 |
+
|
| 9 |
+
def detect(self, query:str, schema_preview: str) -> list[str]:
|
| 10 |
+
hits = []
|
| 11 |
+
q_lower = query.lower()
|
| 12 |
+
for term in self.AMBIGUOUS_TERMS:
|
| 13 |
+
if re.search(rf"\b{term}\b", q_lower):
|
| 14 |
+
hits.append(f"The term '{term}' is ambiguous in this query.'")
|
| 15 |
+
|
| 16 |
+
return hits
|
nl2sql/executor.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from nl2sql.types import StageResult, StageTrace
|
| 3 |
+
from adapters.db.base import DBAdapter
|
| 4 |
+
|
| 5 |
+
class Executor:
|
| 6 |
+
name = "executor"
|
| 7 |
+
|
| 8 |
+
def __init__(self, db: DBAdapter):
|
| 9 |
+
self.db = db
|
| 10 |
+
|
| 11 |
+
def run(self, sql: str) -> StageResult:
|
| 12 |
+
t0 = time.perf_counter()
|
| 13 |
+
try:
|
| 14 |
+
rows, cols = self.db.execute(sql)
|
| 15 |
+
trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
|
| 16 |
+
notes={"row_count": len(rows), "col_count": len(cols)})
|
| 17 |
+
return StageResult(ok=True, data={"rows": rows, "columns": cols}, trace=trace)
|
| 18 |
+
except Exception as e:
|
| 19 |
+
trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter()-t0)*1000,
|
| 20 |
+
notes={"error": str(e)})
|
| 21 |
+
return StageResult(ok=False, data=None, trace=trace, error=[str(e)])
|
nl2sql/generator.py
ADDED
@@ -0,0 +1,49 @@
from __future__ import annotations
import time
from typing import Optional, Dict, Any
from nl2sql.types import StageResult, StageTrace
from adapters.llm.base import LLMProvider

class Generator:
    name = "generator"

    def __init__(self, llm: LLMProvider) -> None:
        self.llm = llm

    def run(self, *, user_query: str, schema_preview: str, plan_text: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> StageResult:
        t0 = time.perf_counter()
        try:
            res = self.llm.generate_sql(
                user_query=user_query,
                schema_preview=schema_preview,
                plan_text=plan_text,
                clarify_answers=clarify_answers or {}
            )
        except Exception as e:
            return StageResult(ok=False, error=[f"Generator failed: {e}"])

        # Expect a 5-tuple
        if not isinstance(res, tuple) or len(res) != 5:
            return StageResult(ok=False, error=["Generator contract violation: expected 5-tuple (sql, rationale, t_in, t_out, cost)"])

        sql, rationale, t_in, t_out, cost = res

        # Type/shape checks
        if not isinstance(sql, str) or not sql.strip():
            return StageResult(ok=False, error=["Generator produced empty or non-string SQL"])
        if not sql.lower().lstrip().startswith("select"):
            return StageResult(ok=False, error=[f"Generated non-SELECT SQL: {sql}"])

        rationale = rationale or ""  # safe length
        trace = StageTrace(
            stage=self.name,
            duration_ms=(time.perf_counter() - t0) * 1000.0,
            token_in=t_in,
            token_out=t_out,
            cost_usd=cost,
            notes={"rationale_len": len(rationale)},
        )

        return StageResult(ok=True, data={"sql": sql, "rationale": rationale}, trace=trace)
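
The stage enforces a strict provider contract: generate_sql must return (sql, rationale, token_in, token_out, cost). A minimal offline provider sketch honoring that contract (the CannedLLM class and its canned SQL are illustrative, not part of this commit):

from nl2sql.generator import Generator

class CannedLLM:
    """Hypothetical offline provider returning the required 5-tuple."""
    def generate_sql(self, **kwargs):
        return "SELECT name FROM singer;", "project the name column", 12, 6, 0.0

r = Generator(llm=CannedLLM()).run(
    user_query="singer names",
    schema_preview="CREATE TABLE singer(id int, name text);",
    plan_text="-- plan --",
)
assert r.ok and r.data["sql"].startswith("SELECT")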
nl2sql/pipeline.py
ADDED
@@ -0,0 +1,165 @@
from __future__ import annotations
import traceback
from typing import Dict, Any, Optional, List
from nl2sql.types import StageResult
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.planner import Planner
from nl2sql.generator import Generator
from nl2sql.safety import Safety
from nl2sql.executor import Executor
from nl2sql.verifier import Verifier
from nl2sql.repair import Repair


class Pipeline:
    """
    NL2SQL Copilot pipeline with guaranteed dict output.
    All stages return structured traces and errors, but the final result is a JSON-safe dict.
    """

    def __init__(self, *,
                 detector: AmbiguityDetector,
                 planner: Planner,
                 generator: Generator,
                 safety: Safety,
                 executor: Executor,
                 verifier: Verifier,
                 repair: Repair):
        self.detector = detector
        self.planner = planner
        self.generator = generator
        self.safety = safety
        self.executor = executor
        self.verifier = verifier
        self.repair = repair

    # ------------------------------------------------------------
    def _trace_list(self, *stages: StageResult) -> List[dict]:
        traces = []
        for s in stages:
            if not s:
                continue
            t = getattr(s, "trace", None)
            if t:
                traces.append(t.__dict__)
        return traces

    # ------------------------------------------------------------
    def _safe_stage(self, fn, **kwargs) -> StageResult:
        """Run a stage safely; if it throws, catch and convert to StageResult."""
        try:
            r = fn(**kwargs)
            if isinstance(r, StageResult):
                return r
            else:
                # not ideal, but wrap it
                return StageResult(ok=True, data=r, trace=None)
        except Exception as e:
            tb = traceback.format_exc()
            return StageResult(ok=False, data=None, trace=None, error=[f"{e}", tb])

    # ------------------------------------------------------------
    def run(self, *, user_query: str, schema_preview: str,
            clarify_answers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Always returns:
        {
            "ambiguous": bool,
            "error": bool,
            "details": list[str] | None,
            "sql": str | None,
            "rationale": str | None,
            "verified": bool | None,
            "traces": list[dict]
        }
        """
        traces: List[dict] = []
        details: List[str] = []
        sql, rationale, verified = None, None, None

        # --- 1) ambiguity detection
        try:
            questions = self.detector.detect(user_query, schema_preview)
            if questions:
                return {
                    "ambiguous": True,
                    "error": False,
                    "details": [f"Ambiguities found: {len(questions)}"],
                    "questions": questions,
                    "traces": []
                }
        except Exception as e:
            return {"ambiguous": True, "error": True, "details": [f"Detector failed: {e}"], "traces": []}

        # --- 2) planner
        r_plan = self._safe_stage(self.planner.run, user_query=user_query, schema_preview=schema_preview)
        traces.extend(self._trace_list(r_plan))
        if not r_plan.ok:
            return {"ambiguous": False, "error": True, "details": r_plan.error, "traces": traces}

        # --- 3) generator
        r_gen = self._safe_stage(self.generator.run,
                                 user_query=user_query,
                                 schema_preview=schema_preview,
                                 plan_text=r_plan.data.get("plan"),
                                 clarify_answers=clarify_answers or {})
        traces.extend(self._trace_list(r_gen))
        if not r_gen.ok:
            return {"ambiguous": False, "error": True, "details": r_gen.error, "traces": traces}
        sql = r_gen.data.get("sql")
        rationale = r_gen.data.get("rationale")

        # --- 4) safety
        r_safe = self._safe_stage(self.safety.check, sql=sql)
        traces.extend(self._trace_list(r_safe))
        if not r_safe.ok:
            return {"ambiguous": False, "error": True, "details": r_safe.error, "traces": traces}

        # --- 5) executor
        r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
        traces.extend(self._trace_list(r_exec))
        if not r_exec.ok:
            details.extend(r_exec.error or [])

        # --- 6) verifier
        r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
        traces.extend(self._trace_list(r_ver))
        verified = bool(r_ver.ok)

        # --- 7) repair loop if verification failed
        if not verified:
            for attempt in range(2):
                r_fix = self._safe_stage(self.repair.run,
                                         sql=sql,
                                         error_msg="; ".join(details or ["unknown"]),
                                         schema_preview=schema_preview)
                traces.extend(self._trace_list(r_fix))
                if not r_fix.ok:
                    break
                sql = r_fix.data.get("sql")
                r_safe = self._safe_stage(self.safety.check, sql=sql)
                traces.extend(self._trace_list(r_safe))
                if not r_safe.ok:
                    details.extend(r_safe.error or [])
                    continue
                r_exec = self._safe_stage(self.executor.run, sql=r_safe.data["sql"])
                traces.extend(self._trace_list(r_exec))
                if not r_exec.ok:
                    details.extend(r_exec.error or [])
                    continue
                r_ver = self._safe_stage(self.verifier.run, sql=sql, exec_result=r_exec)
                traces.extend(self._trace_list(r_ver))
                verified = bool(r_ver.ok)
                if verified:
                    break

        # --- Final result dict
        return {
            "ambiguous": False,
            "error": len(details) > 0 and not verified,
            "details": details or None,
            "sql": sql,
            "rationale": rationale,
            "verified": verified,
            "traces": traces,
        }
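
Because the pipeline duck-types its stages (it only calls detect, run, and check), it can be wired up without an LLM provider. A minimal sketch (not part of this commit) combining the real Safety stage, the NoOp stubs from nl2sql/stubs.py, and illustrative fake detector/planner/generator classes:

from nl2sql.pipeline import Pipeline
from nl2sql.safety import Safety
from nl2sql.stubs import NoOpExecutor, NoOpVerifier, NoOpRepair
from nl2sql.types import StageResult

# These fakes are illustrative only; real stages wrap an LLMProvider.
class FakeDetector:
    def detect(self, user_query, schema_preview):
        return []

class FakePlanner:
    def run(self, *, user_query, schema_preview):
        return StageResult(ok=True, data={"plan": "single-table scan"})

class FakeGenerator:
    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        return StageResult(ok=True, data={"sql": "SELECT * FROM singer;", "rationale": "list all"})

pipe = Pipeline(detector=FakeDetector(), planner=FakePlanner(), generator=FakeGenerator(),
                safety=Safety(), executor=NoOpExecutor(), verifier=NoOpVerifier(),
                repair=NoOpRepair())
out = pipe.run(user_query="show all singers", schema_preview="CREATE TABLE singer(id int);")
print(out["verified"], out["sql"])  # True SELECT * FROM singer;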
nl2sql/planner.py
ADDED
@@ -0,0 +1,16 @@
from __future__ import annotations
import time
from nl2sql.types import StageResult, StageTrace
from adapters.llm.base import LLMProvider

class Planner:
    name = "planner"

    def __init__(self, llm: LLMProvider) -> None:
        self.llm = llm

    def run(self, *, user_query: str, schema_preview: str) -> StageResult:
        t0 = time.perf_counter()
        plan_text, t_in, t_out, cost = self.llm.plan(user_query=user_query, schema_preview=schema_preview)
        trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000,
                           token_in=t_in, token_out=t_out, cost_usd=cost, notes={"len_plan": len(plan_text)})
        return StageResult(ok=True, data={"plan": plan_text}, trace=trace)
nl2sql/repair.py
ADDED
@@ -0,0 +1,29 @@
import time

from nl2sql.types import StageTrace, StageResult
from adapters.llm.base import LLMProvider

GUIDELINES = """
When repairing:
1. Keep the query SELECT-only.
2. Explicitly qualify ambiguous columns with table names.
3. Match GROUP BY fields with aggregations.
4. Use known foreign keys for JOINs.
5. Add a reasonable LIMIT if missing.
Return only the corrected SQL.
"""

class Repair:
    name = "repair"

    def __init__(self, llm: LLMProvider):
        self.llm = llm

    def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
        t0 = time.perf_counter()
        fixed_sql, t_in, t_out, cost = self.llm.repair(sql=sql, error_msg=f"{GUIDELINES}\n\n{error_msg}",
                                                       schema_preview=schema_preview)
        trace = StageTrace(stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000,
                           token_in=t_in, token_out=t_out, cost_usd=cost,
                           notes={"old_sql_len": len(sql), "new_sql_len": len(fixed_sql)})
        return StageResult(ok=True, data={"sql": fixed_sql}, trace=trace)
nl2sql/safety.py
ADDED
@@ -0,0 +1,75 @@
from __future__ import annotations
import re
import time
from nl2sql.types import StageResult, StageTrace

# --- Regex utils ---
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
# string literals (single & double quotes), allow escaped quotes
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)

# case-insensitive, word-boundary forbidden keywords
_FORBIDDEN = re.compile(
    r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
    re.IGNORECASE,
)

# allow: SELECT ... or WITH <cte...> SELECT ...
_ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)

def _strip_comments(s: str) -> str:
    s = _COMMENT_BLOCK.sub(" ", s)
    s = _COMMENT_LINE.sub(" ", s)
    return s

def _mask_strings(s: str) -> str:
    s = _STRING_SINGLE.sub("'X'", s)
    s = _STRING_DOUBLE.sub('"X"', s)
    return s

def _split_statements(s: str) -> list[str]:
    parts = [p.strip() for p in s.split(";")]
    return [p for p in parts if p]

class Safety:
    name = "safety"

    def check(self, sql: str) -> StageResult:
        t0 = time.perf_counter()
        print("🧩 SQL candidate:", sql)
        s = _strip_comments(sql)
        s = _mask_strings(s).strip()

        stmts = _split_statements(s)
        if len(stmts) != 1:
            return StageResult(
                ok=False,
                error=["Multiple statements detected"],
                trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000),
            )

        body = stmts[0]

        if _FORBIDDEN.search(body):
            return StageResult(
                ok=False,
                error=["Forbidden keyword detected"],
                trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000),
            )

        if not _ALLOW_SELECT.match(body):
            return StageResult(
                ok=False,
                error=["Non-SELECT statement"],
                trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000),
            )

        return StageResult(
            ok=True,
            data={
                "sql": sql.strip(),
                "rationale": "Statement validated as SELECT-only (strings/comments ignored).",
            },
            trace=StageTrace(stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000),
        )
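
A few illustrative checks against the guard above (a sketch, not part of this commit), showing the three rejection paths and the CTE allowance:

from nl2sql.safety import Safety

s = Safety()
assert s.check("SELECT * FROM users;").ok                    # plain SELECT passes
assert s.check("WITH r AS (SELECT 1) SELECT * FROM r;").ok   # CTE form passes
assert not s.check("SELECT 1; DROP TABLE users;").ok         # stacked statements rejected
assert not s.check("/*D*/ROP TABLE users;").ok               # comment obfuscation rejected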
nl2sql/stubs.py
ADDED
@@ -0,0 +1,31 @@
from nl2sql.types import StageResult, StageTrace

class NoOpExecutor:
    name = "executor"

    def run(self, sql: str) -> StageResult:
        # pretend success, return empty result set
        return StageResult(
            ok=True,
            data={"rows": [], "columns": []},
            trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
        )

class NoOpVerifier:
    name = "verifier"

    def run(self, sql: str, exec_result: StageResult) -> StageResult:
        # always verified for legacy tests
        return StageResult(
            ok=True,
            data={"verified": True},
            trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
        )

class NoOpRepair:
    name = "repair"

    def run(self, sql: str, error_msg: str, schema_preview: str) -> StageResult:
        # return original SQL unchanged
        return StageResult(
            ok=True,
            data={"sql": sql},
            trace=StageTrace(stage=self.name, duration_ms=0.0, notes={"noop": True})
        )
nl2sql/types.py
ADDED
@@ -0,0 +1,19 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, List

@dataclass(frozen=True)
class StageTrace:
    stage: str
    duration_ms: float
    notes: Optional[Dict[str, Any]] = None
    token_in: Optional[int] = None
    token_out: Optional[int] = None
    cost_usd: Optional[float] = None

@dataclass(frozen=True)
class StageResult:
    ok: bool
    data: Optional[Any] = None
    trace: Optional[StageTrace] = None
    error: Optional[List[str]] = None
    notes: Optional[Dict[str, Any]] = None
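
Both dataclasses are frozen and hold only flat, JSON-friendly fields, which is what lets the pipeline serialize traces via t.__dict__ without a custom encoder. A small sketch (not part of this commit):

import json
from nl2sql.types import StageTrace

t = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20, cost_usd=0.0003)
# Flat fields only, so the trace round-trips through JSON directly.
print(json.dumps(t.__dict__))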
nl2sql/verifier.py
ADDED
@@ -0,0 +1,33 @@
import sqlglot
from sqlglot import expressions as exp
from nl2sql.types import StageResult, StageTrace

class Verifier:
    name = "verifier"

    def run(self, sql: str, exec_result: StageResult) -> StageResult:
        if not exec_result.ok:
            return StageResult(ok=False, data=None,
                               trace=StageTrace(stage=self.name, duration_ms=0,
                                                notes={"reason": "execution_error"}),
                               error=exec_result.error)

        # Rule 1: check SELECT / GROUP BY consistency
        issues = []
        try:
            tree = sqlglot.parse_one(sql)
            if isinstance(tree, exp.Select):
                group = tree.args.get("group")
                aggs = [a for a in tree.find_all(exp.AggFunc)]
                if aggs and not group:
                    issues.append("Aggregation without GROUP BY.")
        except Exception as e:
            issues.append(f"Parse error during verification: {e}")

        if issues:
            return StageResult(ok=False, data=None,
                               trace=StageTrace(stage=self.name, duration_ms=0,
                                                notes={"issues": issues}),
                               error=issues)
        return StageResult(ok=True, data={"verified": True},
                           trace=StageTrace(stage=self.name, duration_ms=0))
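
Note the rule is deliberately strict: any aggregate in a SELECT without a GROUP BY is flagged, even when it would be legal SQL. A quick illustration (a sketch, not part of this commit), reusing the NoOp executor result as a stand-in for a successful execution:

from nl2sql.verifier import Verifier
from nl2sql.stubs import NoOpExecutor

ok_exec = NoOpExecutor().run("SELECT 1;")  # any successful execution result will do
v = Verifier()
assert v.run(sql="SELECT country FROM singer GROUP BY country;", exec_result=ok_exec).ok
assert not v.run(sql="SELECT MAX(age) FROM singer;", exec_result=ok_exec).ok  # aggregate, no GROUP BY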
requirements.txt
CHANGED
@@ -1,8 +1,11 @@
fastapi==0.115.2
uvicorn[standard]==0.30.6
pydantic==2.9.2
sqlglot==27.26.0
requests==2.32.3
streamlit==1.39.0
plotly==5.24.1
pytest==8.3.3
python-dotenv==1.1.1
openai==2.6.1
psycopg[binary]~=3.2
tests/conftest.py
ADDED
@@ -0,0 +1,7 @@
import os
from dotenv import load_dotenv

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ENV_PATH = os.path.join(ROOT_DIR, ".env")

load_dotenv(dotenv_path=ENV_PATH)
tests/test_ambiguity.py
ADDED
@@ -0,0 +1,19 @@
from nl2sql.ambiguity_detector import AmbiguityDetector
from nl2sql.types import StageResult
from app.routers import nl2sql

def test_detects_ambiguous_terms():
    det = AmbiguityDetector()
    res = det.detect("Show me recent top singers", "table: singer(id,name,age)")
    assert len(res) >= 1
    assert "recent" in res[0].lower()

def test_not_false_positive():
    det = AmbiguityDetector()
    res = det.detect("List all singers older than 30", "table: singer(id, name, age)")
    assert res == []

def test_ambiguity_response():
    fake_result = StageResult(ok=True, data={"ambiguous": True, "questions": ["Clarify column?"]})
    response = nl2sql._to_dict(fake_result.data)
    assert response["ambiguous"] is True
tests/test_executor.py
ADDED
@@ -0,0 +1,16 @@
from nl2sql.executor import Executor
from adapters.db.sqlite_adapter import SQLiteAdapter

def test_executor_runs_select(tmp_path):
    db_path = tmp_path / "test.db"
    import sqlite3
    conn = sqlite3.connect(db_path)
    conn.execute("CREATE TABLE users(id INT, name TEXT);")
    conn.execute("INSERT INTO users VALUES (1, 'Alice');")
    conn.commit()
    conn.close()

    ex = Executor(SQLiteAdapter(str(db_path)))
    res = ex.run("SELECT * FROM users;")
    assert res.ok
    assert res.data["rows"][0][1] == "Alice"
tests/test_generator.py
ADDED
@@ -0,0 +1,96 @@
import pytest
from nl2sql.generator import Generator
from nl2sql.types import StageResult


# --- Dummy LLMs (respect the 5-tuple contract) --------------------------------

class LLM_OK:
    def generate_sql(self, **kwargs):
        # contract: (sql, rationale, t_in, t_out, cost)
        return "SELECT * FROM singer;", "list all", 10, 5, 0.00001


class LLM_EMPTY_SQL:
    def generate_sql(self, **kwargs):
        # empty SQL → should be an error
        return "", "reason", 10, 5, 0.0


class LLM_NON_SELECT:
    def generate_sql(self, **kwargs):
        # non-SELECT SQL → should be an error
        return "UPDATE users SET name='x' WHERE id=1;", "bad", 8, 3, 0.0


class LLM_CONTRACT_NONE:
    def generate_sql(self, **kwargs):
        # contract violation: None instead of 5-tuple
        return None


class LLM_CONTRACT_SHORT:
    def generate_sql(self, **kwargs):
        # contract violation: too few items
        return ("SELECT * FROM singer;", "list all")  # only 2


# --- Parametrized negative cases ----------------------------------------------

@pytest.mark.parametrize(
    "llm, err_keyword",
    [
        (LLM_EMPTY_SQL(), "empty"),        # empty or non-string sql
        (LLM_NON_SELECT(), "non-select"),  # generated non-SELECT
        (LLM_CONTRACT_NONE(), "contract violation"),
        (LLM_CONTRACT_SHORT(), "contract violation"),
    ],
)
def test_generator_errors_do_not_create_trace(llm, err_keyword):
    gen = Generator(llm=llm)
    r = gen.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={}
    )
    assert isinstance(r, StageResult)
    assert r.ok is False
    # Error message is flexible; just check a keyword
    joined = " ".join(r.error or []).lower()
    assert err_keyword in joined
    # On errors, Generator should not attach a trace (we only time successful stages)
    assert r.trace is None


# --- Positive case (success) ---------------------------------------------------

def test_generator_success_has_valid_trace_and_data():
    gen = Generator(llm=LLM_OK())
    r = gen.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={}
    )

    # Basic success checks
    assert isinstance(r, StageResult)
    assert r.ok is True
    assert r.data and r.data["sql"].lower().startswith("select")
    assert "rationale" in r.data

    # Trace should exist and be coherent
    assert r.trace is not None
    assert r.trace.stage == "generator"
    assert isinstance(r.trace.duration_ms, float)
    assert r.trace.token_in == 10
    assert r.trace.token_out == 5
    # cost can be float or None depending on provider; if present it must be numeric
    if r.trace.cost_usd is not None:
        assert isinstance(r.trace.cost_usd, float)

    # Optional notes check – rationale_len should match the length of rationale
    notes = r.trace.notes or {}
    if "rationale_len" in notes:
        assert notes["rationale_len"] == len(r.data.get("rationale", ""))
tests/test_nl2sql_router.py
ADDED
@@ -0,0 +1,95 @@
from fastapi.testclient import TestClient
from app.main import app
from nl2sql.types import StageResult, StageTrace

client = TestClient(app)


def fake_trace(stage: str):
    return StageTrace(stage=stage, duration_ms=10.0)

path = app.url_path_for("nl2sql_handler")

# --- 1) Clarify / ambiguity case ---------------------------------------------
def test_ambiguity_route(monkeypatch):
    from app.routers import nl2sql

    # mock the pipeline to return a StageResult with ambiguous=True
    def fake_run(*args, **kwargs):
        return StageResult(
            ok=True,
            data={
                "ambiguous": True,
                "questions": ["Which table do you mean?"],
                "traces": [fake_trace("detector")],
            },
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    resp = client.post(
        path,
        json={
            "query": "show all records",
            "schema_preview": "CREATE TABLE ...",
        },
    )

    assert resp.status_code == 200
    data = resp.json()
    assert data["ambiguous"] is True
    assert "questions" in data


# --- 2) Error / failure case -------------------------------------------------
def test_error_route(monkeypatch):
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return StageResult(ok=False, error=["Bad SQL"], data={"traces": [fake_trace("safety")]})

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    resp = client.post(
        path,
        json={
            "query": "drop table users;",
            "schema_preview": "CREATE TABLE users(id int);",
        },
    )

    assert resp.status_code == 400
    assert "Bad SQL" in resp.json()["detail"]


# --- 3) Success / happy path -------------------------------------------------
def test_success_route(monkeypatch):
    from app.routers import nl2sql

    def fake_run(*args, **kwargs):
        return StageResult(
            ok=True,
            data={
                "ambiguous": False,
                "sql": "SELECT * FROM users;",
                "rationale": "Simple listing",
                "traces": [fake_trace("planner"), fake_trace("generator")],
            },
        )

    monkeypatch.setattr(nl2sql._pipeline, "run", fake_run)

    resp = client.post(
        path,
        json={
            "query": "show all users",
            "schema_preview": "CREATE TABLE users(id int, name text);",
        },
    )

    assert resp.status_code == 200
    data = resp.json()
    assert data["sql"].lower().startswith("select")
    assert isinstance(data["traces"], list)
    assert any(t["stage"] == "planner" for t in data["traces"])
tests/test_openai_provider.py
ADDED
@@ -0,0 +1,87 @@
import json
import pytest
from adapters.llm.openai_provider import OpenAIProvider


# Helper class to fake the completion object returned by the OpenAI SDK
class FakeCompletion:
    def __init__(self, content: str, prompt_tokens=5, completion_tokens=7):
        self.choices = [type("Choice", (), {"message": type("Msg", (), {"content": content})})]
        self.usage = type("Usage", (), {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens
        })


# --- Case 1: clean valid JSON --------------------------------------------------
def test_generate_sql_valid_json(monkeypatch):
    provider = OpenAIProvider()

    fake_content = json.dumps({
        "sql": "SELECT * FROM singer;",
        "rationale": "List all singers."
    })
    fake_completion = FakeCompletion(fake_content)

    # Monkeypatch client.chat.completions.create
    def fake_create(*args, **kwargs):
        return fake_completion

    monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

    sql, rationale, t_in, t_out, cost = provider.generate_sql(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);",
        plan_text="-- plan --",
        clarify_answers={}
    )

    assert sql.strip().lower().startswith("select")
    assert "singer" in sql.lower()
    assert "list" in rationale.lower()
    assert t_in == 5 and t_out == 7
    assert isinstance(cost, float)


# --- Case 2: malformed JSON with extra text (should still recover) ------------
def test_generate_sql_recover_from_partial_json(monkeypatch):
    provider = OpenAIProvider()

    # invalid JSON with text around it
    fake_content = "Here is the result:\n{ \"sql\": \"SELECT * FROM users;\", \"rationale\": \"list users\" }\nThanks!"
    fake_completion = FakeCompletion(fake_content)

    def fake_create(*args, **kwargs):
        return fake_completion

    monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

    sql, rationale, *_ = provider.generate_sql(
        user_query="show all users",
        schema_preview="CREATE TABLE users(id int, name text);",
        plan_text="-- plan --"
    )

    assert sql.lower().startswith("select")
    assert "user" in sql.lower()
    assert "list" in rationale.lower()


# --- Case 3: completely invalid JSON (should raise ValueError) ----------------
def test_generate_sql_invalid_json(monkeypatch):
    provider = OpenAIProvider()

    fake_content = "This is nonsense output without braces"
    fake_completion = FakeCompletion(fake_content)

    def fake_create(*args, **kwargs):
        return fake_completion

    monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

    with pytest.raises(ValueError):
        provider.generate_sql(
            user_query="show X",
            schema_preview="CREATE TABLE t(id int);",
            plan_text="-- plan --"
        )
tests/test_pipeline_integration.py
ADDED
@@ -0,0 +1,142 @@
from nl2sql.pipeline import Pipeline
from nl2sql.stubs import NoOpExecutor, NoOpVerifier, NoOpRepair
from nl2sql.types import StageResult, StageTrace


# --- Dummy stages to isolate the pipeline --------------------------------------

class DummyDetector:
    """Simulates the ambiguity detector stage."""
    def __init__(self, ambiguous=False):
        self.ambiguous = ambiguous

    def detect(self, user_query, schema_preview):
        # If ambiguous=True, return clarification questions
        return ["Which column?"] if self.ambiguous else []


class DummyPlanner:
    """Simulates the planner stage."""
    def run(self, *, user_query, schema_preview):
        trace = StageTrace(stage="planner", duration_ms=1.0)
        if "fail_plan" in user_query:
            return StageResult(ok=False, error=["Planner failed"], trace=trace)
        return StageResult(ok=True, data={"plan": "plan text"}, trace=trace)


class DummyGenerator:
    """Simulates the generator stage."""
    def run(self, *, user_query, schema_preview, plan_text, clarify_answers):
        trace = StageTrace(stage="generator", duration_ms=1.0)
        if "fail_gen" in user_query:
            return StageResult(ok=False, error=["Generator failed"], trace=trace)
        sql = "SELECT * FROM singer;"
        rationale = "List all singers."
        return StageResult(ok=True, data={"sql": sql, "rationale": rationale}, trace=trace)


class DummySafety:
    """Simulates the safety stage."""
    def check(self, sql):
        trace = StageTrace(stage="safety", duration_ms=1.0)
        if "DROP" in sql.upper():
            return StageResult(ok=False, error=["Unsafe SQL"], trace=trace)
        return StageResult(ok=True, data={"sql": sql, "rationale": "safe"}, trace=trace)


# --- 1) Success path ----------------------------------------------------------
def test_pipeline_success():
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=False),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
        executor=NoOpExecutor(),
        verifier=NoOpVerifier(),
        repair=NoOpRepair(),
    )

    r = pipeline.run(
        user_query="show all singers",
        schema_preview="CREATE TABLE singer(id int, name text);"
    )

    # Pipeline.run always returns a JSON-safe dict; traces are serialized dicts
    assert isinstance(r, dict)
    assert r["error"] is False
    assert r["sql"].lower().startswith("select")
    assert any(t["stage"] == "planner" for t in r["traces"])
    assert any(t["stage"] == "generator" for t in r["traces"])
    assert any(t["stage"] == "safety" for t in r["traces"])


# --- 2) Ambiguity case --------------------------------------------------------
def test_pipeline_ambiguity():
    pipeline = Pipeline(
        detector=DummyDetector(ambiguous=True),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
        executor=NoOpExecutor(),
        verifier=NoOpVerifier(),
        repair=NoOpRepair(),
    )

    r = pipeline.run(
        user_query="show data",
        schema_preview="CREATE TABLE x(id int);"
    )

    assert isinstance(r, dict)
    assert r["ambiguous"] is True
    assert isinstance(r["questions"], list)


# --- 3) Planner failure -------------------------------------------------------
def test_pipeline_plan_fail():
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
        executor=NoOpExecutor(),
        verifier=NoOpVerifier(),
        repair=NoOpRepair(),
    )
    r = pipeline.run(
        user_query="fail_plan",
        schema_preview="CREATE TABLE singer(id int);"
    )
    assert isinstance(r, dict)
    assert r["error"] is True
    assert "Planner failed" in " ".join(r["details"] or [])


# --- 4) Generator failure -----------------------------------------------------
def test_pipeline_gen_fail():
    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=DummyGenerator(),
        safety=DummySafety(),
        executor=NoOpExecutor(),
        verifier=NoOpVerifier(),
        repair=NoOpRepair(),
    )
    r = pipeline.run(
        user_query="fail_gen",
        schema_preview="CREATE TABLE singer(id int);"
    )
    assert r["error"] is True
    assert "Generator failed" in " ".join(r["details"] or [])


# --- 5) Safety failure --------------------------------------------------------
def test_pipeline_safety_fail():
    class UnsafeGen(DummyGenerator):
        def run(self, **kw):
            trace = StageTrace(stage="generator", duration_ms=1.0)
            # Generate a DROP TABLE → unsafe
            return StageResult(ok=True, data={"sql": "DROP TABLE x;", "rationale": "oops"}, trace=trace)

    pipeline = Pipeline(
        detector=DummyDetector(),
        planner=DummyPlanner(),
        generator=UnsafeGen(),
        safety=DummySafety(),
        executor=NoOpExecutor(),
        verifier=NoOpVerifier(),
        repair=NoOpRepair(),
    )
    r = pipeline.run(
        user_query="drop something",
        schema_preview="CREATE TABLE x(id int);"
    )
    assert r["error"] is True
    assert "unsafe" in " ".join(r["details"] or []).lower()
tests/test_safety.py
ADDED
@@ -0,0 +1,103 @@
from nl2sql.safety import Safety
import pytest


def test_safety_allows_select():
    s = Safety()
    result = s.check("SELECT * FROM users;")
    assert result.ok
    assert "sql" in result.data
    assert result.trace.stage == "safety"

def test_safety_allows_with_select_cte():
    s = Safety()
    sql = """
    WITH recent AS (
        SELECT id FROM users WHERE created_at > '2024-01-01'
    )
    SELECT * FROM users u JOIN recent r ON u.id = r.id;
    """
    r = s.check(sql)
    assert r.ok

def test_safety_allows_select_with_comments_and_newlines():
    s = Safety()
    sql = "/* head */ \n -- inline\n SELECT 1; -- tail"
    r = s.check(sql)
    assert r.ok

def test_safety_allows_keywords_inside_string_literals():
    s = Safety()
    sql = "SELECT 'DROP TABLE x' as note, 'delete from y' as text;"
    r = s.check(sql)
    assert r.ok, r.error


def test_safety_blocks_delete():
    s = Safety()
    result = s.check("DELETE FROM users;")
    assert not result.ok
    assert any("Forbidden" in e or "Non-SELECT" in e for e in (result.error or []))

@pytest.mark.parametrize("sql", [
    "UPDATE users SET name='X' WHERE id=1;",
    "INSERT INTO users(id) VALUES (1);",
    "DROP TABLE users;",
    "CREATE TABLE x(id INT);",
    "ALTER TABLE users ADD COLUMN x INT;",
    "ATTACH DATABASE 'hack.db' AS h;",
    "PRAGMA journal_mode=WAL;",
])
def test_safety_blocks_forbidden_statements(sql):
    s = Safety()
    res = s.check(sql)
    assert not res.ok

def test_safety_blocks_stacked_delete_after_select():
    s = Safety()
    sql = "SELECT * FROM users; DELETE FROM users;"
    r = s.check(sql)
    assert not r.ok

def test_safety_blocks_stacked_delete_with_spaces():
    s = Safety()
    sql = "SELECT * FROM users ; \n DELETE users;"
    r = s.check(sql)
    assert not r.ok

def test_safety_blocks_delete_inside_cte():
    s = Safety()
    sql = """
    WITH bad AS (DELETE FROM users)
    SELECT * FROM users;
    """
    r = s.check(sql)
    assert not r.ok

@pytest.mark.parametrize("sql", [
    "/*D*/ROP TABLE users;",
    "PR/*x*/AGMA journal_mode=WAL;",
    "AL/* comment */TER TABLE x ADD COLUMN y INT;",
])
def test_safety_blocks_comment_obfuscation(sql):
    s = Safety()
    r = s.check(sql)
    assert not r.ok

@pytest.mark.parametrize("sql", [
    "pragma journal_mode=WAL;",  # lower-case
    " PRAGMA user_version = 5 ; ",
    "\nATTACH DATABASE 'hack.db' AS h;",
])
def test_safety_blocks_forbidden_case_and_spacing(sql):
    s = Safety()
    r = s.check(sql)
    assert not r.ok

def test_safety_blocks_multiple_nonempty_statements_even_if_second_is_comment():
    s = Safety()
    sql = "SELECT 1; -- now do something bad\n"
    sql_bad = "SELECT 1; /* spacer */ DROP TABLE x;"
    assert s.check(sql).ok
    assert not s.check(sql_bad).ok
tests/test_stage_types.py
ADDED
@@ -0,0 +1,18 @@
from nl2sql.types import StageResult, StageTrace

def test_error_response():
    r = StageResult(ok=False, error=["Syntax error"])
    assert not r.ok
    assert r.error == ["Syntax error"]

def test_trace_dataclass_structure():
    t = StageTrace(stage="planner", duration_ms=12.5, token_in=10, token_out=20)
    assert t.stage == "planner"
    assert isinstance(t.duration_ms, float)
    assert t.token_out == 20

def test_stage_result_defaults():
    r = StageResult(ok=True)
    assert r.ok
    assert r.data is None
    assert r.error is None
ui/benchmark_app.py
ADDED
@@ -0,0 +1,45 @@
import json
import pandas as pd
import streamlit as st
import plotly.express as px
from pathlib import Path

st.set_page_config(page_title="NL2SQL Benchmark Dashboard", layout="wide")

st.title("📊 NL2SQL Copilot – Benchmark Dashboard")

# 1. Load results
result_files = list(Path("benchmarks/results").glob("*.jsonl"))
if not result_files:
    st.warning("No benchmark result files found in benchmarks/results/")
    st.stop()

file = st.selectbox("Select benchmark file", result_files)
rows = [json.loads(line) for line in open(file)]
df = pd.DataFrame(rows)

# 2. Summary metrics
st.subheader("Aggregate Metrics")
col1, col2, col3, col4 = st.columns(4)
col1.metric("Total Queries", len(df))
col2.metric("Execution Accuracy", f"{df['exec_acc'].mean()*100:.1f}%")
col3.metric("Safety Violations", f"{df['safe_fail'].mean()*100:.1f}%")
col4.metric("Average Latency (ms)", f"{df['latency_ms'].mean():.0f}")

# 3. Latency Distribution
st.subheader("Latency Distribution")
fig1 = px.histogram(df, x="latency_ms", nbins=30, title="Latency Histogram")
st.plotly_chart(fig1, use_container_width=True)

# 4. Cost vs Accuracy
st.subheader("Cost vs Execution Accuracy")
fig2 = px.scatter(df, x="cost_usd", y="exec_acc", color="provider",
                  title="Trade-off: Cost vs Accuracy", hover_data=["query"])
st.plotly_chart(fig2, use_container_width=True)

# 5. Repair Stats
if "repair_attempts" in df.columns:
    st.subheader("Repair Attempts")
    fig3 = px.bar(df.groupby("repair_attempts").size().reset_index(name="count"),
                  x="repair_attempts", y="count", title="Number of Repair Attempts per Query")
    st.plotly_chart(fig3, use_container_width=True)