abhinavthedev committed · verified
Commit aa3a171 · 1 Parent(s): 0a519f8

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,81 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Multi-stage build using openenv-base
+ # This Dockerfile is flexible and works for both:
+ # - In-repo environments (with local OpenEnv sources)
+ # - Standalone environments (with openenv from PyPI/Git)
+ # The build script (openenv build) handles context detection and sets appropriate build args.
+
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
+ FROM ${BASE_IMAGE} AS builder
+
+ WORKDIR /app
+
+ # Ensure git is available (required for installing dependencies from VCS)
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends git && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Build argument to control whether we're building standalone or in-repo
+ ARG BUILD_MODE=in-repo
+ ARG ENV_NAME=sql_debug
+
+ # Copy environment code (always at root of build context)
+ COPY . /app/env
+
+ # For in-repo builds, openenv is already vendored in the build context
+ # For standalone builds, openenv will be installed via pyproject.toml
+ WORKDIR /app/env
+
+ # Ensure uv is available (for local builds where base image lacks it)
+ RUN if ! command -v uv >/dev/null 2>&1; then \
+         curl -LsSf https://astral.sh/uv/install.sh | sh && \
+         mv /root/.local/bin/uv /usr/local/bin/uv && \
+         mv /root/.local/bin/uvx /usr/local/bin/uvx; \
+     fi
+
+ # Install dependencies using uv sync
+ # If uv.lock exists, use it; otherwise resolve on the fly
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-install-project --no-editable; \
+     else \
+         uv sync --no-install-project --no-editable; \
+     fi
+
+ RUN --mount=type=cache,target=/root/.cache/uv \
+     if [ -f uv.lock ]; then \
+         uv sync --frozen --no-editable; \
+     else \
+         uv sync --no-editable; \
+     fi
+
+ # Final runtime stage
+ FROM ${BASE_IMAGE}
+
+ WORKDIR /app
+
+ # Copy the virtual environment from builder
+ COPY --from=builder /app/env/.venv /app/.venv
+
+ # Copy the environment code
+ COPY --from=builder /app/env /app/env
+
+ # Set PATH to use the virtual environment
+ ENV PATH="/app/.venv/bin:$PATH"
+
+ # Set PYTHONPATH so imports work correctly
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ # Run the FastAPI server
+ # The module path is constructed to work with the /app/env structure
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
README.md CHANGED
@@ -1,10 +1,231 @@
- ---
- title: Sql Debug
- emoji: 🏃
- colorFrom: green
- colorTo: indigo
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Sql Debug Environment Server
+ emoji: 🏒
+ colorFrom: pink
+ colorTo: red
+ sdk: docker
+ pinned: false
+ app_port: 8000
+ base_path: /web
+ tags:
+   - openenv
+   - sql
+   - debugging
+   - optimization
+ ---
+
+ # 🏒 OpenEnv: SQL Debug Environment
+
+ An [OpenEnv](https://openenv.dev)-compliant environment where AI agents fix broken SQL queries and optimize slow ones against in-memory SQLite databases.
+
+ > ✅ **Validator:** `openenv validate` passes against this environment
+ > 🚀 **Live API:** `https://abhinavthedev-sql-debug.hf.space`
+ > 📖 **Swagger UI:** `https://abhinavthedev-sql-debug.hf.space/docs`
+
+ ---
+
+ ## 🎯 Environment Description
+
+ This environment simulates the work of a SQL engineer who must repair syntax errors, correct logic bugs, and improve query performance. Agents receive a schema, a broken or slow query, and a natural-language target description. They submit SQL queries, observe the execution result and query plan, and are scored on correctness and efficiency.
+
+ The environment is intentionally practical: each task mirrors a real debugging pattern used in analytics, reporting, and data engineering workflows.
+
+ ---
+
+ ## 📋 Tasks
+
+ ### Task 1 - Syntax Fix *(Easy)*
+ **Task ID:** `syntax_fix_001`
+
+ **Objective:** Fix a malformed query so it returns all orders where `amount > 500`.
+
+ | Field | Description |
+ |---|---|
+ | `schema` | `orders` table with `id`, `customer`, `amount`, `order_date` |
+ | `broken_query` | `SELEC * FORM orders WERE amount > 500` |
+ | `target` | Return all orders where amount is greater than 500 |
+
+ **Max steps:** 5 | **Difficulty:** Easy
+
+ ---
+
+ ### Task 2 - Logic Fix *(Medium)*
+ **Task ID:** `logic_fix_001`
+
+ **Objective:** Correct a join bug so only employees in valid departments are returned.
+
+ | Field | Description |
+ |---|---|
+ | `schema` | `employees` and `departments` tables |
+ | `broken_query` | Query uses `LEFT JOIN` but should exclude missing departments |
+ | `target` | Return employees in departments with budget > 400000 |
+
+ **Max steps:** 8 | **Difficulty:** Medium
+
+ ---
+
+ ### Task 3 - Query Optimization *(Hard)*
+ **Task ID:** `optimize_001`
+
+ **Objective:** Rewrite a correlated subquery into an efficient CTE or grouped subquery.
+
+ | Field | Description |
+ |---|---|
+ | `schema` | `transactions` table with generated sample rows |
+ | `broken_query` | Correlated subquery that scans per row |
+ | `target` | Return completed transactions above the user's average amount |
+
+ **Max steps:** 10 | **Difficulty:** Hard
+
+ ---
+
+ ## 🔌 API Reference
+
+ ### Base URL
+ ```text
+ https://abhinavthedev-sql-debug.hf.space
+ ```
+
+ ### Core Endpoints
+
+ | Method | Endpoint | Description |
+ |---|---|---|
+ | `POST` | `/reset` | Start a new episode; pass `task_id` to choose a task |
+ | `POST` | `/step` | Submit a SQL query and receive the next observation |
+ | `GET` | `/state/{session_id}` | Inspect the current episode state |
+ | `GET` | `/schema` | View action, observation, and state schemas |
+ | `GET` | `/ws` | WebSocket endpoint for low-latency sessions |
+ | `GET` | `/health` | Health check |
+ | `GET` | `/docs` | Swagger UI |
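The endpoints above can be driven with any HTTP client. A minimal stdlib sketch follows; the request and response shapes here are assumptions based on the tables in this README, not the official OpenEnv client:

```python
import json
import urllib.request

BASE_URL = "https://abhinavthedev-sql-debug.hf.space"


def post_json(url: str, payload: dict) -> dict:
    """POST a JSON body and decode the JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())


def step_payload(query: str) -> dict:
    # Action body shape, taken from the Action Space section of this README
    return {"query": query}


if __name__ == "__main__":
    # /reset body shape is assumed from the endpoint description above
    obs = post_json(f"{BASE_URL}/reset", {"task_id": "syntax_fix_001"})
    result = post_json(f"{BASE_URL}/step", step_payload("SELECT * FROM orders WHERE amount > 500"))
    print(result.get("reward"), result.get("done"))
```

For real use, the typed `SQLDebugEnv` client in `client.py` is the intended interface; this sketch only shows the wire format.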
+
+ ---
+
+ ## 🎮 Action Space
+
+ The agent submits a single SQL query each step.
+
+ ```json
+ {
+   "query": "SELECT * FROM orders WHERE amount > 500"
+ }
+ ```
+
+ ### Example Actions
+
+ ```json
+ { "query": "SELECT * FROM orders WHERE amount > 500" }
+
+ { "query": "SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id = d.id WHERE d.budget > 400000" }
+
+ { "query": "WITH avg_amount AS (SELECT user_id, AVG(amount) AS avg_amount FROM transactions GROUP BY user_id) SELECT t.* FROM transactions t JOIN avg_amount a ON t.user_id = a.user_id WHERE t.status = 'completed' AND t.amount > a.avg_amount" }
+ ```
+
+ ---
+
+ ## 📊 Observation Space
+
+ ```json
+ {
+   "task_id": "syntax_fix_001",
+   "schema_sql": "CREATE TABLE orders (...)",
+   "current_query": "SELEC * FORM orders WERE amount > 500",
+   "error_message": "near \"SELEC\": syntax error",
+   "query_result": [],
+   "execution_plan": "",
+   "step_count": 0,
+   "target_description": "Return all orders where amount is greater than 500",
+   "reward_so_far": 0.0,
+   "available_tasks": ["syntax_fix_001", "logic_fix_001", "optimize_001"],
+   "done": false,
+   "reward": 0.05
+ }
+ ```
+
+ ---
+
+ ## 💰 Reward Function
+
+ The reward is computed from syntax validity, result correctness, and query plan quality.
+
+ | Event | Reward |
+ |---|---|
+ | Query fails with syntax error | `0.05` |
+ | Query runs successfully | `0.2` base credit |
+ | Correct row match on easy and medium tasks | up to `0.6` of the score |
+ | Good query plan on hard task | up to `0.2` of the score |
+ | Uses correlated-subquery pattern on hard task | heavy plan penalty |
+ | Excessively long query | length penalty |
+
+ Final scores are clamped to the range `[0.0, 1.0]`.
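The table above maps onto a single scoring formula. A minimal sketch of the combination, with weights and penalty constants mirrored from `grader.py` in this diff:

```python
def sketch_score(syntax_ok: bool, result_match_pct: float,
                 plan_score: float, query_len: int = 0) -> float:
    """Mirror of the grader's rule: syntax 20% + correctness 60% + plan 20%."""
    if not syntax_ok:
        return 0.05  # tiny credit so the agent still gets a gradient signal
    base = 0.2 + 0.6 * result_match_pct + 0.2 * plan_score
    # Queries longer than 800 characters start losing score
    length_penalty = max(0.0, (query_len - 800) / 2000)
    return max(0.0, min(1.0, base - length_penalty))

# A fully correct easy-task query with no plan credit scores 0.2 + 0.6 = 0.8
```

Note that on easy and medium tasks `plan_score` stays `0.0`, so a perfect fix tops out at `0.8`; only the hard task can reach `1.0`.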
+
+ ---
+
+ ## 🚀 Setup & Usage
+
+ ### Option 1 - Run Locally
+
+ ```bash
+ pip install -e .
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
+ # Open http://localhost:8000/docs
+ ```
+
+ ### Option 2 - Run with Docker
+
+ ```bash
+ docker build -t sql-debug-env -f server/Dockerfile .
+ docker run -p 8000:8000 sql-debug-env
+ curl http://localhost:8000/health
+ ```
+
+ ### Option 3 - Run the Inference Loop
+
+ ```bash
+ export SERVER_URL=https://abhinavthedev-sql-debug.hf.space
+ export API_KEY=sk-...
+ python inference.py
+ ```
+
+ The inference script defaults to `syntax_fix_001`, logs each step, and stops when the episode ends or the step budget is reached.
+
+ ---
+
+ ## 🏗️ Project Structure
+
+ ```text
+ sql_exp/
+ ├── client.py                    # OpenEnv client wrapper
+ ├── grader.py                    # Reward computation
+ ├── inference.py                 # LLM-driven inference loop
+ ├── models.py                    # Action and observation models
+ ├── openenv.yaml                 # OpenEnv manifest
+ ├── pyproject.toml               # Project metadata and dependencies
+ ├── runner.py                    # SQLite query runner
+ ├── server/
+ │   ├── app.py                   # FastAPI app and OpenEnv wiring
+ │   ├── Dockerfile               # Container definition
+ │   └── sql_debug_environment.py # Core environment logic
+ ├── tasks/
+ │   ├── task_easy.py             # Syntax-fix task
+ │   ├── task_medium.py           # Join logic task
+ │   └── task_hard.py             # Query optimization task
+ ├── test.py                      # Manual websocket smoke test
+ └── README.md                    # Project overview
+ ```
+
+ ---
+
+ ## 🛠️ Tech Stack
+
+ - **Python 3.10+** - Runtime
+ - **FastAPI** - HTTP framework
+ - **OpenEnv Core** - Environment server and client primitives
+ - **SQLite** - Query execution engine
+ - **Uvicorn** - ASGI server
+ - **Docker** - Containerization
+
+ ---
+
+ ## 📝 License
+
+ BSD-style license, matching the source headers in this repository.
__init__.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """SQL Debug Environment."""
+
+ from .client import SQLDebugEnv
+ from .models import SQLDebugAction, SQLDebugObservation
+
+ __all__ = [
+     "SQLDebugAction",
+     "SQLDebugObservation",
+     "SQLDebugEnv",
+ ]
client.py ADDED
@@ -0,0 +1,79 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # client.py
+ """
+ SQL Debug Environment client.
+ This is what inference.py uses to talk to the running server.
+ """
+
+ from typing import Dict
+
+ from openenv.core import EnvClient
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_server.types import State
+
+ from models import SQLDebugAction, SQLDebugObservation
+
+
+ class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
+     """
+     Client for the SQL Debug & Optimizer environment.
+
+     Maintains a persistent WebSocket connection to the server.
+     Each instance gets its own dedicated environment session.
+
+     Usage (direct server):
+         with SQLDebugEnv(base_url="http://localhost:8000") as env:
+             result = env.reset()
+             print(result.observation.target_description)
+             result = env.step(SQLDebugAction(query="SELECT * FROM orders"))
+             print(result.reward)
+
+     Usage (Docker):
+         env = SQLDebugEnv.from_docker_image("sql-debug-env:latest")
+         try:
+             result = env.reset()
+             result = env.step(SQLDebugAction(query="SELECT * FROM orders WHERE amount > 500"))
+         finally:
+             env.close()
+     """
+
+     def _step_payload(self, action: SQLDebugAction) -> Dict:
+         """Convert SQLDebugAction to JSON payload."""
+         return {"query": action.query}
+
+     def _parse_result(self, payload: Dict) -> StepResult[SQLDebugObservation]:
+         """Parse server JSON response into a typed StepResult."""
+         obs_data = payload.get("observation", {})
+
+         observation = SQLDebugObservation(
+             task_id=obs_data.get("task_id", ""),
+             schema_sql=obs_data.get("schema_sql", ""),
+             current_query=obs_data.get("current_query", ""),
+             error_message=obs_data.get("error_message", ""),
+             query_result=obs_data.get("query_result", []),
+             execution_plan=obs_data.get("execution_plan", ""),
+             step_count=obs_data.get("step_count", 0),
+             target_description=obs_data.get("target_description", ""),
+             reward_so_far=obs_data.get("reward_so_far", 0.0),
+             available_tasks=obs_data.get("available_tasks", []),
+             done=payload.get("done", False),
+             reward=payload.get("reward", 0.0),
+         )
+
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward", 0.0),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> State:
+         """Parse server JSON response into a State object."""
+         return State(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+         )
grader.py ADDED
@@ -0,0 +1,74 @@
+ def compute_reward(task: dict, agent_query: str, run_result: dict) -> dict:
+     """
+     task = one of TASK dicts from tasks/
+     agent_query = the SQL string the agent submitted
+     run_result = output from runner.run_query()
+
+     Returns a dict: { value, syntax_ok, result_match_pct, plan_score, message }
+     """
+
+     # ── Step 1: Did the query even run? ───────────────────────────────────────
+     syntax_ok = (run_result["error"] is None)
+
+     if not syntax_ok:
+         # Give tiny credit for trying (not zero, so agent gets gradient signal)
+         return {
+             "value": 0.05,
+             "syntax_ok": False,
+             "result_match_pct": 0.0,
+             "plan_score": 0.0,
+             "message": f"Syntax error: {run_result['error'][:100]}",
+         }
+
+     # ── Step 2: Did we get the right rows? ────────────────────────────────────
+     result_match_pct = 0.0
+
+     if task["expected_rows"] is not None:
+         expected = task["expected_rows"]
+         got = run_result["rows"]
+
+         # Count how many expected rows are present in the result
+         matches = sum(1 for row in expected if row in got)
+         result_match_pct = matches / max(len(expected), 1)
+
+         # Penalize extra rows (returned too many rows = wrong query)
+         if len(got) > len(expected) * 2:
+             result_match_pct *= 0.7  # 30% penalty for bloated results
+
+     else:
+         # Hard task: no fixed rows — give full match credit if query runs
+         result_match_pct = 1.0
+
+     # ── Step 3: Is the query plan good? (hard task only) ─────────────────────
+     plan_score = 0.0
+
+     if task.get("check_plan"):
+         query_upper = agent_query.upper()
+         good_patterns = task.get("good_patterns", [])
+
+         # Each good pattern found = partial credit
+         found = sum(1 for p in good_patterns if p.upper() in query_upper)
+         plan_score = found / max(len(good_patterns), 1)
+
+         # Also penalize if they still use correlated subquery pattern
+         if "WHERE" in query_upper and "SELECT AVG" in query_upper:
+             plan_score *= 0.3  # Heavy penalty — they didn't really optimize
+
+     # ── Step 4: Combine into final score ──────────────────────────────────────
+     # Weights: syntax 20% + correctness 60% + plan 20%
+     base_score = 0.2 + (0.6 * result_match_pct) + (0.2 * plan_score)
+
+     # Penalize absurdly long queries (e.g. agent spams SELECT *)
+     length_penalty = max(0.0, (len(agent_query) - 800) / 2000)
+     final = max(0.0, min(1.0, base_score - length_penalty))
+
+     status = "perfect" if final >= 0.99 else "partial" if final > 0.2 else "wrong"
+     msg = f"{status} | rows matched: {result_match_pct:.0%} | plan: {plan_score:.0%}"
+
+     return {
+         "value": round(final, 3),
+         "syntax_ok": True,
+         "result_match_pct": result_match_pct,
+         "plan_score": plan_score,
+         "message": msg,
+     }
inference.py ADDED
@@ -0,0 +1,194 @@
+ # inference.py
+ """
+ SQL Debug & Optimizer — OpenEnv Inference Script
+
+ Mandatory stdout format:
+     [START] task=<task_name> env=<benchmark> model=<model_name>
+     [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+     [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
+ """
+
+ import asyncio
+ import os
+ import textwrap
+ from typing import List, Optional
+
+ from openai import OpenAI
+ from client import SQLDebugEnv, SQLDebugAction
+
+ # ── Mandatory env vars (injected by evaluator on submission) ──────────────────
+ IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
+
+ # ── Task + run config ─────────────────────────────────────────────────────────
+ TASK_NAME = os.getenv("SQL_ENV_TASK", "syntax_fix_001")
+ BENCHMARK = "sql-debug-optimizer"
+ MAX_STEPS = 8            # well under 20 min limit; each step is ~2s
+ TEMPERATURE = 0.0        # deterministic = reproducible scores
+ MAX_TOKENS = 400
+ SUCCESS_THRESHOLD = 0.5  # reward >= 0.5 = success
+
+
+ # ── Mandatory stdout loggers — DO NOT change field names or order ─────────────
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+     # action must be single-line — newlines break log parsing
+     action_clean = action.replace("\n", " ").replace("\r", "").strip()
+     error_val = error if error else "null"
+     done_val = str(done).lower()
+     print(
+         f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
+         f"done={done_val} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(
+         f"[END] success={str(success).lower()} steps={steps} "
+         f"score={score:.2f} rewards={rewards_str}",
+         flush=True,
+     )
+
+
+ # ── Prompt design ─────────────────────────────────────────────────────────────
+
+ SYSTEM_PROMPT = textwrap.dedent("""
+     You are an expert SQL engineer helping debug and optimize SQL queries.
+
+     Rules (follow exactly):
+     - Respond with ONLY the corrected SQL query.
+     - No markdown, no code fences (no ```sql), no explanation.
+     - No comments inside the SQL.
+     - If the query has a syntax error, fix it first.
+     - If the query has a logic bug (wrong JOIN, wrong WHERE), fix the logic.
+     - If asked to optimize, replace correlated subqueries with CTEs using WITH.
+     - Output raw SQL only — it will be executed directly.
+ """).strip()
+
+
+ def build_prompt(obs) -> str:
+     """Build the user prompt from the current observation."""
+     result_preview = str(obs.query_result[:3]) if obs.query_result else "empty / error"
+     return textwrap.dedent(f"""
+         TASK: {obs.target_description}
+
+         DATABASE SCHEMA:
+         {obs.schema_sql.strip()[:800]}
+
+         CURRENT QUERY (this is broken or slow — fix it):
+         {obs.current_query.strip()}
+
+         ERROR: {obs.error_message or "none"}
+         CURRENT RESULT (first 3 rows): {result_preview}
+         STEP: {obs.step_count + 1} of {MAX_STEPS}
+
+         Write the corrected SQL query:
+     """).strip()
+
+
+ def call_llm(client: OpenAI, obs) -> str:
+     """Ask the LLM for a better SQL query. Returns clean SQL string."""
+     try:
+         completion = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": build_prompt(obs)},
+             ],
+             temperature=TEMPERATURE,
+             max_tokens=MAX_TOKENS,
+             stream=False,
+         )
+         raw = (completion.choices[0].message.content or "").strip()
+
+         # Strip markdown code fences if model adds them despite instructions
+         if "```" in raw:
+             lines = raw.split("\n")
+             raw = "\n".join(
+                 line for line in lines if not line.strip().startswith("```")
+             ).strip()
+
+         return raw if raw else "SELECT 1"
+
+     except Exception as exc:
+         print(f"[DEBUG] LLM call failed: {exc}", flush=True)
+         return "SELECT 1"
+
+
+ # ── Main loop ─────────────────────────────────────────────────────────────────
+
+ async def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     # Connect to the environment (Docker or local server)
+     SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
+     env = SQLDebugEnv(base_url=SERVER_URL)
+
+     rewards: List[float] = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
+
+     try:
+         # Reset — get the broken query and task info
+         result = await env.reset(task_id=TASK_NAME)
+         obs = result.observation
+
+         for step in range(1, MAX_STEPS + 1):
+             if result.done:
+                 break
+
+             # Ask LLM for a better query
+             sql_query = call_llm(client, obs)
+
+             # Submit to environment
+             result = await env.step(SQLDebugAction(query=sql_query))
+             obs = result.observation
+
+             reward = result.reward or 0.0
+             done = result.done
+             error = obs.error_message if obs.error_message else None
+
+             rewards.append(reward)
+             steps_taken = step
+
+             log_step(
+                 step=step,
+                 action=sql_query,
+                 reward=reward,
+                 done=done,
+                 error=error,
+             )
+
+             if done:
+                 break
+
+         # Score = best reward achieved (already 0.0–1.0 from grader)
+         score = max(rewards) if rewards else 0.0
+         score = min(max(score, 0.0), 1.0)
+         success = score >= SUCCESS_THRESHOLD
+
+     except Exception as exc:
+         print(f"[DEBUG] Episode error: {exc}", flush=True)
+
+     finally:
+         try:
+             await env.close()
+         except Exception as e:
+             print(f"[DEBUG] env.close() error: {e}", flush=True)
+
+     log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
models.py ADDED
@@ -0,0 +1,38 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Data models for the SQL Debug & Optimizer Environment.
+ """
+
+ from typing import Any, Dict, List
+ from pydantic import Field
+ from openenv.core.env_server.types import Action, Observation
+
+
+ class SQLDebugAction(Action):
+     """
+     What the agent submits each step — just a SQL query string.
+     The environment will run it, grade it, and return a new observation.
+     """
+     query: str = Field(..., description="The SQL query the agent wants to try")
+
+
+ class SQLDebugObservation(Observation):
+     """
+     What the agent sees after each step.
+     Contains everything it needs to improve its next query.
+     """
+     task_id: str = Field(default="", description="Which task is active")
+     schema_sql: str = Field(default="", description="CREATE TABLE statements for this task")
+     current_query: str = Field(default="", description="Last query that was run")
+     error_message: str = Field(default="", description="SQLite error if query failed, else empty string")
+     query_result: List[Dict[str, Any]] = Field(default_factory=list, description="First 10 rows returned")
+     execution_plan: str = Field(default="", description="EXPLAIN QUERY PLAN output")
+     step_count: int = Field(default=0, description="How many steps taken so far")
+     target_description: str = Field(default="", description="Plain English goal for this task")
+     reward_so_far: float = Field(default=0.0, description="Best reward achieved this episode")
+     available_tasks: List[str] = Field(default_factory=list, description="All task IDs you can reset to")
openenv.yaml ADDED
@@ -0,0 +1,7 @@
+ spec_version: 1
+ name: sql_debug
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
+
openenv_sql_debug.egg-info/PKG-INFO ADDED
@@ -0,0 +1,11 @@
+ Metadata-Version: 2.4
+ Name: openenv-sql_debug
+ Version: 0.1.0
+ Summary: Sql Debug environment for OpenEnv
+ Requires-Python: >=3.10
+ Requires-Dist: openenv-core[core]>=0.2.2
+ Requires-Dist: openai>=2.30.0
+ Requires-Dist: uvicorn>=0.43.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_sql_debug.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,18 @@
+ README.md
+ pyproject.toml
+ ./__init__.py
+ ./client.py
+ ./grader.py
+ ./inference.py
+ ./models.py
+ ./runner.py
+ ./test.py
+ openenv_sql_debug.egg-info/PKG-INFO
+ openenv_sql_debug.egg-info/SOURCES.txt
+ openenv_sql_debug.egg-info/dependency_links.txt
+ openenv_sql_debug.egg-info/entry_points.txt
+ openenv_sql_debug.egg-info/requires.txt
+ openenv_sql_debug.egg-info/top_level.txt
+ server/__init__.py
+ server/app.py
+ server/sql_debug_environment.py
openenv_sql_debug.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_sql_debug.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = sql_debug.server.app:main
openenv_sql_debug.egg-info/requires.txt ADDED
@@ -0,0 +1,7 @@
+ openenv-core[core]>=0.2.2
+ openai>=2.30.0
+ uvicorn>=0.43.0
+
+ [dev]
+ pytest>=8.0.0
+ pytest-cov>=4.0.0
openenv_sql_debug.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ sql_debug
pyproject.toml ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-sql_debug"
+ version = "0.1.0"
+ description = "Sql Debug environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
+     # install from github:
+     # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
+     "openenv-core[core]>=0.2.2",
+     "openai>=2.30.0",
+     "uvicorn>=0.43.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ # Server entry point - enables running via: uv run --project . server
+ # or: python -m sql_debug.server.app
+ server = "sql_debug.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["sql_debug", "sql_debug.server"]
+ package-dir = { "sql_debug" = ".", "sql_debug.server" = "server" }
runner.py ADDED
@@ -0,0 +1,19 @@
+ import sqlite3
+
+ def run_query(schema_sql: str, query: str) -> dict:
+     """
+     Runs query against an in-memory SQLite DB seeded with schema_sql.
+     Returns: { "rows": [...], "error": str|None, "plan": str }
+     """
+     conn = sqlite3.connect(":memory:")
+     conn.row_factory = sqlite3.Row
+     try:
+         conn.executescript(schema_sql)
+         plan_rows = conn.execute(f"EXPLAIN QUERY PLAN {query}").fetchall()
+         plan = " | ".join(str(dict(r)) for r in plan_rows)
+         result_rows = [dict(r) for r in conn.execute(query).fetchall()]
+         return {"rows": result_rows, "error": None, "plan": plan}
+     except Exception as e:
+         return {"rows": [], "error": str(e), "plan": ""}
+     finally:
+         conn.close()
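For reference, `run_query` can be exercised like this. The function is re-declared so the snippet runs standalone, and the sample schema is illustrative rather than one of the real task schemas:

```python
import sqlite3


def run_query(schema_sql: str, query: str) -> dict:
    # Standalone copy of run_query above, same semantics
    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    try:
        conn.executescript(schema_sql)
        plan_rows = conn.execute(f"EXPLAIN QUERY PLAN {query}").fetchall()
        plan = " | ".join(str(dict(r)) for r in plan_rows)
        rows = [dict(r) for r in conn.execute(query).fetchall()]
        return {"rows": rows, "error": None, "plan": plan}
    except Exception as e:
        return {"rows": [], "error": str(e), "plan": ""}
    finally:
        conn.close()


SCHEMA = """
CREATE TABLE orders (id INTEGER, customer TEXT, amount REAL);
INSERT INTO orders VALUES (1, 'ada', 900.0), (2, 'bob', 100.0);
"""

good = run_query(SCHEMA, "SELECT id FROM orders WHERE amount > 500")
bad = run_query(SCHEMA, "SELEC * FORM orders")
# good yields the matching row; bad reports a syntax error with empty rows
```

Note that the plan is captured before the query executes, so even a correct query that returns nothing still yields a usable `EXPLAIN QUERY PLAN` string.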
server/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """SQL Debug environment server components."""
+
+ from .sql_debug_environment import SQLDebugEnvironment
+
+ __all__ = ["SQLDebugEnvironment"]
server/app.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ FastAPI server for the SQL Debug & Optimizer Environment.
+
+ Exposes the environment over HTTP + WebSocket so inference.py
+ (and the OpenEnv evaluator) can interact with it remotely.
+
+ Endpoints created automatically by openenv:
+     POST /reset  — start new episode (optionally pass task_id in body)
+     POST /step   — submit an action, get observation + reward
+     GET  /state  — current episode state
+     GET  /schema — action/observation JSON schemas
+     WS   /ws     — WebSocket for persistent low-latency sessions
+
+ Run locally:
+     uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
+
+ Or via Docker (defined in Dockerfile):
+     docker build -t sql-debug-env .
+     docker run -p 8000:8000 sql-debug-env
+ """
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as e:
+     raise ImportError(
+         "openenv-core is required. Install with: pip install openenv-core"
+     ) from e
+
+ try:
+     from models import SQLDebugAction, SQLDebugObservation
+     from .sql_debug_environment import SQLDebugEnvironment
+ except ModuleNotFoundError:
+     from models import SQLDebugAction, SQLDebugObservation
+     from sql_debug.server.sql_debug_environment import SQLDebugEnvironment
+
+
+ app = create_app(
+     SQLDebugEnvironment,
+     SQLDebugAction,
+     SQLDebugObservation,
+     env_name="sql_debug_optimizer",
+     max_concurrent_envs=4,  # one per task running in parallel
+ )
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000):
+     import uvicorn
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     main()
server/requirements.txt ADDED
@@ -0,0 +1,3 @@
+ openenv[core]>=0.2.0
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
server/sql_debug_environment.py ADDED
@@ -0,0 +1,160 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ SQL Debug & Optimizer Environment — server-side implementation.
+
+ The server runs this. The agent never touches this file directly.
+ It loads tasks, runs queries in SQLite, grades them, and returns observations.
+ """
+
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import State
+
+ try:
+     from ..models import SQLDebugAction, SQLDebugObservation
+ except ImportError:
+     from models import SQLDebugAction, SQLDebugObservation
+
+ from runner import run_query
+ from grader import compute_reward
+
+
+ def _load_all_tasks() -> dict:
+     """Load every task from the tasks/ folder into a dict keyed by task_id."""
+     from tasks.task_easy import TASK as EASY
+     from tasks.task_medium import TASK as MEDIUM
+     from tasks.task_hard import TASK as HARD
+     return {
+         EASY["task_id"]: EASY,
+         MEDIUM["task_id"]: MEDIUM,
+         HARD["task_id"]: HARD,
+     }
+
+
+ class SQLDebugEnvironment(Environment):
+     """
+     SQL Debug & Optimizer environment.
+
+     The agent receives a broken or slow SQL query and must fix/optimize it.
+     Each step the agent submits a new query — the environment runs it in
+     SQLite, grades it (0.0–1.0), and returns the result as an observation.
+
+     Three tasks:
+         syntax_fix_001 (easy)   — fix typos in SQL keywords
+         logic_fix_001  (medium) — fix a wrong JOIN type causing bad results
+         optimize_001   (hard)   — rewrite a correlated subquery as a CTE
+     """
+
+     SUPPORTS_CONCURRENT_SESSIONS: bool = True
+
+     def __init__(self):
+         self._all_tasks = _load_all_tasks()
+         self._current_task = None
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._best_reward = 0.0
+         self._current_query = ""
+
+     # ── reset ────────────────────────────────────────────────────────────
+
+     def reset(self, task_id: str | None = None) -> SQLDebugObservation:
+         """
+         Start a new episode.
+         Pass task_id to pick a specific task, or leave None for the default (easy).
+         """
+         if task_id is None:
+             task_id = next(iter(self._all_tasks))  # default: easy
+
+         if task_id not in self._all_tasks:
+             # Unknown task — return an error observation instead of crashing
+             return SQLDebugObservation(
+                 task_id=task_id,
+                 error_message=f"Unknown task_id '{task_id}'. Available: {list(self._all_tasks.keys())}",
+                 available_tasks=list(self._all_tasks.keys()),
+             )
+
+         self._current_task = self._all_tasks[task_id]
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._best_reward = 0.0
+         self._current_query = self._current_task["broken_query"]
+
+         # Run the broken query so the agent sees the starting error
+         run_result = run_query(
+             self._current_task["schema_sql"],
+             self._current_query,
+         )
+
+         return SQLDebugObservation(
+             task_id=task_id,
+             schema_sql=self._current_task["schema_sql"],
+             current_query=self._current_query,
+             error_message=run_result["error"] or "",
+             query_result=run_result["rows"][:10],
+             execution_plan=run_result["plan"],
+             step_count=0,
+             target_description=self._current_task["target_description"],
+             reward_so_far=0.0,
+             available_tasks=list(self._all_tasks.keys()),
+             done=False,
+             reward=0.0,
+         )
+
+     # ── step ──────────────────────────────────────────────────────────────
+
+     def step(self, action: SQLDebugAction) -> SQLDebugObservation:
+         """
+         Agent submits a query.
+         We run it, grade it, and return the new observation + reward.
+         """
+         if self._current_task is None:
+             return SQLDebugObservation(
+                 error_message="Call reset() before step()",
+                 available_tasks=list(self._all_tasks.keys()),
+                 done=True,
+                 reward=0.0,
+             )
+
+         self._state.step_count += 1
+         self._current_query = action.query
+
+         # Run the query in SQLite
+         run_result = run_query(
+             self._current_task["schema_sql"],
+             action.query,
+         )
+
+         # Grade it (returns dict with value, syntax_ok, result_match_pct, etc.)
+         reward_dict = compute_reward(self._current_task, action.query, run_result)
+         reward_value = reward_dict["value"]
+
+         # Track the best reward this episode
+         self._best_reward = max(self._best_reward, reward_value)
+
+         # Episode ends on a perfect score or at max steps
+         max_steps = self._current_task.get("max_steps", 8)
+         done = (reward_value >= 0.99) or (self._state.step_count >= max_steps)
+
+         return SQLDebugObservation(
+             task_id=self._current_task["task_id"],
+             schema_sql=self._current_task["schema_sql"],
+             current_query=action.query,
+             error_message=run_result["error"] or "",
+             query_result=run_result["rows"][:10],
+             execution_plan=run_result["plan"],
+             step_count=self._state.step_count,
+             target_description=self._current_task["target_description"],
+             reward_so_far=self._best_reward,
+             available_tasks=list(self._all_tasks.keys()),
+             done=done,
+             reward=reward_value,
+         )
+
+     # ── state ─────────────────────────────────────────────────────────────
+
+     @property
+     def state(self) -> State:
+         return self._state
tasks/__init__.py ADDED
File without changes
tasks/task_easy.py ADDED
@@ -0,0 +1,31 @@
+ TASK = {
+     "task_id": "syntax_fix_001",
+     "difficulty": "easy",
+     "max_steps": 5,
+
+     # This creates the database the agent works with
+     "schema_sql": """
+         CREATE TABLE orders (
+             id INTEGER, customer TEXT, amount REAL, order_date TEXT
+         );
+         INSERT INTO orders VALUES (1, 'Alice', 520.0, '2024-01-15');
+         INSERT INTO orders VALUES (2, 'Bob', 90.0, '2024-01-16');
+         INSERT INTO orders VALUES (3, 'Carol', 800.0, '2024-01-17');
+         INSERT INTO orders VALUES (4, 'Dan', 150.0, '2024-01-18');
+     """,
+
+     # This is the broken query the agent must fix
+     "broken_query": "SELEC * FORM orders WERE amount > 500",
+
+     # Plain English: what should the fixed query do?
+     "target_description": "Return all orders where amount is greater than 500",
+
+     # What the correct answer looks like — used by the grader to check
+     "expected_rows": [
+         {"id": 1, "customer": "Alice", "amount": 520.0, "order_date": "2024-01-15"},
+         {"id": 3, "customer": "Carol", "amount": 800.0, "order_date": "2024-01-17"},
+     ],
+
+     # For the easy task, plan quality doesn't matter
+     "check_plan": False,
+ }
tasks/task_hard.py ADDED
@@ -0,0 +1,51 @@
+ # tasks/task_hard.py
+ import random
+
+ def generate_schema(n_rows=5000, seed=42):
+     """Generates schema + INSERT statements for n_rows transactions."""
+     rng = random.Random(seed)
+     statuses = ['completed', 'pending', 'failed']
+     inserts = []
+     for i in range(1, n_rows + 1):
+         user_id = rng.randint(1, 100)
+         amount = round(rng.uniform(10, 1000), 2)
+         status = rng.choice(statuses)
+         # Five values to match the five columns (ts is a fixed placeholder)
+         inserts.append(
+             f"INSERT INTO transactions VALUES ({i}, {user_id}, {amount}, '2024-01-01', '{status}');"
+         )
+     return (
+         "CREATE TABLE transactions (id INTEGER, user_id INTEGER, amount REAL, ts TEXT, status TEXT);\n"
+         + "\n".join(inserts[:200])  # Keep it fast for demo (200 rows)
+     )
+
+ TASK = {
+     "task_id": "optimize_001",
+     "difficulty": "hard",
+     "max_steps": 10,
+
+     "schema_sql": generate_schema(200),  # Use 200 rows for speed in hackathon
+
+     # Slow: correlated subquery — runs the inner SELECT once per outer row
+     "broken_query": """
+         SELECT *
+         FROM transactions t1
+         WHERE amount > (
+             SELECT AVG(amount)
+             FROM transactions t2
+             WHERE t2.user_id = t1.user_id
+         )
+         AND t1.status = 'completed'
+     """,
+
+     "target_description": (
+         "Return all completed transactions where the amount exceeds that user's average. "
+         "Optimize it — avoid correlated subqueries. Use a CTE or subquery with GROUP BY."
+     ),
+
+     # For the hard task we grade differently — no fixed expected_rows
+     "expected_rows": None,
+
+     # We check that the query plan is efficient (no per-row correlated scans)
+     "check_plan": True,
+
+     # Keywords we look for in the agent's solution
+     "good_patterns": ["WITH", "GROUP BY", "AVG("],
+ }
tasks/task_medium.py ADDED
@@ -0,0 +1,38 @@
+ TASK = {
+     "task_id": "logic_fix_001",
+     "difficulty": "medium",
+     "max_steps": 8,
+
+     "schema_sql": """
+         CREATE TABLE employees (id INTEGER, name TEXT, dept_id INTEGER, salary REAL);
+         CREATE TABLE departments (id INTEGER, dept_name TEXT, budget REAL);
+
+         INSERT INTO departments VALUES (1, 'Engineering', 500000);
+         INSERT INTO departments VALUES (2, 'Sales', 300000);
+
+         INSERT INTO employees VALUES (1, 'Alice', 1, 95000);
+         INSERT INTO employees VALUES (2, 'Bob', 2, 60000);
+         INSERT INTO employees VALUES (3, 'Carol', 1, 85000);
+         INSERT INTO employees VALUES (4, 'Dan', 99, 55000);  -- dept 99 doesn't exist!
+     """,
+
+     # Bug: LEFT JOIN is the wrong join type here. Dan (no dept) only drops
+     # out because the WHERE on d.budget discards NULLs; INNER JOIN states
+     # the intent correctly and is what the task expects.
+     "broken_query": """
+         SELECT e.name, d.dept_name
+         FROM employees e
+         LEFT JOIN departments d ON e.dept_id = d.id
+         WHERE d.budget > 400000
+     """,
+
+     "target_description": (
+         "Return names of employees in departments with budget > 400000. "
+         "Do NOT include employees whose department doesn't exist."
+     ),
+
+     "expected_rows": [
+         {"name": "Alice", "dept_name": "Engineering"},
+         {"name": "Carol", "dept_name": "Engineering"},
+     ],
+
+     "check_plan": False,
+ }
test.py ADDED
@@ -0,0 +1,29 @@
+ # Smoke test for the SQL Debug environment over WebSocket.
+ from client import SQLDebugEnv
+ from models import SQLDebugAction
+
+ def test():
+     # Use WebSocket URL
+     env = SQLDebugEnv(base_url="ws://localhost:8000")
+
+     try:
+         for task_id in ["syntax_fix_001", "logic_fix_001", "optimize_001"]:
+             print(f"\n{'='*60}")
+             print(f"Testing: {task_id}")
+
+             # Connect and reset
+             result = env.reset(task_id=task_id)
+             obs = result.observation
+
+             print(f"✓ task_id: {obs.task_id}")
+             print(f"✓ description: {obs.target_description[:50]}...")
+             print(f"✓ query: {obs.current_query[:60]}...")
+
+             # Try one step
+             result = env.step(SQLDebugAction(query="SELECT 1"))
+             print(f"✓ step reward: {result.reward}")
+
+     finally:
+         env.close()
+
+ if __name__ == "__main__":
+     test()
uv.lock ADDED
The diff for this file is too large to render. See raw diff