md896 committed on
Commit 30cf758 · 1 Parent(s): 825ffea

Initial OpenEnv SQL debug environment
.dockerignore ADDED
@@ -0,0 +1,14 @@
+ __pycache__/
+ *.pyc
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .DS_Store
+ .git/
+ .gitignore
+ .env
+ .env.*
+ !.env.example
+ .venv/
+ .cursor/
+
.env.example ADDED
@@ -0,0 +1,6 @@
+ OPENAI_API_KEY=
+ HF_TOKEN=
+ API_BASE_URL=https://api.openai.com/v1
+ MODEL_NAME=gpt-4o-mini
+ ENV_BASE_URL=http://localhost:7860
+
.gitignore ADDED
@@ -0,0 +1,19 @@
+ __pycache__/
+ *.pyc
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .DS_Store
+
+ # local env / secrets
+ .env
+ .env.*
+ !.env.example
+
+ # OpenEnv / uv
+ .venv/
+ .python-version
+
+ # editor metadata
+ .cursor/
+
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for layer caching
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY server/ ./server/
+ COPY openenv.yaml .
+
+ # Create non-root user for security
+ RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
+ USER appuser
+
+ # Expose port
+ EXPOSE 7860
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
+     CMD curl -f http://localhost:7860/health || exit 1
+
+ # Start server
+ CMD ["uvicorn", "server.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
+
README.md CHANGED
@@ -1,10 +1,193 @@
- ---
- title: Sql Debug Env
- emoji: 💻
- colorFrom: indigo
- colorTo: gray
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SQL Debug Environment (`sql-debug-env`)
+
+ ![Python](https://img.shields.io/badge/Python-3.11+-3776AB?logo=python&logoColor=white)
+ ![FastAPI](https://img.shields.io/badge/FastAPI-0.115-009688?logo=fastapi&logoColor=white)
+ ![Pydantic](https://img.shields.io/badge/Pydantic-v2-E92063?logo=pydantic&logoColor=white)
+ ![SQLite](https://img.shields.io/badge/SQLite-In_Memory-003B57?logo=sqlite&logoColor=white)
+ ![Docker](https://img.shields.io/badge/Docker-Ready-2496ED?logo=docker&logoColor=white)
+ ![OpenEnv](https://img.shields.io/badge/OpenEnv-Validated-2ea44f)
+
+ An OpenEnv environment for a real task people do every day: **debugging SQL**. The agent gets a broken query, a live (in-memory) SQLite database, and a description of the expected output. It can inspect schema/errors/samples and submit fixed queries until it solves the task.
+
+ ## What’s in this repo
+ - **FastAPI server**: `server/main.py` (endpoints: `/health`, `/tasks`, `/reset`, `/step`, `/state`)
+ - **Environment logic**: `server/env.py` + `server/database.py`
+ - **Tasks**: `server/tasks/` (easy → medium → hard, deterministic seed data)
+ - **Baseline agent**: `inference.py` (OpenAI client + `[START]/[STEP]/[END]` logs)
+
+ ## Tech Stack
+ - Python 3.11+
+ - FastAPI + Uvicorn
+ - Pydantic v2
+ - SQLite (in-memory)
+ - OpenEnv Core
+ - Docker
+ - OpenAI Python SDK (baseline inference)
+
+ ## Production Notes
+ - Stateless HTTP API with per-session environment instances keyed by `X-Session-Id`
+ - Deterministic task data (in-memory SQLite) for reproducible grading
+ - Reward clamped to `[0.0, 1.0]` with partial-progress shaping
+ - Docker-first deployment path (local and Hugging Face Spaces)
+ - Local benchmark endpoint for live latency checks (`/benchmark`)
+
+ ## API Docs (FastAPI Auto Docs)
+ Use these for interactive testing in the browser:
+
+ - Swagger UI: `http://localhost:7860/docs`
+ - ReDoc: `http://localhost:7860/redoc`
+ - OpenAPI spec: `http://localhost:7860/openapi.json`
+
+ ## Action Space
+ | Action | Required fields | Cost / reward effect |
+ |---|---|---|
+ | `submit_query` | `query` | Main evaluation step (dense reward based on grading) |
+ | `inspect_schema` | none | Free information action (small positive reward component) |
+ | `inspect_error` | none | Free information action (small positive reward component) |
+ | `inspect_sample` | `table_name` | Free information action (small positive reward component) |
+ | `reset_query` | none | Penalty action (reduces reward for that step) |
+
+ ## Observation Space
+ | Field | Type |
+ |---|---|
+ | `task_id` | `string` |
+ | `task_description` | `string` |
+ | `original_query` | `string` |
+ | `current_query` | `string_or_null` |
+ | `expected_description` | `string` |
+ | `last_action_type` | `string` |
+ | `last_query_result` | `object_or_null` |
+ | `steps_taken` | `integer` |
+ | `steps_remaining` | `integer` |
+ | `current_score` | `float` |
+ | `schema_info` | `object_or_null` |
+ | `error_details` | `string_or_null` |
+ | `sample_rows` | `array_or_null` |
+ | `hint` | `string_or_null` |
+ | `is_done` | `boolean` |
+ | `success` | `boolean` |
+
+ ## Reward Function
+ | Component | Range | Description |
+ |---|---|---|
+ | `correctness` | `[0.0, 0.6]` | Row-level match vs expected output |
+ | `efficiency` | `[0.0, 0.2]` | Bonus for solving with fewer steps |
+ | `syntax_progress` | `[0.0, 0.1]` | Small reward for producing syntactically valid SQL |
+ | `schema_bonus` | `[0.0, 0.1]` | Bonus for referencing correct tables/columns |
+ | `penalty` | `[0.0, 0.2]` | Deduction magnitude for resets/regressions/urgency near step limit |
+
+ ## Tasks
+ ### Task 1: Easy — Syntax Error Fix (`easy_syntax_fix`)
+ Two straightforward issues: a misspelled keyword (`GRUP BY`) and an `ORDER BY` alias mismatch.
+
+ ### Task 2: Medium — Logic Error Fix (`medium_logic_fix`)
+ Logic bugs around outer joins + filtering scope + aggregation scope.
+
+ ### Task 3: Hard — Multi-Bug Fix (`hard_multi_bug`)
+ Five bugs across correlated subqueries, window functions, CTE scope, date logic, and duplication.
+
+ ## Baseline
+ The baseline script is intentionally simple: it loops `reset → step` and asks an OpenAI model to choose the next JSON action.
+
+ ## Reliability & Benchmarking
+
+ ### Verified status (local)
+ - `openenv validate --verbose`: **PASS**
+ - `python3 -m unittest discover -s tests -p "test_*.py"`: **10/10 PASS**
+ - Docker smoke test: **PASS** (`/health`, `/tasks`, `/reset`, `/step`)
+ - FastAPI docs available: **PASS** (`/docs`, `/redoc`, `/openapi.json`)
+
+ ### Endpoint benchmark (local Docker run, n=25)
+ Measured with `scripts/benchmark_local.py` on a running local container:
+
+ | Endpoint | avg | p50 | p95 |
+ |---|---:|---:|---:|
+ | `GET /health` | 0.69 ms | 0.67 ms | 0.76 ms |
+ | `GET /tasks` | 0.82 ms | 0.81 ms | 0.90 ms |
+ | `POST /reset` | 1.34 ms | 1.26 ms | 1.62 ms |
+ | `POST /step` (`inspect_schema`) | 1.07 ms | 1.01 ms | 1.34 ms |
+
+ Re-run anytime:
+
+ ```bash
+ python3 scripts/benchmark_local.py
+ ```
+
+ Notes:
+ - These are local-machine numbers (single container, warm runtime).
+ - For submission-grade reporting, also capture one run against your HF Space URL after deploy.
+
+ ## Setup & Usage
+
+ ### Local Development
+ ```bash
+ pip install -r requirements.txt
+ uvicorn server.main:app --host 0.0.0.0 --port 7860
+ ```
+
+ ### Docker
+ ```bash
+ docker build -t sql-debug-env .
+ docker run -p 7860:7860 sql-debug-env
+ ```
+
+ ### Quick smoke test
+ ```bash
+ curl http://localhost:7860/health
+ curl http://localhost:7860/tasks
+ curl -X POST http://localhost:7860/reset -H "Content-Type: application/json" -d '{"task_id":"easy_syntax_fix"}'
+ curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d '{"action":{"action_type":"inspect_schema"}}'
+ curl "http://localhost:7860/benchmark?runs=20"
+ ```
+
+ ### Real-time benchmark API (for dashboards/web pages)
+ This is a live endpoint, not static/dummy data. Every request runs fresh measurements.
+
+ - Endpoint: `GET /benchmark?runs=20`
+ - `runs` range: `1` to `100`
+ - Returns JSON with `avg_ms`, `p50_ms`, `p95_ms`, `n`, and a fresh `timestamp_epoch_ms`
+
+ Example:
+ ```bash
+ curl "http://localhost:7860/benchmark?runs=30"
+ ```
+
+ ### Run Baseline
+ ```bash
+ export API_BASE_URL="https://api.openai.com/v1"
+ export MODEL_NAME="gpt-4o-mini"
+ export OPENAI_API_KEY="your-key"
+ export ENV_BASE_URL="http://localhost:7860"
+ export HF_TOKEN="$OPENAI_API_KEY"
+ export SEED="1"
+ python inference.py
+ ```
+
+ ### OpenEnv Validation
+ ```bash
+ pip install openenv-core
+ openenv validate
+ ```
+
+ ### Suggested pre-submit check
+ ```bash
+ openenv validate --verbose
+ python3 -m unittest discover -s tests -p "test_*.py"
+ docker build -t sql-debug-env .
+ docker run --rm -p 7860:7860 sql-debug-env
+ # in another terminal:
+ curl -s http://localhost:7860/health
+ curl -s http://localhost:7860/docs >/dev/null
+ curl -s "http://localhost:7860/benchmark?runs=20"
+ ```
+
+ ## Hugging Face Spaces (Docker)
+ 1. Create a new **Space → Docker**.
+ 2. Push this repo.
+ 3. Update `openenv.yaml` → `api.base_url` to your Space URL: `https://<your-space>.hf.space`
+ 4. Wait for build, then verify:
+
+ ```bash
+ curl -X POST https://<your-space>.hf.space/reset -H "Content-Type: application/json" -d '{}'
+ ```
+
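The curl commands in the README above can also be driven from Python. As a minimal client-side sketch (not part of this repo: `build_step_payload` is a hypothetical helper built from the documented action table and the `{"action": {...}}` request shape shown in the smoke test):

```python
# Required fields per action, as documented in the README's Action Space table.
REQUIRED_FIELDS = {
    "submit_query": ["query"],
    "inspect_schema": [],
    "inspect_error": [],
    "inspect_sample": ["table_name"],
    "reset_query": [],
}


def build_step_payload(action_type: str, **fields) -> dict:
    """Validate required fields and wrap the action as a /step request body."""
    if action_type not in REQUIRED_FIELDS:
        raise ValueError(f"unknown action_type: {action_type}")
    missing = [f for f in REQUIRED_FIELDS[action_type] if f not in fields]
    if missing:
        raise ValueError(f"{action_type} requires fields: {missing}")
    return {"action": {"action_type": action_type, **fields}}
```

The resulting dict can be POSTed to `/step` with any HTTP client, e.g. `httpx.post(f"{base_url}/step", json=build_step_payload("inspect_sample", table_name="customers"))`.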
inference.py ADDED
@@ -0,0 +1,328 @@
+ """
+ inference.py — OpenEnv SQL Debug Environment Baseline Agent
+ MUST be at root level. MUST use exact [START]/[STEP]/[END] log format.
+ Uses OpenAI client. Reads from environment variables.
+ Runtime target: < 20 minutes on 2vCPU / 8GB.
+ """
+ import json
+ import os
+ import re
+ import time
+ from typing import Any, Dict, List, Optional
+
+ import httpx
+ from openai import OpenAI
+
+
+ # ── Configuration from environment variables ────────────────────────────────
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
+ API_KEY = os.environ.get("OPENAI_API_KEY", HF_TOKEN or "sk-placeholder")
+
+ # ── Environment config ───────────────────────────────────────────────────────
+ ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
+ BENCHMARK = "sql-debug-env"
+ TEMPERATURE = 0.0
+ MAX_TOKENS = 1024
+ SEED = int(os.environ.get("SEED", "1"))
+
+ # ── Per-task config ──────────────────────────────────────────────────────────
+ TASK_CONFIGS = {
+     "easy_syntax_fix": {"max_steps": 10, "success_threshold": 0.8},
+     "medium_logic_fix": {"max_steps": 20, "success_threshold": 0.7},
+     "hard_multi_bug": {"max_steps": 30, "success_threshold": 0.5},
+ }
+
+
+ # ── Logging functions (EXACT FORMAT — DO NOT MODIFY) ────────────────────────
+ def log_start(task: str, env: str, model: str):
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]):
+     error_str = error if error else "null"
+     # Escape action for single-line logging
+     action_clean = action.replace("\n", "\\n").replace('"', '\\"')[:200]
+     print(
+         f"[STEP] step={step} action=\"{action_clean}\" "
+         f"reward={reward:.4f} done={str(done).lower()} error={error_str}",
+         flush=True
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]):
+     rewards_str = json.dumps([round(r, 4) for r in rewards])
+     print(
+         f"[END] success={str(success).lower()} steps={steps} "
+         f"score={score:.4f} rewards={rewards_str}",
+         flush=True
+     )
+
+
+ # ── System prompt ────────────────────────────────────────────────────────────
+ SYSTEM_PROMPT = """You are an expert SQL debugger. You will receive a broken SQL query and must fix it.
+
+ You interact with a SQL debugging environment via JSON actions.
+
+ Available actions (respond with ONLY valid JSON, no markdown, no explanation):
+
+ 1. Submit a fixed query:
+    {"action_type": "submit_query", "query": "SELECT ..."}
+
+ 2. Inspect schema (free, no penalty):
+    {"action_type": "inspect_schema"}
+
+ 3. Inspect last error (free, no penalty):
+    {"action_type": "inspect_error"}
+
+ 4. Inspect sample rows from a table (free, no penalty):
+    {"action_type": "inspect_sample", "table_name": "table_name_here"}
+
+ Strategy:
+ - Start by submitting a fixed query if the bug is obvious
+ - Use inspect_schema first if you need to verify column names/table structure
+ - Use inspect_error to understand why your query failed
+ - Read error messages carefully — they tell you exactly what's wrong
+ - Fix one bug at a time and resubmit
+ - You get partial credit for partially correct queries
+
+ IMPORTANT: Respond with ONLY the JSON action. No explanation, no markdown blocks, just raw JSON."""
+
+
+ def build_prompt(obs: Dict[str, Any], step: int, reward_history: List[float]) -> str:
+     """Build the user prompt for each step."""
+     lines = [
+         f"=== SQL Debugging Task (Step {step}) ===",
+         f"Task: {obs.get('task_description', '')[:500]}",
+         "",
+         "ORIGINAL BROKEN QUERY:",
+         "```sql",
+         f"{obs.get('original_query', '')}",
+         "```",
+     ]
+
+     if obs.get('current_query'):
+         lines += [
+             "",
+             "YOUR LAST SUBMITTED QUERY:",
+             "```sql",
+             f"{obs.get('current_query', '')}",
+             "```",
+         ]
+
+     last_result = obs.get('last_query_result')
+     if last_result:
+         if last_result.get('success'):
+             rows = last_result.get('rows', [])
+             lines += [
+                 "",
+                 f"LAST QUERY RESULT: {len(rows)} rows returned",
+                 f"Sample (first 3): {json.dumps(rows[:3], default=str)}",
+             ]
+         else:
+             lines += [
+                 "",
+                 f"LAST QUERY ERROR: {last_result.get('error_message', 'Unknown error')}",
+             ]
+
+     if obs.get('schema_info'):
+         schema = obs['schema_info'].get('tables', {})
+         lines += ["", "DATABASE SCHEMA:"]
+         for table, cols in schema.items():
+             col_str = ", ".join(f"{c['name']} ({c['type']})" for c in cols)
+             lines.append(f"  {table}: {col_str}")
+
+     if obs.get('error_details'):
+         lines += ["", f"ERROR DETAILS: {obs['error_details']}"]
+
+     if obs.get('sample_rows'):
+         lines += ["", f"SAMPLE ROWS: {json.dumps(obs['sample_rows'][:3], default=str)}"]
+
+     if obs.get('hint'):
+         lines += ["", f"HINT: {obs['hint']}"]
+
+     lines += [
+         "",
+         f"Current score: {obs.get('current_score', 0):.3f}",
+         f"Steps remaining: {obs.get('steps_remaining', 0)}",
+         f"Expected output: {obs.get('expected_description', '')}",
+         "",
+         "What is your next action? (respond with ONLY valid JSON)"
+     ]
+
+     return "\n".join(lines)
+
+
+ def call_model(client: OpenAI, prompt: str) -> Dict[str, Any]:
+     """Call model and parse JSON action response."""
+     text = ""
+     try:
+         response = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": prompt}
+             ],
+             temperature=TEMPERATURE,
+             seed=SEED,
+             max_tokens=MAX_TOKENS,
+         )
+         text = (response.choices[0].message.content or "").strip()
+
+         # Strip markdown if model wraps in backticks
+         if text.startswith("```"):
+             text = text.split("```")[1]
+             if text.startswith("json"):
+                 text = text[4:]
+             text = text.strip()
+
+         return json.loads(text)
+     except json.JSONDecodeError:
+         # Fallback: try to extract JSON from the response
+         match = re.search(r'\{.*\}', text, re.DOTALL)
+         if match:
+             try:
+                 return json.loads(match.group())
+             except json.JSONDecodeError:
+                 pass
+         # Default fallback action
+         return {"action_type": "inspect_schema"}
+     except Exception as e:
+         print(f"[DEBUG] Model error: {e}", flush=True)
+         return {"action_type": "inspect_schema"}
+
+
+ def run_task(
+     client: OpenAI,
+     task_id: str,
+     config: Dict[str, Any]
+ ) -> Dict[str, Any]:
+     """Run one task episode synchronously via HTTP."""
+     max_steps = config["max_steps"]
+     success_threshold = config["success_threshold"]
+
+     log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
+
+     rewards = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     with httpx.Client(base_url=ENV_BASE_URL, timeout=30.0) as http:
+         # Reset
+         reset_resp = http.post("/reset", json={"task_id": task_id})
+         reset_resp.raise_for_status()
+         result = reset_resp.json()
+         obs = result["observation"]
+         done = result["done"]
+
+         reward_history = []
+
+         for step in range(1, max_steps + 1):
+             if done:
+                 break
+
+             # Get model action
+             prompt = build_prompt(obs, step, reward_history)
+             action_dict = call_model(client, prompt)
+
+             # Execute step
+             try:
+                 step_resp = http.post("/step", json={"action": action_dict})
+                 step_resp.raise_for_status()
+                 step_result = step_resp.json()
+             except Exception as e:
+                 log_step(step=step, action=str(action_dict), reward=0.0, done=False, error=str(e))
+                 continue
+
+             obs = step_result["observation"]
+             reward = float(step_result.get("reward") or 0.0)
+             done = step_result["done"]
+             error = None
+             info = step_result.get("info") or {}
+
+             # Extract error for logging
+             last_result = obs.get("last_query_result")
+             if last_result and not last_result.get("success"):
+                 error = last_result.get("error_message", "")
+
+             action_str = action_dict.get("query") or action_dict.get("action_type", "unknown")
+
+             rewards.append(reward)
+             reward_history.append(reward)
+             steps_taken = step
+             score = float(info.get("grade_score") or obs.get("current_score") or 0.0)
+
+             log_step(step=step, action=action_str, reward=reward, done=done, error=error)
+
+             if done:
+                 break
+
+     # Compute final score
+     score = min(max(score, 0.0), 1.0)
+     success = score >= success_threshold
+
+     log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
+
+     return {
+         "task_id": task_id,
+         "score": score,
+         "success": success,
+         "steps": steps_taken,
+         "rewards": rewards
+     }
+
+
+ def main():
+     """Run baseline agent across all 3 tasks."""
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     print("[DEBUG] Starting SQL Debug Env baseline", flush=True)
+     print(f"[DEBUG] Model: {MODEL_NAME}", flush=True)
+     print(f"[DEBUG] Env URL: {ENV_BASE_URL}", flush=True)
+
+     # Wait for server to be ready
+     max_wait = 30
+     for i in range(max_wait):
+         try:
+             resp = httpx.get(f"{ENV_BASE_URL}/health", timeout=5)
+             if resp.status_code == 200:
+                 print("[DEBUG] Server ready", flush=True)
+                 break
+         except httpx.HTTPError:
+             pass
+         print(f"[DEBUG] Waiting for server... ({i+1}/{max_wait})", flush=True)
+         time.sleep(1)
+
+     all_results = []
+
+     for task_id, config in TASK_CONFIGS.items():
+         print(f"\n[DEBUG] Running task: {task_id}", flush=True)
+         try:
+             result = run_task(client, task_id, config)
+             all_results.append(result)
+         except Exception as e:
+             print(f"[DEBUG] Task {task_id} failed: {e}", flush=True)
+             log_end(success=False, steps=0, score=0.0, rewards=[])
+
+         # Small delay between tasks
+         time.sleep(2)
+
+     # Summary
+     print("\n[DEBUG] === BASELINE RESULTS ===", flush=True)
+     total_score = 0.0
+     for r in all_results:
+         print(f"[DEBUG] {r['task_id']}: score={r['score']:.3f} success={r['success']}", flush=True)
+         total_score += r['score']
+
+     if all_results:
+         avg = total_score / len(all_results)
+         print(f"[DEBUG] Average score: {avg:.3f}", flush=True)
+
+
+ if __name__ == "__main__":
+     main()
+
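Because `log_step` above emits a fixed single-line format, downstream tooling can parse it back. A minimal sketch (the regex is derived from the f-string in `log_step`; `parse_step_line` is a hypothetical helper, not part of this repo):

```python
import re

# Matches: [STEP] step=1 action="..." reward=0.1000 done=false error=null
STEP_RE = re.compile(
    r'\[STEP\] step=(\d+) action="(.*)" reward=([\d.]+) done=(true|false) error=(.*)'
)


def parse_step_line(line: str):
    """Parse one [STEP] log line into a dict, or return None if it doesn't match."""
    m = STEP_RE.match(line)
    if not m:
        return None
    step, action, reward, done, error = m.groups()
    return {
        "step": int(step),
        "action": action,
        "reward": float(reward),
        "done": done == "true",
        "error": None if error == "null" else error,
    }
```

This is handy for summarizing reward curves from captured baseline logs without re-running the agent.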
openenv.yaml ADDED
@@ -0,0 +1,104 @@
+ name: sql-debug-env
+ version: 0.1.0
+ description: >
+   A reinforcement learning environment for training AI agents to debug SQL queries.
+   Agents receive broken SQL queries against a live SQLite database and must fix them
+   through iterative actions: submitting queries, inspecting schemas, and analyzing errors.
+   Models a real-world task performed daily by data analysts, engineers, and scientists.
+
+ author: md-ayan
+ license: apache-2.0
+
+ tags:
+   - openenv
+   - sql
+   - debugging
+   - data-engineering
+   - real-world
+   - analytics
+
+ tasks:
+   - id: easy_syntax_fix
+     name: "Top Customers by Revenue — Syntax Error Fix"
+     difficulty: easy
+     max_steps: 10
+     description: "Fix 2 syntax/reference bugs in a customer analytics query"
+
+   - id: medium_logic_fix
+     name: "Department Headcount Report — Logic Error Fix"
+     difficulty: medium
+     max_steps: 20
+     description: "Fix JOIN type, WHERE clause placement, and aggregation scope bugs"
+
+   - id: hard_multi_bug
+     name: "SaaS Cohort Activation Report — Multi-Bug Fix"
+     difficulty: hard
+     max_steps: 30
+     description: "Fix 5 bugs: correlated subquery, window function, duplicate rows, date logic, CTE scope"
+
+ api:
+   base_url: "https://YOUR-USERNAME-sql-debug-env.hf.space"
+   reset: "/reset"
+   step: "/step"
+   state: "/state"
+   health: "/health"
+   tasks: "/tasks"
+
+ observation_space:
+   type: structured
+   fields:
+     - name: task_description
+       type: string
+     - name: original_query
+       type: string
+     - name: current_query
+       type: string_or_null
+     - name: last_query_result
+       type: object_or_null
+     - name: steps_taken
+       type: integer
+     - name: current_score
+       type: float
+
+ action_space:
+   type: structured
+   actions:
+     - id: submit_query
+       description: "Submit a fixed SQL query for evaluation"
+       required_fields: [query]
+     - id: inspect_schema
+       description: "Get database schema (free action)"
+     - id: inspect_error
+       description: "Get last error details (free action)"
+     - id: inspect_sample
+       description: "Get 3 sample rows from a table"
+       required_fields: [table_name]
+     - id: reset_query
+       description: "Reset to original broken query (penalty: -0.05)"
+
+ reward:
+   range: [0.0, 1.0]
+   components:
+     - name: correctness
+       range: [0.0, 0.6]
+       description: "Row-level match vs expected output"
+     - name: efficiency
+       range: [0.0, 0.2]
+       description: "Bonus for solving with fewer steps"
+     - name: syntax_progress
+       range: [0.0, 0.1]
+       description: "Valid SQL even if wrong content"
+     - name: schema_bonus
+       range: [0.0, 0.1]
+       description: "Correct table/column references"
+     - name: penalty
+       range: [0.0, 0.2]
+       description: "Penalty deduction magnitude for bad actions / urgency"
+
+ runtime:
+   max_concurrent_sessions: 64
+   episode_timeout_seconds: 300
+   machine_requirements:
+     vcpu: 2
+     memory_gb: 8
+
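The reward components in the spec above sum to at most 1.0 before the penalty is applied. A quick sketch of the documented combination and clamping (this mirrors the declared ranges, not the server's actual `compute_reward` implementation):

```python
def clamp_reward(correctness: float, efficiency: float,
                 syntax_progress: float, schema_bonus: float,
                 penalty: float) -> float:
    """Combine the documented components and clamp to [0.0, 1.0].

    Per openenv.yaml: correctness [0, 0.6], efficiency [0, 0.2],
    syntax_progress [0, 0.1], schema_bonus [0, 0.1], penalty magnitude [0, 0.2].
    """
    total = correctness + efficiency + syntax_progress + schema_bonus - penalty
    return min(max(total, 0.0), 1.0)
```

For example, a fully correct, efficient solution reaches 1.0, while the penalty can never push a step's reward below 0.0.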
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "sql-debug-env"
+ version = "0.1.0"
+ requires-python = ">=3.11"
+ dependencies = [
+     "fastapi==0.115.0",
+     "uvicorn[standard]==0.30.6",
+     "pydantic==2.9.2",
+     "openenv-core>=0.1.0",
+     "openai>=1.50.0",
+     "httpx>=0.27.0",
+     "python-multipart==0.0.9"
+ ]
+
+ [project.scripts]
+ server = "server.app:main"
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ fastapi==0.115.0
+ uvicorn[standard]==0.30.6
+ pydantic==2.9.2
+ openenv-core>=0.1.0
+ openai>=1.50.0
+ httpx>=0.27.0
+ python-multipart==0.0.9
+
scripts/benchmark_local.py ADDED
@@ -0,0 +1,63 @@
+ """
+ Lightweight local benchmark for sql-debug-env.
+
+ Runs deterministic endpoint checks and prints simple latency metrics.
+ No LLM key required.
+ """
+ from __future__ import annotations
+
+ import statistics
+ import time
+ from typing import Dict, List
+
+ import httpx
+
+
+ BASE_URL = "http://localhost:7860"
+
+
+ def timed_call(client: httpx.Client, method: str, path: str, json_body: Dict | None = None) -> float:
+     start = time.perf_counter()
+     if method == "GET":
+         r = client.get(path)
+     else:
+         r = client.post(path, json=json_body)
+     r.raise_for_status()
+     return (time.perf_counter() - start) * 1000
+
+
+ def summarize(samples: List[float]) -> str:
+     p50 = statistics.median(samples)
+     p95 = sorted(samples)[int(len(samples) * 0.95) - 1]
+     avg = statistics.mean(samples)
+     return f"avg={avg:.2f}ms p50={p50:.2f}ms p95={p95:.2f}ms n={len(samples)}"
+
+
+ def main() -> None:
+     with httpx.Client(base_url=BASE_URL, timeout=30.0) as client:
+         # Warmup + health check
+         client.get("/health").raise_for_status()
+
+         health_times = [timed_call(client, "GET", "/health") for _ in range(25)]
+         tasks_times = [timed_call(client, "GET", "/tasks") for _ in range(25)]
+
+         reset_times: List[float] = []
+         step_times: List[float] = []
+         for _ in range(25):
+             reset_times.append(
+                 timed_call(client, "POST", "/reset", {"task_id": "easy_syntax_fix"})
+             )
+             step_times.append(
+                 timed_call(client, "POST", "/step", {"action": {"action_type": "inspect_schema"}})
+             )
+
+     print("Benchmark results (local)")
+     print(f"GET /health: {summarize(health_times)}")
+     print(f"GET /tasks: {summarize(tasks_times)}")
+     print(f"POST /reset: {summarize(reset_times)}")
+     print(f"POST /step (inspect_schema): {summarize(step_times)}")
+
+
+ if __name__ == "__main__":
+     main()
+
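`summarize` above computes p95 with a simple index into the sorted samples, which only behaves well for sample sizes like n=25. A more general nearest-rank percentile can be sketched as (hypothetical helper, not part of this repo):

```python
import math


def percentile(samples, pct):
    """Nearest-rank percentile: the value at rank ceil(pct/100 * n) in sorted order."""
    if not samples:
        raise ValueError("samples must be non-empty")
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]
```

For n=25 and pct=95 this picks the 24th-smallest sample, matching the intent of the index arithmetic in `summarize` while also handling tiny sample sets safely.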
server/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # sql-debug-env
+
server/app.py ADDED
@@ -0,0 +1,20 @@
+ import os
+
+ import uvicorn
+
+ from .main import app  # noqa: F401  (re-exported so "server.app:app" resolves)
+
+
+ def main():
+     """
+     OpenEnv entry point.
+
+     This module is required for `openenv validate` multi-mode deployment checks.
+     """
+     host = os.environ.get("HOST", "0.0.0.0")
+     port = int(os.environ.get("PORT", "7860"))
+     uvicorn.run("server.app:app", host=host, port=port, workers=1)
+
+
+ if __name__ == "__main__":
+     main()
+
server/database.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLite in-memory database management.
3
+ Creates fresh DB instances per episode with deterministic seed data.
4
+ """
5
+ import sqlite3
6
+ import time
7
+ from typing import Dict, Any, List
8
+
9
+
10
+ class EpisodeDatabase:
11
+ """
12
+ Manages a single SQLite in-memory database for one episode.
13
+ Seeded with deterministic data per task.
14
+ """
15
+
16
+ def __init__(self, task_id: str, schema_sql: str, seed_data_sql: str):
17
+ self.task_id = task_id
18
+ self.conn = sqlite3.connect(":memory:", check_same_thread=False)
19
+ self.conn.row_factory = sqlite3.Row
20
+ self.conn.execute("PRAGMA foreign_keys = ON")
21
+ self._setup(schema_sql, seed_data_sql)
22
+
23
+ def _setup(self, schema_sql: str, seed_data_sql: str):
24
+ """Create schema and insert seed data."""
25
+        cursor = self.conn.cursor()
+        for statement in schema_sql.strip().split(";"):
+            stmt = statement.strip()
+            if stmt:
+                cursor.execute(stmt)
+        for statement in seed_data_sql.strip().split(";"):
+            stmt = statement.strip()
+            if stmt:
+                cursor.execute(stmt)
+        self.conn.commit()
+
+    def execute_query(self, query: str) -> Dict[str, Any]:
+        """
+        Execute a read-only SQL query safely.
+        Returns rows or error. Enforces SELECT-only.
+        Execution timeout: 5 seconds.
+        """
+        query_stripped = query.strip().upper()
+
+        # Block dangerous operations
+        blocked = ["DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
+                   "TRUNCATE", "REPLACE", "ATTACH", "DETACH"]
+        for kw in blocked:
+            if query_stripped.startswith(kw) or f" {kw} " in query_stripped:
+                return {
+                    "success": False,
+                    "rows": None,
+                    "row_count": None,
+                    "error_message": f"BLOCKED: Only SELECT queries are allowed. '{kw}' is not permitted.",
+                    "execution_time_ms": 0.0
+                }
+
+        start = time.time()
+        try:
+            cursor = self.conn.cursor()
+            cursor.execute(query)
+            rows = cursor.fetchall()
+            elapsed = (time.time() - start) * 1000
+
+            # Convert Row objects to dicts
+            result_rows = [dict(row) for row in rows]
+
+            return {
+                "success": True,
+                "rows": result_rows,
+                "row_count": len(result_rows),
+                "error_message": None,
+                "execution_time_ms": round(elapsed, 2)
+            }
+        except sqlite3.Error as e:
+            elapsed = (time.time() - start) * 1000
+            return {
+                "success": False,
+                "rows": None,
+                "row_count": None,
+                "error_message": str(e),
+                "execution_time_ms": round(elapsed, 2)
+            }
+
+    def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
+        """Return schema info: tables and their columns."""
+        schema = {}
+        cursor = self.conn.cursor()
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
+        tables = [row[0] for row in cursor.fetchall()]
+
+        for table in tables:
+            cursor.execute(f"PRAGMA table_info({table})")
+            columns = []
+            for col in cursor.fetchall():
+                columns.append({
+                    "name": col[1],
+                    "type": col[2],
+                    "nullable": "YES" if col[3] == 0 else "NO",
+                    "primary_key": "YES" if col[5] > 0 else "NO"
+                })
+            schema[table] = columns
+
+        return schema
+
+    def get_sample_rows(self, table_name: str, limit: int = 3) -> List[Dict[str, Any]]:
+        """Get sample rows from a table."""
+        result = self.execute_query(f"SELECT * FROM {table_name} LIMIT {limit}")
+        return result.get("rows", []) or []
+
+    def close(self):
+        self.conn.close()
+
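Review note: the SELECT-only guard above rejects a query when a blocked keyword either leads the statement or appears as a space-delimited word inside it. A standalone sketch of that check (`is_query_blocked` is a hypothetical helper name, not part of this diff):

```python
# Sketch of the keyword guard used in EpisodeDatabase.execute_query.
# is_query_blocked is a hypothetical helper, not part of this repo.
BLOCKED = ["DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
           "TRUNCATE", "REPLACE", "ATTACH", "DETACH"]

def is_query_blocked(query: str) -> bool:
    q = query.strip().upper()
    # Reject queries that start with a blocked keyword, or contain one
    # as a space-delimited word anywhere in the statement.
    return any(q.startswith(kw) or f" {kw} " in q for kw in BLOCKED)

assert not is_query_blocked("SELECT * FROM users")
assert is_query_blocked("DROP TABLE users")
# Stacked statements with embedded DML are also caught.
assert is_query_blocked("SELECT 1; DELETE FROM users")
```

Note the space delimiters mean column names like `created_at` do not false-positive on `CREATE`.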
server/env.py ADDED
@@ -0,0 +1,236 @@
+"""
+Core SQL Debug Environment.
+Manages episode state, delegates to tasks and reward function.
+"""
+import uuid
+import asyncio
+from typing import Optional, Dict, Any, List
+from .models import (
+    SQLDebugAction, SQLDebugObservation, SQLDebugReward,
+    EpisodeState, ActionType, QueryResult, SchemaInfo
+)
+from .database import EpisodeDatabase
+from .reward import compute_reward
+from .tasks.task_easy import EasyTask
+from .tasks.task_medium import MediumTask, MediumTaskGrader
+from .tasks.task_hard import HardTask
+
+
+TASKS = {
+    "easy_syntax_fix": EasyTask(),
+    "medium_logic_fix": MediumTask(),
+    "hard_multi_bug": HardTask(),
+}
+
+
+class SQLDebugEnv:
+    """
+    The SQL Debug Environment.
+    Manages one active episode at a time per session.
+    Thread-safe for concurrent sessions via the instance-per-session pattern.
+    """
+
+    def __init__(self, task_id: str = "easy_syntax_fix"):
+        self.task_id = task_id
+        self.task = TASKS[task_id]
+        self._db: Optional[EpisodeDatabase] = None
+        self._state: Optional[EpisodeState] = None
+        self._lock = asyncio.Lock()
+
+    async def reset(self) -> tuple[SQLDebugObservation, Dict]:
+        """Reset environment to initial state. Returns (observation, info)."""
+        async with self._lock:
+            # Close previous DB if exists
+            if self._db:
+                self._db.close()
+
+            # Fresh DB
+            self._db = EpisodeDatabase(
+                task_id=self.task.task_id,
+                schema_sql=self.task.schema_sql,
+                seed_data_sql=self.task.seed_data_sql
+            )
+
+            # Fresh state
+            self._state = EpisodeState(
+                task_id=self.task.task_id,
+                task_difficulty=self.task.difficulty,
+                original_query=self.task.broken_query,
+                current_query=None,
+                best_score_so_far=0.0,
+                steps_taken=0,
+                max_steps=self.task.max_steps,
+                action_history=[],
+                reward_history=[],
+                is_done=False,
+                success=False,
+                db_schema=self._db.get_schema()
+            )
+
+            obs = SQLDebugObservation(
+                task_id=self.task.task_id,
+                task_description=self.task.description,
+                original_query=self.task.broken_query,
+                current_query=None,
+                expected_description=self.task.expected_output_description,
+                last_action_type="reset",
+                last_query_result=None,
+                steps_taken=0,
+                steps_remaining=self.task.max_steps,
+                current_score=0.0,
+                schema_info=SchemaInfo(tables=self._db.get_schema()),
+                is_done=False,
+                success=False
+            )
+
+            return obs, {"task": self.task.to_dict()}
+
+    async def step(self, action: SQLDebugAction) -> tuple[SQLDebugObservation, float, bool, Dict]:
+        """
+        Execute one action.
+        Returns (observation, reward_value, done, info).
+        """
+        async with self._lock:
+            if self._state is None:
+                raise RuntimeError("Call reset() before step()")
+
+            if self._state.is_done:
+                raise RuntimeError("Episode is done. Call reset() to start a new episode.")
+
+            self._state.steps_taken += 1
+            steps_taken = self._state.steps_taken
+
+            query_result_raw = None
+            prev_best_score = self._state.best_score_so_far
+            grade_score = self._state.best_score_so_far
+            schema_info = None
+            error_details = None
+            sample_rows = None
+            hint = None
+
+            # --- Execute action ---
+            if action.action_type == ActionType.SUBMIT_QUERY:
+                if not action.query:
+                    raise ValueError("query is required for submit_query action")
+
+                self._state.current_query = action.query
+                query_result_raw = self._db.execute_query(action.query)
+
+                # Grade the result
+                actual_rows = query_result_raw.get("rows") if query_result_raw.get("success") else None
+
+                # Use custom grader for medium task
+                if self.task.task_id == "medium_logic_fix":
+                    grade_score = MediumTaskGrader.grade(actual_rows or [])
+                else:
+                    grade_score = self.task.grade(actual_rows)
+
+                if grade_score > self._state.best_score_so_far:
+                    self._state.best_score_so_far = grade_score
+
+            elif action.action_type == ActionType.INSPECT_SCHEMA:
+                schema = self._db.get_schema()
+                schema_info = SchemaInfo(tables=schema)
+                grade_score = self._state.best_score_so_far
+
+            elif action.action_type == ActionType.INSPECT_ERROR:
+                # Return last error if available. Note the stored value may be
+                # None for a successful query, so fall back explicitly.
+                if self._state.action_history:
+                    last = self._state.action_history[-1]
+                    error_details = last.get("error_message") or "No error recorded from last query."
+                else:
+                    error_details = "No query has been submitted yet."
+                grade_score = self._state.best_score_so_far
+
+            elif action.action_type == ActionType.INSPECT_SAMPLE:
+                if not action.table_name:
+                    raise ValueError("table_name required for inspect_sample")
+                sample_rows = self._db.get_sample_rows(action.table_name)
+                grade_score = self._state.best_score_so_far
+
+            elif action.action_type == ActionType.RESET_QUERY:
+                self._state.current_query = self.task.broken_query
+                grade_score = self._state.best_score_so_far
+
+            # --- Compute reward ---
+            schema_tables = list(self._db.get_schema().keys())
+            reward_obj = compute_reward(
+                action_type=action.action_type.value,
+                query_result=query_result_raw,
+                grade_score=grade_score,
+                steps_taken=steps_taken,
+                max_steps=self.task.max_steps,
+                previous_best_score=prev_best_score,
+                schema_tables=schema_tables,
+                submitted_query=action.query if action.action_type == ActionType.SUBMIT_QUERY else None
+            )
+
+            # --- Check done conditions ---
+            is_done = False
+            success = False
+
+            if grade_score >= 0.95:
+                is_done = True
+                success = True
+            elif steps_taken >= self.task.max_steps:
+                is_done = True
+                success = self._state.best_score_so_far >= 0.5
+
+            self._state.is_done = is_done
+            self._state.success = success
+
+            # --- Hint logic ---
+            hint_threshold = 3 if self.task.difficulty == "easy" else 5
+            if steps_taken >= hint_threshold:
+                hint = self.task.hint
+
+            # --- Record history ---
+            self._state.action_history.append({
+                "step": steps_taken,
+                "action_type": action.action_type.value,
+                "query": action.query,
+                "grade_score": grade_score,
+                "reward": reward_obj.value,
+                "error_message": query_result_raw.get("error_message") if query_result_raw else None
+            })
+            self._state.reward_history.append(reward_obj.value)
+
+            # --- Build observation ---
+            qr = QueryResult(**query_result_raw) if query_result_raw else None
+
+            obs = SQLDebugObservation(
+                task_id=self.task.task_id,
+                task_description=self.task.description,
+                original_query=self.task.broken_query,
+                current_query=self._state.current_query,
+                expected_description=self.task.expected_output_description,
+                last_action_type=action.action_type.value,
+                last_query_result=qr,
+                steps_taken=steps_taken,
+                steps_remaining=max(0, self.task.max_steps - steps_taken),
+                current_score=self._state.best_score_so_far,
+                schema_info=schema_info,
+                error_details=error_details,
+                sample_rows=sample_rows,
+                hint=hint,
+                is_done=is_done,
+                success=success
+            )
+
+            return obs, reward_obj.value, is_done, {
+                "grade_score": grade_score,
+                "reward_breakdown": reward_obj.breakdown,
+                "success": success,
+                "steps_taken": steps_taken
+            }
+
+    def get_state(self) -> EpisodeState:
+        if self._state is None:
+            raise RuntimeError("Call reset() first")
+        return self._state
+
+    def close(self):
+        if self._db:
+            self._db.close()
+        self._db = None
+
+
server/main.py ADDED
@@ -0,0 +1,242 @@
+"""
+FastAPI server exposing the OpenEnv HTTP API.
+Endpoints: POST /reset, POST /step, GET /state
+Also includes: GET /tasks (list available tasks), GET /health
+"""
+import asyncio
+import time
+import statistics
+from typing import Dict, Optional
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException, Header
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+from .models import SQLDebugAction, SQLDebugObservation, EpisodeState
+from .env import SQLDebugEnv, TASKS
+
+
+# Session management: one env instance per session
+# For HF Space: allow up to 64 concurrent sessions
+MAX_SESSIONS = 64
+_sessions: Dict[str, SQLDebugEnv] = {}
+_session_lock = asyncio.Lock()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    yield
+    # Cleanup all sessions on shutdown
+    for env in _sessions.values():
+        env.close()
+
+
+app = FastAPI(
+    title="SQL Debug Environment",
+    description="OpenEnv-compliant SQL query debugging environment for RL agent training.",
+    version="0.1.0",
+    lifespan=lifespan
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.get("/")
+async def root():
+    return {
+        "name": "sql-debug-env",
+        "status": "ok",
+        "message": "Use /health, /tasks, /reset, /step, /state, /benchmark",
+    }
+
+
+@app.get("/favicon.ico", status_code=204)
+async def favicon():
+    return None
+
+
+class ResetRequest(BaseModel):
+    task_id: Optional[str] = "easy_syntax_fix"
+
+
+class StepRequest(BaseModel):
+    action: SQLDebugAction
+
+
+async def get_or_create_session(session_id: str, task_id: str = "easy_syntax_fix") -> SQLDebugEnv:
+    async with _session_lock:
+        if session_id not in _sessions:
+            if len(_sessions) >= MAX_SESSIONS:
+                # Evict oldest session
+                oldest = next(iter(_sessions))
+                _sessions[oldest].close()
+                del _sessions[oldest]
+            _sessions[session_id] = SQLDebugEnv(task_id=task_id)
+        return _sessions[session_id]
+
+
+@app.get("/health")
+async def health():
+    return {"status": "ok", "sessions_active": len(_sessions)}
+
+
+@app.get("/tasks")
+async def list_tasks():
+    """List all available tasks with metadata."""
+    return {
+        "tasks": [task.to_dict() for task in TASKS.values()]
+    }
+
+
+def _stats(values: list[float]) -> Dict[str, float]:
+    ordered = sorted(values)
+    n = len(ordered)
+    p95_idx = max(0, int(n * 0.95) - 1)
+    return {
+        "avg_ms": round(statistics.mean(ordered), 3),
+        "p50_ms": round(statistics.median(ordered), 3),
+        "p95_ms": round(ordered[p95_idx], 3),
+        "n": n,
+    }
+
+
+@app.get("/benchmark")
+async def benchmark(runs: int = 20):
+    """
+    Real-time benchmark endpoint (fresh measurements on every call).
+    Safe to call from dashboards/web pages for live verification.
+    """
+    runs = max(1, min(runs, 100))
+
+    health_times: list[float] = []
+    tasks_times: list[float] = []
+    reset_times: list[float] = []
+    step_times: list[float] = []
+
+    bench_env = SQLDebugEnv(task_id="easy_syntax_fix")
+    try:
+        for _ in range(runs):
+            t0 = time.perf_counter()
+            _ = {"status": "ok", "sessions_active": len(_sessions)}
+            health_times.append((time.perf_counter() - t0) * 1000)
+
+            t0 = time.perf_counter()
+            _ = [task.to_dict() for task in TASKS.values()]
+            tasks_times.append((time.perf_counter() - t0) * 1000)
+
+            t0 = time.perf_counter()
+            await bench_env.reset()
+            reset_times.append((time.perf_counter() - t0) * 1000)
+
+            t0 = time.perf_counter()
+            await bench_env.step(SQLDebugAction(action_type="inspect_schema"))
+            step_times.append((time.perf_counter() - t0) * 1000)
+    finally:
+        bench_env.close()
+
+    return {
+        "benchmark": {
+            "runs": runs,
+            "task_id": "easy_syntax_fix",
+            "timestamp_epoch_ms": int(time.time() * 1000),
+            "results": {
+                "health": _stats(health_times),
+                "tasks": _stats(tasks_times),
+                "reset": _stats(reset_times),
+                "step_inspect_schema": _stats(step_times),
+            },
+        }
+    }
+
+
+@app.post("/reset")
+async def reset(
+    request: ResetRequest = ResetRequest(),
+    x_session_id: Optional[str] = Header(default=None)
+):
+    """
+    Reset the environment for a new episode.
+
+    Returns initial observation with task description and broken query.
+    """
+    session_id = x_session_id or "default"
+    task_id = request.task_id or "easy_syntax_fix"
+
+    if task_id not in TASKS:
+        raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}. Valid: {list(TASKS.keys())}")
+
+    # Always create fresh env on reset
+    async with _session_lock:
+        if session_id in _sessions:
+            _sessions[session_id].close()
+        _sessions[session_id] = SQLDebugEnv(task_id=task_id)
+
+    env = _sessions[session_id]
+    observation, info = await env.reset()
+
+    return {
+        "observation": observation.model_dump(),
+        "info": info,
+        "reward": None,
+        "done": False
+    }
+
+
+@app.post("/step")
+async def step(
+    request: StepRequest,
+    x_session_id: Optional[str] = Header(default=None)
+):
+    """
+    Execute one action in the environment.
+
+    Action types:
+    - submit_query: Submit SQL for evaluation (requires 'query' field)
+    - inspect_schema: Get table schema (free action)
+    - inspect_error: Get last error message (free action)
+    - inspect_sample: Get sample rows from table (requires 'table_name')
+    - reset_query: Reset to original broken query (small penalty)
+    """
+    session_id = x_session_id or "default"
+
+    if session_id not in _sessions:
+        raise HTTPException(status_code=400, detail="Session not found. Call /reset first.")
+
+    env = _sessions[session_id]
+
+    try:
+        observation, reward, done, info = await env.step(request.action)
+    except RuntimeError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except ValueError as e:
+        raise HTTPException(status_code=422, detail=str(e))
+
+    return {
+        "observation": observation.model_dump(),
+        "reward": reward,
+        "done": done,
+        "info": info
+    }
+
+
+@app.get("/state")
+async def state(x_session_id: Optional[str] = Header(default=None)):
+    """Return current full episode state."""
+    session_id = x_session_id or "default"
+
+    if session_id not in _sessions:
+        raise HTTPException(status_code=400, detail="No active session. Call /reset first.")
+
+    env = _sessions[session_id]
+    try:
+        current_state = env.get_state()
+        return current_state.model_dump()
+    except RuntimeError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
server/models.py ADDED
@@ -0,0 +1,138 @@
+"""
+Typed Pydantic models for the SQL Debug Environment.
+Implements the OpenEnv spec: Observation, Action, Reward.
+"""
+from typing import Optional, List, Dict, Any
+from pydantic import BaseModel, Field
+from enum import Enum
+
+
+class ActionType(str, Enum):
+    SUBMIT_QUERY = "submit_query"      # Submit a fixed SQL query for evaluation
+    INSPECT_SCHEMA = "inspect_schema"  # Request schema info (costs 0 reward, gives info)
+    INSPECT_ERROR = "inspect_error"    # Request error details (costs 0, gives stack trace)
+    INSPECT_SAMPLE = "inspect_sample"  # Request 3 sample rows from a table
+    RESET_QUERY = "reset_query"        # Reset to the original broken query (costs -0.05 penalty)
+
+
+class SQLDebugAction(BaseModel):
+    """
+    Action model for the SQL Debug Environment.
+
+    The agent can either:
+    - submit_query: Submit a fixed SQL string for evaluation
+    - inspect_schema: Get table schema info (free action, no reward change)
+    - inspect_error: Get detailed error message from last query run
+    - inspect_sample: Get sample rows from a specified table
+    - reset_query: Go back to original broken query (costs -0.05 penalty)
+    """
+    action_type: ActionType = Field(
+        description="Type of action to take"
+    )
+    query: Optional[str] = Field(
+        default=None,
+        description="SQL query string. Required when action_type is 'submit_query'."
+    )
+    table_name: Optional[str] = Field(
+        default=None,
+        description="Table name. Required when action_type is 'inspect_sample'."
+    )
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "action_type": "submit_query",
+                "query": "SELECT u.name, COUNT(o.id) as order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name ORDER BY order_count DESC"
+            }
+        }
+
+
+class QueryResult(BaseModel):
+    """Result of executing a SQL query."""
+    success: bool
+    rows: Optional[List[Dict[str, Any]]] = None
+    row_count: Optional[int] = None
+    error_message: Optional[str] = None
+    execution_time_ms: Optional[float] = None
+
+
+class SchemaInfo(BaseModel):
+    """Database schema information."""
+    tables: Dict[str, List[Dict[str, str]]]  # table_name -> list of {name, type, nullable}
+    sample_data: Optional[Dict[str, List[Dict[str, Any]]]] = None
+
+
+class SQLDebugObservation(BaseModel):
+    """
+    Observation returned after each step.
+
+    Contains the current state of the debugging session:
+    - The original broken query (always visible)
+    - The agent's current best query
+    - Result of last action
+    - Progress indicators
+    - Schema/error info if requested
+    """
+    task_id: str = Field(description="Current task identifier")
+    task_description: str = Field(description="Natural language description of the bug to fix")
+    original_query: str = Field(description="The original broken SQL query")
+    current_query: Optional[str] = Field(default=None, description="Agent's last submitted query")
+    expected_description: str = Field(description="Description of what the correct output should look like")
+
+    # Last action result
+    last_action_type: str
+    last_query_result: Optional[QueryResult] = None
+
+    # Progress
+    steps_taken: int
+    steps_remaining: int
+    current_score: float = Field(description="Current score 0.0-1.0 for this episode")
+
+    # Contextual help (populated based on action type)
+    schema_info: Optional[SchemaInfo] = None
+    error_details: Optional[str] = None
+    sample_rows: Optional[List[Dict[str, Any]]] = None
+
+    # Hints (unlocked after step 3 on easy, step 5 on medium/hard)
+    hint: Optional[str] = None
+
+    # Episode status
+    is_done: bool = False
+    success: bool = False
+
+
+class SQLDebugReward(BaseModel):
+    """
+    Reward signal for the SQL Debug Environment.
+
+    Reward components (all sum to final reward):
+    - correctness: 0.0-0.6 based on row-level match vs expected output
+    - efficiency: 0.0-0.2 bonus for solving in fewer steps
+    - syntax_progress: 0.0-0.1 for getting a syntactically valid query (even if wrong)
+    - schema_bonus: 0.0-0.1 for queries that reference correct tables/columns
+    - penalties: negative values for reset_query, infinite loops, destructive SQL
+    """
+    value: float = Field(ge=0.0, le=1.0, description="Total reward for this step")
+    correctness: float = Field(ge=0.0, le=0.6)
+    efficiency: float = Field(ge=0.0, le=0.2)
+    syntax_progress: float = Field(ge=0.0, le=0.1)
+    schema_bonus: float = Field(ge=0.0, le=0.1)
+    penalty: float = Field(ge=0.0, le=0.2, description="Penalty deduction magnitude (non-negative)")
+    breakdown: str = Field(description="Human-readable reward breakdown")
+
+
+class EpisodeState(BaseModel):
+    """Full internal state of an episode. Used by state() endpoint."""
+    task_id: str
+    task_difficulty: str
+    original_query: str
+    current_query: Optional[str]
+    best_score_so_far: float
+    steps_taken: int
+    max_steps: int
+    action_history: List[Dict[str, Any]]
+    reward_history: List[float]
+    is_done: bool
+    success: bool
+    db_schema: Dict[str, Any]
+
+
server/reward.py ADDED
@@ -0,0 +1,125 @@
+"""
+Reward function for the SQL Debug Environment.
+
+Reward is computed at every step (not just end of episode).
+This provides dense, meaningful signal for RL training.
+
+Reward components:
+- correctness: 0.0–0.6 (row-level match vs expected)
+- efficiency: 0.0–0.2 (bonus for solving quickly)
+- syntax_progress: 0.0–0.1 (valid SQL even if wrong content)
+- schema_bonus: 0.0–0.1 (correct tables/columns referenced)
+- penalty: 0.0 to 0.2 (deduction for bad actions)
+
+Total range: 0.0 to 1.0 (clamped to [0.0, 1.0])
+"""
+from typing import Optional, List, Dict, Any
+from .models import SQLDebugReward
+
+
+def compute_reward(
+    action_type: str,
+    query_result: Optional[Dict[str, Any]],
+    grade_score: float,
+    steps_taken: int,
+    max_steps: int,
+    previous_best_score: float,
+    schema_tables: List[str],
+    submitted_query: Optional[str] = None,
+) -> SQLDebugReward:
+    """
+    Compute the full reward for a step.
+
+    Args:
+        action_type: The action taken this step
+        query_result: Result dict from EpisodeDatabase.execute_query()
+        grade_score: 0.0-1.0 score from task grader
+        steps_taken: How many steps have been used (1-indexed)
+        max_steps: Maximum steps for this task
+        previous_best_score: Best grade score seen so far
+        schema_tables: List of valid table names in this task's DB
+        submitted_query: The SQL query string (if action was submit_query)
+    """
+
+    correctness = 0.0
+    efficiency = 0.0
+    syntax_progress = 0.0
+    schema_bonus = 0.0
+    penalty = 0.0  # deduction magnitude (non-negative)
+
+    if action_type == "submit_query":
+        # Correctness: primary signal
+        correctness = min(0.6, grade_score * 0.6)
+
+        # Syntax progress: reward for at least getting a valid query
+        if query_result and query_result.get("success"):
+            syntax_progress = 0.1
+        elif query_result and not query_result.get("success"):
+            # Partially reward if it's getting closer (fewer errors)
+            error = query_result.get("error_message", "")
+            if "no such column" in error.lower():
+                syntax_progress = 0.03  # Structure is right but wrong column
+            elif "no such table" in error.lower():
+                syntax_progress = 0.01
+            else:
+                syntax_progress = 0.0
+
+        # Schema bonus: correct table references
+        if submitted_query and schema_tables:
+            query_upper = submitted_query.upper()
+            tables_referenced = sum(
+                1 for t in schema_tables if t.upper() in query_upper
+            )
+            schema_bonus = min(0.1, (tables_referenced / len(schema_tables)) * 0.1)
+
+        # Efficiency bonus: reward solving with fewer steps
+        if grade_score >= 0.95:  # Near-perfect solution
+            steps_fraction = steps_taken / max_steps
+            if steps_fraction <= 0.3:
+                efficiency = 0.2
+            elif steps_fraction <= 0.5:
+                efficiency = 0.15
+            elif steps_fraction <= 0.7:
+                efficiency = 0.1
+            else:
+                efficiency = 0.05
+
+        # Penalty: if score went DOWN from previous best (regressed)
+        if grade_score < previous_best_score - 0.1:
+            penalty = 0.05
+
+    elif action_type == "reset_query":
+        # Penalize resetting — agent should be making progress
+        penalty = 0.05
+
+    elif action_type in ("inspect_schema", "inspect_error", "inspect_sample"):
+        # Free information actions — small positive for using schema info
+        # (encourages agents to explore rather than blindly guess)
+        syntax_progress = 0.01
+
+    # Penalty: approaching step limit (urgency signal)
+    steps_remaining = max_steps - steps_taken
+    if steps_remaining <= 2 and grade_score < 0.5:
+        penalty += 0.03
+
+    total_raw = correctness + efficiency + syntax_progress + schema_bonus - penalty
+    total = round(max(0.0, min(1.0, total_raw)), 4)
+
+    # Note: penalty is subtracted, so the breakdown shows it with a minus sign.
+    breakdown = (
+        f"correctness={correctness:.3f} + "
+        f"efficiency={efficiency:.3f} + "
+        f"syntax={syntax_progress:.3f} + "
+        f"schema={schema_bonus:.3f} - "
+        f"penalty={penalty:.3f} = {total:.4f}"
+    )
+
+    return SQLDebugReward(
+        value=total,
+        correctness=correctness,
+        efficiency=efficiency,
+        syntax_progress=syntax_progress,
+        schema_bonus=schema_bonus,
+        penalty=penalty,
+        breakdown=breakdown
+    )
+
+
server/tasks/__init__.py ADDED
@@ -0,0 +1,2 @@
+# sql-debug-env
+
server/tasks/base.py ADDED
@@ -0,0 +1,169 @@
1
+ """Base class for all SQL Debug tasks."""
2
+ from abc import ABC, abstractmethod
3
+ from typing import Dict, Any, List, Optional, Tuple
4
+
5
+
6
+ class BaseTask(ABC):
7
+ """
8
+ Abstract base for all tasks.
9
+
10
+ Each task defines:
11
+ - A broken SQL query (the one the agent must fix)
12
+ - A database schema (SQLite CREATE TABLE statements)
13
+ - Seed data (INSERT statements, deterministic)
14
+ - Expected output (what the correct query should return)
15
+ - A grader (compares agent output vs expected)
16
+ - Metadata (id, name, difficulty, description, hint)
17
+ """
18
+
19
+ @property
20
+ @abstractmethod
21
+ def task_id(self) -> str:
22
+ pass
23
+
24
+ @property
25
+ @abstractmethod
26
+ def name(self) -> str:
27
+ pass
28
+
29
+ @property
30
+ @abstractmethod
31
+ def difficulty(self) -> str:
32
+ pass # "easy", "medium", "hard"
33
+
34
+ @property
35
+ @abstractmethod
36
+ def description(self) -> str:
37
+ """Natural language description given to the agent."""
38
+ pass
39
+
40
+ @property
41
+ @abstractmethod
42
+ def expected_output_description(self) -> str:
43
+ """Describes what the correct output looks like."""
44
+ pass
45
+
46
+ @property
47
+ @abstractmethod
48
+ def broken_query(self) -> str:
49
+ """The SQL query with bugs that the agent must fix."""
50
+ pass
51
+
52
+ @property
53
+ @abstractmethod
54
+ def schema_sql(self) -> str:
55
+ """SQLite CREATE TABLE statements."""
56
+ pass
57
+
58
+ @property
59
+ @abstractmethod
60
+ def seed_data_sql(self) -> str:
61
+ """INSERT statements for deterministic test data."""
62
+ pass
63
+
64
+ @property
65
+ @abstractmethod
66
+ def expected_output(self) -> List[Dict[str, Any]]:
67
+ """
68
+ The exact rows the correct query should return.
69
+ Used by the grader to score the agent's output.
70
+ Must be deterministic and match seed_data_sql exactly.
71
+ """
72
+ pass
73
+
74
+ @property
75
+ def hint(self) -> str:
76
+ """Optional hint shown after N steps. Override in subclass."""
77
+ return ""
78
+
79
+ @property
80
+ def max_steps(self) -> int:
81
+ """Maximum steps for this task."""
82
+ return {"easy": 10, "medium": 20, "hard": 30}.get(self.difficulty, 20)
83
+
84
+ def grade(self, actual_rows: Optional[List[Dict[str, Any]]]) -> float:
85
+ """
86
+ Grade the agent's query output vs expected output.
87
+ Returns a score 0.0-1.0.
88
+
89
+ Scoring:
90
+ - 1.0: exact match (correct rows, correct order if ORDER BY expected)
91
+ - 0.5-0.9: partial match (subset of correct rows, or wrong order)
92
+ - 0.1-0.4: syntactically valid but wrong content
93
+ - 0.0: null result, syntax error, or empty when non-empty expected
94
+ """
95
+ if not actual_rows:
96
+ return 0.0
97
+
98
+ expected = self.expected_output
99
+
100
+ if not expected:
101
+ # Expected empty result
102
+ return 1.0 if len(actual_rows) == 0 else 0.0
103
+
104
+ # Exact row count match
105
+ if len(actual_rows) != len(expected):
106
+ # Partial credit for getting some rows right
107
+ overlap = self._count_matching_rows(actual_rows, expected)
108
+ return round(min(0.5, overlap / max(len(expected), 1) * 0.5), 3)
109
+
110
+ # Check row-by-row match (order-sensitive if task requires it)
111
+ matching = self._count_matching_rows(actual_rows, expected)
112
+ score = matching / len(expected)
113
+
114
+ # Check column names match
115
+ if actual_rows and expected:
116
+ actual_cols = set(actual_rows[0].keys())
117
+ expected_cols = set(expected[0].keys())
118
+ if actual_cols != expected_cols:
119
+ score *= 0.7 # Penalty for wrong columns
120
+
121
+ return round(score, 3)
122
+
123
+ def _count_matching_rows(
124
+ self,
125
+ actual: List[Dict[str, Any]],
126
+ expected: List[Dict[str, Any]]
127
+ ) -> int:
128
+ """Count how many actual rows match expected rows (normalized comparison)."""
129
+ matches = 0
130
+ expected_normalized = [self._normalize_row(r) for r in expected]
131
+
132
+ for i, actual_row in enumerate(actual):
133
+ actual_norm = self._normalize_row(actual_row)
134
+ if i < len(expected_normalized):
135
+ # Positional match (respects ORDER BY)
136
+ if actual_norm == expected_normalized[i]:
137
+ matches += 1
138
+ else:
139
+ # Extra rows don't count
140
+ break
141
+
142
+ return matches
143
+
144
+ def _normalize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
145
+ """Normalize a row for comparison: lowercase keys, string-normalize values."""
146
+ normalized = {}
147
+ for k, v in row.items():
148
+ key = k.lower().strip()
149
+ if isinstance(v, float):
150
+ val = round(v, 2)
151
+ elif isinstance(v, str):
152
+ val = v.strip()
153
+ else:
154
+ val = v
155
+ normalized[key] = val
156
+ return normalized
157
+
158
+ def to_dict(self) -> Dict[str, Any]:
159
+ return {
160
+ "task_id": self.task_id,
161
+ "name": self.name,
162
+ "difficulty": self.difficulty,
163
+ "description": self.description,
164
+ "expected_output_description": self.expected_output_description,
165
+ "broken_query": self.broken_query,
166
+ "max_steps": self.max_steps,
167
+ "hint": self.hint
168
+ }
169
+
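The grading scheme above (normalize each row, then count positional matches for partial credit) can be exercised in isolation. The sketch below is a trimmed re-statement of the same logic, not the `BaseTask` code itself, with the two key behaviors called out in comments:

```python
from typing import Any, Dict, List


def normalize_row(row: Dict[str, Any]) -> Dict[str, Any]:
    # Lowercase/strip keys, round floats to 2 dp, strip strings.
    out = {}
    for k, v in row.items():
        if isinstance(v, float):
            v = round(v, 2)
        elif isinstance(v, str):
            v = v.strip()
        out[k.lower().strip()] = v
    return out


def count_matching_rows(actual: List[Dict], expected: List[Dict]) -> int:
    # Positional comparison: stops at the first mismatch, so ORDER BY matters.
    exp = [normalize_row(r) for r in expected]
    matches = 0
    for i, row in enumerate(actual):
        if i < len(exp) and normalize_row(row) == exp[i]:
            matches += 1
        else:
            break
    return matches


expected = [{"name": "Alice", "total": 1947.5}, {"name": "Carol", "total": 740.0}]
# Key case and float noise are normalized away, so this first row still matches.
print(count_matching_rows([{"NAME ": "Alice", "total": 1947.501}], expected))  # 1
# A correct row in the wrong position scores nothing.
print(count_matching_rows([{"name": "Carol", "total": 740.0}], expected))      # 0
```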
server/tasks/task_easy.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ TASK 1 — EASY: Syntax Error Fix
3
+ Difficulty: Easy
4
+ Bug type: Simple syntax errors (typo in keyword, missing alias, wrong column name)
5
+ Max steps: 10
6
+ Expected baseline model score: 0.8-1.0
7
+ """
8
+ from typing import List, Dict, Any
9
+ from .base import BaseTask
10
+
11
+
12
+ class EasyTask(BaseTask):
13
+ """
14
+ Scenario: An e-commerce company wants to find the top 5 customers
15
+ by total order value. The query has a syntax error:
16
+ uses 'GRUP BY' instead of 'GROUP BY' and references wrong column alias.
17
+
18
+ Database: customers, orders, order_items
19
+ Bug 1: 'GRUP BY' typo
20
+ Bug 2: ORDER BY references 'total' but SELECT aliases it as 'total_value'
21
+ """
22
+
23
+ @property
24
+ def task_id(self) -> str:
25
+ return "easy_syntax_fix"
26
+
27
+ @property
28
+ def name(self) -> str:
29
+ return "Top Customers by Revenue — Syntax Error Fix"
30
+
31
+ @property
32
+ def difficulty(self) -> str:
33
+ return "easy"
34
+
35
+ @property
36
+ def description(self) -> str:
37
+ return """You are debugging a SQL query for an e-commerce analytics dashboard.
38
+
39
+ The query is supposed to find the top 5 customers by their total order value
40
+ (sum of quantity * unit_price across all their orders).
41
+
42
+ The query has 2 syntax/reference bugs that prevent it from running:
43
+ 1. A typo in a SQL keyword
44
+ 2. An ORDER BY clause that references a column alias incorrectly
45
+
46
+ Fix both bugs so the query runs and returns the correct result.
47
+
48
+ The result should show: customer_name, total_value (rounded to 2 decimal places),
49
+ ordered from highest to lowest, top 5 only."""
50
+
51
+ @property
52
+ def expected_output_description(self) -> str:
53
+ return "5 rows: customer_name, total_value (DESC order). Alice Chen should be first with 1947.50."
54
+
55
+ @property
56
+ def broken_query(self) -> str:
57
+ return """SELECT
58
+ c.name AS customer_name,
59
+ ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_value
60
+ FROM customers c
61
+ JOIN orders o ON c.id = o.customer_id
62
+ JOIN order_items oi ON o.id = oi.order_id
63
+ GRUP BY c.id, c.name
64
+ ORDER BY total DESC
65
+ LIMIT 5"""
66
+
67
+ @property
68
+ def schema_sql(self) -> str:
69
+ return """
70
+ CREATE TABLE customers (
71
+ id INTEGER PRIMARY KEY,
72
+ name TEXT NOT NULL,
73
+ email TEXT UNIQUE NOT NULL,
74
+ created_at TEXT DEFAULT CURRENT_TIMESTAMP
75
+ );
76
+
77
+ CREATE TABLE orders (
78
+ id INTEGER PRIMARY KEY,
79
+ customer_id INTEGER NOT NULL,
80
+ order_date TEXT NOT NULL,
81
+ status TEXT DEFAULT 'completed',
82
+ FOREIGN KEY (customer_id) REFERENCES customers(id)
83
+ );
84
+
85
+ CREATE TABLE order_items (
86
+ id INTEGER PRIMARY KEY,
87
+ order_id INTEGER NOT NULL,
88
+ product_name TEXT NOT NULL,
89
+ quantity INTEGER NOT NULL,
90
+ unit_price REAL NOT NULL,
91
+ FOREIGN KEY (order_id) REFERENCES orders(id)
92
+ )"""
93
+
94
+ @property
95
+ def seed_data_sql(self) -> str:
96
+ return """
97
+ INSERT INTO customers VALUES (1,'Alice Chen','alice@example.com','2023-01-01');
98
+ INSERT INTO customers VALUES (2,'Bob Kumar','bob@example.com','2023-01-05');
99
+ INSERT INTO customers VALUES (3,'Carol White','carol@example.com','2023-01-10');
100
+ INSERT INTO customers VALUES (4,'David Park','david@example.com','2023-02-01');
101
+ INSERT INTO customers VALUES (5,'Eva Rodriguez','eva@example.com','2023-02-15');
102
+ INSERT INTO customers VALUES (6,'Frank Liu','frank@example.com','2023-03-01');
103
+
104
+ INSERT INTO orders VALUES (1,1,'2023-06-01','completed');
105
+ INSERT INTO orders VALUES (2,1,'2023-07-15','completed');
106
+ INSERT INTO orders VALUES (3,2,'2023-06-10','completed');
107
+ INSERT INTO orders VALUES (4,3,'2023-06-20','completed');
108
+ INSERT INTO orders VALUES (5,3,'2023-08-01','completed');
109
+ INSERT INTO orders VALUES (6,4,'2023-07-01','completed');
110
+ INSERT INTO orders VALUES (7,5,'2023-07-20','completed');
111
+ INSERT INTO orders VALUES (8,5,'2023-08-10','completed');
112
+ INSERT INTO orders VALUES (9,6,'2023-09-01','completed');
113
+
114
+ INSERT INTO order_items VALUES (1,1,'Laptop',1,1200.00);
115
+ INSERT INTO order_items VALUES (2,1,'Mouse',2,25.00);
116
+ INSERT INTO order_items VALUES (3,2,'Keyboard',1,150.00);
117
+ INSERT INTO order_items VALUES (4,2,'Monitor',1,450.00);
118
+ INSERT INTO order_items VALUES (5,2,'Webcam',1,97.50);
119
+ INSERT INTO order_items VALUES (6,3,'Headphones',1,350.00);
120
+ INSERT INTO order_items VALUES (7,3,'USB Hub',2,45.00);
121
+ INSERT INTO order_items VALUES (8,4,'Tablet',1,600.00);
122
+ INSERT INTO order_items VALUES (9,4,'Case',1,35.00);
123
+ INSERT INTO order_items VALUES (10,5,'Charger',2,30.00);
124
+ INSERT INTO order_items VALUES (11,5,'Cable',3,15.00);
125
+ INSERT INTO order_items VALUES (12,6,'Desk Lamp',1,85.00);
126
+ INSERT INTO order_items VALUES (13,6,'Chair Mat',1,60.00);
127
+ INSERT INTO order_items VALUES (14,7,'Speakers',1,220.00);
128
+ INSERT INTO order_items VALUES (15,7,'Microphone',1,180.00);
129
+ INSERT INTO order_items VALUES (16,8,'Webcam',1,97.50);
130
+ INSERT INTO order_items VALUES (17,9,'Monitor',1,450.00)"""
131
+
132
+ @property
133
+ def expected_output(self) -> List[Dict[str, Any]]:
134
+ # Alice (orders 1,2): 1200 + 2*25 + 150 + 450 + 97.50 = 1947.50
+ # Carol (orders 4,5): 600 + 35 + 2*30 + 3*15 = 740.00
+ # Eva (orders 7,8): 220 + 180 + 97.50 = 497.50
+ # Frank (order 9): 450.00
+ # Bob (order 3): 350 + 2*45 = 440.00
+ # David (order 6): 85 + 60 = 145.00, excluded by LIMIT 5
+ return [
+ {"customer_name": "Alice Chen", "total_value": 1947.50},
+ {"customer_name": "Carol White", "total_value": 740.00},
+ {"customer_name": "Eva Rodriguez", "total_value": 497.50},
+ {"customer_name": "Frank Liu", "total_value": 450.00},
+ {"customer_name": "Bob Kumar", "total_value": 440.00},
+ ]
153
+
154
+ @property
155
+ def hint(self) -> str:
156
+ return "Hint: Check every SQL keyword spelling carefully. Also check that your ORDER BY column name exactly matches the alias in your SELECT clause."
157
+
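The two bug classes in this task fail differently at runtime, which an agent can exploit via the error messages. A minimal sqlite3 sketch with an invented single-table schema (not the task's tables):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (name TEXT, v REAL)")
con.executemany("INSERT INTO t VALUES (?, ?)", [("a", 2.0), ("a", 3.0), ("b", 1.0)])


def try_sql(sql):
    # Return fetched rows on success, or the error message on failure.
    try:
        return con.execute(sql).fetchall()
    except sqlite3.OperationalError as e:
        return str(e)


# Bug 1 analogue: misspelled GROUP BY is a hard parse error
# (exact message depends on where the parser trips over the typo).
print(try_sql("SELECT name, SUM(v) AS total_value FROM t GRUP BY name"))

# Bug 2 analogue: ORDER BY names a nonexistent alias.
print(try_sql("SELECT name, SUM(v) AS total_value FROM t GROUP BY name ORDER BY total"))

# Both fixed: the query runs and the alias resolves.
rows = try_sql(
    "SELECT name, SUM(v) AS total_value FROM t GROUP BY name ORDER BY total_value DESC"
)
print(rows)  # [('a', 5.0), ('b', 1.0)]
```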
server/tasks/task_hard.py ADDED
@@ -0,0 +1,199 @@
1
+ """
2
+ TASK 3 — HARD: Multi-bug + Optimization
3
+ Difficulty: Hard
4
+ Bug types:
5
+ 1. Correlated subquery returns wrong scope
6
+ 2. Window function partition incorrect
7
+ 3. CTE has circular logic bug
8
+ 4. Off-by-one in date range
9
+ 5. Missing DISTINCT causing row duplication
10
+ Max steps: 30
11
+ Expected baseline model score: 0.0-0.3 (frontier models barely pass)
12
+ """
13
+ from typing import List, Dict, Any
14
+ from .base import BaseTask
15
+
16
+
17
+ class HardTask(BaseTask):
18
+ """
19
+ Scenario: SaaS product analytics — find users who:
20
+ 1. Signed up in Q1 2023 (Jan 1 – Mar 31)
21
+ 2. Made at least 2 purchases in their first 30 days
22
+ 3. Return their: user_id, username, signup_date,
23
+ first_purchase_date, days_to_first_purchase,
24
+ purchases_in_first_30_days, total_lifetime_value
25
+
26
+ Bugs:
27
+ 1. The q1_users CTE filters on the wrong column: besides the signup_date
28
+ range it also restricts p.purchase_date to Q1 2023, so whether a user
29
+ enters the cohort wrongly depends on when they purchased, not only on
30
+ when they signed up
31
+ 2. The window function for running total uses PARTITION BY user_id but
32
+ ORDER BY is missing — gives wrong cumulative values
33
+ 3. HAVING clause uses COUNT(*) but should use COUNT(DISTINCT purchase_id)
34
+ due to JOIN multiplication
35
+ 4. The subquery for first_purchase_date is not correlated properly
36
+ (missing WHERE p.user_id = u.id)
37
+ 5. days_to_first_purchase calculation uses wrong date subtraction direction
38
+ """
39
+
40
+ @property
41
+ def task_id(self) -> str:
42
+ return "hard_multi_bug"
43
+
44
+ @property
45
+ def name(self) -> str:
46
+ return "SaaS Cohort Activation Report — Multi-Bug Fix"
47
+
48
+ @property
49
+ def difficulty(self) -> str:
50
+ return "hard"
51
+
52
+ @property
53
+ def description(self) -> str:
54
+ return """You are debugging a SaaS product analytics query.
55
+
56
+ The query should identify "activated users": users who signed up in Q1 2023
57
+ AND made at least 2 purchases within their first 30 days of signup.
58
+
59
+ For each activated user, return:
60
+ - user_id (INTEGER)
61
+ - username (TEXT)
62
+ - signup_date (TEXT, YYYY-MM-DD)
63
+ - first_purchase_date (TEXT, YYYY-MM-DD)
64
+ - days_to_first_purchase (INTEGER, how many days after signup they first purchased)
65
+ - purchases_in_first_30_days (INTEGER)
66
+ - total_lifetime_value (REAL, sum of all their purchases ever, rounded to 2 dp)
67
+
68
+ Results ordered by total_lifetime_value DESC.
69
+
70
+ The query has FIVE bugs — some are logic errors, one is a missing correlation
71
+ in a subquery, one is an incorrect window function, one causes row duplication.
72
+ You must find and fix all of them to get the correct result.
73
+
74
+ Q1 2023 = signup_date >= '2023-01-01' AND signup_date <= '2023-03-31'"""
75
+
76
+ @property
77
+ def expected_output_description(self) -> str:
78
+ return "2 rows: users who made 2+ purchases in first 30 days. Maya Torres first (higher LTV), then James Osei."
79
+
80
+ @property
81
+ def broken_query(self) -> str:
82
+ return """WITH q1_users AS (
83
+ SELECT DISTINCT u.id, u.username, u.signup_date
84
+ FROM users u
85
+ JOIN purchases p ON u.id = p.user_id
86
+ WHERE u.signup_date >= '2023-01-01'
87
+ AND u.signup_date <= '2023-03-31'
88
+ AND p.purchase_date <= '2023-03-31'
89
+ ),
90
+ user_purchase_stats AS (
91
+ SELECT
92
+ q.id AS user_id,
93
+ q.username,
94
+ q.signup_date,
95
+ (SELECT MIN(purchase_date) FROM purchases WHERE amount > 0) AS first_purchase_date,
96
+ COUNT(*) AS purchases_in_first_30_days,
97
+ SUM(SUM(p.amount)) OVER (PARTITION BY q.id) AS total_lifetime_value
98
+ FROM q1_users q
99
+ JOIN purchases p ON q.id = p.user_id
100
+ WHERE julianday(p.purchase_date) - julianday(q.signup_date) <= 30
101
+ GROUP BY q.id, q.username, q.signup_date
102
+ )
103
+ SELECT
104
+ user_id,
105
+ username,
106
+ signup_date,
107
+ first_purchase_date,
108
+ CAST(julianday(q1_users.signup_date) - julianday(first_purchase_date) AS INTEGER) AS days_to_first_purchase,
109
+ purchases_in_first_30_days,
110
+ ROUND(total_lifetime_value, 2) AS total_lifetime_value
111
+ FROM user_purchase_stats
112
+ WHERE purchases_in_first_30_days >= 2
113
+ ORDER BY total_lifetime_value DESC"""
114
+
115
+ @property
116
+ def schema_sql(self) -> str:
117
+ return """
118
+ CREATE TABLE users (
119
+ id INTEGER PRIMARY KEY,
120
+ username TEXT NOT NULL,
121
+ email TEXT UNIQUE,
122
+ signup_date TEXT NOT NULL,
123
+ plan TEXT DEFAULT 'free'
124
+ );
125
+
126
+ CREATE TABLE purchases (
127
+ id INTEGER PRIMARY KEY,
128
+ user_id INTEGER NOT NULL,
129
+ product_name TEXT NOT NULL,
130
+ amount REAL NOT NULL,
131
+ purchase_date TEXT NOT NULL,
132
+ FOREIGN KEY (user_id) REFERENCES users(id)
133
+ )"""
134
+
135
+ @property
136
+ def seed_data_sql(self) -> str:
137
+ return """
138
+ INSERT INTO users VALUES (1,'maya_torres','maya@ex.com','2023-01-15','pro');
139
+ INSERT INTO users VALUES (2,'james_osei','james@ex.com','2023-02-10','pro');
140
+ INSERT INTO users VALUES (3,'sophie_liang','sophie@ex.com','2023-03-05','free');
141
+ INSERT INTO users VALUES (4,'raj_mehta','raj@ex.com','2023-06-01','free');
142
+ INSERT INTO users VALUES (5,'anna_kovacs','anna@ex.com','2022-12-20','pro');
143
+
144
+ -- Maya: 2 purchases in first 30 days (days 5 and 18), more later
145
+ INSERT INTO purchases VALUES (1,1,'Pro Plan',99.00,'2023-01-20');
146
+ INSERT INTO purchases VALUES (2,1,'Add-on Pack',29.00,'2023-02-02');
147
+ INSERT INTO purchases VALUES (3,1,'Pro Renewal',99.00,'2023-04-15');
148
+ INSERT INTO purchases VALUES (4,1,'Consulting',150.00,'2023-07-01');
149
+
150
+ -- James: 2 purchases in first 30 days (days 3 and 25)
151
+ INSERT INTO purchases VALUES (5,2,'Starter Plan',49.00,'2023-02-13');
152
+ INSERT INTO purchases VALUES (6,2,'Storage Add-on',19.00,'2023-03-07');
153
+ INSERT INTO purchases VALUES (7,2,'Starter Renewal',49.00,'2023-05-10');
154
+
155
+ -- Sophie: only 1 purchase in first 30 days (should NOT qualify)
156
+ INSERT INTO purchases VALUES (8,3,'Free Trial Upgrade',9.00,'2023-03-10');
157
+ INSERT INTO purchases VALUES (9,3,'Pro Plan',99.00,'2023-04-20');
158
+
159
+ -- Raj: signed up Q2, not Q1 (should NOT qualify)
160
+ INSERT INTO purchases VALUES (10,4,'Starter Plan',49.00,'2023-06-05');
161
+ INSERT INTO purchases VALUES (11,4,'Add-on',19.00,'2023-06-10');
162
+
163
+ -- Anna: signed up Q4 2022, not Q1 2023 (should NOT qualify)
164
+ INSERT INTO purchases VALUES (12,5,'Pro Plan',99.00,'2023-01-01');
165
+ INSERT INTO purchases VALUES (13,5,'Consulting',150.00,'2023-03-15')"""
166
+
167
+ @property
168
+ def expected_output(self) -> List[Dict[str, Any]]:
169
+ # Maya: signup 2023-01-15, first purchase 2023-01-20 (day 5)
170
+ # purchases in 30 days: Jan-20 (day5), Feb-02 (day18) = 2 ✓
171
+ # total LTV: 99+29+99+150 = 377
172
+ # James: signup 2023-02-10, first purchase 2023-02-13 (day 3)
173
+ # purchases in 30 days: Feb-13 (day3), Mar-07 (day25) = 2 ✓
174
+ # total LTV: 49+19+49 = 117
175
+ return [
176
+ {
177
+ "user_id": 1,
178
+ "username": "maya_torres",
179
+ "signup_date": "2023-01-15",
180
+ "first_purchase_date": "2023-01-20",
181
+ "days_to_first_purchase": 5,
182
+ "purchases_in_first_30_days": 2,
183
+ "total_lifetime_value": 377.00
184
+ },
185
+ {
186
+ "user_id": 2,
187
+ "username": "james_osei",
188
+ "signup_date": "2023-02-10",
189
+ "first_purchase_date": "2023-02-13",
190
+ "days_to_first_purchase": 3,
191
+ "purchases_in_first_30_days": 2,
192
+ "total_lifetime_value": 117.00
193
+ }
194
+ ]
195
+
196
+ @property
197
+ def hint(self) -> str:
198
+ return "Hint: There are 5 bugs total. Check: (1) the subquery for first_purchase_date needs a WHERE correlation, (2) the date subtraction direction for days_to_first_purchase, (3) COUNT(*) vs COUNT(DISTINCT) when JOINs can multiply rows, (4) window functions need ORDER BY for meaningful results, (5) the q1_users CTE may be filtering on the wrong table's date column."
199
+
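Two of the five bug classes can be shown in a few lines of SQLite: the date-subtraction direction (bug 5) and JOIN row multiplication inflating COUNT(*) (bug 3). The tables below are invented for illustration, not the task schema:

```python
import sqlite3

con = sqlite3.connect(":memory:")

# Bug 5 analogue: subtraction direction. Signup 2023-01-15, purchase 2023-01-20.
# Subtracting purchase from signup yields a negative "days since" value.
wrong, right = con.execute(
    "SELECT CAST(julianday('2023-01-15') - julianday('2023-01-20') AS INTEGER),"
    "       CAST(julianday('2023-01-20') - julianday('2023-01-15') AS INTEGER)"
).fetchone()
print(wrong, right)  # -5 5

# Bug 3 analogue: a 1:N JOIN multiplies rows, so COUNT(*) overcounts;
# COUNT(DISTINCT id) recovers the true purchase count.
con.executescript("""
CREATE TABLE purchases (id INTEGER, user_id INTEGER);
CREATE TABLE events (user_id INTEGER);
INSERT INTO purchases VALUES (1, 1), (2, 1);
INSERT INTO events VALUES (1), (1), (1);
""")
naive, distinct = con.execute(
    "SELECT COUNT(*), COUNT(DISTINCT p.id) "
    "FROM purchases p JOIN events e ON p.user_id = e.user_id"
).fetchone()
print(naive, distinct)  # 6 2
```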
server/tasks/task_medium.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ TASK 2 — MEDIUM: Logic Error Fix
3
+ Difficulty: Medium
4
+ Bug types: Wrong JOIN type causing missing rows, a WHERE-clause date filter
5
+ that breaks outer-join semantics, aggregation over the wrong employee set
6
+ Max steps: 20
7
+ Expected baseline model score: 0.3-0.6
8
+ """
9
+ from typing import List, Dict, Any
10
+ from .base import BaseTask
11
+
12
+
13
+ class MediumTask(BaseTask):
14
+ """
15
+ Scenario: HR analytics team wants monthly headcount and average salary
16
+ by department for the current year, including departments with zero employees
17
+ (i.e., departments that exist but no one joined this year).
18
+
19
+ Bugs:
20
+ 1. Uses INNER JOIN instead of LEFT JOIN — excludes empty departments
21
+ 2. Uses AVG(salary) over all employees instead of only those who joined this year
22
+ 3. The date filter for 'this year' is applied in WHERE, breaking the LEFT JOIN
23
+ (it should live in the ON clause, or be wrapped in a CASE expression)
24
+ 4. GROUP BY missing department_id (ambiguous grouping)
25
+ """
26
+
27
+ @property
28
+ def task_id(self) -> str:
29
+ return "medium_logic_fix"
30
+
31
+ @property
32
+ def name(self) -> str:
33
+ return "Department Headcount Report — Logic Error Fix"
34
+
35
+ @property
36
+ def difficulty(self) -> str:
37
+ return "medium"
38
+
39
+ @property
40
+ def description(self) -> str:
41
+ return """You are debugging an HR analytics SQL query.
42
+
43
+ The query should produce a monthly department headcount report showing:
44
+ - department_name
45
+ - headcount: number of employees who joined IN 2023
46
+ - avg_salary: average salary of employees who joined IN 2023
47
+ - All departments must appear, even those with 0 new hires in 2023
48
+
49
+ The current query has 3 logic bugs:
50
+ 1. It uses the wrong JOIN type, which silently drops departments with no 2023 hires
51
+ 2. The WHERE clause on hire_date breaks the outer join semantics
52
+ 3. The AVG calculation includes employees from all years, not just 2023
53
+
54
+ Fix these logic errors. The result should be ordered by department_name ascending."""
55
+
56
+ @property
57
+ def expected_output_description(self) -> str:
58
+ return "4 rows (all departments), headcount=0 for 'Legal', correct avg_salary only from 2023 hires."
59
+
60
+ @property
61
+ def broken_query(self) -> str:
62
+ return """SELECT
63
+ d.name AS department_name,
64
+ COUNT(e.id) AS headcount,
65
+ ROUND(AVG(e.salary), 2) AS avg_salary
66
+ FROM departments d
67
+ INNER JOIN employees e ON d.id = e.department_id
68
+ WHERE strftime('%Y', e.hire_date) = '2023'
69
+ GROUP BY d.name
70
+ ORDER BY department_name ASC"""
71
+
72
+ @property
73
+ def schema_sql(self) -> str:
74
+ return """
75
+ CREATE TABLE departments (
76
+ id INTEGER PRIMARY KEY,
77
+ name TEXT NOT NULL,
78
+ budget REAL
79
+ );
80
+
81
+ CREATE TABLE employees (
82
+ id INTEGER PRIMARY KEY,
83
+ name TEXT NOT NULL,
84
+ department_id INTEGER NOT NULL,
85
+ salary REAL NOT NULL,
86
+ hire_date TEXT NOT NULL,
87
+ FOREIGN KEY (department_id) REFERENCES departments(id)
88
+ )"""
89
+
90
+ @property
91
+ def seed_data_sql(self) -> str:
92
+ return """
93
+ INSERT INTO departments VALUES (1,'Engineering',500000);
94
+ INSERT INTO departments VALUES (2,'Marketing',200000);
95
+ INSERT INTO departments VALUES (3,'Sales',300000);
96
+ INSERT INTO departments VALUES (4,'Legal',150000);
97
+
98
+ INSERT INTO employees VALUES (1,'Ana Lima',1,95000,'2023-03-15');
99
+ INSERT INTO employees VALUES (2,'Ben Sharma',1,102000,'2023-06-01');
100
+ INSERT INTO employees VALUES (3,'Chris Wang',1,88000,'2022-01-10');
101
+ INSERT INTO employees VALUES (4,'Diana Patel',2,72000,'2023-04-20');
102
+ INSERT INTO employees VALUES (5,'Erik Johnson',2,68000,'2022-11-05');
103
+ INSERT INTO employees VALUES (6,'Fatima Al-Hassan',3,55000,'2023-01-08');
104
+ INSERT INTO employees VALUES (7,'George Okafor',3,61000,'2023-07-22');
105
+ INSERT INTO employees VALUES (8,'Hannah Kim',3,58000,'2022-05-30');
106
+ INSERT INTO employees VALUES (9,'Ivan Petrov',1,91000,'2022-08-14')"""
107
+
108
+ @property
109
+ def expected_output(self) -> List[Dict[str, Any]]:
110
+ # Engineering 2023 hires: Ana 95000, Ben 102000 → count=2, avg=98500
111
+ # Marketing 2023 hires: Diana 72000 → count=1, avg=72000
112
+ # Sales 2023 hires: Fatima 55000, George 61000 → count=2, avg=58000
113
+ # Legal 2023 hires: none → count=0, avg=NULL
114
+ return [
115
+ {"department_name": "Engineering", "headcount": 2, "avg_salary": 98500.00},
116
+ {"department_name": "Legal", "headcount": 0, "avg_salary": None},
117
+ {"department_name": "Marketing", "headcount": 1, "avg_salary": 72000.00},
118
+ {"department_name": "Sales", "headcount": 2, "avg_salary": 58000.00},
119
+ ]
120
+
121
+ @property
122
+ def hint(self) -> str:
123
+ return "Hint: When you want ALL rows from the left table even when there's no match on the right, think about which JOIN type preserves those rows. Also, WHERE on a nullable column after a join changes join semantics — consider moving that condition."
124
+
125
+
126
+ class MediumTaskGrader:
127
+ """
128
+ Custom grader for medium task — handles NULL comparison.
129
+ """
130
+ @staticmethod
131
+ def grade(actual: List[Dict]) -> float:
132
+ if not actual or len(actual) != 4:
133
+ return 0.0
134
+
135
+ # Sort both by dept name for comparison
136
+ actual_sorted = sorted(actual, key=lambda r: r.get("department_name", ""))
137
+ expected = [
138
+ {"department_name": "Engineering", "headcount": 2, "avg_salary": 98500.00},
139
+ {"department_name": "Legal", "headcount": 0, "avg_salary": None},
140
+ {"department_name": "Marketing", "headcount": 1, "avg_salary": 72000.00},
141
+ {"department_name": "Sales", "headcount": 2, "avg_salary": 58000.00},
142
+ ]
143
+
144
+ matches = 0
145
+ for a, e in zip(actual_sorted, expected):
146
+ dept_ok = str(a.get("department_name","")).lower() == str(e["department_name"]).lower()
147
+ count_ok = int(a.get("headcount", -1)) == e["headcount"]
148
+
149
+ e_salary = e["avg_salary"]
150
+ a_salary = a.get("avg_salary")
151
+ if e_salary is None:
152
+ salary_ok = a_salary is None or a_salary == 0
153
+ else:
154
+ try:
155
+ salary_ok = abs(float(a_salary) - float(e_salary)) < 1.0
156
+ except (TypeError, ValueError):
157
+ salary_ok = False
158
+
159
+ if dept_ok and count_ok and salary_ok:
160
+ matches += 1
161
+
162
+ return round(matches / 4, 3)
163
+
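The join-semantics bug at the heart of this task is easy to demonstrate directly in SQLite. A minimal sketch with two invented tables (not the task's full seed data), contrasting a filter in WHERE against the same filter in the ON clause:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
CREATE TABLE departments (id INTEGER, name TEXT);
CREATE TABLE employees (department_id INTEGER, salary REAL, hire_date TEXT);
INSERT INTO departments VALUES (1, 'Engineering'), (2, 'Legal');
INSERT INTO employees VALUES (1, 95000, '2023-03-15'), (1, 88000, '2022-01-10');
""")

# Filter in WHERE: the NULL row produced for Legal fails the predicate,
# so the LEFT JOIN silently degrades to an INNER JOIN.
where_rows = con.execute("""
    SELECT d.name, COUNT(e.department_id)
    FROM departments d
    LEFT JOIN employees e ON d.id = e.department_id
    WHERE strftime('%Y', e.hire_date) = '2023'
    GROUP BY d.id, d.name ORDER BY d.name
""").fetchall()
print(where_rows)  # [('Engineering', 1)]

# Filter in ON: unmatched departments survive with a zero count.
on_rows = con.execute("""
    SELECT d.name, COUNT(e.department_id)
    FROM departments d
    LEFT JOIN employees e ON d.id = e.department_id
        AND strftime('%Y', e.hire_date) = '2023'
    GROUP BY d.id, d.name ORDER BY d.name
""").fetchall()
print(on_rows)  # [('Engineering', 1), ('Legal', 0)]
```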
tests/test_env.py ADDED
@@ -0,0 +1,44 @@
1
+ import asyncio
2
+ import unittest
3
+
4
+ from server.env import SQLDebugEnv
5
+ from server.models import SQLDebugAction, ActionType
6
+
7
+
8
+ class TestEnv(unittest.TestCase):
9
+ def test_reset_and_inspect_schema(self):
10
+ async def run():
11
+ env = SQLDebugEnv(task_id="easy_syntax_fix")
12
+ obs, info = await env.reset()
13
+ self.assertFalse(obs.is_done)
14
+
15
+ action = SQLDebugAction(action_type=ActionType.INSPECT_SCHEMA)
16
+ obs2, reward, done, info2 = await env.step(action)
17
+ self.assertFalse(done)
18
+ self.assertIsNotNone(obs2.schema_info)
19
+ self.assertGreaterEqual(reward, 0.0)
20
+
21
+ asyncio.run(run())
22
+
23
+ def test_submit_broken_query_does_not_finish(self):
24
+ async def run():
25
+ env = SQLDebugEnv(task_id="easy_syntax_fix")
26
+ obs, _ = await env.reset()
27
+
28
+ action = SQLDebugAction(
29
+ action_type=ActionType.SUBMIT_QUERY,
30
+ query=env.task.broken_query,
31
+ )
32
+ obs2, reward, done, _ = await env.step(action)
33
+
34
+ self.assertFalse(done)
35
+ self.assertLessEqual(reward, 0.2)
36
+ self.assertGreaterEqual(reward, -1.0)
37
+ self.assertEqual(obs2.current_query, env.task.broken_query)
38
+
39
+ asyncio.run(run())
40
+
41
+
42
+ if __name__ == "__main__":
43
+ unittest.main()
44
+
tests/test_graders.py ADDED
@@ -0,0 +1,46 @@
1
+ import unittest
2
+
3
+ from server.tasks.task_easy import EasyTask
4
+ from server.tasks.task_medium import MediumTask, MediumTaskGrader
5
+ from server.tasks.task_hard import HardTask
6
+
7
+
8
+ class TestGraders(unittest.TestCase):
9
+ def test_easy_grade_perfect(self):
10
+ task = EasyTask()
11
+ score = task.grade(task.expected_output)
12
+ self.assertAlmostEqual(score, 1.0, places=3)
13
+
14
+ def test_hard_grade_perfect(self):
15
+ task = HardTask()
16
+ score = task.grade(task.expected_output)
17
+ self.assertAlmostEqual(score, 1.0, places=3)
18
+
19
+ def test_easy_grade_empty(self):
20
+ task = EasyTask()
21
+ score = task.grade(None)
22
+ self.assertEqual(score, 0.0)
23
+
24
+ def test_medium_grader_perfect(self):
25
+ task = MediumTask()
26
+ score = MediumTaskGrader.grade(task.expected_output)
27
+ self.assertAlmostEqual(score, 1.0, places=3)
28
+
29
+ def test_medium_grader_partial(self):
30
+ # Flip one row's avg_salary so it no longer matches within tolerance.
31
+ task = MediumTask()
32
+ actual = [dict(r) for r in task.expected_output]
33
+
34
+ # Expected avg_salary is None for "Legal". Any non-None/non-zero value should fail.
35
+ for r in actual:
36
+ if r["department_name"] == "Legal":
37
+ r["avg_salary"] = 12345.0
38
+
39
+ score = MediumTaskGrader.grade(actual)
40
+ self.assertLess(score, 1.0)
41
+ self.assertAlmostEqual(score, 0.75, places=3)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ unittest.main()
46
+
tests/test_reward.py ADDED
@@ -0,0 +1,51 @@
1
+ import unittest
2
+
3
+ from server.reward import compute_reward
4
+
5
+
6
+ class TestReward(unittest.TestCase):
7
+ def test_submit_query_perfect_reward(self):
8
+ reward = compute_reward(
9
+ action_type="submit_query",
10
+ query_result={"success": True},
11
+ grade_score=1.0,
12
+ steps_taken=1,
13
+ max_steps=10,
14
+ previous_best_score=0.0,
15
+ schema_tables=["t1", "t2"],
16
+ submitted_query="SELECT * FROM t1 JOIN t2",
17
+ )
18
+ self.assertAlmostEqual(reward.value, 1.0, places=4)
19
+
20
+ def test_reset_query_penalty(self):
21
+ reward = compute_reward(
22
+ action_type="reset_query",
23
+ query_result=None,
24
+ grade_score=0.0,
25
+ steps_taken=1,
26
+ max_steps=10,
27
+ previous_best_score=0.0,
28
+ schema_tables=[],
29
+ submitted_query=None,
30
+ )
31
+ self.assertAlmostEqual(reward.value, 0.0, places=4)
32
+
33
+ def test_inspect_schema_urgency_penalty(self):
34
+ # Make steps_remaining <= 2 and grade_score < 0.5 to trigger urgency penalty.
35
+ reward = compute_reward(
36
+ action_type="inspect_schema",
37
+ query_result=None,
38
+ grade_score=0.0,
39
+ steps_taken=8,
40
+ max_steps=9,
41
+ previous_best_score=0.0,
42
+ schema_tables=[],
43
+ submitted_query=None,
44
+ )
45
+ # syntax_progress=0.01, penalty=0.03 => total_raw=-0.02, clamped to 0.0
46
+ self.assertAlmostEqual(reward.value, 0.0, places=4)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ unittest.main()
51
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff