iris-at-text2sparql / src /validation.py
Alex Latipov
Harden frozen eval prompts and judge JSON handling
d745844
"""Cheap symbolic validation for the Text2SPARQL repair pipeline.
All validation is symbolic — no LLM calls. Scores candidates for selection.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from .config import RuntimeConfig
from .models import (
CandidateQuery,
ContextPackage,
DatasetConfig,
QueryRequest,
ValidationResult,
)
logger = logging.getLogger(__name__)

# A result count strictly greater than this raises the "huge_result" flag.
_HUGE_RESULT_THRESHOLD = 10_000
def parse_check(query: str) -> tuple[bool, str | None]:
    """Report whether *query* is syntactically valid SPARQL.

    Parsing is delegated to rdflib's SPARQL grammar parser (``parseQuery``).
    If rdflib cannot be imported, a warning is logged and the structural
    fallback ``_basic_parse_check`` decides instead.

    Args:
        query: SPARQL query string.

    Returns:
        ``(True, None)`` when the query parses, otherwise ``(False, message)``
        where *message* is the text of the raised parser exception (or the
        fallback's explanation).
    """
    try:
        from rdflib.plugins.sparql.parser import parseQuery

        parseQuery(query)
    except ImportError:
        logger.warning("rdflib not installed — using basic parse check")
        return _basic_parse_check(query)
    except Exception as exc:
        # Any failure raised while parsing means "does not parse".
        return False, str(exc)
    else:
        return True, None
# Lexical spans the fallback checker blanks out before inspecting structure,
# so keywords, braces or '#' inside them are not mistaken for query syntax.
# IRIREF characters follow the SPARQL 1.1 grammar (no whitespace, <>"{}|^`\).
_IRI_REF_RE = re.compile(r'<[^<>"{}|^`\\\s]*>')
_STRING_LITERAL_RE = re.compile(
    r'"(?:[^"\\\n]|\\.)*"'  # "double-quoted"
    r"|'(?:[^'\\\n]|\\.)*'"  # 'single-quoted'
)
_LINE_COMMENT_RE = re.compile(r"#[^\n]*")
# Keywords must stand alone: not part of ?var, $var, ex:name or a longer word.
_QUERY_FORM_RE = re.compile(
    r"(?<![\w:?$])(SELECT|ASK|CONSTRUCT|DESCRIBE)\b", re.IGNORECASE
)
_WHERE_KEYWORD_RE = re.compile(r"(?<![\w:?$])WHERE\b", re.IGNORECASE)


def _basic_parse_check(query: str) -> tuple[bool, str | None]:
    """Basic structural SPARQL parse check used when rdflib is unavailable.

    IRIs, single-line string literals and ``#`` comments are blanked out
    first. IRIs are removed before comments because ``#`` is common inside
    IRIs. The remaining text is then checked for:

    * a standalone query-form keyword (SELECT/ASK/CONSTRUCT/DESCRIBE);
    * balanced *and correctly ordered* curly braces;
    * a graph pattern where the form needs one. SELECT/ASK need ``WHERE``
      or a ``{`` after the keyword, since SPARQL 1.1 makes ``WHERE``
      optional. CONSTRUCT needs ``WHERE``. DESCRIBE needs nothing.

    This is a heuristic, not a parser: it can accept invalid queries.

    Args:
        query: SPARQL query string.

    Returns:
        ``(True, None)`` if the checks pass, else ``(False, reason)``.
    """
    code = _IRI_REF_RE.sub(" ", query)
    code = _STRING_LITERAL_RE.sub(" ", code)
    code = _LINE_COMMENT_RE.sub(" ", code)

    form = _QUERY_FORM_RE.search(code)
    if form is None:
        return False, "No SPARQL query keyword found (SELECT/ASK/CONSTRUCT/DESCRIBE)"

    open_count = code.count("{")
    close_count = code.count("}")
    if open_count != close_count:
        return False, f"Unbalanced braces: {open_count} open, {close_count} close"
    # Equal counts can still be misordered, e.g. "} ... {".
    depth = 0
    for char in code:
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth < 0:
                return False, "Unbalanced braces: '}' before matching '{'"

    keyword = form.group(1).upper()
    if keyword == "DESCRIBE":
        # "DESCRIBE <iri>" is complete without any graph pattern.
        return True, None
    has_where = _WHERE_KEYWORD_RE.search(code) is not None
    if keyword == "CONSTRUCT":
        # Both "CONSTRUCT {tpl} WHERE {...}" and "CONSTRUCT WHERE {...}"
        # spell out WHERE; a bare template is not a complete query.
        if not has_where:
            return False, "Missing WHERE clause"
    elif not has_where and "{" not in code[form.end():]:
        # SELECT/ASK: WHERE is optional, but a group pattern is required.
        return False, "Missing WHERE clause"
    return True, None
def execute_query(
    query: str, endpoint_url: str, timeout_sec: int
) -> tuple[bool, list[dict], int | None, str | None, bool]:
    """Run *query* against a SPARQL endpoint and summarise the response.

    The request is sent with SPARQLWrapper using HTTP POST, a JSON return
    format and the given timeout.

    * A response carrying ``"boolean"`` (ASK) yields a single
      ``{"boolean": value}`` row and a count of 1.
    * A response with ``results.bindings`` yields the total binding count and
      a preview of the first 5 rows, each flattened to ``{var: value}``.
    * Any other response shape counts as success with no rows.

    Args:
        query: SPARQL query string.
        endpoint_url: SPARQL endpoint URL.
        timeout_sec: Request timeout in seconds.

    Returns:
        ``(execute_ok, results, result_count, error_message, timed_out)``.
        On failure the error text is cut to 500 characters. ``timed_out`` is
        inferred from timeout-like phrases in that text.
    """
    try:
        from SPARQLWrapper import JSON, POST, SPARQLWrapper

        client = SPARQLWrapper(endpoint_url)
        client.setQuery(query)
        client.setReturnFormat(JSON)
        client.setTimeout(timeout_sec)
        client.setMethod(POST)
        payload = client.query().convert()

        if "boolean" in payload:
            # ASK query: exactly one boolean "row".
            return True, [{"boolean": payload["boolean"]}], 1, None, False

        if "results" not in payload or "bindings" not in payload["results"]:
            return True, [], 0, None, False

        rows = payload["results"]["bindings"]
        total = len(rows)
        preview = [
            {name: cell.get("value", "") for name, cell in row.items()}
            for row in rows[:5]
        ]
        return True, preview, total, None, False
    except ImportError:
        logger.warning("SPARQLWrapper not installed — skipping endpoint execution")
        return False, [], None, "SPARQLWrapper not installed", False
    except Exception as exc:
        message = str(exc)
        lowered = message.lower()
        timeout_markers = ("timeout", "timed out", "time out", "deadline")
        timed_out = any(marker in lowered for marker in timeout_markers)
        return False, [], None, message[:500], timed_out
# One prologue item that may precede the query-form keyword: a comment line,
# a PREFIX declaration (a space before the IRI is optional), or BASE.
_PROLOGUE_ITEM_RE = re.compile(
    r"\s*(?:#[^\n]*|PREFIX\s+[^\s:]*:\s*<[^>]*>|BASE\s*<[^>]*>)",
    re.IGNORECASE,
)
# A COUNT aggregate call. The lookbehind rejects ?count, $count, ex:count(
# and names that merely end in "count", such as discount(.
_COUNT_AGGREGATE_RE = re.compile(r"(?<![\w:?$])COUNT\s*\(", re.IGNORECASE)


def _detect_query_form(query: str) -> str:
    """Detect the SPARQL query form (ASK, SELECT, etc.).

    The prologue (comments, PREFIX and BASE declarations) is skipped, and the
    first keyword of the remaining text decides the form. A SELECT query
    that calls the COUNT aggregate is reported as ``"count"``. COUNT inside
    other forms (e.g. a CONSTRUCT subquery) does not change their form.

    Args:
        query: SPARQL query string.

    Returns:
        One of ``"ask"``, ``"count"``, ``"select"``, ``"construct"``,
        ``"describe"`` or ``"unknown"``.
    """
    pos = 0
    while True:
        item = _PROLOGUE_ITEM_RE.match(query, pos)
        if item is None or item.end() == pos:
            break
        pos = item.end()
    body = query[pos:].lstrip()
    head = body.upper()

    if head.startswith("ASK"):
        return "ask"
    if head.startswith("SELECT"):
        return "count" if _COUNT_AGGREGATE_RE.search(body) else "select"
    if head.startswith("CONSTRUCT"):
        return "construct"
    if head.startswith("DESCRIBE"):
        return "describe"
    return "unknown"
def score_answer_type_fit(
    question: str, query: str, answer_type_hint: str
) -> float:
    """Rate how well the query's form agrees with the expected answer type.

    Only the detected query form is compared with the hint. ``question`` is
    part of the signature but is not consulted.

    Args:
        question: Natural language question (unused).
        query: SPARQL query.
        answer_type_hint: Expected type ("ask", "count", "select").

    Returns:
        A score in [0.0, 1.0]:

        * 1.0 for an exact match;
        * 0.3 when COUNT and plain SELECT are swapped;
        * 0.2 for any other form under a "select" hint;
        * 0.0 for other mismatches;
        * 0.5 when the hint is not recognised.
    """
    detected = _detect_query_form(query)
    # hint -> (score per detected form, score for any other form)
    fit_table = {
        "ask": ({"ask": 1.0}, 0.0),
        "count": ({"count": 1.0, "select": 0.3}, 0.0),
        "select": ({"select": 1.0, "count": 0.3}, 0.2),
    }
    for hint, (per_form, fallback) in fit_table.items():
        if answer_type_hint == hint:
            return per_form.get(detected, fallback)
    return 0.5  # Unknown hint
# PREFIX declarations inside a query: group 1 is the prefix label (may be
# empty), group 2 is the namespace IRI.
_PREFIX_DECL_RE = re.compile(r"PREFIX\s+([^\s:]*):\s*<([^>]*)>", re.IGNORECASE)


def _mentions_uri(uri: str, query: str, prefixes: list[tuple[str, str]]) -> bool:
    """Return True if *query* references *uri* as ``<uri>`` or ``pfx:local``."""
    if f"<{uri}>" in query:
        return True
    for label, namespace in prefixes:
        if not namespace or uri == namespace or not uri.startswith(namespace):
            continue
        local = uri[len(namespace):]
        # The lookarounds stop "dbr:Berlin" from matching inside
        # "dbr:Berlin_Wall" or "xdbr:Berlin".
        pattern = (
            r"(?<![\w\-.:?$])"
            + re.escape(label)
            + ":"
            + re.escape(local)
            + r"(?![\w\-:%]|\.[\w\-])"
        )
        if re.search(pattern, query):
            return True
    return False


def score_schema_fit(query: str, context: ContextPackage) -> float:
    """Score how well the query uses entity/relation/class URIs from the context.

    A context URI counts as used when the query contains it as a full IRI
    (``<uri>``). It also counts when the query writes it as a prefixed name
    that expands to the URI through one of the query's own ``PREFIX``
    declarations. Exact IRI matching avoids crediting ``<.../Berlin_Wall>``
    for the candidate ``.../Berlin``.

    Args:
        query: SPARQL query.
        context: Context package with entity, relation and class candidates
            (each a dict with a ``"uri"`` entry).

    Returns:
        A score between 0.0 and 1.0: matched URIs divided by
        ``min(total, 3)``, capped at 1.0, so using any three candidates
        already earns the maximum. Returns 0.5 (neutral) when the context
        offers no URIs.
    """
    candidate_groups = (
        context.entity_candidates or [],
        context.relation_candidates or [],
        context.class_candidates or [],
    )
    prefixes = [(m.group(1), m.group(2)) for m in _PREFIX_DECL_RE.finditer(query)]

    total_candidates = 0
    matched = 0
    for group in candidate_groups:
        for candidate in group:
            uri = str(candidate.get("uri") or "").strip()
            if uri.startswith("<") and uri.endswith(">"):
                uri = uri[1:-1].strip()
            if not uri:
                continue
            total_candidates += 1
            if _mentions_uri(uri, query, prefixes):
                matched += 1

    if total_candidates == 0:
        return 0.5  # No context to judge against
    return min(1.0, matched / min(total_candidates, 3))
def compute_validation_score(
    parse_ok: bool,
    execute_ok: bool,
    result_count: int | None,
    answer_type_fit: float,
    schema_fit: float,
    suspicious_flags: list[str],
    weights: dict[str, float],
) -> float:
    """Combine the individual validation signals into one ranking score.

    Each term's weight is read from *weights*, with the default used when
    the key is absent:

        score = + 5.0 if parse_ok
                + 5.0 if execute_ok
                + 2.0 * answer_type_fit
                + 2.0 * schema_fit
                - 2.0 if "timeout" flagged
                - 1.5 if "empty_result" flagged
                - 1.0 if "huge_result" flagged
                - 0.5 * len(suspicious_flags)

    Specific penalty flags are therefore also counted by the last,
    per-flag term.

    Args:
        parse_ok: Whether query parsed.
        execute_ok: Whether query executed.
        result_count: Number of results (not used by the formula).
        answer_type_fit: Answer type fit score [0,1].
        schema_fit: Schema fit score [0,1].
        suspicious_flags: List of suspicious flag strings.
        weights: Scoring weights dict.

    Returns:
        The total, rounded to 4 decimal places.
    """
    total = 0.0
    # Bonuses for the two hard checks, applied in order.
    for passed, key, default in (
        (parse_ok, "parse_ok", 5.0),
        (execute_ok, "execute_ok", 5.0),
    ):
        if passed:
            total += weights.get(key, default)
    # Soft fit scores, always applied.
    total += answer_type_fit * weights.get("answer_type_fit", 2.0)
    total += schema_fit * weights.get("schema_fit", 2.0)
    # Specific penalties for notable flags.
    for flag, default in (
        ("timeout", -2.0),
        ("empty_result", -1.5),
        ("huge_result", -1.0),
    ):
        if flag in suspicious_flags:
            total += weights.get(flag, default)
    # Generic per-flag penalty.
    total += len(suspicious_flags) * weights.get("suspicious_flag", -0.5)
    return round(total, 4)
def validate_candidate(
    candidate: CandidateQuery,
    request: QueryRequest,
    context: ContextPackage,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
) -> ValidationResult:
    """Validate a single candidate query.

    Runs all symbolic checks:
    - Parser check (a query that fails it is returned immediately and never
      sent to the endpoint)
    - Endpoint execution
    - Timeout check
    - Result count check (0 rows -> "empty_result", more than
      ``_HUGE_RESULT_THRESHOLD`` -> "huge_result")
    - Answer type sanity check
    - Schema plausibility check

    A "form_mismatch" flag is raised only when the answer-type hint is one
    that ``score_answer_type_fit`` recognises ("ask", "count", "select") and
    the detected form clearly contradicts it. A SELECT for a "count" hint is
    tolerated, and an undetectable form is never flagged.

    Args:
        candidate: The candidate SPARQL query.
        request: The original query request.
        context: Context package.
        dataset: Dataset configuration.
        runtime: Runtime configuration.

    Returns:
        ValidationResult with all check results and score.
    """
    flags: list[str] = []
    query = candidate.query

    # 1. Parser check
    parse_ok, parse_error = parse_check(query)
    if not parse_ok:
        flags.append("parse_fail")
        return ValidationResult(
            candidate_id=candidate.candidate_id,
            parse_ok=False,
            execute_ok=False,
            timeout=False,
            execution_error=parse_error,
            result_count=None,
            result_preview=[],
            answer_type_fit=0.0,
            schema_fit=0.0,
            suspicious_flags=flags,
            score=compute_validation_score(
                False, False, None, 0.0, 0.0, flags,
                runtime.selection_weights,
            ),
        )

    # 2. Endpoint execution
    execute_ok, results, result_count, exec_error, timed_out = execute_query(
        query, dataset.endpoint_url, runtime.request_timeout_sec
    )
    if timed_out:
        flags.append("timeout")
    if not execute_ok:
        flags.append("execute_fail")
    if result_count is not None:
        if result_count == 0:
            flags.append("empty_result")
        elif result_count > _HUGE_RESULT_THRESHOLD:
            flags.append("huge_result")

    # 3. Answer type check
    answer_type_hint = context.answer_type_hint or "select"
    at_fit = score_answer_type_fit(request.question, query, answer_type_hint)
    query_form = _detect_query_form(query)
    # Only a recognised hint can be contradicted. An unrecognised hint is
    # neutral (0.5) in score_answer_type_fit, and flagging it here would
    # penalise every identifiable query while "unknown" forms escaped.
    recognised_hints = ("ask", "count", "select")
    if (
        answer_type_hint in recognised_hints
        and query_form != "unknown"
        and query_form != answer_type_hint
        and not (answer_type_hint == "count" and query_form == "select")
    ):
        flags.append("form_mismatch")

    # 4. Schema fit
    s_fit = score_schema_fit(query, context)

    # 5. Compute score
    score = compute_validation_score(
        parse_ok, execute_ok, result_count, at_fit, s_fit,
        flags, runtime.selection_weights,
    )
    return ValidationResult(
        candidate_id=candidate.candidate_id,
        parse_ok=parse_ok,
        execute_ok=execute_ok,
        timeout=timed_out,
        execution_error=exec_error,
        result_count=result_count,
        result_preview=results,
        answer_type_fit=at_fit,
        schema_fit=s_fit,
        suspicious_flags=flags,
        score=score,
    )
def validate_all(
    candidates: list[CandidateQuery],
    request: QueryRequest,
    context: ContextPackage,
    dataset: DatasetConfig,
    runtime: RuntimeConfig,
) -> list[ValidationResult]:
    """Validate every candidate, in order, with ``validate_candidate``.

    Each candidate is logged at INFO level before validation, and its score
    and flags are logged after.

    Args:
        candidates: List of candidate queries.
        request: The original query request.
        context: Context package.
        dataset: Dataset configuration.
        runtime: Runtime configuration.

    Returns:
        One ValidationResult per candidate, in the same order as *candidates*.
    """
    def _validate_one(cand: CandidateQuery) -> ValidationResult:
        logger.info("Validating candidate %s", cand.candidate_id)
        outcome = validate_candidate(cand, request, context, dataset, runtime)
        logger.info(
            "Candidate %s: score=%.2f, flags=%s",
            cand.candidate_id, outcome.score, outcome.suspicious_flags,
        )
        return outcome

    return [_validate_one(cand) for cand in candidates]