Spaces:
Running
Running
| """SQL safety validation. | |
| Rejects any query that is not a pure SELECT statement. | |
| """ | |
| import re | |
# Regex fragments for statement keywords that are never allowed in a query.
# Every keyword is wrapped in ``\b`` word boundaries, so an identifier that
# merely contains a keyword (e.g. ``updated_at``) does not match.
_FORBIDDEN_KEYWORDS = [
    rf"\b{keyword}\b"
    for keyword in (
        "DROP",
        "DELETE",
        "UPDATE",
        "ALTER",
        "TRUNCATE",
        "INSERT",
        "CREATE",
        "GRANT",
        "REVOKE",
        "EXEC",
        "EXECUTE",
    )
]

# One case-insensitive alternation over all fragments, compiled once at
# import time so callers do not rebuild it per query.
_FORBIDDEN_PATTERN = re.compile(
    "|".join(_FORBIDDEN_KEYWORDS),
    flags=re.IGNORECASE,
)
def validate_sql(sql: str) -> tuple[bool, str]:
    """Check whether a SQL string is a single read-only query.

    The check is purely textual (no SQL parsing) and errs on the side of
    rejecting. A query is accepted only when all of the following hold:

    * it is non-empty once trailing semicolons/whitespace are removed;
    * it starts with ``SELECT`` or ``WITH`` (case-insensitive);
    * it contains no further ``;``, i.e. it is not several stacked
      statements such as ``SELECT 1; COMMIT``;
    * no forbidden keyword from ``_FORBIDDEN_PATTERN`` appears as a whole
      word anywhere in the text.

    Because string literals and comments are not parsed, a semicolon or a
    forbidden keyword inside a literal (``WHERE note = 'please update'``)
    also causes rejection. Literals and comments are intentionally *not*
    stripped before scanning: quoting/escaping and comment rules differ
    between SQL dialects, so stripping could hide text that the database
    would still execute.

    Parameters
    ----------
    sql
        Raw SQL text to validate.

    Returns
    -------
    (is_safe, reason)
        ``is_safe`` is True when the query passed every check, in which case
        ``reason`` is the empty string; otherwise ``reason`` is a short
        human-readable explanation of the first failed check.
    """
    # Trailing semicolons and whitespace are plain statement terminators;
    # strip them together so inputs like "SELECT 1; ;" are fully normalised.
    stripped = sql.strip().rstrip("; \t\r\n\f\v")
    if not stripped:
        return False, "Empty query."

    # Must start with SELECT or WITH (CTE)
    if not re.match(r"(SELECT|WITH)\b", stripped, re.IGNORECASE):
        return False, "Only SELECT queries are allowed."

    # A semicolon left after normalisation separates stacked statements. The
    # keyword list cannot enumerate every harmful follow-up statement
    # (COMMIT, CALL, SET, ...), so multi-statement input is refused outright.
    if ";" in stripped:
        return False, "Multiple statements are not allowed."

    # Check for forbidden keywords
    match = _FORBIDDEN_PATTERN.search(stripped)
    if match:
        return False, f"Forbidden keyword detected: {match.group().upper()}"

    return True, ""
def check_sql_against_schema(sql: str, schema: dict[str, list[dict]]) -> tuple[bool, list[str]]:
    """Run lightweight, regex-based sanity checks of a query against a schema.

    This is a heuristic pre-check, not a SQL parser. It performs two checks:

    1. Every name that follows ``FROM`` or ``JOIN`` must be a table in
       ``schema`` (case-insensitive, optional double quotes ignored).
       Names defined by the query itself as common table expressions
       (``WITH name AS (...)``) are exempt. A schema-qualified reference
       such as ``public.users`` is accepted when either the full dotted
       name or its last component is a schema key.
    2. If the query contains ``GROUP BY`` it must also contain a call to
       one of ``SUM``, ``COUNT``, ``AVG``, ``MIN`` or ``MAX``.

    Column names are not validated; only the keys of ``schema`` are used.

    Parameters
    ----------
    sql
        SQL text to inspect.
    schema
        Mapping of table name to that table's column descriptions. Only the
        table names (keys) are consulted here.

    Returns
    -------
    (is_valid, issues)
        ``issues`` holds one message per problem found (an unknown table is
        reported once per occurrence); ``is_valid`` is True exactly when
        ``issues`` is empty.
    """
    issues: list[str] = []
    known_tables = {table.lower() for table in schema}

    # Names introduced by the query itself via WITH ... AS ( ... ); these are
    # legitimate FROM/JOIN targets even though they are not schema tables.
    cte_names: set[str] = set()
    if re.search(r"\bWITH\b", sql, re.IGNORECASE):
        cte_names = {
            name.lower()
            for name in re.findall(
                r'"?\b(\w+)"?\s*(?:\([^()]*\)\s*)?\bAS\b\s*'
                r'(?:(?:NOT\s+)?MATERIALIZED\s*)?\(',
                sql,
                re.IGNORECASE,
            )
        }

    # Extract table references (FROM / JOIN). The leading \b stops
    # identifiers that merely end in "from"/"join" (e.g. ``valid_from DESC``)
    # from being read as a FROM clause. Dotted, optionally quoted names are
    # captured whole so a schema qualifier is not mistaken for the table.
    table_refs = re.findall(
        r'\b(?:FROM|JOIN)\s+((?:"?\w+"?\.)*"?\w+"?)', sql, re.IGNORECASE
    )
    for raw_ref in table_refs:
        tref = raw_ref.replace('"', "")
        full_name = tref.lower()
        if full_name in cte_names:
            continue
        base_name = full_name.rsplit(".", 1)[-1]
        if full_name not in known_tables and base_name not in known_tables:
            issues.append(f"Table '{tref}' not found in schema")

    # Basic check: a GROUP BY query is expected to use an aggregate function.
    # (lightweight check — not full SQL parsing). Whitespace between the
    # keywords and before the opening parenthesis is tolerated.
    has_group_by = re.search(r"\bGROUP\s+BY\b", sql, re.IGNORECASE) is not None
    if has_group_by and "SELECT" in sql.upper():
        has_aggregate = re.search(
            r"(?:SUM|COUNT|AVG|MIN|MAX)\s*\(", sql, re.IGNORECASE
        )
        if has_aggregate is None:
            issues.append("GROUP BY present but no aggregation function found")

    return (len(issues) == 0, issues)