sql-error-classifier-train / src /sql_templates.py
nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
"""Error injectors that transform exercise context into labeled mistakes."""
from __future__ import annotations
import random
from typing import Callable, Dict, List, Tuple
from src.exercises import Exercise
# Column names offered to inject_column_error, which reports the picked name
# as a column that "does not exist". Presumably none of these appear in any
# exercise schema -- confirm against src.exercises.
FAKE_COLUMNS = ["fullname", "studentname", "coursename", "dept_name", "totals"]
# Table names offered to inject_table_error, which reports the picked name as
# a relation that "does not exist". Presumably these are misspelled/singular
# forms of real exercise tables -- confirm against src.exercises.
FAKE_TABLES = ["student", "course", "enrolment", "employe", "orderz"]
def _pick(rng: random.Random, items: List[str], k: int = 1) -> str | List[str]:
if k == 1:
return rng.choice(items)
return rng.sample(items, k)
def _first_table(exercise: Exercise) -> str:
return exercise.tables[0]
def _second_table(exercise: Exercise) -> str:
return exercise.tables[1] if len(exercise.tables) > 1 else exercise.tables[0]
# --- Error injectors: (rng, exercise) -> (erroneous_sql, error_message) ---
def inject_syntax_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Corrupt the exercise's correct query with one random textual mutation.

    Args:
        rng: Random source used to choose the mutation.
        exercise: Exercise whose ``correct_query`` is mutated.

    Returns:
        ``(mutated_sql, error_message)``. The mutated SQL is guaranteed to
        differ from ``exercise.correct_query``.
    """
    sql = exercise.correct_query
    mutations = [
        lambda s: s.replace("SELECT", "SELEC", 1),
        lambda s: s.replace("FROM", "FRO", 1),
        lambda s: s[:-1],
        lambda s: s.replace(")", "", 1),
        lambda s: s + " WHERE",
        lambda s: s.replace(",", "", 1),
        lambda s: s.replace("'", '"', 1) if "'" in s else s + " 'unclosed",
    ]
    broken = rng.choice(mutations)(sql)
    if broken == sql:
        # str.replace is a silent no-op when the token is absent (e.g. a
        # lowercase "select", or a query with no ")" or ","). Returning the
        # untouched query would label a correct answer as a syntax error, so
        # fall back to a mutation that always changes the text. The RNG is
        # consumed exactly as before, so seeded runs stay aligned.
        broken = sql + " WHERE"
    return broken, "syntax error at or near unexpected token"
def inject_join_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query with a missing or questionable join condition.

    A column is picked at random from ``exercise.columns``; one of four
    join templates over the exercise's first and second tables is then
    chosen at random.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    left = _first_table(exercise)
    right = _second_table(exercise)
    column = _pick(rng, list(exercise.columns))

    # JOIN with no ON clause at all.
    no_on_clause = f"SELECT {column} FROM {left} JOIN {right}"
    # Join on generic `id` columns of both tables.
    id_join = f"SELECT {column} FROM {left} INNER JOIN {right} ON {left}.id = {right}.id"
    # Join on the randomly picked column itself.
    same_column_join = (
        f"SELECT {left}.{column} FROM {left} "
        f"LEFT JOIN {right} ON {left}.{column} = {right}.{column}"
    )
    # Comma join filtered on a `wrong_id` column.
    comma_join = f"SELECT * FROM {left}, {right} WHERE {left}.wrong_id = {right}.wrong_id"

    candidates = [no_on_clause, id_join, same_column_join, comma_join]
    return rng.choice(candidates), "missing ON clause or invalid join condition"
def inject_aggregation_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build an aggregate query that omits its GROUP BY clause.

    The first exercise column is selected bare and the last one is wrapped
    in ``AVG``; no GROUP BY is emitted. ``rng`` is accepted for signature
    parity with the other injectors and is not used.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    table = _first_table(exercise)
    columns = list(exercise.columns)
    plain_column = columns[0]
    averaged_column = columns[-1]
    query = f"SELECT {plain_column}, AVG({averaged_column}) FROM {table}"
    message = "column must appear in GROUP BY clause or be used in aggregate function"
    return query, message
def inject_having_where_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query that filters on an aggregate inside WHERE.

    Uses the first exercise column as the plain column, the last one inside
    ``COUNT``, and a random threshold between 1 and 5 (inclusive).

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    table = _first_table(exercise)
    columns = list(exercise.columns)
    plain_column = columns[0]
    counted_column = columns[-1]
    threshold = rng.randint(1, 5)
    select_part = f"SELECT {plain_column}, COUNT({counted_column}) FROM {table} "
    where_part = f"WHERE COUNT({counted_column}) > {threshold}"
    return select_part + where_part, "aggregate functions are not allowed in WHERE"
def inject_subquery_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query comparing a column with ``=`` to an unrestricted subquery.

    The subquery selects the same randomly picked column from the exercise's
    second table with no filter.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    outer_table = _first_table(exercise)
    inner_table = _second_table(exercise)
    column = _pick(rng, list(exercise.columns))
    subquery = f"(SELECT {column} FROM {inner_table})"
    query = f"SELECT {column} FROM {outer_table} WHERE {column} = {subquery}"
    return query, "subquery returned more than one row"
def inject_window_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query that misuses a window function.

    A column is picked at random, then one of three window-function
    templates over the exercise's first table is chosen at random.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    table = _first_table(exercise)
    column = _pick(rng, list(exercise.columns))

    # Empty OVER () specification.
    empty_over = f"SELECT {column}, ROW_NUMBER() OVER () FROM {table}"
    # Window aggregate combined with GROUP BY on the same column.
    window_with_group = (
        f"SELECT {column}, SUM({column}) OVER (ORDER BY {column}) FROM {table} GROUP BY {column}"
    )
    # PARTITION written without the BY keyword.
    partition_without_by = f"SELECT {column}, RANK() OVER (PARTITION {column}) FROM {table}"

    candidates = [empty_over, window_with_group, partition_without_by]
    return rng.choice(candidates), "window function requires PARTITION BY or ORDER BY"
def inject_null_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query that tests a random column against NULL with ``=``.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    table = _first_table(exercise)
    column = _pick(rng, list(exercise.columns))
    query = f"SELECT * FROM {table} WHERE {column} = NULL"
    message = "use IS NULL or IS NOT NULL to test for null values"
    return query, message
def inject_date_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query containing a malformed date literal or date-function call.

    One of four fixed templates over the exercise's first table is chosen
    at random. The templates reference ``order_date`` / ``hire_date``
    literally rather than columns taken from the exercise.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    table = _first_table(exercise)
    prefix = f"SELECT * FROM {table} WHERE "
    conditions = [
        "order_date = '31/02/2023'",
        "order_date = DATE '2023-13-40'",
        "STR_TO_DATE('bad-date', '%Y-%m-%d')",
        "hire_date > 'yesterday'",
    ]
    candidates = [prefix + condition for condition in conditions]
    return rng.choice(candidates), "invalid date format or unknown date function"
def inject_column_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query selecting a random ``FAKE_COLUMNS`` entry from the first table.

    Returns:
        ``(erroneous_sql, error_message)``; the message names the fake column.
    """
    table = _first_table(exercise)
    fake_column = _pick(rng, FAKE_COLUMNS)
    query = f"SELECT {fake_column} FROM {table}"
    message = f"column '{fake_column}' does not exist"
    return query, message
