sqlbot / ai /validator.py
sqlbot
Initial Hugging Face sqlbot setup
28035e9
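"""SQL safety validation.

Rejects any query that is not a pure SELECT (or WITH ... SELECT) statement,
and offers a lightweight, regex-based check that the tables a query
references exist in a known schema.
"""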
import re
_FORBIDDEN_KEYWORDS = [
    r"\bDROP\b",
    r"\bDELETE\b",
    r"\bUPDATE\b",
    r"\bALTER\b",
    r"\bTRUNCATE\b",
    r"\bINSERT\b",
    r"\bCREATE\b",
    r"\bGRANT\b",
    r"\bREVOKE\b",
    r"\bEXEC\b",
    r"\bEXECUTE\b",
    # Write operations that can hide inside a query that still starts with
    # SELECT/WITH:
    r"\bINTO\b",   # SELECT ... INTO new_table / SELECT ... INTO OUTFILE
    r"\bMERGE\b",  # data-modifying statement; some engines accept it in WITH
]
# One case-insensitive alternation; search() reports the leftmost hit.
_FORBIDDEN_PATTERN = re.compile("|".join(_FORBIDDEN_KEYWORDS), re.IGNORECASE)


def validate_sql(sql: str) -> tuple[bool, str]:
    """Check if a SQL string is safe to execute.

    The check is conservative and purely textual:

    1. the query (minus trailing ``;`` terminators) must be non-empty;
    2. it must start with ``SELECT`` or ``WITH``;
    3. it must not contain any keyword from ``_FORBIDDEN_KEYWORDS``;
    4. it must not contain an embedded ``;`` (a second statement).

    Keywords and ``;`` are matched anywhere in the text, including inside
    string literals and comments, so some harmless queries are rejected.
    That is deliberate: a false rejection is cheaper than a false accept.

    Parameters
    ----------
    sql : str
        Candidate SQL text.

    Returns
    -------
    (is_safe, reason)
        ``(True, "")`` when accepted, otherwise ``False`` with a short
        human-readable reason.
    """
    # Remove trailing terminators, including interleaved whitespace such
    # as "SELECT 1 ; ;".
    stripped = sql.strip()
    while stripped.endswith(";"):
        stripped = stripped[:-1].rstrip()
    if not stripped:
        return False, "Empty query."

    # Must start with SELECT or WITH (CTE)
    if not re.match(r"^\s*(SELECT|WITH)\b", stripped, re.IGNORECASE):
        return False, "Only SELECT queries are allowed."

    # Check for forbidden keywords
    match = _FORBIDDEN_PATTERN.search(stripped)
    if match:
        return False, f"Forbidden keyword detected: {match.group().upper()}"

    # A remaining ";" means another statement may follow the SELECT, e.g.
    # "SELECT 1; SET ROLE admin". It is checked last so queries rejected
    # by the checks above keep their original reason.
    if ";" in stripped:
        return False, "Multiple statements are not allowed."

    return True, ""
# A single-quoted literal (with '' escapes), a -- line comment, or a
# /* block comment */. Text inside them must not be scanned as SQL.
_LITERAL_OR_COMMENT = re.compile(r"'(?:[^']|'')*'|--[^\n]*|/\*.*?\*/", re.DOTALL)

# Syntax that uses the FROM keyword without naming a table, e.g.
# EXTRACT(YEAR FROM created_at) or "a IS NOT DISTINCT FROM b".
# Handles at most one level of nested parentheses inside the call.
_FROM_KEYWORD_SYNTAX = re.compile(
    r"\b(?:EXTRACT|SUBSTRING|TRIM|OVERLAY)\s*\((?:[^()]|\([^()]*\))*\)"
    r"|\bIS\s+(?:NOT\s+)?DISTINCT\s+FROM\b",
    re.IGNORECASE,
)

# One identifier: a "quoted name" or a bare word (\b prevents backtracking
# into a shorter prefix of the word).
_IDENT = r'(?:"[^"]+"|\w+\b)'
_IDENT_PART = re.compile(r'"([^"]+)"|(\w+)')

# A table reference after FROM/JOIN, optionally dotted (schema.table).
# - A "(" right after FROM/JOIN is a subquery and never matches.
# - A name followed by "(" is a table function such as generate_series(...)
#   and is skipped.
# - "LATERAL" is consumed as a keyword, never reported as a table.
_TABLE_REF = re.compile(
    rf"\b(?:FROM|JOIN)\s+(?:LATERAL\s+)?(?!LATERAL\b)"
    rf"({_IDENT}(?:\s*\.\s*{_IDENT})*)(?!\s*[.(])",
    re.IGNORECASE,
)

# A CTE definition: "WITH [RECURSIVE] name [(cols)] AS [NOT] [MATERIALIZED] ("
# or the same after a comma for subsequent CTEs.
_CTE_NAME = re.compile(
    r'(?:\bWITH(?:\s+RECURSIVE)?|,)\s*(?:"([^"]+)"|(\w+))\s*'
    r"(?:\([^()]*\)\s*)?\bAS\s*(?:NOT\s+)?(?:MATERIALIZED\s*)?\(",
    re.IGNORECASE,
)

_GROUP_BY = re.compile(r"\bGROUP\s+BY\b", re.IGNORECASE)
_SELECT = re.compile(r"\bSELECT\b", re.IGNORECASE)
_AGGREGATE_CALL = re.compile(r"\b(?:SUM|COUNT|AVG|MIN|MAX)\s*\(", re.IGNORECASE)


def _strip_literals_and_comments(sql: str) -> str:
    """Replace string literals with '' and comments with a space."""
    return _LITERAL_OR_COMMENT.sub(
        lambda m: "''" if m.group().startswith("'") else " ", sql
    )


def _cte_names(text: str) -> set[str]:
    """Return the lower-cased names defined by WITH ... AS ( ... ) clauses."""
    return {(m.group(1) or m.group(2)).lower() for m in _CTE_NAME.finditer(text)}


def _table_refs(text: str) -> list[list[str]]:
    """Return each FROM/JOIN table reference split into its dotted parts."""
    return [
        [quoted or bare for quoted, bare in _IDENT_PART.findall(m.group(1))]
        for m in _TABLE_REF.finditer(text)
    ]


def check_sql_against_schema(
    sql: str, schema: dict[str, list[dict]]
) -> tuple[bool, list[str]]:
    """Programmatically check that tables referenced in SQL exist in the schema.

    This is a lightweight, regex-based check, not a full SQL parser.

    - Every table named after FROM / JOIN must be a key of ``schema``
      (case-insensitive). For a dotted reference such as ``public.orders``,
      either the full dotted name or its last part may match.
    - Names defined as CTEs (``WITH name AS (...)``) are accepted.
    - Subqueries, table functions, and the FROM inside EXTRACT / SUBSTRING /
      TRIM / OVERLAY / IS DISTINCT FROM are ignored.
    - String literals and comments are ignored.
    - Column names are NOT checked. Only the first table of a comma-separated
      ``FROM a, b`` list is seen.

    Parameters
    ----------
    sql : str
        SQL text to check.
    schema : dict[str, list[dict]]
        Mapping of table name to column descriptions. Only the keys are used.

    Returns
    -------
    (is_valid, issues)
        ``is_valid`` is True exactly when ``issues`` is empty. Each issue is a
        human-readable message; a repeated missing table is reported once.
    """
    issues: list[str] = []
    known_tables = {t.lower() for t in schema}

    text = _strip_literals_and_comments(sql)
    text = _FROM_KEYWORD_SYNTAX.sub(" 0 ", text)
    cte_names = _cte_names(text)

    for parts in _table_refs(text):
        name = ".".join(parts)
        if name.lower() in known_tables or parts[-1].lower() in known_tables:
            continue
        if len(parts) == 1 and name.lower() in cte_names:
            continue
        issue = f"Table '{name}' not found in schema"
        if issue not in issues:
            issues.append(issue)

    # Heuristic: a GROUP BY query with no aggregate call is suspicious for
    # generated SQL. It is valid SQL, but it is reported so it gets reviewed.
    if (
        _GROUP_BY.search(text)
        and _SELECT.search(text)
        and not _AGGREGATE_CALL.search(text)
    ):
        issues.append("GROUP BY present but no aggregation function found")

    return (not issues, issues)