Spaces:
Sleeping
Sleeping
"""
data_factory/validator.py
=========================
SQL execution validation layer.
GUARANTEE: Every record that passes this validator has SQL that:
1. Runs without error against the actual seeded SQLite schema
2. Returns at least one row (non-empty result)
3. Returns the expected column names
No LLM-generated SQL ever reaches this validator — SQL always comes from
the human-verified template library. This validator is an extra safety net
to catch any copy-paste or formatting regressions.
"""
| from __future__ import annotations | |
| import sqlite3 | |
| from dataclasses import dataclass, field | |
| from typing import Any, Optional | |
| from data_factory.schemas import build_connection, SCHEMA_CONTEXT | |
| from data_factory.templates import Template | |
# ─────────────────────────────────────────────────────────────────────────────
# DATA CLASSES
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class ValidationResult:
    """Outcome of executing one SQL string against a seeded SQLite database.

    Declared as a dataclass so it can be built with keyword arguments
    (``ValidationResult(passed=..., sql=..., error=...)``) and so that
    ``columns`` gets a fresh, per-instance empty list by default.
    """

    passed: bool  # True when the SQL executed and met the validator's criteria
    sql: str  # The SQL as actually executed (stripped, trailing ';' removed)
    error: Optional[str] = None  # Human-readable reason when passed is False
    row_count: int = 0  # Number of rows fetched from the result set
    columns: list[str] = field(default_factory=list)  # Names from cursor.description
@dataclass
class DataRecord:
    """One training example ready to be written to JSONL/Parquet.

    Declared as a dataclass so ``build_record`` can construct it with keyword
    arguments.
    """

    domain: str
    difficulty: str
    sql: str
    nl_question: str  # The NL paraphrase used as prompt
    persona: str  # ceo | chatty | lazy_typist | non_techie | analyst | augmented
    has_order: bool
    schema_context: str
    row_count: int  # From validation run
    columns: list[str]  # From validation run
    source: str  # "template_base" | "vllm_persona" | "rule_augmented"
    template_id: int  # Index into ALL_TEMPLATES

    def to_training_dict(self) -> dict[str, Any]:
        """
        Return the dictionary that will be written to the output dataset.

        Format is compatible with TRL / HuggingFace `datasets`:
            prompt  : chat-format messages list (system + user)
            sql     : ground-truth SQL (label / reward reference)
            metadata: auxiliary fields for curriculum or filtering

        The returned dict does not share mutable state with the record
        (``columns`` is copied), so callers may post-process it freely.
        """
        system_msg = (
            "You are an expert SQL analyst. "
            "Write a single SELECT query that answers the question. "
            "Output ONLY the SQL query — no markdown, no explanation, no backticks."
        )
        user_msg = (
            f"DATABASE SCHEMA\n"
            f"---------------\n"
            f"{self.schema_context}\n\n"
            f"QUESTION: {self.nl_question}"
        )
        return {
            "prompt": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "sql": self.sql,
            "metadata": {
                "domain": self.domain,
                "difficulty": self.difficulty,
                "persona": self.persona,
                "has_order": self.has_order,
                "row_count": self.row_count,
                # Copy so mutating the output cannot alter this record.
                "columns": list(self.columns),
                "source": self.source,
                "template_id": self.template_id,
            },
        }
