|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| SQL/Data Cleaning Sandbox Environment Implementation.
|
|
|
| Three tasks (easy -> medium -> hard) for AI agents:
|
| 1. Data Triage - query revenue from sales data
|
| 2. Data Cleaning - fix duplicates & nulls in a users table
|
| 3. Schema Migration - normalize a flat table into two related tables
|
| """
|
|
|
| import io
|
| import os
|
| import sqlite3
|
| import sys
|
| import tempfile
|
| import traceback
|
| from contextlib import redirect_stderr, redirect_stdout
|
| from uuid import uuid4
|
|
|
| from openenv.core.env_server.interfaces import Environment
|
| from openenv.core.env_server.types import State
|
|
|
| try:
|
| from ..models import SqlSandboxAction, SqlSandboxObservation
|
| except ImportError:
|
| from models import SqlSandboxAction, SqlSandboxObservation
|
|
|
|
|
|
|
|
|
# Task catalogue keyed by task id. Every entry repeats its own id, carries the
# prompt shown to the agent, and sets the step budget for one episode.
TASKS = dict(
    easy=dict(
        id="easy",
        description=(
            "Find the total revenue from the 'sales' table for January 2024. "
            "The table has columns: id, product, amount, sale_date (YYYY-MM-DD). "
            "Return the exact total as a single number by running a SQL query. "
            "The expected result should be a SELECT query that returns one number."
        ),
        max_steps=10,
    ),
    medium=dict(
        id="medium",
        description=(
            "The 'users' table has duplicate emails and NULL values in the 'age' column. "
            "Clean the data so that: (1) all emails are lowercase, "
            "(2) duplicate emails are removed (keep the row with the lowest id), "
            "(3) all NULL ages are replaced with 0. "
            "Use SQL or Python to fix the table in-place."
        ),
        max_steps=15,
    ),
    hard=dict(
        id="hard",
        description=(
            "The 'flat_orders' table has columns: order_id, order_date, "
            "customer_name, customer_email, product, quantity, price. "
            "Normalize this into two tables: 'customers' (id INTEGER PRIMARY KEY, "
            "name TEXT, email TEXT UNIQUE) and 'orders' (id INTEGER PRIMARY KEY, "
            "customer_id INTEGER REFERENCES customers(id), order_date TEXT, "
            "product TEXT, quantity INTEGER, price REAL). "
            "Maintain foreign key integrity and migrate all data."
        ),
        max_steps=20,
    ),
)
|
|
|
|
|
|
|
|
|
|
|
| def _seed_easy(conn: sqlite3.Connection):
|
| """Create sales table with known data."""
|
| conn.execute("DROP TABLE IF EXISTS sales")
|
| conn.execute(
|
| "CREATE TABLE sales (id INTEGER PRIMARY KEY, product TEXT, amount REAL, sale_date TEXT)"
|
| )
|
| rows = [
|
| (1, "Widget A", 150.00, "2024-01-05"),
|
| (2, "Widget B", 250.50, "2024-01-12"),
|
| (3, "Widget C", 99.99, "2024-01-20"),
|
| (4, "Widget A", 150.00, "2024-01-28"),
|
| (5, "Widget D", 349.51, "2024-01-15"),
|
| (6, "Widget A", 200.00, "2024-02-03"),
|
| (7, "Widget B", 75.00, "2023-12-30"),
|
| ]
|
| conn.executemany("INSERT INTO sales VALUES (?,?,?,?)", rows)
|
| conn.commit()
|
|
|
|
|
| def _seed_medium(conn: sqlite3.Connection):
|
| """Create users table with messy data."""
|
| conn.execute("DROP TABLE IF EXISTS users")
|
| conn.execute(
|
| "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT, age INTEGER)"
|
| )
|
| rows = [
|
| (1, "Alice", "Alice@Example.com", 30),
|
| (2, "Bob", "bob@example.com", None),
|
| (3, "Charlie", "charlie@test.com", 25),
|
| (4, "Alice Dup", "alice@example.com", 28),
|
| (5, "Dave", "DAVE@Test.COM", None),
|
| (6, "Eve", "eve@example.com", 35),
|
| (7, "Dave Dup", "dave@test.com", 40),
|
| (8, "Frank", "frank@example.com", None),
|
| ]
|
| conn.executemany("INSERT INTO users VALUES (?,?,?,?)", rows)
|
| conn.commit()
|
|
|
|
|
| def _seed_hard(conn: sqlite3.Connection):
|
| """Create flat_orders table."""
|
| conn.execute("DROP TABLE IF EXISTS flat_orders")
|
| conn.execute("DROP TABLE IF EXISTS customers")
|
| conn.execute("DROP TABLE IF EXISTS orders")
|
| conn.execute(
|
| "CREATE TABLE flat_orders ("
|
| "order_id INTEGER, order_date TEXT, customer_name TEXT, "
|
| "customer_email TEXT, product TEXT, quantity INTEGER, price REAL)"
|
| )
|
| rows = [
|
| (1, "2024-01-10", "Alice", "alice@example.com", "Laptop", 1, 999.99),
|
| (2, "2024-01-11", "Bob", "bob@example.com", "Mouse", 2, 25.50),
|
| (3, "2024-01-12", "Alice", "alice@example.com", "Keyboard", 1, 75.00),
|
| (4, "2024-01-13", "Charlie", "charlie@example.com", "Monitor", 1, 300.00),
|
| (5, "2024-01-14", "Bob", "bob@example.com", "Webcam", 1, 50.00),
|
| (6, "2024-01-15", "Diana", "diana@example.com", "USB Hub", 3, 15.99),
|
| ]
|
| conn.executemany("INSERT INTO flat_orders VALUES (?,?,?,?,?,?,?)", rows)
|
| conn.commit()
|
|
|
|
|
# Dispatch table: task id -> function that (re)creates that task's starting tables.
SEED_FNS = dict(easy=_seed_easy, medium=_seed_medium, hard=_seed_hard)
|
|
|
|
|
|
|
|
|
|
|
# Expected January 2024 revenue for the seeded ``sales`` table
# (150.00 + 250.50 + 99.99 + 150.00 + 349.51).
EASY_EXPECTED = 1000.00


def grade_easy(conn: sqlite3.Connection, last_output: str) -> float:
    """Check whether the agent's last output contains the January 2024 revenue.

    Args:
        conn: Unused; accepted so all graders share the same signature.
        last_output: Text produced by the agent's most recent command.

    Returns:
        1.0 if any number found in ``last_output`` is within 0.01 of
        ``EASY_EXPECTED``, otherwise 0.0 (including for empty/None output).
    """
    import re  # local import: ``re`` is not imported at module level

    if not last_output:
        return 0.0

    # The optional sign applies to integers AND decimals. Previously the sign
    # was only attached to the decimal alternative, so "-1000" was scanned as
    # "1000" and wrongly graded as correct.
    number_pattern = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)"
    for token in re.findall(number_pattern, last_output):
        if abs(float(token) - EASY_EXPECTED) < 0.01:
            return 1.0
    return 0.0
|
|
|
|
|
def grade_medium(conn: sqlite3.Connection, last_output: str) -> float:
    """Score how well the ``users`` table has been cleaned.

    Credit is awarded per criterion:

    * 0.3 - every email equals its lowercase form.
    * 0.3 - no row has a NULL ``age``.
    * 0.4 - no two rows share an email (case-insensitively) AND the surviving
      ids are exactly the lowest id of each email in the seed data. Checking
      the ids stops a solution from earning full marks by deleting most of the
      table or by keeping the wrong row of a duplicate pair.

    Args:
        conn: Connection to the sandbox database.
        last_output: Unused; accepted so all graders share the same signature.

    Returns:
        A score in [0.0, 1.0] rounded to two decimals. An empty or missing
        ``users`` table scores 0.0. If a database error interrupts grading,
        the credit earned up to that point is returned.
    """
    # Lowest id per distinct lowercase email in the seeded rows
    # (ids 4 and 7 are the duplicates of 1 and 5).
    expected_ids = [1, 2, 3, 5, 6, 8]

    def count(sql: str) -> int:
        """Run a single-value COUNT query and return the value."""
        return conn.execute(sql).fetchone()[0]

    score = 0.0
    try:
        if count("SELECT COUNT(*) FROM users") == 0:
            return 0.0

        if count("SELECT COUNT(*) FROM users WHERE email != LOWER(email)") == 0:
            score += 0.3

        if count("SELECT COUNT(*) FROM users WHERE age IS NULL") == 0:
            score += 0.3

        duplicate_groups = count(
            "SELECT COUNT(*) FROM (SELECT LOWER(email) as e FROM users GROUP BY e HAVING COUNT(*) > 1)"
        )
        if duplicate_groups == 0:
            surviving_ids = [row[0] for row in conn.execute("SELECT id FROM users ORDER BY id")]
            if surviving_ids == expected_ids:
                score += 0.4
    except sqlite3.Error:
        # Best effort: a missing table/column or a closed connection simply
        # earns no further credit instead of failing the step.
        pass
    return round(score, 2)
|
|
|
|
|
def grade_hard(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the normalization of ``flat_orders`` into ``customers`` and ``orders``.

    Five independent checks are worth 0.2 each:

    1. ``customers`` has the columns id, name, email.
    2. ``orders`` has the columns id, customer_id, order_date, product,
       quantity, price.
    3. ``customers`` holds exactly 4 rows.
    4. ``orders`` holds exactly 6 rows.
    5. ``orders`` is non-empty and every row's ``customer_id`` matches an
       existing ``customers.id``.

    Each check is evaluated on its own, so a missing table only forfeits the
    checks that need it. Check 5 counts a NULL ``customer_id`` as an orphan
    (a plain ``NOT IN`` silently skips NULLs) and gives no credit for an
    empty ``orders`` table, where integrity would hold only vacuously.

    Args:
        conn: Connection to the sandbox database.
        last_output: Unused; accepted so all graders share the same signature.

    Returns:
        A score in [0.0, 1.0] rounded to two decimals.
    """

    def scalar(sql: str):
        """Return the first column of the first row, or None if the query fails."""
        try:
            row = conn.execute(sql).fetchone()
        except sqlite3.Error:
            return None
        return None if row is None else row[0]

    def column_names(table: str) -> set:
        """Return the column names of ``table`` (empty if it does not exist)."""
        try:
            # ``table`` is always one of the literals passed below.
            return {info[1] for info in conn.execute(f"PRAGMA table_info({table})")}
        except sqlite3.Error:
            return set()

    order_rows = scalar("SELECT COUNT(*) FROM orders")
    orphan_rows = scalar(
        "SELECT COUNT(*) FROM orders WHERE NOT EXISTS "
        "(SELECT 1 FROM customers WHERE customers.id = orders.customer_id)"
    )

    checks = (
        {"id", "name", "email"} <= column_names("customers"),
        {"id", "customer_id", "order_date", "product", "quantity", "price"}
        <= column_names("orders"),
        scalar("SELECT COUNT(*) FROM customers") == 4,
        order_rows == 6,
        bool(order_rows) and orphan_rows == 0,
    )
    return round(0.2 * sum(checks), 2)