def inject_table_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query selecting a random exercise column from a ``FAKE_TABLES`` entry.

    The fake table is drawn first and the column second.

    Returns:
        ``(erroneous_sql, error_message)``; the message names the fake table.
    """
    fake_table = _pick(rng, FAKE_TABLES)
    column = _pick(rng, list(exercise.columns))
    query = f"SELECT {column} FROM {fake_table}"
    message = f"relation '{fake_table}' does not exist"
    return query, message
def inject_datatype_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query comparing a random column with a non-numeric string literal.

    The column is drawn first, then the literal (one of ``abc``, ``ten``,
    ``N/A``).

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    table = _first_table(exercise)
    column = _pick(rng, list(exercise.columns))
    literal = rng.choice(['abc', 'ten', 'N/A'])
    query = f"SELECT {column} FROM {table} WHERE {column} = '{literal}'"
    return query, "operator does not exist: integer = character varying"
def inject_duplicate_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Drop DISTINCT when the question asks for unique values.

    Every ``DISTINCT`` keyword (any letter case, followed by any amount of
    whitespace or directly by ``(``) is removed from the correct query. When
    the query contains no ``DISTINCT`` keyword, a plain single-column SELECT
    over the exercise's first table is produced instead.

    Args:
        rng: Random source; only consumed when no DISTINCT keyword is present.
        exercise: Exercise providing ``correct_query``, ``columns``, ``tables``.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    import re  # function-scope import keeps this block self-contained

    sql = exercise.correct_query
    # \b keeps identifiers such as `distinct_count` intact; IGNORECASE covers
    # mixed-case spellings that a literal str.replace would miss (which would
    # return the correct query unchanged under an error label).
    without_distinct = re.sub(r"\bDISTINCT\b\s*", "", sql, flags=re.IGNORECASE)
    if without_distinct != sql:
        bad = without_distinct
    else:
        col = _pick(rng, list(exercise.columns))
        bad = f"SELECT {col} FROM {_first_table(exercise)}"
    return bad, "query returns duplicate rows; DISTINCT may be required"
def inject_logical_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """
    Produce a query that runs against the schema but answers the question incorrectly.

    Variants are tied to the exercise question and correct answer. Each
    applicable rule contributes one or more candidate queries; candidates that
    turn out identical to the correct query are discarded so a correct answer
    is never returned under an error label. If no candidate remains, two
    generic queries over the first table are used instead.

    Args:
        rng: Random source for column picks and the final variant choice.
        exercise: Exercise providing ``correct_query``, ``question``,
            ``columns`` and ``tables``.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    import re  # function-scope import keeps this block self-contained

    sql = exercise.correct_query
    q = exercise.question.lower()
    sql_lower = sql.lower()
    sql_upper = sql.upper()

    def swap(pattern: str, replacement: str, count: int = 1) -> str:
        """Case-insensitive regex replacement applied to the correct query."""
        return re.sub(pattern, replacement, sql, count=count, flags=re.IGNORECASE)

    variants: List[str] = []
    if "average" in q or "avg" in sql_lower:
        # The guard is case-insensitive, so the replacement must be too.
        variants.append(swap(r"\bAVG\(", "SUM("))
        variants.append(swap(r"\bAVG\(", "MAX("))
    if " and " in q and " AND " in sql:
        variants.append(sql.replace(" AND ", " OR ", 1))
    if "join" in sql_lower:
        t1, t2 = _first_table(exercise), _second_table(exercise)
        # NOTE(review): assumes both tables expose `name` and `id` columns --
        # confirm against the exercise schemas.
        variants.append(
            f"SELECT {t1}.name, {t2}.name FROM {t1} "
            f"JOIN {t2} ON {t1}.id = {t2}.id"
        )
        variants.append(swap(r"\bINNER\s+JOIN\b", "LEFT JOIN"))
    if "between" in q and "BETWEEN" in sql_upper:
        # Swap the two bounds in place so the real column, literal casing and
        # the rest of the query are preserved. A bound is a quoted string or a
        # bare numeric/identifier token; anything else leaves the query
        # untouched and the variant is filtered out below.
        bound = r"('[^']*'|[-+]?[\w.]+)"
        variants.append(
            swap(
                rf"(\bBETWEEN\s+){bound}(\s+AND\s+){bound}",
                r"\g<1>\g<4>\g<3>\g<2>",
            )
        )
    if "rank" in q or "over" in sql_lower:
        col = _pick(rng, list(exercise.columns))
        variants.append(
            f"SELECT name, {col} FROM {_first_table(exercise)} ORDER BY {col} DESC"
        )
    if "total" in q and "WHERE" in sql_upper:
        variants.append(sql.replace("active", "inactive"))
    if "highest" in q or "max" in sql_lower:
        col = _pick(rng, list(exercise.columns))
        variants.append(
            f"SELECT name FROM {_first_table(exercise)} "
            f"WHERE {col} >= (SELECT AVG({col}) FROM {_second_table(exercise)})"
        )
    if "enrolled" in q and "INNER JOIN" in sql_upper:
        variants.append(swap(r"\bINNER\s+JOIN\b", "LEFT JOIN"))
    if "not provided" in q or "is null" in sql_lower:
        variants.append(swap(r"\bIS\s+NULL\b", "= ''", count=0))

    # A rule whose replacement found nothing to replace yields the correct
    # query verbatim; drop those so the sample is genuinely wrong.
    variants = [candidate for candidate in variants if candidate != sql]
    if not variants:
        col = _pick(rng, list(exercise.columns))
        t = _first_table(exercise)
        variants = [
            f"SELECT {col} FROM {t} ORDER BY {col} DESC LIMIT 10",
            f"SELECT COUNT(*) FROM {t}",
        ]
    bad = rng.choice(variants)
    return bad, "query executes but produces incorrect result set"
