Ru2SQL / src /models /postprocess.py
Tyycha's picture
fix bugs
cc2ed2f
"""Постобработка SQL: чистка вывода модели, валидация и нормализация.
Соответствует разделу 2.5 пояснительной записки. Pipeline:
raw_output ──► strip_model_artifacts ──► is_valid_sql ──► sql | ""
Дополнительно модуль предоставляет:
is_select_only(sql) — AST-уровневый гвардейл против DDL/DML
перед выполнением сгенерированного запроса;
normalize_sql(sql) — каноническая форма для расчёта Exact Match
(совместима с evaluate_pauq.py).
"""
from __future__ import annotations
import logging
import re
import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError
logger = logging.getLogger(__name__)
# Ключевые слова, с которых может начинаться корректный SQL-запрос.
_SQL_START_KEYWORDS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE")
_SQL_START_REGEX = re.compile(
r"\b(" + "|".join(_SQL_START_KEYWORDS) + r")\b",
flags=re.IGNORECASE,
)
_FENCE_REGEX = re.compile(r"```(?:sql)?\s*(.*?)```", flags=re.DOTALL | re.IGNORECASE)
_PREFIX_REGEX = re.compile(r"^\s*(?:SQL|Ответ|Answer)\s*:\s*", flags=re.IGNORECASE)
# Типы AST-узлов, которые мы считаем «осмысленными» SQL-запросами.
# sqlglot — лояльный парсер: 'garbage text' он распарсит как Column/Table.
# Без проверки isinstance такие случаи будут проходить is_valid_sql.
_VALID_ROOT_TYPES: tuple[type[exp.Expression], ...] = (
exp.Select,
exp.With,
exp.Insert,
exp.Update,
exp.Delete,
exp.Union,
exp.Intersect,
exp.Except,
)
def strip_model_artifacts(text: str) -> str:
"""Очищает вывод модели от markdown и пояснений до начала SQL-запроса.
Шаги:
1. Если ответ обёрнут в ```sql ... ``` — извлекается содержимое.
2. Удаляются префиксы вида «SQL:», «Ответ:», «Answer:».
3. Ищется первое вхождение SQL-ключевого слова, всё до него отбрасывается.
4. Берётся первый statement до первой точки с запятой включительно.
"""
fence = _FENCE_REGEX.search(text)
if fence:
text = fence.group(1)
text = _PREFIX_REGEX.sub("", text)
keyword_match = _SQL_START_REGEX.search(text)
if keyword_match:
text = text[keyword_match.start():]
text = text.strip()
if ";" in text:
head, _, _ = text.partition(";")
text = head.strip() + ";"
return text.strip()
def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
"""Проверяет, что строка — это валидный SQL-запрос.
Парсится через sqlglot и дополнительно проверяется, что корень AST —
это один из «осмысленных» типов запроса (SELECT/WITH/INSERT/UPDATE/
DELETE/UNION). Без проверки типа sqlglot принимает за SQL даже
случайные идентификаторы, потому что он лояльный парсер.
"""
if not sql or not sql.strip():
return False
try:
parsed = sqlglot.parse_one(sql, dialect=dialect)
except (ParseError, ValueError, TypeError) as e:
logger.debug("sqlglot не смог разобрать SQL: %s", e)
return False
if parsed is None:
return False
return isinstance(parsed, _VALID_ROOT_TYPES)
def is_select_only(sql: str, dialect: str = "sqlite") -> bool:
"""Возвращает True, если SQL — это SELECT (в т. ч. внутри WITH-CTE).
Используется как guardrail перед выполнением сгенерированного запроса
на реальной базе данных: модель не должна получить возможность вызвать
DROP/UPDATE/DELETE/INSERT, даже если такие конструкции синтаксически
корректны.
"""
if not sql or not sql.strip():
return False
try:
parsed = sqlglot.parse_one(sql, dialect=dialect)
except (ParseError, ValueError, TypeError):
return False
if parsed is None:
return False
if isinstance(parsed, exp.Select):
return True
if isinstance(parsed, exp.With):
return isinstance(parsed.this, exp.Select)
if isinstance(parsed, exp.Subquery):
return isinstance(parsed.this, exp.Select)
return False
def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
"""Каноническая форма для расчёта Exact Match.
Использует sqlglot с флагом ``normalize=True`` — это нормализует регистр
ключевых слов и идентификаторов. Результат приводится к верхнему регистру,
чтобы EM считался идентично эталонной реализации в ``evaluate_pauq.py``.
"""
try:
parsed = sqlglot.parse_one(sql, dialect=dialect)
return parsed.sql(dialect=dialect, normalize=True).upper()
except (ParseError, ValueError, TypeError):
return re.sub(r"\s+", " ", sql.upper()).strip().rstrip(";")
def postprocess(raw_output: str) -> str:
"""Полный pipeline постобработки вывода модели.
1. Чистка артефактов через :func:`strip_model_artifacts`.
2. Валидация через :func:`is_valid_sql`.
3. Возврат пустой строки при провале валидации.
Соответствует разделу 2.5 пояснительной записки.
"""
sql = strip_model_artifacts(raw_output)
if not is_valid_sql(sql):
logger.warning("postprocess отбросил невалидный SQL: %r", sql[:120])
return ""
return sql