|
|
|
|
|
# Dispatch table: task id -> grader returning that task's progress score in [0, 1].
GRADERS = dict(easy=grade_easy, medium=grade_medium, hard=grade_hard)
|
|
|
|
|
|
|
|
|
|
|
class SqlSandboxEnvironment(Environment):
    """
    SQL / Data Cleaning Sandbox - a real-world OpenEnv environment.

    The agent sends SQL or Python commands to clean messy databases.
    Partial progress rewards are given after each step.

    Each instance works on its own SQLite file in the system temp directory.
    The active task is taken from the ``task_id`` argument of ``reset()``,
    falling back to the ``TASK_ID`` environment variable and then ``"easy"``.
    """

    # NOTE(review): ``_exec_python`` redirects the process-wide stdout/stderr;
    # if sessions can run in parallel threads their captured output could
    # interleave - confirm how the server schedules sessions.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # File-name suffixes SQLite may create next to the main database file.
    _DB_FILE_SUFFIXES = ("", "-journal", "-wal", "-shm")

    def __init__(self):
        """Create the episode state and pick the initial task; no database is opened yet."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._db_path = os.path.join(tempfile.gettempdir(), f"sqlsandbox_{uuid4().hex[:8]}.db")
        self._conn: sqlite3.Connection | None = None
        self._task_id = os.environ.get("TASK_ID", "easy")
        self._task = TASKS[self._task_id]
        self._max_steps = self._task["max_steps"]
        self._done = False
        self._last_reward = 0.0

    def _get_conn(self) -> sqlite3.Connection:
        """Return a usable connection to the sandbox database, opening one if needed.

        Agent-supplied Python receives the live connection and may close it.
        A connection that can no longer be used is discarded and replaced so
        that one ``conn.close()`` in a snippet does not break every later step.
        """
        if self._conn is not None:
            try:
                self._conn.execute("SELECT 1").fetchall()
            except sqlite3.ProgrammingError:
                self._conn = None
        if self._conn is None:
            self._conn = sqlite3.connect(self._db_path)
            self._conn.execute("PRAGMA foreign_keys = ON")
        return self._conn

    def _partial_reward(self, last_output: str) -> float:
        """Run the grader of the active task to compute partial progress."""
        return GRADERS[self._task_id](self._get_conn(), last_output)

    def _exec_sql(self, query: str) -> tuple[str, str | None]:
        """Execute one SQL statement and commit.

        Returns:
            ``(output, error)``. For statements that produce a result set the
            output is a `` | ``-separated header line followed by one line per
            row (or ``(no rows)``). For other statements it reports how many
            rows that statement affected. On failure the output is empty and
            ``error`` holds the exception message.
        """
        try:
            conn = self._get_conn()
            cur = conn.execute(query)
            if cur.description:
                header = " | ".join(col[0] for col in cur.description)
                rows = cur.fetchall()
                if rows:
                    body = "\n".join(" | ".join(str(cell) for cell in row) for row in rows)
                    output = f"{header}\n{body}"
                else:
                    output = header + "\n(no rows)"
            else:
                # ``rowcount`` is per statement (and -1 for e.g. DDL);
                # ``conn.total_changes`` is cumulative for the connection and
                # would include the seed rows and all earlier commands.
                output = f"OK {max(cur.rowcount, 0)} row(s) affected"
            conn.commit()
            return output, None
        except Exception as e:
            # Boundary with agent input: report any failure to the agent.
            return "", str(e)

    def _exec_python(self, code: str) -> tuple[str, str | None]:
        """Execute an agent-supplied Python snippet against the sandbox database.

        The snippet sees ``sqlite3``, ``DB_PATH``, ``conn`` and ``cursor`` as
        globals. Anything it prints is captured.

        Returns:
            ``(stdout, error)`` where ``error`` is the captured stderr (or
            ``None`` if empty) on success, or the formatted traceback if the
            snippet raised.
        """
        stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
        try:
            conn = self._get_conn()
            globs = {
                "__builtins__": __builtins__,
                "sqlite3": sqlite3,
                "DB_PATH": self._db_path,
                "conn": conn,
                "cursor": conn.cursor(),
            }
            with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
                # SECURITY: this runs arbitrary code with full builtins inside
                # the server process. It is only acceptable when the whole
                # process is itself sandboxed - do not expose it otherwise.
                exec(code, globs)

            # Persist whatever the snippet changed. The snippet may have
            # closed ``conn``; ``_get_conn`` then supplies a fresh connection
            # and the commit is a harmless no-op.
            self._get_conn().commit()

            return stdout_buf.getvalue(), stderr_buf.getvalue() or None
        except (Exception, SystemExit):
            # SystemExit is included so ``exit()`` / ``sys.exit()`` in a
            # snippet is reported as an error instead of escaping the step.
            return stdout_buf.getvalue(), traceback.format_exc()

    def reset(self, **kwargs) -> SqlSandboxObservation:
        """Start a new episode on a freshly seeded database.

        Args:
            **kwargs: ``task_id`` (optional) selects the task. When absent or
                ``None`` the ``TASK_ID`` environment variable is used, then
                ``"easy"``. Other keyword arguments are ignored.

        Returns:
            The initial observation containing the task description.

        Raises:
            KeyError: If the selected task id is not a key of ``TASKS``.
        """
        if self._conn is not None:
            try:
                self._conn.close()
            except sqlite3.Error:
                pass
            self._conn = None

        # Start every episode from an empty database so tables left behind by
        # a previous episode (possibly of a different task) cannot leak in.
        # Failure to delete is tolerated: the seed function still rebuilds the
        # task's own tables.
        for suffix in self._DB_FILE_SUFFIXES:
            try:
                os.remove(self._db_path + suffix)
            except OSError:
                pass

        # ``or`` also covers an explicit ``task_id=None``.
        self._task_id = kwargs.get("task_id") or os.environ.get("TASK_ID", "easy")
        self._task = TASKS[self._task_id]
        self._max_steps = self._task["max_steps"]

        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._done = False
        self._last_reward = 0.0

        SEED_FNS[self._task_id](self._get_conn())

        return SqlSandboxObservation(
            output=f"Environment ready. Task: {self._task['description']}",
            error=None,
            current_step=0,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=False,
            reward=0.0,
        )

    def step(self, action: SqlSandboxAction) -> SqlSandboxObservation:
        """Run one agent command and grade the resulting database state.

        ``action.tool == "sql"`` runs ``action.command`` as SQL; any other
        tool value runs it as Python. The episode ends once the grader
        returns 1.0 or the step budget is used up. A command that reports an
        error has 0.05 deducted from its reward (floored at 0).

        Returns:
            The observation for this step. After the episode has ended, a
            notice is returned together with the final reward.
        """
        self._state.step_count += 1
        step = self._state.step_count

        if self._done:
            return SqlSandboxObservation(
                output="Episode already finished. Call reset().",
                error=None,
                current_step=step,
                max_steps=self._max_steps,
                task_description=self._task["description"],
                done=True,
                reward=self._last_reward,
            )

        if action.tool == "sql":
            output, error = self._exec_sql(action.command)
        else:
            output, error = self._exec_python(action.command)

        progress = self._partial_reward(output)

        # Termination is decided on the grader's score, before any penalty.
        done = step >= self._max_steps or progress >= 1.0
        if done:
            self._done = True

        reward = round(max(0.0, progress - 0.05) if error else progress, 4)
        # Remember the reward exactly as reported, so the value repeated after
        # the episode ends matches the final observation.
        self._last_reward = reward

        return SqlSandboxObservation(
            output=output[:4000],
            error=error[:2000] if error else None,
            current_step=step,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step counter)."""
        return self._state
|
|
|