def inject_performance_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query that uses ``SELECT *``, a wildcard LIKE scan or a cross join.

    A column and a vowel for the LIKE pattern are drawn first (in that
    order), then one of four templates is chosen at random.

    Returns:
        ``(erroneous_sql, error_message)``.
    """
    left = _first_table(exercise)
    right = _second_table(exercise)
    like_column = _pick(rng, list(exercise.columns))
    like_letter = rng.choice(['a', 'e', 'i'])

    select_everything = f"SELECT * FROM {left}"
    star_join = f"SELECT * FROM {left} JOIN {right} ON {left}.id = {right}.id"
    wildcard_scan = f"SELECT * FROM {left} WHERE {like_column} LIKE '%{like_letter}%'"
    cross_join = f"SELECT * FROM {left} CROSS JOIN {right}"

    candidates = [select_everything, star_join, wildcard_scan, cross_join]
    return rng.choice(candidates), "inefficient query: SELECT * or cartesian join detected"
def inject_filtering_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]:
    """Build a query whose WHERE clause selects the wrong rows.

    Candidates are: the correct query with its first ``>`` flipped to ``<``
    (or, lacking ``>``, its first ``=`` turned into ``!=``); a contradictory
    range filter; a negated comparison; and the correct query with its first
    `` AND `` turned into `` OR `` (or, lacking one, a BETWEEN with reversed
    bounds).

    Args:
        rng: Random source for the column, the threshold and the variant.
        exercise: Exercise providing ``correct_query``, ``columns``, ``tables``.

    Returns:
        ``(erroneous_sql, error_message)``. The SQL is not the unmodified
        correct query.
    """
    sql = exercise.correct_query
    col = _pick(rng, list(exercise.columns))
    t = _first_table(exercise)
    threshold = rng.randint(50, 90)
    # Contradictory range (col > x AND col < x - 20): synthesized from
    # scratch, so it is also usable as a safe fallback below.
    impossible_range = (
        f"SELECT {col} FROM {t} WHERE {col} > {threshold} AND {col} < {threshold - 20}"
    )
    variants = [
        sql.replace(">", "<", 1) if ">" in sql else sql.replace("=", "!=", 1),
        impossible_range,
        f"SELECT {col} FROM {t} WHERE NOT {col} > {threshold}",
        sql.replace(" AND ", " OR ", 1) if " AND " in sql else (
            f"SELECT {col} FROM {t} WHERE {col} BETWEEN {threshold} AND {threshold - 10}"
        ),
    ]
    bad = rng.choice(variants)
    if bad == sql:
        # The first variant is a no-op when the query contains neither ">"
        # nor "="; never emit the correct query under an error label. The RNG
        # has already been consumed exactly as before.
        bad = impossible_range
    return bad, "WHERE clause filters incorrect rows"
# Registry of all error injectors, keyed by an integer id (0-14). Every value
# has the signature (rng, exercise) -> (erroneous_sql, error_message).
# NOTE(review): the ids presumably double as the classifier's class labels --
# confirm against the training code before renumbering or reordering.
ERROR_INJECTORS: Dict[int, Callable[[random.Random, Exercise], Tuple[str, str]]] = {
0: inject_syntax_error,
1: inject_join_error,
2: inject_aggregation_error,
3: inject_having_where_error,
4: inject_subquery_error,
5: inject_window_error,
6: inject_null_error,
7: inject_date_error,
8: inject_column_error,
9: inject_table_error,
10: inject_datatype_error,
11: inject_duplicate_error,
12: inject_logical_error,
13: inject_performance_error,
14: inject_filtering_error,
}