# ─────────────────────────────────────────────────────────────────────────────
# VALIDATOR
# ─────────────────────────────────────────────────────────────────────────────
class SQLValidator:
    """
    Validates SQL against a seeded in-memory SQLite connection.

    One validator per domain to reuse the same connection for all templates
    in that domain (performance optimization). The validator can be used as a
    context manager (``with SQLValidator(domain) as v: ...``) so the
    connection is always closed.
    """

    # Leading keywords that must never run against the shared seeded DB.
    # Only the first word is inspected, so this guards against accidental
    # regressions in trusted templates; it is not a full SQL sandbox.
    _FORBIDDEN = frozenset({
        "insert", "update", "delete", "drop", "alter",
        "create", "replace", "truncate", "pragma",
    })

    def __init__(self, domain: str, seed: int = 42) -> None:
        self.domain = domain
        self._conn = build_connection(domain, seed=seed)

    def __enter__(self) -> "SQLValidator":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def validate(self, sql: str, *, require_rows: bool = True) -> ValidationResult:
        """
        Execute SQL and return a ValidationResult. Never raises.

        Parameters
        ----------
        sql          : SQL text. Surrounding whitespace and trailing ';' are
                       stripped before execution. ``None`` counts as empty.
        require_rows : When True (default), a query that executes but returns
                       zero rows fails, enforcing the module guarantee that
                       validated SQL returns a non-empty result.

        Returns
        -------
        ValidationResult with ``passed`` set, plus ``row_count``/``columns``
        when the query executed, or ``error`` describing the failure.
        """
        sql = (sql or "").strip().rstrip(";")
        if not sql:
            return ValidationResult(passed=False, sql=sql, error="Empty SQL string.")

        # Block any write operations (sql is non-empty, so split() has a word).
        first_word = sql.split(maxsplit=1)[0].lower()
        if first_word in self._FORBIDDEN:
            return ValidationResult(
                passed=False, sql=sql,
                error=f"Write operation '{first_word.upper()}' is not permitted."
            )

        try:
            cur = self._conn.execute(sql)
            try:
                cols = [d[0] for d in cur.description] if cur.description else []
                rows = cur.fetchall()
            finally:
                cur.close()
        # sqlite3.Warning (multiple statements on older Pythons) and ValueError
        # (e.g. embedded NUL) are not sqlite3.Error subclasses; catch them so
        # the "never raises" contract holds.
        except (sqlite3.Error, sqlite3.Warning, ValueError) as exc:
            return ValidationResult(passed=False, sql=sql, error=str(exc))

        if require_rows and not rows:
            return ValidationResult(
                passed=False, sql=sql, error="Query returned no rows.",
                row_count=0, columns=cols,
            )
        return ValidationResult(
            passed=True,
            sql=sql,
            row_count=len(rows),
            columns=cols,
        )

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()
def validate_template(template: Template, seed: int = 42) -> ValidationResult:
    """Validate a single template against a freshly seeded DB for its domain.

    Builds a throwaway SQLValidator for ``template["domain"]``, runs
    ``template["sql"]`` through it, and always closes the connection, even
    if the template lookup or validation raises.
    """
    validator = SQLValidator(template["domain"], seed=seed)
    try:
        return validator.validate(template["sql"])
    finally:
        validator.close()
def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]:
    """
    Run validation across all templates and return a summary dict.

    Used during CI / smoke testing. One SQLValidator is built lazily per
    domain that actually appears in ``templates``, and every validator that
    was built is closed even if a template raises part-way through.

    Returns
    -------
    {"total": int, "passed": int, "failed": int, "failures": list[dict]}
    where each failure is {"index", "domain", "sql" (first 80 chars), "error"}.
    Templates whose domain is not in SCHEMA_MAP are reported as failures
    rather than aborting the whole run.
    """
    from data_factory.schemas import SCHEMA_MAP

    validators: dict[str, SQLValidator] = {}
    passed = 0
    failures: list[dict[str, Any]] = []
    try:
        for i, t in enumerate(templates):
            domain = t["domain"]
            if domain not in SCHEMA_MAP:
                failures.append({"index": i, "domain": domain,
                                 "sql": t["sql"][:80],
                                 "error": f"Unknown domain '{domain}'."})
                continue
            v = validators.get(domain)
            if v is None:
                # Seed each domain's DB only once, and only if it is needed.
                v = validators[domain] = SQLValidator(domain, seed)
            result = v.validate(t["sql"])
            if result.passed:
                passed += 1
            else:
                failures.append({"index": i, "domain": domain,
                                 "sql": t["sql"][:80], "error": result.error})
    finally:
        for v in validators.values():
            v.close()
    return {
        "total": len(templates),
        "passed": passed,
        "failed": len(failures),
        "failures": failures,
    }
def build_record(
    template: Template,
    template_idx: int,
    nl_question: str,
    persona: str,
    source: str,
    validator: SQLValidator,
) -> Optional[DataRecord]:
    """
    Turn a template plus one NL paraphrase into a DataRecord, if its SQL validates.

    Parameters
    ----------
    template     : The source template (contains SQL, domain, difficulty).
    template_idx : Index of template in ALL_TEMPLATES (for deduplication).
    nl_question  : The NL paraphrase to use as the prompt.
    persona      : Which persona/strategy generated this NL.
    source       : 'template_base' | 'vllm_persona' | 'rule_augmented'
    validator    : Pre-built SQLValidator for this domain.

    Returns
    -------
    A DataRecord carrying the validation row count and column names, or
    None when the template SQL fails validation.
    """
    outcome = validator.validate(template["sql"])
    if outcome.passed:
        domain = template["domain"]
        return DataRecord(
            domain=domain,
            difficulty=template["difficulty"],
            sql=template["sql"],
            nl_question=nl_question,
            persona=persona,
            has_order=template["has_order"],
            schema_context=SCHEMA_CONTEXT[domain],
            row_count=outcome.row_count,
            columns=outcome.columns,
            source=source,
            template_id=template_idx,
        )
    return None