Your Full Name
Clamp scores to strictly (0, 1) range for validator compliance
3cbb05e
Raw
History Blame Contribute Delete
14.9 kB
"""SQL Query Debugging Tasks with Graders.
Three difficulty levels:
- Easy: Simple syntax errors (typos, missing keywords)
- Medium: Logic errors (wrong joins, incorrect conditions)
- Hard: Complex issues (subquery errors, aggregation bugs, performance issues)
"""
import sqlite3
from dataclasses import dataclass
from typing import Optional
@dataclass
class TaskDefinition:
"""Definition of a SQL debugging task."""
task_id: str
difficulty: str # easy, medium, hard
description: str
schema_ddl: str
sample_data_sql: str
broken_query: str
correct_query: str
expected_output: list[tuple]
expected_output_hint: str
hints: list[str]
max_steps: int
error_types: list[str] # Types of errors in the broken query
# =============================================================================
# EASY TASK: Simple Syntax Errors
# =============================================================================
EASY_TASK = TaskDefinition(
task_id="easy_syntax_fix",
difficulty="easy",
description="Fix a query with simple syntax errors to retrieve all customers from California",
schema_ddl="""
CREATE TABLE customers (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL,
state TEXT NOT NULL,
created_at DATE NOT NULL
);
""",
sample_data_sql="""
INSERT INTO customers VALUES (1, 'Alice Johnson', 'alice@email.com', 'CA', '2024-01-15');
INSERT INTO customers VALUES (2, 'Bob Smith', 'bob@email.com', 'NY', '2024-02-20');
INSERT INTO customers VALUES (3, 'Carol Davis', 'carol@email.com', 'CA', '2024-03-10');
INSERT INTO customers VALUES (4, 'David Wilson', 'david@email.com', 'TX', '2024-01-25');
INSERT INTO customers VALUES (5, 'Eve Brown', 'eve@email.com', 'CA', '2024-04-05');
""",
broken_query="""
SELCT name, email FORM customers WERE state = 'CA' ORDERY BY name;
""",
correct_query="""
SELECT name, email FROM customers WHERE state = 'CA' ORDER BY name;
""",
expected_output=[
("Alice Johnson", "alice@email.com"),
("Carol Davis", "carol@email.com"),
("Eve Brown", "eve@email.com"),
],
expected_output_hint="Should return 3 rows with name and email of California customers, sorted alphabetically",
hints=[
"Check the spelling of SQL keywords like SELECT, FROM, WHERE, ORDER BY",
"The query has 4 misspelled keywords",
"SELCT→SELECT, FORM→FROM, WERE→WHERE, ORDERY→ORDER",
],
max_steps=10,
error_types=["syntax_typo"],
)
# =============================================================================
# MEDIUM TASK: Logic Errors with JOINs
# =============================================================================
MEDIUM_TASK = TaskDefinition(
task_id="medium_join_logic",
difficulty="medium",
description="Fix a query that should find total order amounts per customer, but has JOIN and aggregation issues",
schema_ddl="""
CREATE TABLE customers (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
tier TEXT NOT NULL
);
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
customer_id INTEGER NOT NULL,
amount DECIMAL(10,2) NOT NULL,
status TEXT NOT NULL,
FOREIGN KEY (customer_id) REFERENCES customers(id)
);
""",
sample_data_sql="""
INSERT INTO customers VALUES (1, 'Acme Corp', 'gold');
INSERT INTO customers VALUES (2, 'Beta Inc', 'silver');
INSERT INTO customers VALUES (3, 'Gamma LLC', 'gold');
INSERT INTO customers VALUES (4, 'Delta Co', 'bronze');
INSERT INTO orders VALUES (1, 1, 500.00, 'completed');
INSERT INTO orders VALUES (2, 1, 300.00, 'completed');
INSERT INTO orders VALUES (3, 2, 150.00, 'completed');
INSERT INTO orders VALUES (4, 2, 200.00, 'cancelled');
INSERT INTO orders VALUES (5, 3, 1000.00, 'completed');
INSERT INTO orders VALUES (6, 3, 250.00, 'completed');
INSERT INTO orders VALUES (7, 1, 100.00, 'pending');
""",
broken_query="""
SELECT c.name, SUM(o.amount) as total
FROM customers c
LEFT JOIN orders o ON c.id = o.id
WHERE o.status = 'completed'
GROUP BY c.id;
""",
correct_query="""
SELECT c.name, SUM(o.amount) as total
FROM customers c
INNER JOIN orders o ON c.id = o.customer_id
WHERE o.status = 'completed'
GROUP BY c.id, c.name
ORDER BY total DESC;
""",
expected_output=[
("Gamma LLC", 1250.00),
("Acme Corp", 800.00),
("Beta Inc", 150.00),
],
expected_output_hint="Should return 3 customers with completed orders and their totals, ordered by total descending",
hints=[
"Check the JOIN condition - what column should orders be joined on?",
"The LEFT JOIN becomes problematic with the WHERE clause filtering",
"JOIN should be ON c.id = o.customer_id, consider using INNER JOIN for filtering",
],
max_steps=12,
error_types=["wrong_join_column", "join_type_issue", "missing_group_by_column"],
)
# =============================================================================
# HARD TASK: Complex Subquery and Window Function Issues
# =============================================================================
HARD_TASK = TaskDefinition(
task_id="hard_complex_analysis",
difficulty="hard",
description="Fix a complex analytics query that should find customers whose latest order exceeds their average order amount",
schema_ddl="""
CREATE TABLE customers (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
segment TEXT NOT NULL
);
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
customer_id INTEGER NOT NULL,
amount DECIMAL(10,2) NOT NULL,
order_date DATE NOT NULL,
FOREIGN KEY (customer_id) REFERENCES customers(id)
);
""",
sample_data_sql="""
INSERT INTO customers VALUES (1, 'TechStart', 'enterprise');
INSERT INTO customers VALUES (2, 'SmallBiz', 'smb');
INSERT INTO customers VALUES (3, 'MegaCorp', 'enterprise');
INSERT INTO customers VALUES (4, 'LocalShop', 'smb');
INSERT INTO orders VALUES (1, 1, 100.00, '2024-01-01');
INSERT INTO orders VALUES (2, 1, 200.00, '2024-02-01');
INSERT INTO orders VALUES (3, 1, 500.00, '2024-03-01');
INSERT INTO orders VALUES (4, 2, 50.00, '2024-01-15');
INSERT INTO orders VALUES (5, 2, 75.00, '2024-02-15');
INSERT INTO orders VALUES (6, 2, 60.00, '2024-03-15');
INSERT INTO orders VALUES (7, 3, 1000.00, '2024-01-10');
INSERT INTO orders VALUES (8, 3, 2000.00, '2024-02-10');
INSERT INTO orders VALUES (9, 3, 1500.00, '2024-03-10');
INSERT INTO orders VALUES (10, 4, 25.00, '2024-01-20');
INSERT INTO orders VALUES (11, 4, 30.00, '2024-02-20');
INSERT INTO orders VALUES (12, 4, 100.00, '2024-03-20');
""",
broken_query="""
SELECT c.name, c.segment, latest.amount as latest_order, avg_order.avg_amount
FROM customers c
JOIN (
SELECT customer_id, amount
FROM orders
WHERE order_date = (SELECT MAX(order_date) FROM orders)
) latest ON c.id = latest.customer_id
JOIN (
SELECT customer_id, AVG(amount) as avg_amount
FROM orders
GROUP BY customer_id
) avg_order ON c.id = avg_order.id
WHERE latest.amount > avg_order.avg_amount;
""",
correct_query="""
SELECT c.name, c.segment, latest.amount as latest_order, avg_order.avg_amount
FROM customers c
JOIN (
SELECT o1.customer_id, o1.amount
FROM orders o1
INNER JOIN (
SELECT customer_id, MAX(order_date) as max_date
FROM orders
GROUP BY customer_id
) o2 ON o1.customer_id = o2.customer_id AND o1.order_date = o2.max_date
) latest ON c.id = latest.customer_id
JOIN (
SELECT customer_id, AVG(amount) as avg_amount
FROM orders
GROUP BY customer_id
) avg_order ON c.id = avg_order.customer_id
WHERE latest.amount > avg_order.avg_amount
ORDER BY c.name;
""",
expected_output=[
("LocalShop", "smb", 100.00, 51.666666666666664),
("TechStart", "enterprise", 500.00, 266.6666666666667),
],
expected_output_hint="Should return 2 customers (LocalShop and TechStart) whose most recent order exceeds their historical average",
hints=[
"The latest order subquery finds the global max date, not per-customer max date",
"The avg_order subquery has a join condition bug - check the column name",
"For per-customer latest, you need to group by customer_id when finding max date",
],
max_steps=15,
error_types=["incorrect_subquery_logic", "wrong_join_column", "missing_correlation"],
)
# =============================================================================
# Task Registry
# =============================================================================
TASKS: dict[str, TaskDefinition] = {
"easy_syntax_fix": EASY_TASK,
"medium_join_logic": MEDIUM_TASK,
"hard_complex_analysis": HARD_TASK,
}
def get_task(task_id: str) -> TaskDefinition:
"""Get a task by ID."""
if task_id not in TASKS:
raise ValueError(f"Unknown task: {task_id}. Available: {list(TASKS.keys())}")
return TASKS[task_id]
def list_tasks() -> list[str]:
"""List all available task IDs."""
return list(TASKS.keys())
# =============================================================================
# Graders
# =============================================================================
class SQLGrader:
"""Grades SQL query fixes against expected results."""
def __init__(self, task: TaskDefinition):
self.task = task
self.conn: Optional[sqlite3.Connection] = None
def setup_database(self) -> sqlite3.Connection:
"""Create an in-memory database with the task schema and data."""
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
# Execute schema DDL
for statement in self.task.schema_ddl.strip().split(";"):
statement = statement.strip()
if statement:
cursor.execute(statement)
# Insert sample data
for statement in self.task.sample_data_sql.strip().split(";"):
statement = statement.strip()
if statement:
cursor.execute(statement)
conn.commit()
self.conn = conn
return conn
def execute_query(self, query: str) -> tuple[bool, Optional[list[tuple]], Optional[str]]:
"""
Execute a query and return (success, results, error_message).
"""
if self.conn is None:
self.setup_database()
try:
cursor = self.conn.cursor()
cursor.execute(query)
results = cursor.fetchall()
return True, results, None
except sqlite3.Error as e:
return False, None, str(e)
def grade(self, submitted_query: str) -> tuple[float, str, dict[str, float]]:
"""
Grade a submitted query fix.
Returns:
- score: float between 0.0 and 1.0
- reason: explanation of the score
- partial_scores: breakdown of scoring components
"""
partial_scores = {}
# Component 1: Syntactic validity (0.2)
success, results, error = self.execute_query(submitted_query)
if not success:
partial_scores["syntax_valid"] = 0.0
return 0.0, f"Query failed to execute: {error}", partial_scores
partial_scores["syntax_valid"] = 0.2
# Component 2: Returns correct number of rows (0.2)
expected_rows = len(self.task.expected_output)
actual_rows = len(results) if results else 0
if actual_rows == expected_rows:
partial_scores["row_count"] = 0.2
elif actual_rows > 0:
# Partial credit for being close
ratio = min(actual_rows, expected_rows) / max(actual_rows, expected_rows)
partial_scores["row_count"] = 0.2 * ratio
else:
partial_scores["row_count"] = 0.0
# Component 3: Column count matches (0.1)
if results and self.task.expected_output:
expected_cols = len(self.task.expected_output[0])
actual_cols = len(results[0]) if results else 0
if actual_cols == expected_cols:
partial_scores["column_count"] = 0.1
else:
partial_scores["column_count"] = 0.0
else:
partial_scores["column_count"] = 0.0
# Component 4: Data correctness (0.5)
if results:
# Normalize results for comparison (handle float precision)
def normalize_row(row):
return tuple(
round(v, 2) if isinstance(v, float) else v
for v in row
)
normalized_results = set(normalize_row(r) for r in results)
normalized_expected = set(normalize_row(r) for r in self.task.expected_output)
if normalized_results == normalized_expected:
partial_scores["data_correct"] = 0.5
else:
# Partial credit for overlapping results
intersection = normalized_results & normalized_expected
union = normalized_results | normalized_expected
if union:
jaccard = len(intersection) / len(union)
partial_scores["data_correct"] = 0.5 * jaccard
else:
partial_scores["data_correct"] = 0.0
else:
partial_scores["data_correct"] = 0.0
total_score = sum(partial_scores.values())
# Clamp to strictly between 0 and 1 (validators reject exact 0.0 and 1.0)
total_score = max(0.001, min(0.999, total_score))
# Generate reason
if total_score >= 0.99:
reason = "Query is correct - returns exact expected results"
elif total_score >= 0.7:
reason = "Query executes and returns similar results, but with some differences"
elif total_score >= 0.4:
reason = "Query executes but results differ significantly from expected"
elif total_score >= 0.2:
reason = "Query executes but returns incorrect or no matching data"
else:
reason = "Query failed to execute or has critical issues"
return round(total_score, 4), reason, partial_scores
def cleanup(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def grade_task(task_id: str, submitted_query: str) -> tuple[float, str, dict[str, float]]:
"""
Grade a submitted query for a specific task.
Args:
task_id: The task identifier
submitted_query: The SQL query submitted by the agent
Returns:
- score: float between 0.0 and 1.0
- reason: explanation of the score
- partial_scores: breakdown of scoring components
"""
task = get_task(task_id)
grader = SQLGrader(task)
try:
score, reason, partial_scores = grader.grade(submitted_query)
return score, reason, partial_scores
finally:
grader.cleanup()