"""Error injectors that transform exercise context into labeled mistakes.""" from __future__ import annotations import random from typing import Callable, Dict, List, Tuple from src.exercises import Exercise FAKE_COLUMNS = ["fullname", "studentname", "coursename", "dept_name", "totals"] FAKE_TABLES = ["student", "course", "enrolment", "employe", "orderz"] def _pick(rng: random.Random, items: List[str], k: int = 1) -> str | List[str]: if k == 1: return rng.choice(items) return rng.sample(items, k) def _first_table(exercise: Exercise) -> str: return exercise.tables[0] def _second_table(exercise: Exercise) -> str: return exercise.tables[1] if len(exercise.tables) > 1 else exercise.tables[0] # --- Error injectors: (exercise) -> (erroneous_sql, error_message) --- def inject_syntax_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: sql = exercise.correct_query mutations = [ lambda s: s.replace("SELECT", "SELEC", 1), lambda s: s.replace("FROM", "FRO", 1), lambda s: s[:-1], lambda s: s.replace(")", "", 1), lambda s: s + " WHERE", lambda s: s.replace(",", "", 1), lambda s: s.replace("'", '"', 1) if "'" in s else s + " 'unclosed", ] return rng.choice(mutations)(sql), "syntax error at or near unexpected token" def inject_join_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t1, t2 = _first_table(exercise), _second_table(exercise) col = _pick(rng, list(exercise.columns)) variants = [ f"SELECT {col} FROM {t1} JOIN {t2}", f"SELECT {col} FROM {t1} INNER JOIN {t2} ON {t1}.id = {t2}.id", ( f"SELECT {t1}.{col} FROM {t1} " f"LEFT JOIN {t2} ON {t1}.{col} = {t2}.{col}" ), f"SELECT * FROM {t1}, {t2} WHERE {t1}.wrong_id = {t2}.wrong_id", ] return rng.choice(variants), "missing ON clause or invalid join condition" def inject_aggregation_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) cols = list(exercise.columns) group_col = cols[0] agg_col = cols[-1] bad = f"SELECT {group_col}, AVG({agg_col}) FROM {t}" return bad, "column must appear in GROUP BY clause or be used in aggregate function" def inject_having_where_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) cols = list(exercise.columns) group_col, agg_col = cols[0], cols[-1] bad = ( f"SELECT {group_col}, COUNT({agg_col}) FROM {t} " f"WHERE COUNT({agg_col}) > {rng.randint(1, 5)}" ) return bad, "aggregate functions are not allowed in WHERE" def inject_subquery_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t1, t2 = _first_table(exercise), _second_table(exercise) col = _pick(rng, list(exercise.columns)) bad = f"SELECT {col} FROM {t1} WHERE {col} = (SELECT {col} FROM {t2})" return bad, "subquery returned more than one row" def inject_window_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) col = _pick(rng, list(exercise.columns)) variants = [ f"SELECT {col}, ROW_NUMBER() OVER () FROM {t}", f"SELECT {col}, SUM({col}) OVER (ORDER BY {col}) FROM {t} GROUP BY {col}", f"SELECT {col}, RANK() OVER (PARTITION {col}) FROM {t}", ] return rng.choice(variants), "window function requires PARTITION BY or ORDER BY" def inject_null_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) col = _pick(rng, list(exercise.columns)) bad = f"SELECT * FROM {t} WHERE {col} = NULL" return bad, "use IS NULL or IS NOT NULL to test for null values" def inject_date_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) variants = [ f"SELECT * FROM {t} WHERE order_date = '31/02/2023'", f"SELECT * FROM {t} WHERE order_date = DATE '2023-13-40'", f"SELECT * FROM {t} WHERE STR_TO_DATE('bad-date', '%Y-%m-%d')", f"SELECT * FROM {t} WHERE hire_date > 'yesterday'", ] return rng.choice(variants), "invalid date format or unknown date function" def inject_column_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) col = _pick(rng, FAKE_COLUMNS) bad = f"SELECT {col} FROM {t}" return bad, f"column '{col}' does not exist" def inject_table_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: tbl = _pick(rng, FAKE_TABLES) col = _pick(rng, list(exercise.columns)) bad = f"SELECT {col} FROM {tbl}" return bad, f"relation '{tbl}' does not exist" def inject_datatype_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t = _first_table(exercise) col = _pick(rng, list(exercise.columns)) bad = f"SELECT {col} FROM {t} WHERE {col} = '{rng.choice(['abc', 'ten', 'N/A'])}'" return bad, "operator does not exist: integer = character varying" def inject_duplicate_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: """Drop DISTINCT when the question asks for unique values.""" sql = exercise.correct_query if "DISTINCT" in sql.upper(): bad = sql.upper().replace("DISTINCT ", "").replace("distinct ", "") # restore original casing loosely bad = sql.replace("DISTINCT ", "").replace("distinct ", "") else: col = _pick(rng, list(exercise.columns)) bad = f"SELECT {col} FROM {_first_table(exercise)}" return bad, "query returns duplicate rows; DISTINCT may be required" def inject_logical_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: """ Produce a query that runs against the schema but answers the question incorrectly. Variants are tied to the exercise question and correct answer. """ sql = exercise.correct_query q = exercise.question.lower() variants: List[str] = [] if "average" in q or "avg" in sql.lower(): variants.append(sql.replace("AVG(", "SUM(", 1)) variants.append(sql.replace("AVG(", "MAX(", 1)) if " and " in q and " AND " in sql: variants.append(sql.replace(" AND ", " OR ", 1)) if "join" in sql.lower(): t1, t2 = _first_table(exercise), _second_table(exercise) variants.append( f"SELECT {t1}.name, {t2}.name FROM {t1} " f"JOIN {t2} ON {t1}.id = {t2}.id" ) variants.append(sql.replace("INNER JOIN", "LEFT JOIN", 1)) if "between" in q and "BETWEEN" in sql.upper(): upper = sql.upper() between_part = upper.split("BETWEEN", 1)[1] bounds = between_part.split("AND", 1) if len(bounds) == 2: lo = bounds[0].strip().split()[-1] hi = bounds[1].strip().split()[0] variants.append( sql.split("WHERE", 1)[0] + f" WHERE price BETWEEN {hi} AND {lo}" ) if "rank" in q or "over" in sql.lower(): col = _pick(rng, list(exercise.columns)) variants.append( f"SELECT name, {col} FROM {_first_table(exercise)} ORDER BY {col} DESC" ) if "total" in q and "WHERE" in sql.upper(): variants.append(sql.replace("active", "inactive")) if "highest" in q or "max" in sql.lower(): col = _pick(rng, list(exercise.columns)) variants.append( f"SELECT name FROM {_first_table(exercise)} " f"WHERE {col} >= (SELECT AVG({col}) FROM {_second_table(exercise)})" ) if "enrolled" in q and "INNER JOIN" in sql.upper(): variants.append(sql.replace("INNER JOIN", "LEFT JOIN", 1)) if "not provided" in q or "is null" in sql.lower(): variants.append(sql.replace("IS NULL", "= ''")) if not variants: col = _pick(rng, list(exercise.columns)) t = _first_table(exercise) variants = [ f"SELECT {col} FROM {t} ORDER BY {col} DESC LIMIT 10", f"SELECT COUNT(*) FROM {t}", ] bad = rng.choice(variants) return bad, "query executes but produces incorrect result set" def inject_performance_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: t1, t2 = _first_table(exercise), _second_table(exercise) variants = [ f"SELECT * FROM {t1}", f"SELECT * FROM {t1} JOIN {t2} ON {t1}.id = {t2}.id", ( f"SELECT * FROM {t1} " f"WHERE {_pick(rng, list(exercise.columns))} " f"LIKE '%{rng.choice(['a', 'e', 'i'])}%'" ), f"SELECT * FROM {t1} CROSS JOIN {t2}", ] return rng.choice(variants), "inefficient query: SELECT * or cartesian join detected" def inject_filtering_error(rng: random.Random, exercise: Exercise) -> Tuple[str, str]: sql = exercise.correct_query col = _pick(rng, list(exercise.columns)) t = _first_table(exercise) threshold = rng.randint(50, 90) variants = [ sql.replace(">", "<", 1) if ">" in sql else sql.replace("=", "!=", 1), f"SELECT {col} FROM {t} WHERE {col} > {threshold} AND {col} < {threshold - 20}", f"SELECT {col} FROM {t} WHERE NOT {col} > {threshold}", sql.replace(" AND ", " OR ", 1) if " AND " in sql else ( f"SELECT {col} FROM {t} WHERE {col} BETWEEN {threshold} AND {threshold - 10}" ), ] return rng.choice(variants), "WHERE clause filters incorrect rows" ERROR_INJECTORS: Dict[int, Callable[[random.Random, Exercise], Tuple[str, str]]] = { 0: inject_syntax_error, 1: inject_join_error, 2: inject_aggregation_error, 3: inject_having_where_error, 4: inject_subquery_error, 5: inject_window_error, 6: inject_null_error, 7: inject_date_error, 8: inject_column_error, 9: inject_table_error, 10: inject_datatype_error, 11: inject_duplicate_error, 12: inject_logical_error, 13: inject_performance_error, 14: inject_filtering_error, }