"""Lightweight structural SQL features for the classification head.""" from __future__ import annotations import re from typing import List AGG_FUNCS = ("COUNT", "SUM", "AVG", "MAX", "MIN") WINDOW_FUNCS = ("ROW_NUMBER", "RANK", "DENSE_RANK", "OVER") def _upper(sql: str) -> str: return sql.upper() def extract_sql_features(student_query: str, correct_query: str = "") -> List[float]: """ Rule-based signals that complement semantic embeddings. Returns a fixed-length float vector. """ s = _upper(student_query) c = _upper(correct_query) if correct_query else "" has_agg = any(f" {f}(" in s or f"{f}(" in s for f in AGG_FUNCS) has_group = "GROUP BY" in s has_join = "JOIN" in s has_on = " ON " in s has_where = " WHERE " in s has_having = " HAVING " in s has_distinct = "DISTINCT" in s has_subquery = "(" in s and "SELECT" in s[s.find("(") :] has_window = "OVER" in s has_null_eq = "= NULL" in s or "=NULL" in s has_is_null = "IS NULL" in s or "IS NOT NULL" in s has_select_star = bool(re.search(r"SELECT\s+\*", s)) has_or = " OR " in s has_and = " AND " in s correct_has_distinct = "DISTINCT" in c correct_has_group = "GROUP BY" in c correct_has_inner = "INNER JOIN" in c student_has_left = "LEFT JOIN" in s return [ float(has_agg), float(has_agg and not has_group), float(has_join and not has_on), float(has_join), float(has_where and has_having), float(has_agg and has_where and not has_having), float(has_distinct), float(correct_has_distinct and not has_distinct), float(has_subquery), float(has_window), float(has_null_eq), float(has_is_null), float(has_select_star), float(has_or and has_and), float(correct_has_inner and student_has_left), float(len(s) / max(len(c), 1)), # length ratio vs reference ] FEATURE_NAMES = [ "has_aggregate", "agg_without_group_by", "join_without_on", "has_join", "where_and_having", "agg_in_where", "has_distinct", "missing_distinct_vs_correct", "has_subquery", "has_window", "null_equals", "is_null_check", "select_star", "and_or_mix", "left_vs_inner_join", "length_ratio", ]