Spaces:
Running
Running
Initial OpenEnv SQL debug environment
Browse files- .dockerignore +14 -0
- .env.example +6 -0
- .gitignore +19 -0
- Dockerfile +31 -0
- README.md +193 -10
- inference.py +328 -0
- openenv.yaml +104 -0
- pyproject.toml +21 -0
- requirements.txt +8 -0
- scripts/benchmark_local.py +63 -0
- server/__init__.py +2 -0
- server/app.py +20 -0
- server/database.py +112 -0
- server/env.py +236 -0
- server/main.py +242 -0
- server/models.py +138 -0
- server/reward.py +125 -0
- server/tasks/__init__.py +2 -0
- server/tasks/base.py +169 -0
- server/tasks/task_easy.py +157 -0
- server/tasks/task_hard.py +199 -0
- server/tasks/task_medium.py +163 -0
- tests/test_env.py +44 -0
- tests/test_graders.py +46 -0
- tests/test_reward.py +51 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
.mypy_cache/
|
| 5 |
+
.ruff_cache/
|
| 6 |
+
.DS_Store
|
| 7 |
+
.git/
|
| 8 |
+
.gitignore
|
| 9 |
+
.env
|
| 10 |
+
.env.*
|
| 11 |
+
!.env.example
|
| 12 |
+
.venv/
|
| 13 |
+
.cursor/
|
| 14 |
+
|
.env.example
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENAI_API_KEY=
|
| 2 |
+
HF_TOKEN=
|
| 3 |
+
API_BASE_URL=https://api.openai.com/v1
|
| 4 |
+
MODEL_NAME=gpt-4o-mini
|
| 5 |
+
ENV_BASE_URL=http://localhost:7860
|
| 6 |
+
|
.gitignore
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
.mypy_cache/
|
| 5 |
+
.ruff_cache/
|
| 6 |
+
.DS_Store
|
| 7 |
+
|
| 8 |
+
# local env / secrets
|
| 9 |
+
.env
|
| 10 |
+
.env.*
|
| 11 |
+
!.env.example
|
| 12 |
+
|
| 13 |
+
# OpenEnv / uv
|
| 14 |
+
.venv/
|
| 15 |
+
.python-version
|
| 16 |
+
|
| 17 |
+
# editor metadata
|
| 18 |
+
.cursor/
|
| 19 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
curl \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Copy requirements first for layer caching
|
| 11 |
+
COPY requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
# Copy application code
|
| 15 |
+
COPY server/ ./server/
|
| 16 |
+
COPY openenv.yaml .
|
| 17 |
+
|
| 18 |
+
# Create non-root user for security
|
| 19 |
+
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
|
| 20 |
+
USER appuser
|
| 21 |
+
|
| 22 |
+
# Expose port
|
| 23 |
+
EXPOSE 7860
|
| 24 |
+
|
| 25 |
+
# Health check
|
| 26 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 27 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 28 |
+
|
| 29 |
+
# Start server
|
| 30 |
+
CMD ["uvicorn", "server.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
| 31 |
+
|
README.md
CHANGED
|
@@ -1,10 +1,193 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
--
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SQL Debug Environment (`sql-debug-env`)
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+

|
| 5 |
+

|
| 6 |
+

|
| 7 |
+

|
| 8 |
+

|
| 9 |
+
|
| 10 |
+
An OpenEnv environment for a real task people do every day: **debugging SQL**. The agent gets a broken query, a live (in-memory) SQLite database, and a description of the expected output. It can inspect schema/errors/samples and submit fixed queries until it solves the task.
|
| 11 |
+
|
| 12 |
+
## What’s in this repo
|
| 13 |
+
- **FastAPI server**: `server/main.py` (endpoints: `/health`, `/tasks`, `/reset`, `/step`, `/state`)
|
| 14 |
+
- **Environment logic**: `server/env.py` + `server/database.py`
|
| 15 |
+
- **Tasks**: `server/tasks/` (easy → medium → hard, deterministic seed data)
|
| 16 |
+
- **Baseline agent**: `inference.py` (OpenAI client + `[START]/[STEP]/[END]` logs)
|
| 17 |
+
|
| 18 |
+
## Tech Stack
|
| 19 |
+
- Python 3.11+
|
| 20 |
+
- FastAPI + Uvicorn
|
| 21 |
+
- Pydantic v2
|
| 22 |
+
- SQLite (in-memory)
|
| 23 |
+
- OpenEnv Core
|
| 24 |
+
- Docker
|
| 25 |
+
- OpenAI Python SDK (baseline inference)
|
| 26 |
+
|
| 27 |
+
## Production Notes
|
| 28 |
+
- Stateless HTTP API with per-session environment instances keyed by `X-Session-Id`
|
| 29 |
+
- Deterministic task data (in-memory SQLite) for reproducible grading
|
| 30 |
+
- Reward clamped to `[0.0, 1.0]` with partial-progress shaping
|
| 31 |
+
- Docker-first deployment path (local and Hugging Face Spaces)
|
| 32 |
+
- Local benchmark endpoint for live latency checks (`/benchmark`)
|
| 33 |
+
|
| 34 |
+
## API Docs (FastAPI Auto Docs)
|
| 35 |
+
Use these for interactive testing in browser:
|
| 36 |
+
|
| 37 |
+
- Swagger UI: `http://localhost:7860/docs`
|
| 38 |
+
- ReDoc: `http://localhost:7860/redoc`
|
| 39 |
+
- OpenAPI spec: `http://localhost:7860/openapi.json`
|
| 40 |
+
|
| 41 |
+
## Action Space
|
| 42 |
+
| Action | Required fields | Cost / reward effect |
|
| 43 |
+
|---|---|---|
|
| 44 |
+
| `submit_query` | `query` | Main evaluation step (dense reward based on grading) |
|
| 45 |
+
| `inspect_schema` | none | Free information action (small positive reward component) |
|
| 46 |
+
| `inspect_error` | none | Free information action (small positive reward component) |
|
| 47 |
+
| `inspect_sample` | `table_name` | Free information action (small positive reward component) |
|
| 48 |
+
| `reset_query` | none | Penalty action (reduces reward for that step) |
|
| 49 |
+
|
| 50 |
+
## Observation Space
|
| 51 |
+
| Field | Type |
|
| 52 |
+
|---|---|
|
| 53 |
+
| `task_id` | `string` |
|
| 54 |
+
| `task_description` | `string` |
|
| 55 |
+
| `original_query` | `string` |
|
| 56 |
+
| `current_query` | `string_or_null` |
|
| 57 |
+
| `expected_description` | `string` |
|
| 58 |
+
| `last_action_type` | `string` |
|
| 59 |
+
| `last_query_result` | `object_or_null` |
|
| 60 |
+
| `steps_taken` | `integer` |
|
| 61 |
+
| `steps_remaining` | `integer` |
|
| 62 |
+
| `current_score` | `float` |
|
| 63 |
+
| `schema_info` | `object_or_null` |
|
| 64 |
+
| `error_details` | `string_or_null` |
|
| 65 |
+
| `sample_rows` | `array_or_null` |
|
| 66 |
+
| `hint` | `string_or_null` |
|
| 67 |
+
| `is_done` | `boolean` |
|
| 68 |
+
| `success` | `boolean` |
|
| 69 |
+
|
| 70 |
+
## Reward Function
|
| 71 |
+
| Component | Range | Description |
|
| 72 |
+
|---|---|---|
|
| 73 |
+
| `correctness` | `[0.0, 0.6]` | Row-level match vs expected output |
|
| 74 |
+
| `efficiency` | `[0.0, 0.2]` | Bonus for solving with fewer steps |
|
| 75 |
+
| `syntax_progress` | `[0.0, 0.1]` | Small reward for producing syntactically valid SQL |
|
| 76 |
+
| `schema_bonus` | `[0.0, 0.1]` | Bonus for referencing correct tables/columns |
|
| 77 |
+
| `penalty` | `[0.0, 0.2]` | Deduction magnitude for resets/regressions/urgency near step limit |
|
| 78 |
+
|
| 79 |
+
## Tasks
|
| 80 |
+
### Task 1: Easy — Syntax Error Fix (`easy_syntax_fix`)
|
| 81 |
+
Two straightforward issues: a misspelled keyword (`GRUP BY`) and an `ORDER BY` alias mismatch.
|
| 82 |
+
|
| 83 |
+
### Task 2: Medium — Logic Error Fix (`medium_logic_fix`)
|
| 84 |
+
Logic bugs around outer joins + filtering scope + aggregation scope.
|
| 85 |
+
|
| 86 |
+
### Task 3: Hard — Multi-Bug Fix (`hard_multi_bug`)
|
| 87 |
+
Five bugs across correlated subqueries, window functions, CTE scope, date logic, and duplication.
|
| 88 |
+
|
| 89 |
+
## Baseline
|
| 90 |
+
The baseline script is intentionally simple: it loops `reset → step` and asks an OpenAI model to choose the next JSON action.
|
| 91 |
+
|
| 92 |
+
## Reliability & Benchmarking
|
| 93 |
+
|
| 94 |
+
### Verified status (local)
|
| 95 |
+
- `openenv validate --verbose`: **PASS**
|
| 96 |
+
- `python3 -m unittest discover -s tests -p "test_*.py"`: **10/10 PASS**
|
| 97 |
+
- Docker smoke test: **PASS** (`/health`, `/tasks`, `/reset`, `/step`)
|
| 98 |
+
- FastAPI docs available: **PASS** (`/docs`, `/redoc`, `/openapi.json`)
|
| 99 |
+
|
| 100 |
+
### Endpoint benchmark (local Docker run, n=25)
|
| 101 |
+
Measured with `scripts/benchmark_local.py` on a running local container:
|
| 102 |
+
|
| 103 |
+
| Endpoint | avg | p50 | p95 |
|
| 104 |
+
|---|---:|---:|---:|
|
| 105 |
+
| `GET /health` | 0.69 ms | 0.67 ms | 0.76 ms |
|
| 106 |
+
| `GET /tasks` | 0.82 ms | 0.81 ms | 0.90 ms |
|
| 107 |
+
| `POST /reset` | 1.34 ms | 1.26 ms | 1.62 ms |
|
| 108 |
+
| `POST /step` (`inspect_schema`) | 1.07 ms | 1.01 ms | 1.34 ms |
|
| 109 |
+
|
| 110 |
+
Re-run anytime:
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
python3 scripts/benchmark_local.py
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Notes:
|
| 117 |
+
- These are local-machine numbers (single container, warm runtime).
|
| 118 |
+
- For submission-grade reporting, also capture one run against your HF Space URL after deploy.
|
| 119 |
+
|
| 120 |
+
## Setup & Usage
|
| 121 |
+
|
| 122 |
+
### Local Development
|
| 123 |
+
```bash
|
| 124 |
+
pip install -r requirements.txt
|
| 125 |
+
uvicorn server.main:app --host 0.0.0.0 --port 7860
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Docker
|
| 129 |
+
```bash
|
| 130 |
+
docker build -t sql-debug-env .
|
| 131 |
+
docker run -p 7860:7860 sql-debug-env
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
### Quick smoke test
|
| 135 |
+
```bash
|
| 136 |
+
curl http://localhost:7860/health
|
| 137 |
+
curl http://localhost:7860/tasks
|
| 138 |
+
curl -X POST http://localhost:7860/reset -H "Content-Type: application/json" -d '{"task_id":"easy_syntax_fix"}'
|
| 139 |
+
curl -X POST http://localhost:7860/step -H "Content-Type: application/json" -d '{"action":{"action_type":"inspect_schema"}}'
|
| 140 |
+
curl "http://localhost:7860/benchmark?runs=20"
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
### Real-time benchmark API (for dashboards/web pages)
|
| 144 |
+
This is a live endpoint, not static/dummy data. Every request runs fresh measurements.
|
| 145 |
+
|
| 146 |
+
- Endpoint: `GET /benchmark?runs=20`
|
| 147 |
+
- `runs` range: `1` to `100`
|
| 148 |
+
- Returns JSON with `avg_ms`, `p50_ms`, `p95_ms`, `n`, and a fresh `timestamp_epoch_ms`
|
| 149 |
+
|
| 150 |
+
Example:
|
| 151 |
+
```bash
|
| 152 |
+
curl "http://localhost:7860/benchmark?runs=30"
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
### Run Baseline
|
| 156 |
+
```bash
|
| 157 |
+
export API_BASE_URL="https://api.openai.com/v1"
|
| 158 |
+
export MODEL_NAME="gpt-4o-mini"
|
| 159 |
+
export OPENAI_API_KEY="your-key"
|
| 160 |
+
export ENV_BASE_URL="http://localhost:7860"
|
| 161 |
+
export HF_TOKEN="$OPENAI_API_KEY"
|
| 162 |
+
export SEED="1"
|
| 163 |
+
python inference.py
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
### OpenEnv Validation
|
| 167 |
+
```bash
|
| 168 |
+
pip install openenv-core
|
| 169 |
+
openenv validate
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
### Suggested pre-submit check
|
| 173 |
+
```bash
|
| 174 |
+
openenv validate --verbose
|
| 175 |
+
python3 -m unittest discover -s tests -p "test_*.py"
|
| 176 |
+
docker build -t sql-debug-env .
|
| 177 |
+
docker run --rm -p 7860:7860 sql-debug-env
|
| 178 |
+
# in another terminal:
|
| 179 |
+
curl -s http://localhost:7860/health
|
| 180 |
+
curl -s http://localhost:7860/docs >/dev/null
|
| 181 |
+
curl -s "http://localhost:7860/benchmark?runs=20"
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
## Hugging Face Spaces (Docker)
|
| 185 |
+
1. Create a new **Space → Docker**.
|
| 186 |
+
2. Push this repo.
|
| 187 |
+
3. Update `openenv.yaml` → `api.base_url` to your Space URL: `https://<your-space>.hf.space`
|
| 188 |
+
4. Wait for build, then verify:
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
curl -X POST https://<your-space>.hf.space/reset -H "Content-Type: application/json" -d '{}'
|
| 192 |
+
```
|
| 193 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py — OpenEnv SQL Debug Environment Baseline Agent
|
| 3 |
+
MUST be at root level. MUST use exact [START]/[STEP]/[END] log format.
|
| 4 |
+
Uses OpenAI client. Reads from environment variables.
|
| 5 |
+
Runtime target: < 20 minutes on 2vCPU / 8GB.
|
| 6 |
+
"""
|
| 7 |
+
import asyncio
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
from typing import List, Dict, Any, Optional
|
| 13 |
+
from openai import OpenAI
|
| 14 |
+
import httpx
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ── Configuration from environment variables ────────────────────────────────
|
| 18 |
+
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 19 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 20 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
| 21 |
+
API_KEY = os.environ.get("OPENAI_API_KEY", HF_TOKEN or "sk-placeholder")
|
| 22 |
+
|
| 23 |
+
# ── Environment config ───────────────────────────────────────────────────────
|
| 24 |
+
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
|
| 25 |
+
BENCHMARK = "sql-debug-env"
|
| 26 |
+
TEMPERATURE = 0.0
|
| 27 |
+
MAX_TOKENS = 1024
|
| 28 |
+
SEED = int(os.environ.get("SEED", "1"))
|
| 29 |
+
|
| 30 |
+
# ── Per-task config ──────────────────────────────────────────────────────────
|
| 31 |
+
TASK_CONFIGS = {
|
| 32 |
+
"easy_syntax_fix": {"max_steps": 10, "success_threshold": 0.8},
|
| 33 |
+
"medium_logic_fix": {"max_steps": 20, "success_threshold": 0.7},
|
| 34 |
+
"hard_multi_bug": {"max_steps": 30, "success_threshold": 0.5},
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ── Logging functions (EXACT FORMAT — DO NOT MODIFY) ────────────────────────
|
| 39 |
+
def log_start(task: str, env: str, model: str):
    """Emit the mandatory [START] marker announcing a new task episode."""
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
| 43 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]):
    """Emit the mandatory single-line [STEP] record for one environment step."""
    # Collapse newlines and escape quotes so the action stays on one log line,
    # then cap it at 200 characters.
    shown_action = action.replace("\n", "\\n").replace('"', '\\"')[:200]
    shown_error = error or "null"
    line = (
        f'[STEP] step={step} action="{shown_action}" '
        f"reward={reward:.4f} done={str(done).lower()} error={shown_error}"
    )
    print(line, flush=True)
+
|
| 54 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]):
    """Emit the mandatory [END] summary line closing an episode."""
    # Rewards are rounded to 4 decimals before serialization to keep the
    # log line compact and deterministic.
    rounded = [round(r, 4) for r in rewards]
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.4f} rewards={json.dumps(rounded)}"
    )
    print(summary, flush=True)
| 63 |
+
# ── System prompt ────────────────────────────────────────────────────────────
|
| 64 |
+
SYSTEM_PROMPT = """You are an expert SQL debugger. You will receive a broken SQL query and must fix it.
|
| 65 |
+
|
| 66 |
+
You interact with a SQL debugging environment via JSON actions.
|
| 67 |
+
|
| 68 |
+
Available actions (respond with ONLY valid JSON, no markdown, no explanation):
|
| 69 |
+
|
| 70 |
+
1. Submit a fixed query:
|
| 71 |
+
{"action_type": "submit_query", "query": "SELECT ..."}
|
| 72 |
+
|
| 73 |
+
2. Inspect schema (free, no penalty):
|
| 74 |
+
{"action_type": "inspect_schema"}
|
| 75 |
+
|
| 76 |
+
3. Inspect last error (free, no penalty):
|
| 77 |
+
{"action_type": "inspect_error"}
|
| 78 |
+
|
| 79 |
+
4. Inspect sample rows from a table (free, no penalty):
|
| 80 |
+
{"action_type": "inspect_sample", "table_name": "table_name_here"}
|
| 81 |
+
|
| 82 |
+
Strategy:
|
| 83 |
+
- Start by submitting a fixed query if the bug is obvious
|
| 84 |
+
- Use inspect_schema first if you need to verify column names/table structure
|
| 85 |
+
- Use inspect_error to understand why your query failed
|
| 86 |
+
- Read error messages carefully — they tell you exactly what's wrong
|
| 87 |
+
- Fix one bug at a time and resubmit
|
| 88 |
+
- You get partial credit for partially correct queries
|
| 89 |
+
|
| 90 |
+
IMPORTANT: Respond with ONLY the JSON action. No explanation, no markdown blocks, just raw JSON."""
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def build_prompt(obs: Dict[str, Any], step: int, reward_history: List[float]) -> str:
    """Render the per-step user prompt from the latest observation.

    `step` appears in the prompt header; `reward_history` is accepted for
    interface compatibility but is not currently rendered into the text.
    Optional observation fields (last result, schema, error details, sample
    rows, hint) are appended only when present.
    """
    parts: List[str] = [
        f"=== SQL Debugging Task (Step {step}) ===",
        f"Task: {obs.get('task_description', '')[:500]}",
        "",
        "ORIGINAL BROKEN QUERY:",
        "```sql",
        f"{obs.get('original_query', '')}",
        "```",
    ]

    if obs.get('current_query'):
        parts += [
            "",
            "YOUR LAST SUBMITTED QUERY:",
            "```sql",
            f"{obs.get('current_query', '')}",
            "```",
        ]

    outcome = obs.get('last_query_result')
    if outcome:
        if outcome.get('success'):
            returned = outcome.get('rows', [])
            parts += [
                "",
                f"LAST QUERY RESULT: {len(returned)} rows returned",
                f"Sample (first 3): {json.dumps(returned[:3], default=str)}",
            ]
        else:
            parts += [
                "",
                f"LAST QUERY ERROR: {outcome.get('error_message', 'Unknown error')}",
            ]

    if obs.get('schema_info'):
        parts += ["", "DATABASE SCHEMA:"]
        for tbl, columns in obs['schema_info'].get('tables', {}).items():
            described = ", ".join(f"{c['name']} ({c['type']})" for c in columns)
            parts.append(f"  {tbl}: {described}")

    if obs.get('error_details'):
        parts += ["", f"ERROR DETAILS: {obs['error_details']}"]

    if obs.get('sample_rows'):
        parts += ["", f"SAMPLE ROWS: {json.dumps(obs['sample_rows'][:3], default=str)}"]

    if obs.get('hint'):
        parts += ["", f"HINT: {obs['hint']}"]

    parts += [
        "",
        f"Current score: {obs.get('current_score', 0):.3f}",
        f"Steps remaining: {obs.get('steps_remaining', 0)}",
        f"Expected output: {obs.get('expected_description', '')}",
        "",
        "What is your next action? (respond with ONLY valid JSON)",
    ]

    return "\n".join(parts)
| 157 |
+
|
| 158 |
+
def call_model(client: OpenAI, prompt: str) -> Dict[str, Any]:
    """Ask the model for the next action and parse its JSON reply.

    Always returns a dict. On any model or parsing failure the safe fallback
    action ``{"action_type": "inspect_schema"}`` is returned so the episode
    can continue instead of crashing.
    """
    import re  # local: only needed on the malformed-JSON fallback path

    # Pre-initialize so the JSONDecodeError handler can never hit an
    # unbound name, whatever path raised.
    text = ""
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt}
            ],
            temperature=TEMPERATURE,
            seed=SEED,
            max_tokens=MAX_TOKENS,
        )
        text = (response.choices[0].message.content or "").strip()

        # Models sometimes wrap the JSON in ```json fences despite
        # instructions; strip the fence before parsing.
        if text.startswith("```"):
            text = text.split("```")[1]
            if text.startswith("json"):
                text = text[4:]
            text = text.strip()

        return json.loads(text)
    except json.JSONDecodeError:
        # Fallback: pull the first {...} span out of the raw response.
        match = re.search(r'\{.*\}', text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError:
                # was a bare `except:` — narrowed so real errors propagate
                pass
        # Default fallback action keeps the episode moving.
        return {"action_type": "inspect_schema"}
    except Exception as e:
        # Broad by design at this boundary: a baseline agent must survive
        # transient API/network errors and keep stepping.
        print(f"[DEBUG] Model error: {e}", flush=True)
        return {"action_type": "inspect_schema"}
| 196 |
+
|
| 197 |
+
def run_task(
    client: OpenAI,
    task_id: str,
    config: Dict[str, Any]
) -> Dict[str, Any]:
    """Run one task episode synchronously via HTTP.

    Drives the reset -> (prompt -> model -> step)* loop against the
    environment server, emitting the mandatory [START]/[STEP]/[END] log
    lines, and returns a summary dict with the final clamped score.

    Args:
        client: OpenAI client used to choose each action.
        task_id: Environment task identifier sent to /reset.
        config: Per-task settings; reads "max_steps" and "success_threshold".

    Returns:
        Dict with keys: task_id, score, success, steps, rewards.
    """

    max_steps = config["max_steps"]
    success_threshold = config["success_threshold"]

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    rewards = []
    steps_taken = 0
    score = 0.0
    success = False

    with httpx.Client(base_url=ENV_BASE_URL, timeout=30.0) as http:
        # Reset the environment for this task; the returned observation
        # seeds the first prompt.
        reset_resp = http.post("/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        result = reset_resp.json()
        obs = result["observation"]
        done = result["done"]

        reward_history = []

        for step in range(1, max_steps + 1):
            if done:
                break

            # Ask the model for the next JSON action.
            prompt = build_prompt(obs, step, reward_history)
            action_dict = call_model(client, prompt)

            # Execute the step; a transport failure logs a zero-reward step
            # and retries on the next loop iteration rather than aborting.
            try:
                step_resp = http.post("/step", json={"action": action_dict})
                step_resp.raise_for_status()
                step_result = step_resp.json()
            except Exception as e:
                log_step(step=step, action=str(action_dict), reward=0.0, done=False, error=str(e))
                continue

            obs = step_result["observation"]
            # `or 0.0` also maps a null reward from the server to 0.0.
            reward = float(step_result.get("reward") or 0.0)
            done = step_result["done"]
            error = None
            info = step_result.get("info") or {}

            # Surface the environment's last query error (if any) in the
            # [STEP] log line.
            last_result = obs.get("last_query_result")
            if last_result and not last_result.get("success"):
                error = last_result.get("error_message", "")

            # Prefer logging the submitted SQL; fall back to the action type.
            action_str = action_dict.get("query") or action_dict.get("action_type", "unknown")

            rewards.append(reward)
            reward_history.append(reward)
            steps_taken = step
            # Score source preference: grader info, then observation score.
            score = float(info.get("grade_score") or obs.get("current_score") or 0.0)

            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            if done:
                break

    # Compute final score, clamped to the documented [0.0, 1.0] range.
    score = min(max(score, 0.0), 1.0)
    success = score >= success_threshold

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {
        "task_id": task_id,
        "score": score,
        "success": success,
        "steps": steps_taken,
        "rewards": rewards
    }
+
|
| 279 |
+
def main():
    """Run the baseline agent across all 3 tasks and print a summary."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    print("[DEBUG] Starting SQL Debug Env baseline", flush=True)
    print(f"[DEBUG] Model: {MODEL_NAME}", flush=True)
    print(f"[DEBUG] Env URL: {ENV_BASE_URL}", flush=True)

    # Poll /health once per second until the environment server is up.
    max_wait = 30
    for i in range(max_wait):
        try:
            resp = httpx.get(f"{ENV_BASE_URL}/health", timeout=5)
            if resp.status_code == 200:
                print("[DEBUG] Server ready", flush=True)
                break
        except Exception:
            # was a bare `except:` — narrowed so Ctrl-C still interrupts
            pass
        print(f"[DEBUG] Waiting for server... ({i+1}/{max_wait})", flush=True)
        time.sleep(1)
    else:
        # Previously this fell through silently; make the degraded state
        # visible (tasks will then fail fast with connection errors).
        print(f"[DEBUG] Server not reachable after {max_wait}s — continuing anyway", flush=True)

    all_results = []

    for task_id, config in TASK_CONFIGS.items():
        print(f"\n[DEBUG] Running task: {task_id}", flush=True)
        try:
            result = run_task(client, task_id, config)
            all_results.append(result)
        except Exception as e:
            # A crashed task still emits the mandatory [END] line with zeros
            # so downstream log parsers see a complete episode.
            print(f"[DEBUG] Task {task_id} failed: {e}", flush=True)
            log_end(success=False, steps=0, score=0.0, rewards=[])

        # Small delay between tasks to avoid hammering the server/API.
        time.sleep(2)

    # Summary
    print("\n[DEBUG] === BASELINE RESULTS ===", flush=True)
    total_score = 0.0
    for r in all_results:
        print(f"[DEBUG] {r['task_id']}: score={r['score']:.3f} success={r['success']}", flush=True)
        total_score += r['score']

    if all_results:
        avg = total_score / len(all_results)
        print(f"[DEBUG] Average score: {avg:.3f}", flush=True)
|
| 326 |
+
if __name__ == "__main__":
|
| 327 |
+
main()
|
| 328 |
+
|
openenv.yaml
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sql-debug-env
|
| 2 |
+
version: 0.1.0
|
| 3 |
+
description: >
|
| 4 |
+
A reinforcement learning environment for training AI agents to debug SQL queries.
|
| 5 |
+
Agents receive broken SQL queries against a live SQLite database and must fix them
|
| 6 |
+
through iterative actions: submitting queries, inspecting schemas, and analyzing errors.
|
| 7 |
+
Models a real-world task performed daily by data analysts, engineers, and scientists.
|
| 8 |
+
|
| 9 |
+
author: md-ayan
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
|
| 12 |
+
tags:
|
| 13 |
+
- openenv
|
| 14 |
+
- sql
|
| 15 |
+
- debugging
|
| 16 |
+
- data-engineering
|
| 17 |
+
- real-world
|
| 18 |
+
- analytics
|
| 19 |
+
|
| 20 |
+
tasks:
|
| 21 |
+
- id: easy_syntax_fix
|
| 22 |
+
name: "Top Customers by Revenue — Syntax Error Fix"
|
| 23 |
+
difficulty: easy
|
| 24 |
+
max_steps: 10
|
| 25 |
+
description: "Fix 2 syntax/reference bugs in a customer analytics query"
|
| 26 |
+
|
| 27 |
+
- id: medium_logic_fix
|
| 28 |
+
name: "Department Headcount Report — Logic Error Fix"
|
| 29 |
+
difficulty: medium
|
| 30 |
+
max_steps: 20
|
| 31 |
+
description: "Fix JOIN type, WHERE clause placement, and aggregation scope bugs"
|
| 32 |
+
|
| 33 |
+
- id: hard_multi_bug
|
| 34 |
+
name: "SaaS Cohort Activation Report — Multi-Bug Fix"
|
| 35 |
+
difficulty: hard
|
| 36 |
+
max_steps: 30
|
| 37 |
+
description: "Fix 5 bugs: correlated subquery, window function, duplicate rows, date logic, CTE scope"
|
| 38 |
+
|
| 39 |
+
api:
|
| 40 |
+
base_url: "https://YOUR-USERNAME-sql-debug-env.hf.space"
|
| 41 |
+
reset: "/reset"
|
| 42 |
+
step: "/step"
|
| 43 |
+
state: "/state"
|
| 44 |
+
health: "/health"
|
| 45 |
+
tasks: "/tasks"
|
| 46 |
+
|
| 47 |
+
observation_space:
|
| 48 |
+
type: structured
|
| 49 |
+
fields:
|
| 50 |
+
- name: task_description
|
| 51 |
+
type: string
|
| 52 |
+
- name: original_query
|
| 53 |
+
type: string
|
| 54 |
+
- name: current_query
|
| 55 |
+
type: string_or_null
|
| 56 |
+
- name: last_query_result
|
| 57 |
+
type: object_or_null
|
| 58 |
+
- name: steps_taken
|
| 59 |
+
type: integer
|
| 60 |
+
- name: current_score
|
| 61 |
+
type: float
|
| 62 |
+
|
| 63 |
+
action_space:
|
| 64 |
+
type: structured
|
| 65 |
+
actions:
|
| 66 |
+
- id: submit_query
|
| 67 |
+
description: "Submit a fixed SQL query for evaluation"
|
| 68 |
+
required_fields: [query]
|
| 69 |
+
- id: inspect_schema
|
| 70 |
+
description: "Get database schema (free action)"
|
| 71 |
+
- id: inspect_error
|
| 72 |
+
description: "Get last error details (free action)"
|
| 73 |
+
- id: inspect_sample
|
| 74 |
+
description: "Get 3 sample rows from a table"
|
| 75 |
+
required_fields: [table_name]
|
| 76 |
+
- id: reset_query
|
| 77 |
+
description: "Reset to original broken query (penalty: -0.05)"
|
| 78 |
+
|
| 79 |
+
reward:
|
| 80 |
+
range: [0.0, 1.0]
|
| 81 |
+
components:
|
| 82 |
+
- name: correctness
|
| 83 |
+
range: [0.0, 0.6]
|
| 84 |
+
description: "Row-level match vs expected output"
|
| 85 |
+
- name: efficiency
|
| 86 |
+
range: [0.0, 0.2]
|
| 87 |
+
description: "Bonus for solving with fewer steps"
|
| 88 |
+
- name: syntax_progress
|
| 89 |
+
range: [0.0, 0.1]
|
| 90 |
+
description: "Valid SQL even if wrong content"
|
| 91 |
+
- name: schema_bonus
|
| 92 |
+
range: [0.0, 0.1]
|
| 93 |
+
description: "Correct table/column references"
|
| 94 |
+
- name: penalty
|
| 95 |
+
range: [0.0, 0.2]
|
| 96 |
+
description: "Penalty deduction magnitude for bad actions / urgency"
|
| 97 |
+
|
| 98 |
+
runtime:
|
| 99 |
+
max_concurrent_sessions: 64
|
| 100 |
+
episode_timeout_seconds: 300
|
| 101 |
+
machine_requirements:
|
| 102 |
+
vcpu: 2
|
| 103 |
+
memory_gb: 8
|
| 104 |
+
|
pyproject.toml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "sql-debug-env"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
requires-python = ">=3.11"
|
| 9 |
+
dependencies = [
|
| 10 |
+
"fastapi==0.115.0",
|
| 11 |
+
"uvicorn[standard]==0.30.6",
|
| 12 |
+
"pydantic==2.9.2",
|
| 13 |
+
"openenv-core>=0.1.0",
|
| 14 |
+
"openai>=1.50.0",
|
| 15 |
+
"httpx>=0.27.0",
|
| 16 |
+
"python-multipart==0.0.9"
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.scripts]
|
| 20 |
+
server = "server.app:main"
|
| 21 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.0
|
| 2 |
+
uvicorn[standard]==0.30.6
|
| 3 |
+
pydantic==2.9.2
|
| 4 |
+
openenv-core>=0.1.0
|
| 5 |
+
openai>=1.50.0
|
| 6 |
+
httpx>=0.27.0
|
| 7 |
+
python-multipart==0.0.9
|
| 8 |
+
|
scripts/benchmark_local.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lightweight local benchmark for sql-debug-env.
|
| 3 |
+
|
| 4 |
+
Runs deterministic endpoint checks and prints simple latency metrics.
|
| 5 |
+
No LLM key required.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations

import math
import statistics
import time
from typing import Dict, List

import httpx
+
|
| 16 |
+
BASE_URL = "http://localhost:7860"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def timed_call(client: httpx.Client, method: str, path: str, json_body: Dict | None = None) -> float:
    """Issue one HTTP request and return its wall-clock latency in milliseconds.

    A "GET" method sends a plain GET (``json_body`` is ignored); any other
    method value is sent as a POST with ``json_body`` as the JSON payload.
    Raises on a non-2xx response via ``raise_for_status()``.
    """
    began = time.perf_counter()
    response = client.get(path) if method == "GET" else client.post(path, json=json_body)
    response.raise_for_status()
    elapsed_ms = (time.perf_counter() - began) * 1000
    return elapsed_ms
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def summarize(samples: List[float]) -> str:
    """Format latency samples as an ``avg/p50/p95`` summary line.

    p95 uses the nearest-rank method: the sorted element at 1-based rank
    ceil(0.95 * n).  The previous index ``int(len(samples) * 0.95) - 1``
    truncated instead of rounding up, so it reported one rank too low
    whenever 0.95 * n was not an integer (e.g. n=25 picked rank 23, not 24).
    """
    ordered = sorted(samples)
    # Nearest-rank p95; max(1, ...) guards the single-sample case.
    p95_rank = max(1, math.ceil(len(ordered) * 0.95))
    p95 = ordered[p95_rank - 1]
    p50 = statistics.median(ordered)
    avg = statistics.mean(ordered)
    return f"avg={avg:.2f}ms p50={p50:.2f}ms p95={p95:.2f}ms n={len(samples)}"
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main() -> None:
    """Run the local benchmark suite and print one summary line per endpoint."""
    rounds = 25
    reset_body = {"task_id": "easy_syntax_fix"}
    step_body = {"action": {"action_type": "inspect_schema"}}

    with httpx.Client(base_url=BASE_URL, timeout=30.0) as client:
        # Untimed warmup request; also fails fast if the server is down.
        client.get("/health").raise_for_status()

        health_times = [timed_call(client, "GET", "/health") for _ in range(rounds)]
        tasks_times = [timed_call(client, "GET", "/tasks") for _ in range(rounds)]

        reset_times: List[float] = []
        step_times: List[float] = []
        # Alternate reset/step so every timed step runs on a fresh episode.
        for _ in range(rounds):
            reset_times.append(timed_call(client, "POST", "/reset", reset_body))
            step_times.append(timed_call(client, "POST", "/step", step_body))

    print("Benchmark results (local)")
    print(f"GET /health: {summarize(health_times)}")
    print(f"GET /tasks: {summarize(tasks_times)}")
    print(f"POST /reset: {summarize(reset_times)}")
    print(f"POST /step (inspect_schema): {summarize(step_times)}")
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
| 63 |
+
|
server/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sql-debug-env
|
| 2 |
+
|
server/app.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uvicorn
|
| 3 |
+
|
| 4 |
+
from .main import app
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def main():
    """Launch the FastAPI app under uvicorn (OpenEnv entry point).

    This module is required for `openenv validate` multi-mode deployment
    checks.  Bind address comes from the HOST / PORT environment variables,
    defaulting to 0.0.0.0:7860 (the Hugging Face Space port).
    """
    bind_host = os.environ.get("HOST", "0.0.0.0")
    bind_port = int(os.environ.get("PORT", "7860"))
    # Import-string form so uvicorn resolves the app itself; single worker
    # because session state lives in process memory.
    uvicorn.run("server.app:app", host=bind_host, port=bind_port, workers=1)
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if __name__ == "__main__":
|
| 19 |
+
main()
|
| 20 |
+
|
server/database.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLite in-memory database management.
|
| 3 |
+
Creates fresh DB instances per episode with deterministic seed data.
|
| 4 |
+
"""
|
| 5 |
+
import re
import sqlite3
import time
from typing import Any, Dict, List
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EpisodeDatabase:
    """
    Manages a single SQLite in-memory database for one episode.
    Seeded with deterministic data per task.
    """

    # Mutating / structural SQL keywords a read-only sandbox must reject.
    _BLOCKED_KEYWORDS = ("DROP", "DELETE", "UPDATE", "INSERT", "CREATE", "ALTER",
                         "TRUNCATE", "REPLACE", "ATTACH", "DETACH")

    def __init__(self, task_id: str, schema_sql: str, seed_data_sql: str):
        """Open a fresh in-memory DB and seed it with the task's schema/data."""
        self.task_id = task_id
        # check_same_thread=False: the async server may touch the connection
        # from different worker threads; callers serialize access externally.
        self.conn = sqlite3.connect(":memory:", check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        self.conn.execute("PRAGMA foreign_keys = ON")
        self._setup(schema_sql, seed_data_sql)

    def _setup(self, schema_sql: str, seed_data_sql: str):
        """Create schema and insert seed data (statements separated by ';')."""
        cursor = self.conn.cursor()
        for sql_blob in (schema_sql, seed_data_sql):
            for statement in sql_blob.strip().split(";"):
                stmt = statement.strip()
                if stmt:
                    cursor.execute(stmt)
        self.conn.commit()

    def execute_query(self, query: str) -> Dict[str, Any]:
        """
        Execute a read-only SQL query safely.

        Returns a dict {success, rows, row_count, error_message,
        execution_time_ms}; rows is a list of dicts on success, None on
        failure.  Only SELECT-style statements are permitted: any blocked
        keyword appearing as a whole word anywhere in the query is rejected.
        """
        query_upper = query.strip().upper()

        # Word-boundary match catches keywords separated by any whitespace or
        # punctuation; the previous " KW " substring test could be bypassed
        # with e.g. "DELETE\nFROM t" or "DELETE(...)".
        for kw in self._BLOCKED_KEYWORDS:
            if re.search(rf"\b{kw}\b", query_upper):
                return {
                    "success": False,
                    "rows": None,
                    "row_count": None,
                    "error_message": f"BLOCKED: Only SELECT queries are allowed. '{kw}' is not permitted.",
                    "execution_time_ms": 0.0
                }

        start = time.time()
        try:
            cursor = self.conn.cursor()
            cursor.execute(query)
            rows = cursor.fetchall()
            elapsed = (time.time() - start) * 1000

            # Convert sqlite3.Row objects to plain dicts for JSON serialization.
            result_rows = [dict(row) for row in rows]

            return {
                "success": True,
                "rows": result_rows,
                "row_count": len(result_rows),
                "error_message": None,
                "execution_time_ms": round(elapsed, 2)
            }
        except sqlite3.Error as e:
            elapsed = (time.time() - start) * 1000
            return {
                "success": False,
                "rows": None,
                "row_count": None,
                "error_message": str(e),
                "execution_time_ms": round(elapsed, 2)
            }

    def get_schema(self) -> Dict[str, List[Dict[str, str]]]:
        """Return schema info: table name -> list of column descriptors."""
        schema = {}
        cursor = self.conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall()]

        for table in tables:
            # PRAGMA table_info columns: (cid, name, type, notnull, dflt_value, pk)
            cursor.execute(f"PRAGMA table_info({table})")
            columns = []
            for col in cursor.fetchall():
                columns.append({
                    "name": col[1],
                    "type": col[2],
                    # notnull == 0 means the column accepts NULLs.
                    "nullable": "YES" if col[3] == 0 else "NO",
                    "primary_key": "YES" if col[5] > 0 else "NO"
                })
            schema[table] = columns

        return schema

    def get_sample_rows(self, table_name: str, limit: int = 3) -> List[Dict[str, Any]]:
        """Get up to `limit` sample rows from a table; [] for unknown tables.

        `table_name` arrives from client-supplied actions, so it is validated
        against the actual schema (and quoted) instead of being interpolated
        raw into the SQL text — the previous f-string allowed arbitrary SQL
        fragments via the table name.
        """
        if table_name not in self.get_schema():
            return []
        result = self.execute_query(f'SELECT * FROM "{table_name}" LIMIT {int(limit)}')
        return result.get("rows", []) or []

    def close(self):
        """Release the in-memory database."""
        self.conn.close()
| 112 |
+
|
server/env.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core SQL Debug Environment.
|
| 3 |
+
Manages episode state, delegates to tasks and reward function.
|
| 4 |
+
"""
|
| 5 |
+
import uuid
|
| 6 |
+
import asyncio
|
| 7 |
+
from typing import Optional, Dict, Any, List
|
| 8 |
+
from .models import (
|
| 9 |
+
SQLDebugAction, SQLDebugObservation, SQLDebugReward,
|
| 10 |
+
EpisodeState, ActionType, QueryResult, SchemaInfo
|
| 11 |
+
)
|
| 12 |
+
from .database import EpisodeDatabase
|
| 13 |
+
from .reward import compute_reward
|
| 14 |
+
from .tasks.task_easy import EasyTask
|
| 15 |
+
from .tasks.task_medium import MediumTask, MediumTaskGrader
|
| 16 |
+
from .tasks.task_hard import HardTask
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Registry of available debugging tasks, keyed by task_id.
# Instantiated once at import time and shared by every environment instance.
TASKS = {
    "easy_syntax_fix": EasyTask(),
    "medium_logic_fix": MediumTask(),
    "hard_multi_bug": HardTask(),
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SQLDebugEnv:
    """
    The SQL Debug Environment.
    Manages one active episode at a time per session.
    Thread-safe for concurrent sessions via instance-per-session pattern.
    """

    def __init__(self, task_id: str = "easy_syntax_fix"):
        # Raises KeyError for unknown task_ids (callers validate against TASKS).
        self.task_id = task_id
        self.task = TASKS[task_id]
        self._db: Optional[EpisodeDatabase] = None
        self._state: Optional[EpisodeState] = None
        # Serializes reset()/step() so concurrent requests for the same
        # session cannot interleave state mutations.
        self._lock = asyncio.Lock()

    async def reset(self) -> tuple[SQLDebugObservation, Dict]:
        """Reset environment to initial state. Returns (observation, info)."""
        async with self._lock:
            # Close previous DB if exists
            if self._db:
                self._db.close()

            # Fresh DB
            self._db = EpisodeDatabase(
                task_id=self.task.task_id,
                schema_sql=self.task.schema_sql,
                seed_data_sql=self.task.seed_data_sql
            )

            # Fresh state
            self._state = EpisodeState(
                task_id=self.task.task_id,
                task_difficulty=self.task.difficulty,
                original_query=self.task.broken_query,
                current_query=None,
                best_score_so_far=0.0,
                steps_taken=0,
                max_steps=self.task.max_steps,
                action_history=[],
                reward_history=[],
                is_done=False,
                success=False,
                db_schema=self._db.get_schema()
            )

            # Initial observation: the broken query plus full schema info,
            # so the agent can start debugging without extra inspect actions.
            obs = SQLDebugObservation(
                task_id=self.task.task_id,
                task_description=self.task.description,
                original_query=self.task.broken_query,
                current_query=None,
                expected_description=self.task.expected_output_description,
                last_action_type="reset",
                last_query_result=None,
                steps_taken=0,
                steps_remaining=self.task.max_steps,
                current_score=0.0,
                schema_info=SchemaInfo(tables=self._db.get_schema()),
                is_done=False,
                success=False
            )

            return obs, {"task": self.task.to_dict()}

    async def step(self, action: SQLDebugAction) -> tuple[SQLDebugObservation, float, bool, Dict]:
        """
        Execute one action.
        Returns (observation, reward_value, done, info)

        Raises RuntimeError if called before reset() or after the episode is
        done, and ValueError when a required action field is missing.
        """
        async with self._lock:
            if self._state is None:
                raise RuntimeError("Call reset() before step()")

            if self._state.is_done:
                raise RuntimeError("Episode is done. Call reset() to start new episode.")

            self._state.steps_taken += 1
            steps_taken = self._state.steps_taken

            query_result_raw = None
            # Snapshot the best score BEFORE grading this step, so the reward
            # function can detect improvement (previous_best_score below).
            prev_best_score = self._state.best_score_so_far
            grade_score = self._state.best_score_so_far
            schema_info = None
            error_details = None
            sample_rows = None
            hint = None

            # --- Execute action ---
            if action.action_type == ActionType.SUBMIT_QUERY:
                if not action.query:
                    raise ValueError("query is required for submit_query action")

                self._state.current_query = action.query
                query_result_raw = self._db.execute_query(action.query)

                # Grade the result
                actual_rows = query_result_raw.get("rows") if query_result_raw.get("success") else None

                # Use custom grader for medium task
                if self.task.task_id == "medium_logic_fix":
                    grade_score = MediumTaskGrader.grade(actual_rows or [])
                else:
                    grade_score = self.task.grade(actual_rows)

                # best_score_so_far is monotonic: only improvements are kept.
                if grade_score > self._state.best_score_so_far:
                    self._state.best_score_so_far = grade_score

            elif action.action_type == ActionType.INSPECT_SCHEMA:
                schema = self._db.get_schema()
                schema_info = SchemaInfo(tables=schema)
                grade_score = self._state.best_score_so_far

            elif action.action_type == ActionType.INSPECT_ERROR:
                # Return last error if available (read from the per-step
                # history written at the bottom of this method).
                if self._state.action_history:
                    last = self._state.action_history[-1]
                    error_details = last.get("error_message", "No error recorded from last query.")
                else:
                    error_details = "No query has been submitted yet."
                grade_score = self._state.best_score_so_far

            elif action.action_type == ActionType.INSPECT_SAMPLE:
                if not action.table_name:
                    raise ValueError("table_name required for inspect_sample")
                sample_rows = self._db.get_sample_rows(action.table_name)
                grade_score = self._state.best_score_so_far

            elif action.action_type == ActionType.RESET_QUERY:
                self._state.current_query = self.task.broken_query
                grade_score = self._state.best_score_so_far

            # --- Compute reward ---
            schema_tables = list(self._db.get_schema().keys())
            reward_obj = compute_reward(
                action_type=action.action_type.value,
                query_result=query_result_raw,
                grade_score=grade_score,
                steps_taken=steps_taken,
                max_steps=self.task.max_steps,
                previous_best_score=prev_best_score,
                schema_tables=schema_tables,
                submitted_query=action.query if action.action_type == ActionType.SUBMIT_QUERY else None
            )

            # --- Check done conditions ---
            is_done = False
            success = False

            # Episode ends on near-perfect grade, or when the step budget is
            # exhausted (then success requires at least half marks).
            if grade_score >= 0.95:
                is_done = True
                success = True
            elif steps_taken >= self.task.max_steps:
                is_done = True
                success = self._state.best_score_so_far >= 0.5

            self._state.is_done = is_done
            self._state.success = success

            # --- Hint logic ---
            # Easy tasks reveal the hint after 3 steps, others after 5.
            hint_threshold = 3 if self.task.difficulty == "easy" else 5
            if steps_taken >= hint_threshold:
                hint = self.task.hint

            # --- Record history ---
            self._state.action_history.append({
                "step": steps_taken,
                "action_type": action.action_type.value,
                "query": action.query,
                "grade_score": grade_score,
                "reward": reward_obj.value,
                "error_message": query_result_raw.get("error_message") if query_result_raw else None
            })
            self._state.reward_history.append(reward_obj.value)

            # --- Build observation ---
            qr = QueryResult(**query_result_raw) if query_result_raw else None

            obs = SQLDebugObservation(
                task_id=self.task.task_id,
                task_description=self.task.description,
                original_query=self.task.broken_query,
                current_query=self._state.current_query,
                expected_description=self.task.expected_output_description,
                last_action_type=action.action_type.value,
                last_query_result=qr,
                steps_taken=steps_taken,
                steps_remaining=max(0, self.task.max_steps - steps_taken),
                current_score=self._state.best_score_so_far,
                schema_info=schema_info,
                error_details=error_details,
                sample_rows=sample_rows,
                hint=hint,
                is_done=is_done,
                success=success
            )

            return obs, reward_obj.value, is_done, {
                "grade_score": grade_score,
                "reward_breakdown": reward_obj.breakdown,
                "success": success,
                "steps_taken": steps_taken
            }

    def get_state(self) -> EpisodeState:
        # Full episode state (history, scores); raises before first reset().
        if self._state is None:
            raise RuntimeError("Call reset() first")
        return self._state

    def close(self):
        # Release the episode's in-memory database, if any.
        if self._db:
            self._db.close()
            self._db = None
|
| 236 |
+
|
server/main.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server exposing the OpenEnv HTTP API.
|
| 3 |
+
Endpoints: POST /reset, POST /step, GET /state
|
| 4 |
+
Also includes: GET /tasks (list available tasks), GET /health
|
| 5 |
+
"""
|
| 6 |
+
import asyncio
import math
import statistics
import time
from contextlib import asynccontextmanager
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from .models import SQLDebugAction, SQLDebugObservation, EpisodeState
from .env import SQLDebugEnv, TASKS
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Session management: one env instance per session
# For HF Space: allow up to 64 concurrent sessions
MAX_SESSIONS = 64  # hard cap; oldest session is evicted when exceeded
_sessions: Dict[str, SQLDebugEnv] = {}  # session_id -> live environment
_session_lock = asyncio.Lock()  # guards all mutations of _sessions
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan hook: no startup work; closes all session envs on shutdown."""
    yield
    # Cleanup all sessions on shutdown
    for env in _sessions.values():
        env.close()
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# FastAPI application exposing the OpenEnv HTTP API.
app = FastAPI(
    title="SQL Debug Environment",
    description="OpenEnv-compliant SQL query debugging environment for RL agent training.",
    version="0.1.0",
    lifespan=lifespan
)

# Wide-open CORS so arbitrary training clients / dashboards can call the env.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@app.get("/")
|
| 51 |
+
async def root():
|
| 52 |
+
return {
|
| 53 |
+
"name": "sql-debug-env",
|
| 54 |
+
"status": "ok",
|
| 55 |
+
"message": "Use /health, /tasks, /reset, /step, /state, /benchmark",
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@app.get("/favicon.ico", status_code=204)
|
| 60 |
+
async def favicon():
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ResetRequest(BaseModel):
    """Request body for POST /reset."""
    # Which task to load; defaults to the easiest task when omitted.
    task_id: Optional[str] = "easy_syntax_fix"
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class StepRequest(BaseModel):
    """Request body for POST /step."""
    # The action to execute this step (see SQLDebugAction for the variants).
    action: SQLDebugAction
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def get_or_create_session(session_id: str, task_id: str = "easy_syntax_fix") -> SQLDebugEnv:
    """Return the env bound to ``session_id``, creating it under the session lock.

    At capacity, the longest-lived session is evicted first (dict insertion
    order).  NOTE(review): not called by the visible endpoints — /reset builds
    its env directly; confirm against the rest of the file before removing.
    """
    async with _session_lock:
        env = _sessions.get(session_id)
        if env is None:
            if len(_sessions) >= MAX_SESSIONS:
                # Evict the oldest session to stay under the cap.
                oldest_id = next(iter(_sessions))
                _sessions.pop(oldest_id).close()
            env = SQLDebugEnv(task_id=task_id)
            _sessions[session_id] = env
        return env
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@app.get("/health")
|
| 85 |
+
async def health():
|
| 86 |
+
return {"status": "ok", "sessions_active": len(_sessions)}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.get("/tasks")
|
| 90 |
+
async def list_tasks():
|
| 91 |
+
"""List all available tasks with metadata."""
|
| 92 |
+
return {
|
| 93 |
+
"tasks": [task.to_dict() for task in TASKS.values()]
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _stats(values: list[float]) -> Dict[str, float]:
|
| 98 |
+
ordered = sorted(values)
|
| 99 |
+
n = len(ordered)
|
| 100 |
+
p95_idx = max(0, int(n * 0.95) - 1)
|
| 101 |
+
return {
|
| 102 |
+
"avg_ms": round(statistics.mean(ordered), 3),
|
| 103 |
+
"p50_ms": round(statistics.median(ordered), 3),
|
| 104 |
+
"p95_ms": round(ordered[p95_idx], 3),
|
| 105 |
+
"n": n,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@app.get("/benchmark")
|
| 110 |
+
async def benchmark(runs: int = 20):
|
| 111 |
+
"""
|
| 112 |
+
Real-time benchmark endpoint (fresh measurements on every call).
|
| 113 |
+
Safe to call from dashboards/web pages for live verification.
|
| 114 |
+
"""
|
| 115 |
+
runs = max(1, min(runs, 100))
|
| 116 |
+
|
| 117 |
+
health_times: list[float] = []
|
| 118 |
+
tasks_times: list[float] = []
|
| 119 |
+
reset_times: list[float] = []
|
| 120 |
+
step_times: list[float] = []
|
| 121 |
+
|
| 122 |
+
bench_env = SQLDebugEnv(task_id="easy_syntax_fix")
|
| 123 |
+
try:
|
| 124 |
+
for _ in range(runs):
|
| 125 |
+
t0 = time.perf_counter()
|
| 126 |
+
_ = {"status": "ok", "sessions_active": len(_sessions)}
|
| 127 |
+
health_times.append((time.perf_counter() - t0) * 1000)
|
| 128 |
+
|
| 129 |
+
t0 = time.perf_counter()
|
| 130 |
+
_ = [task.to_dict() for task in TASKS.values()]
|
| 131 |
+
tasks_times.append((time.perf_counter() - t0) * 1000)
|
| 132 |
+
|
| 133 |
+
t0 = time.perf_counter()
|
| 134 |
+
await bench_env.reset()
|
| 135 |
+
reset_times.append((time.perf_counter() - t0) * 1000)
|
| 136 |
+
|
| 137 |
+
t0 = time.perf_counter()
|
| 138 |
+
await bench_env.step(SQLDebugAction(action_type="inspect_schema"))
|
| 139 |
+
step_times.append((time.perf_counter() - t0) * 1000)
|
| 140 |
+
finally:
|
| 141 |
+
bench_env.close()
|
| 142 |
+
|
| 143 |
+
return {
|
| 144 |
+
"benchmark": {
|
| 145 |
+
"runs": runs,
|
| 146 |
+
"task_id": "easy_syntax_fix",
|
| 147 |
+
"timestamp_epoch_ms": int(time.time() * 1000),
|
| 148 |
+
"results": {
|
| 149 |
+
"health": _stats(health_times),
|
| 150 |
+
"tasks": _stats(tasks_times),
|
| 151 |
+
"reset": _stats(reset_times),
|
| 152 |
+
"step_inspect_schema": _stats(step_times),
|
| 153 |
+
},
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@app.post("/reset")
|
| 159 |
+
async def reset(
|
| 160 |
+
request: ResetRequest = ResetRequest(),
|
| 161 |
+
x_session_id: Optional[str] = Header(default=None)
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Reset the environment for a new episode.
|
| 165 |
+
|
| 166 |
+
Returns initial observation with task description and broken query.
|
| 167 |
+
"""
|
| 168 |
+
session_id = x_session_id or "default"
|
| 169 |
+
task_id = request.task_id or "easy_syntax_fix"
|
| 170 |
+
|
| 171 |
+
if task_id not in TASKS:
|
| 172 |
+
raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}. Valid: {list(TASKS.keys())}")
|
| 173 |
+
|
| 174 |
+
# Always create fresh env on reset
|
| 175 |
+
async with _session_lock:
|
| 176 |
+
if session_id in _sessions:
|
| 177 |
+
_sessions[session_id].close()
|
| 178 |
+
_sessions[session_id] = SQLDebugEnv(task_id=task_id)
|
| 179 |
+
|
| 180 |
+
env = _sessions[session_id]
|
| 181 |
+
observation, info = await env.reset()
|
| 182 |
+
|
| 183 |
+
return {
|
| 184 |
+
"observation": observation.model_dump(),
|
| 185 |
+
"info": info,
|
| 186 |
+
"reward": None,
|
| 187 |
+
"done": False
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@app.post("/step")
|
| 192 |
+
async def step(
|
| 193 |
+
request: StepRequest,
|
| 194 |
+
x_session_id: Optional[str] = Header(default=None)
|
| 195 |
+
):
|
| 196 |
+
"""
|
| 197 |
+
Execute one action in the environment.
|
| 198 |
+
|
| 199 |
+
Action types:
|
| 200 |
+
- submit_query: Submit SQL for evaluation (requires 'query' field)
|
| 201 |
+
- inspect_schema: Get table schema (free action)
|
| 202 |
+
- inspect_error: Get last error message (free action)
|
| 203 |
+
- inspect_sample: Get sample rows from table (requires 'table_name')
|
| 204 |
+
- reset_query: Reset to original broken query (small penalty)
|
| 205 |
+
"""
|
| 206 |
+
session_id = x_session_id or "default"
|
| 207 |
+
|
| 208 |
+
if session_id not in _sessions:
|
| 209 |
+
raise HTTPException(status_code=400, detail="Session not found. Call /reset first.")
|
| 210 |
+
|
| 211 |
+
env = _sessions[session_id]
|
| 212 |
+
|
| 213 |
+
try:
|
| 214 |
+
observation, reward, done, info = await env.step(request.action)
|
| 215 |
+
except RuntimeError as e:
|
| 216 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 217 |
+
except ValueError as e:
|
| 218 |
+
raise HTTPException(status_code=422, detail=str(e))
|
| 219 |
+
|
| 220 |
+
return {
|
| 221 |
+
"observation": observation.model_dump(),
|
| 222 |
+
"reward": reward,
|
| 223 |
+
"done": done,
|
| 224 |
+
"info": info
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@app.get("/state")
|
| 229 |
+
async def state(x_session_id: Optional[str] = Header(default=None)):
|
| 230 |
+
"""Return current full episode state."""
|
| 231 |
+
session_id = x_session_id or "default"
|
| 232 |
+
|
| 233 |
+
if session_id not in _sessions:
|
| 234 |
+
raise HTTPException(status_code=400, detail="No active session. Call /reset first.")
|
| 235 |
+
|
| 236 |
+
env = _sessions[session_id]
|
| 237 |
+
try:
|
| 238 |
+
current_state = env.get_state()
|
| 239 |
+
return current_state.model_dump()
|
| 240 |
+
except RuntimeError as e:
|
| 241 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 242 |
+
|
server/models.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Typed Pydantic models for the SQL Debug Environment.
|
| 3 |
+
Implements the OpenEnv spec: Observation, Action, Reward.
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, List, Dict, Any
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ActionType(str, Enum):
    """Closed set of actions an agent may take in one environment step."""
    SUBMIT_QUERY = "submit_query"  # Submit a fixed SQL query for evaluation
    INSPECT_SCHEMA = "inspect_schema"  # Request schema info (costs 0 reward, gives info)
    INSPECT_ERROR = "inspect_error"  # Request error details (costs 0, gives stack trace)
    INSPECT_SAMPLE = "inspect_sample"  # Request 3 sample rows from a table
    RESET_QUERY = "reset_query"  # Reset to the original broken query (costs -0.05 penalty)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SQLDebugAction(BaseModel):
    """
    Action model for the SQL Debug Environment.

    The agent can either:
    - submit_query: Submit a fixed SQL string for evaluation
    - inspect_schema: Get table schema info (free action, no reward change)
    - inspect_error: Get detailed error message from last query run
    - inspect_sample: Get sample rows from a specified table
    - reset_query: Go back to original broken query (costs -0.05 penalty)
    """
    action_type: ActionType = Field(
        description="Type of action to take"
    )
    query: Optional[str] = Field(
        default=None,
        description="SQL query string. Required when action_type is 'submit_query'."
    )
    table_name: Optional[str] = Field(
        default=None,
        description="Table name. Required when action_type is 'inspect_sample'."
    )

    # Pydantic model configuration: embeds an example payload into the
    # generated JSON schema (surfaced e.g. in the FastAPI /docs UI).
    class Config:
        json_schema_extra = {
            "example": {
                "action_type": "submit_query",
                "query": "SELECT u.name, COUNT(o.id) as order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name ORDER BY order_count DESC"
            }
        }
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class QueryResult(BaseModel):
    """Result of executing a SQL query."""
    success: bool  # True when the query executed without error
    rows: Optional[List[Dict[str, Any]]] = None  # Result rows (column -> value) on success
    row_count: Optional[int] = None  # Number of rows returned
    error_message: Optional[str] = None  # Engine error text when success is False
    execution_time_ms: Optional[float] = None  # Wall-clock execution time in milliseconds
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SchemaInfo(BaseModel):
    """Database schema information."""
    # table_name -> list of column descriptors {name, type, nullable}
    tables: Dict[str, List[Dict[str, str]]]
    # Optional table_name -> sample rows (populated for sample-inspection views)
    sample_data: Optional[Dict[str, List[Dict[str, Any]]]] = None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SQLDebugObservation(BaseModel):
    """
    Observation returned after each step.

    Contains the current state of the debugging session:
    - The original broken query (always visible)
    - The agent's current best query
    - Result of last action
    - Progress indicators
    - Schema/error info if requested
    """
    task_id: str = Field(description="Current task identifier")
    task_description: str = Field(description="Natural language description of the bug to fix")
    original_query: str = Field(description="The original broken SQL query")
    current_query: Optional[str] = Field(default=None, description="Agent's last submitted query")
    expected_description: str = Field(description="Description of what the correct output should look like")

    # Last action result
    last_action_type: str  # Name of the most recent action taken (ActionType value)
    last_query_result: Optional[QueryResult] = None  # Populated when the last action ran a query

    # Progress
    steps_taken: int
    steps_remaining: int
    current_score: float = Field(description="Current score 0.0-1.0 for this episode")

    # Contextual help (populated based on action type)
    schema_info: Optional[SchemaInfo] = None  # Set after inspect_schema
    error_details: Optional[str] = None  # Set after inspect_error
    sample_rows: Optional[List[Dict[str, Any]]] = None  # Set after inspect_sample

    # Hints (unlocked after step 3 on easy, step 5 on medium/hard)
    hint: Optional[str] = None

    # Episode status
    is_done: bool = False
    success: bool = False
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class SQLDebugReward(BaseModel):
    """
    Reward signal for the SQL Debug Environment.

    ``value`` = correctness + efficiency + syntax_progress + schema_bonus
    - penalty, clamped to [0.0, 1.0]. Components:
    - correctness: 0.0-0.6 based on row-level match vs expected output
    - efficiency: 0.0-0.2 bonus for solving in fewer steps
    - syntax_progress: 0.0-0.1 for getting a syntactically valid query (even if wrong)
    - schema_bonus: 0.0-0.1 for queries that reference correct tables/columns
    - penalty: 0.0-0.2 deduction magnitude (stored as a non-negative number
      and subtracted from the total) for reset_query, score regressions,
      or running low on steps while behind
    """
    value: float = Field(ge=0.0, le=1.0, description="Total reward for this step")
    correctness: float = Field(ge=0.0, le=0.6)
    efficiency: float = Field(ge=0.0, le=0.2)
    syntax_progress: float = Field(ge=0.0, le=0.1)
    schema_bonus: float = Field(ge=0.0, le=0.1)
    penalty: float = Field(ge=0.0, le=0.2, description="Penalty deduction magnitude (non-negative)")
    breakdown: str = Field(description="Human-readable reward breakdown")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class EpisodeState(BaseModel):
    """Full internal state of an episode. Used by state() endpoint."""
    task_id: str  # Identifier of the task being debugged
    task_difficulty: str  # "easy" | "medium" | "hard"
    original_query: str  # The broken query the episode started with
    current_query: Optional[str]  # Last query the agent submitted (None before first submit)
    best_score_so_far: float  # Highest grade score achieved this episode
    steps_taken: int
    max_steps: int
    action_history: List[Dict[str, Any]]  # One entry per step taken
    reward_history: List[float]  # Reward value recorded per step
    is_done: bool
    success: bool
    db_schema: Dict[str, Any]  # Schema description of the episode database
|
| 138 |
+
|
server/reward.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reward function for the SQL Debug Environment.
|
| 3 |
+
|
| 4 |
+
Reward is computed at every step (not just end of episode).
|
| 5 |
+
This provides dense, meaningful signal for RL training.
|
| 6 |
+
|
| 7 |
+
Reward components:
|
| 8 |
+
- correctness: 0.0–0.6 (row-level match vs expected)
|
| 9 |
+
- efficiency: 0.0–0.2 (bonus for solving quickly)
|
| 10 |
+
- syntax_progress: 0.0–0.1 (valid SQL even if wrong content)
|
| 11 |
+
- schema_bonus: 0.0–0.1 (correct tables/columns referenced)
|
| 12 |
+
- penalty: 0.0 to 0.2 (deduction for bad actions)
|
| 13 |
+
|
| 14 |
+
Total range: 0.0 to 1.0 (clamped to [0.0, 1.0])
|
| 15 |
+
"""
|
| 16 |
+
from typing import Optional, List, Dict, Any
|
| 17 |
+
from .models import SQLDebugReward
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def compute_reward(
    action_type: str,
    query_result: Optional[Dict[str, Any]],
    grade_score: float,
    steps_taken: int,
    max_steps: int,
    previous_best_score: float,
    schema_tables: List[str],
    submitted_query: Optional[str] = None,
) -> SQLDebugReward:
    """
    Compute the full reward for a step.

    Args:
        action_type: The action taken this step
        query_result: Result dict from EpisodeDatabase.execute_query()
        grade_score: 0.0-1.0 score from task grader
        steps_taken: How many steps have been used (1-indexed)
        max_steps: Maximum steps for this task
        previous_best_score: Best grade score seen so far
        schema_tables: List of valid table names in this task's DB
        submitted_query: The SQL query string (if action was submit_query)

    Returns:
        SQLDebugReward with the clamped total and per-component breakdown.
    """
    correctness = 0.0
    efficiency = 0.0
    syntax_progress = 0.0
    schema_bonus = 0.0
    penalty = 0.0  # deduction magnitude (non-negative)

    if action_type == "submit_query":
        # Correctness: primary signal. grade_score is already 0-1, so the
        # min() only guards against out-of-range grader output.
        correctness = min(0.6, grade_score * 0.6)

        # Syntax progress: reward for at least getting a valid query
        if query_result and query_result.get("success"):
            syntax_progress = 0.1
        elif query_result and not query_result.get("success"):
            # Partially reward near-misses by error class.
            # `or ""` guards against error_message being present but None,
            # which would crash the .lower() call below.
            error = (query_result.get("error_message") or "").lower()
            if "no such column" in error:
                syntax_progress = 0.03  # Structure is right but wrong column
            elif "no such table" in error:
                syntax_progress = 0.01
            else:
                syntax_progress = 0.0

        # Schema bonus: fraction of the task's tables referenced by name
        if submitted_query and schema_tables:
            query_upper = submitted_query.upper()
            tables_referenced = sum(
                1 for t in schema_tables if t.upper() in query_upper
            )
            schema_bonus = min(0.1, (tables_referenced / len(schema_tables)) * 0.1)

        # Efficiency bonus: reward solving with fewer steps
        if grade_score >= 0.95:  # Near-perfect solution
            steps_fraction = steps_taken / max_steps
            if steps_fraction <= 0.3:
                efficiency = 0.2
            elif steps_fraction <= 0.5:
                efficiency = 0.15
            elif steps_fraction <= 0.7:
                efficiency = 0.1
            else:
                efficiency = 0.05

        # Penalty: the score regressed noticeably from the previous best
        if grade_score < previous_best_score - 0.1:
            penalty = 0.05

    elif action_type == "reset_query":
        # Penalize resetting — agent should be making progress
        penalty = 0.05

    elif action_type in ("inspect_schema", "inspect_error", "inspect_sample"):
        # Information actions — small positive signal for exploring the
        # schema/errors rather than blindly guessing.
        syntax_progress = 0.01

    # Urgency penalty: almost out of steps while still far from solved
    steps_remaining = max_steps - steps_taken
    if steps_remaining <= 2 and grade_score < 0.5:
        penalty += 0.03

    total_raw = correctness + efficiency + syntax_progress + schema_bonus - penalty
    total = round(max(0.0, min(1.0, total_raw)), 4)

    # Penalty is subtracted, so render it with a minus sign.
    breakdown = (
        f"correctness={correctness:.3f} + "
        f"efficiency={efficiency:.3f} + "
        f"syntax={syntax_progress:.3f} + "
        f"schema={schema_bonus:.3f} - "
        f"penalty={penalty:.3f} = {total:.4f}"
    )

    return SQLDebugReward(
        value=total,
        correctness=correctness,
        efficiency=efficiency,
        syntax_progress=syntax_progress,
        schema_bonus=schema_bonus,
        penalty=penalty,
        breakdown=breakdown
    )
|
| 125 |
+
|
server/tasks/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sql-debug-env
|
| 2 |
+
|
server/tasks/base.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base class for all SQL Debug tasks."""
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BaseTask(ABC):
    """
    Abstract base for all tasks.

    Each task defines:
    - A broken SQL query (the one the agent must fix)
    - A database schema (SQLite CREATE TABLE statements)
    - Seed data (INSERT statements, deterministic)
    - Expected output (what the correct query should return)
    - A grader (compares agent output vs expected)
    - Metadata (id, name, difficulty, description, hint)
    """

    @property
    @abstractmethod
    def task_id(self) -> str:
        """Stable unique identifier for this task."""
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable task title."""
        pass

    @property
    @abstractmethod
    def difficulty(self) -> str:
        """One of: "easy", "medium", "hard"."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Natural language description given to the agent."""
        pass

    @property
    @abstractmethod
    def expected_output_description(self) -> str:
        """Describes what the correct output looks like."""
        pass

    @property
    @abstractmethod
    def broken_query(self) -> str:
        """The SQL query with bugs that the agent must fix."""
        pass

    @property
    @abstractmethod
    def schema_sql(self) -> str:
        """SQLite CREATE TABLE statements."""
        pass

    @property
    @abstractmethod
    def seed_data_sql(self) -> str:
        """INSERT statements for deterministic test data."""
        pass

    @property
    @abstractmethod
    def expected_output(self) -> List[Dict[str, Any]]:
        """
        The exact rows the correct query should return.
        Used by the grader to score the agent's output.
        Must be deterministic and match seed_data_sql exactly.
        """
        pass

    @property
    def hint(self) -> str:
        """Optional hint shown after N steps. Override in subclass."""
        return ""

    @property
    def max_steps(self) -> int:
        """Maximum steps for this task (derived from difficulty)."""
        return {"easy": 10, "medium": 20, "hard": 30}.get(self.difficulty, 20)

    def grade(self, actual_rows: Optional[List[Dict[str, Any]]]) -> float:
        """
        Grade the agent's query output vs expected output.
        Returns a score 0.0-1.0.

        Scoring:
        - 1.0: exact match (correct rows, correct order if ORDER BY expected)
        - 0.5-0.9: partial match (subset of correct rows, or wrong order)
        - 0.1-0.4: syntactically valid but wrong content
        - 0.0: query error (None), or empty result when rows were expected

        Note: `None` means the query failed to run; an empty list is a
        legal (possibly correct) empty result set.
        """
        if actual_rows is None:
            # Query errored out — no result to grade.
            return 0.0

        expected = self.expected_output

        if not expected:
            # Task expects an empty result set: full credit only when the
            # agent's output is also empty.  (Previously unreachable — an
            # empty actual result short-circuited to 0.0 before this check.)
            return 1.0 if len(actual_rows) == 0 else 0.0

        if not actual_rows:
            # Empty output but rows were expected.
            return 0.0

        # Row-count mismatch: partial credit for the matching prefix, capped at 0.5
        if len(actual_rows) != len(expected):
            overlap = self._count_matching_rows(actual_rows, expected)
            return round(min(0.5, overlap / max(len(expected), 1) * 0.5), 3)

        # Check row-by-row match (order-sensitive if task requires it)
        matching = self._count_matching_rows(actual_rows, expected)
        score = matching / len(expected)

        # Column names must match exactly (case/whitespace sensitive here,
        # unlike the row-value comparison which normalizes keys).
        actual_cols = set(actual_rows[0].keys())
        expected_cols = set(expected[0].keys())
        if actual_cols != expected_cols:
            score *= 0.7  # Penalty for wrong columns

        return round(score, 3)

    def _count_matching_rows(
        self,
        actual: List[Dict[str, Any]],
        expected: List[Dict[str, Any]]
    ) -> int:
        """Count how many actual rows match expected rows (normalized comparison)."""
        matches = 0
        expected_normalized = [self._normalize_row(r) for r in expected]

        for i, actual_row in enumerate(actual):
            actual_norm = self._normalize_row(actual_row)
            if i < len(expected_normalized):
                # Positional match (respects ORDER BY)
                if actual_norm == expected_normalized[i]:
                    matches += 1
            else:
                # Extra rows don't count
                break

        return matches

    def _normalize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a row for comparison: lowercase keys, string-normalize values."""
        normalized = {}
        for k, v in row.items():
            key = k.lower().strip()
            if isinstance(v, float):
                # Compare floats at 2-decimal precision (currency-style values)
                val = round(v, 2)
            elif isinstance(v, str):
                val = v.strip()
            else:
                val = v
            normalized[key] = val
        return normalized

    def to_dict(self) -> Dict[str, Any]:
        """Serialize task metadata (no seed data or expected output) to a dict."""
        return {
            "task_id": self.task_id,
            "name": self.name,
            "difficulty": self.difficulty,
            "description": self.description,
            "expected_output_description": self.expected_output_description,
            "broken_query": self.broken_query,
            "max_steps": self.max_steps,
            "hint": self.hint
        }
|
| 169 |
+
|
server/tasks/task_easy.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TASK 1 — EASY: Syntax Error Fix
|
| 3 |
+
Difficulty: Easy
|
| 4 |
+
Bug type: Simple syntax errors (typo in keyword, missing alias, wrong column name)
|
| 5 |
+
Max steps: 10
|
| 6 |
+
Expected baseline model score: 0.8-1.0
|
| 7 |
+
"""
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
from .base import BaseTask
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class EasyTask(BaseTask):
    """
    Scenario: An e-commerce company wants to find the top 5 customers
    by total order value. The query has a syntax error:
    uses 'GRUP BY' instead of 'GROUP BY' and references wrong column alias.

    Database: customers, orders, order_items
    Bug 1: 'GRUP BY' typo
    Bug 2: ORDER BY references 'total' but SELECT aliases it as 'total_value'
    """

    @property
    def task_id(self) -> str:
        return "easy_syntax_fix"

    @property
    def name(self) -> str:
        return "Top Customers by Revenue — Syntax Error Fix"

    @property
    def difficulty(self) -> str:
        return "easy"

    @property
    def description(self) -> str:
        return """You are debugging a SQL query for an e-commerce analytics dashboard.

The query is supposed to find the top 5 customers by their total order value
(sum of quantity * unit_price across all their orders).

The query has 2 syntax/reference bugs that prevent it from running:
1. A typo in a SQL keyword
2. An ORDER BY clause that references a column alias incorrectly

Fix both bugs so the query runs and returns the correct result.

The result should show: customer_name, total_value (rounded to 2 decimal places),
ordered from highest to lowest, top 5 only."""

    @property
    def expected_output_description(self) -> str:
        # Alice's true total per the seed data is 1947.50 (see expected_output).
        return "5 rows: customer_name, total_value (DESC order). Alice Chen should be first with 1947.50."

    @property
    def broken_query(self) -> str:
        return """SELECT
    c.name AS customer_name,
    ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_value
FROM customers c
JOIN orders o ON c.id = o.customer_id
JOIN order_items oi ON o.id = oi.order_id
GRUP BY c.id, c.name
ORDER BY total DESC
LIMIT 5"""

    @property
    def schema_sql(self) -> str:
        return """
CREATE TABLE customers (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    email TEXT UNIQUE NOT NULL,
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE orders (
    id INTEGER PRIMARY KEY,
    customer_id INTEGER NOT NULL,
    order_date TEXT NOT NULL,
    status TEXT DEFAULT 'completed',
    FOREIGN KEY (customer_id) REFERENCES customers(id)
);

CREATE TABLE order_items (
    id INTEGER PRIMARY KEY,
    order_id INTEGER NOT NULL,
    product_name TEXT NOT NULL,
    quantity INTEGER NOT NULL,
    unit_price REAL NOT NULL,
    FOREIGN KEY (order_id) REFERENCES orders(id)
)"""

    @property
    def seed_data_sql(self) -> str:
        return """
INSERT INTO customers VALUES (1,'Alice Chen','alice@example.com','2023-01-01');
INSERT INTO customers VALUES (2,'Bob Kumar','bob@example.com','2023-01-05');
INSERT INTO customers VALUES (3,'Carol White','carol@example.com','2023-01-10');
INSERT INTO customers VALUES (4,'David Park','david@example.com','2023-02-01');
INSERT INTO customers VALUES (5,'Eva Rodriguez','eva@example.com','2023-02-15');
INSERT INTO customers VALUES (6,'Frank Liu','frank@example.com','2023-03-01');

INSERT INTO orders VALUES (1,1,'2023-06-01','completed');
INSERT INTO orders VALUES (2,1,'2023-07-15','completed');
INSERT INTO orders VALUES (3,2,'2023-06-10','completed');
INSERT INTO orders VALUES (4,3,'2023-06-20','completed');
INSERT INTO orders VALUES (5,3,'2023-08-01','completed');
INSERT INTO orders VALUES (6,4,'2023-07-01','completed');
INSERT INTO orders VALUES (7,5,'2023-07-20','completed');
INSERT INTO orders VALUES (8,5,'2023-08-10','completed');
INSERT INTO orders VALUES (9,6,'2023-09-01','completed');

INSERT INTO order_items VALUES (1,1,'Laptop',1,1200.00);
INSERT INTO order_items VALUES (2,1,'Mouse',2,25.00);
INSERT INTO order_items VALUES (3,2,'Keyboard',1,150.00);
INSERT INTO order_items VALUES (4,2,'Monitor',1,450.00);
INSERT INTO order_items VALUES (5,2,'Webcam',1,97.50);
INSERT INTO order_items VALUES (6,3,'Headphones',1,350.00);
INSERT INTO order_items VALUES (7,3,'USB Hub',2,45.00);
INSERT INTO order_items VALUES (8,4,'Tablet',1,600.00);
INSERT INTO order_items VALUES (9,4,'Case',1,35.00);
INSERT INTO order_items VALUES (10,5,'Charger',2,30.00);
INSERT INTO order_items VALUES (11,5,'Cable',3,15.00);
INSERT INTO order_items VALUES (12,6,'Desk Lamp',1,85.00);
INSERT INTO order_items VALUES (13,6,'Chair Mat',1,60.00);
INSERT INTO order_items VALUES (14,7,'Speakers',1,220.00);
INSERT INTO order_items VALUES (15,7,'Microphone',1,180.00);
INSERT INTO order_items VALUES (16,8,'Webcam',1,97.50);
INSERT INTO order_items VALUES (17,9,'Monitor',1,450.00)"""

    @property
    def expected_output(self) -> List[Dict[str, Any]]:
        # Derived from seed_data_sql as SUM(quantity * unit_price) per customer:
        #   Alice (orders 1,2): 1200 + 2*25 + 150 + 450 + 97.50 = 1947.50
        #   Carol (orders 4,5): 600 + 35 + 2*30 + 3*15          =  740.00
        #   Eva   (orders 7,8): 220 + 180 + 97.50               =  497.50
        #   Frank (order 9):    450                              =  450.00
        #   Bob   (order 3):    350 + 2*45                       =  440.00
        #   David (order 6):    85 + 60 = 145.00  (6th — not in the top 5)
        # Fixed: the previous expected output credited David's order-6 items
        # (85 + 60) to Frank (595.00); order 6 belongs to customer 4 (David),
        # so Frank's only order is #9 (450.00).
        return [
            {"customer_name": "Alice Chen", "total_value": 1947.50},
            {"customer_name": "Carol White", "total_value": 740.00},
            {"customer_name": "Eva Rodriguez", "total_value": 497.50},
            {"customer_name": "Frank Liu", "total_value": 450.00},
            {"customer_name": "Bob Kumar", "total_value": 440.00},
        ]

    @property
    def hint(self) -> str:
        return "Hint: Check every SQL keyword spelling carefully. Also check that your ORDER BY column name exactly matches the alias in your SELECT clause."
|
| 157 |
+
|
server/tasks/task_hard.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TASK 3 — HARD: Multi-bug + Optimization
|
| 3 |
+
Difficulty: Hard
|
| 4 |
+
Bug types:
|
| 5 |
+
1. Correlated subquery returns wrong scope
|
| 6 |
+
2. Window function partition incorrect
|
| 7 |
+
3. CTE has circular logic bug
|
| 8 |
+
4. Off-by-one in date range
|
| 9 |
+
5. Missing DISTINCT causing row duplication
|
| 10 |
+
Max steps: 30
|
| 11 |
+
Expected baseline model score: 0.0-0.3 (frontier models barely pass)
|
| 12 |
+
"""
|
| 13 |
+
from typing import List, Dict, Any
|
| 14 |
+
from .base import BaseTask
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class HardTask(BaseTask):
|
| 18 |
+
"""
|
| 19 |
+
Scenario: SaaS product analytics — find users who:
|
| 20 |
+
1. Signed up in Q1 2023 (Jan 1 – Mar 31)
|
| 21 |
+
2. Made at least 2 purchases in their first 30 days
|
| 22 |
+
3. Return their: user_id, username, signup_date,
|
| 23 |
+
first_purchase_date, days_to_first_purchase,
|
| 24 |
+
purchases_in_first_30_days, total_lifetime_value
|
| 25 |
+
|
| 26 |
+
Bugs:
|
| 27 |
+
1. Date range is '>= 2023-01-01 AND < 2023-04-01' but query uses '<= 2023-03-31'
|
| 28 |
+
(off by 1 for timestamps — in SQLite string comparison this is actually fine,
|
| 29 |
+
but the REAL bug is the upper bound uses wrong column: filters on purchase_date
|
| 30 |
+
instead of signup_date in the CTE)
|
| 31 |
+
2. The window function for running total uses PARTITION BY user_id but
|
| 32 |
+
ORDER BY is missing — gives wrong cumulative values
|
| 33 |
+
3. HAVING clause uses COUNT(*) but should use COUNT(DISTINCT purchase_id)
|
| 34 |
+
due to JOIN multiplication
|
| 35 |
+
4. The subquery for first_purchase_date is not correlated properly
|
| 36 |
+
(missing WHERE p.user_id = u.id)
|
| 37 |
+
5. days_to_first_purchase calculation uses wrong date subtraction direction
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
@property
|
| 41 |
+
def task_id(self) -> str:
|
| 42 |
+
return "hard_multi_bug"
|
| 43 |
+
|
| 44 |
+
@property
|
| 45 |
+
def name(self) -> str:
|
| 46 |
+
return "SaaS Cohort Activation Report — Multi-Bug Fix"
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def difficulty(self) -> str:
|
| 50 |
+
return "hard"
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def description(self) -> str:
|
| 54 |
+
return """You are debugging a SaaS product analytics query.
|
| 55 |
+
|
| 56 |
+
The query should identify "activated users": users who signed up in Q1 2023
|
| 57 |
+
AND made at least 2 purchases within their first 30 days of signup.
|
| 58 |
+
|
| 59 |
+
For each activated user, return:
|
| 60 |
+
- user_id (INTEGER)
|
| 61 |
+
- username (TEXT)
|
| 62 |
+
- signup_date (TEXT, YYYY-MM-DD)
|
| 63 |
+
- first_purchase_date (TEXT, YYYY-MM-DD)
|
| 64 |
+
- days_to_first_purchase (INTEGER, how many days after signup they first purchased)
|
| 65 |
+
- purchases_in_first_30_days (INTEGER)
|
| 66 |
+
- total_lifetime_value (REAL, sum of all their purchases ever, rounded to 2 dp)
|
| 67 |
+
|
| 68 |
+
Results ordered by total_lifetime_value DESC.
|
| 69 |
+
|
| 70 |
+
The query has FIVE bugs — some are logic errors, one is a missing correlation
|
| 71 |
+
in a subquery, one is an incorrect window function, one causes row duplication.
|
| 72 |
+
You must find and fix all of them to get the correct result.
|
| 73 |
+
|
| 74 |
+
Q1 2023 = signup_date >= '2023-01-01' AND signup_date <= '2023-03-31'"""
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def expected_output_description(self) -> str:
|
| 78 |
+
return "2 rows: users who made 2+ purchases in first 30 days. Maya Torres first (higher LTV), then James Osei."
|
| 79 |
+
|
| 80 |
+
@property
def broken_query(self) -> str:
    """The deliberately buggy SQL the agent starts from.

    NOTE: the five defects in this string (uncorrelated MIN() subquery,
    reversed date subtraction, COUNT(*) inflated by the join, the
    SUM(SUM(...)) OVER window, and the outer reference to q1_users) are
    the task itself — do NOT "fix" this text.
    """
    return """WITH q1_users AS (
SELECT DISTINCT u.id, u.username, u.signup_date
FROM users u
JOIN purchases p ON u.id = p.user_id
WHERE u.signup_date >= '2023-01-01'
AND u.signup_date <= '2023-03-31'
AND p.purchase_date <= '2023-03-31'
),
user_purchase_stats AS (
SELECT
q.id AS user_id,
q.username,
q.signup_date,
(SELECT MIN(purchase_date) FROM purchases WHERE amount > 0) AS first_purchase_date,
COUNT(*) AS purchases_in_first_30_days,
SUM(SUM(p.amount)) OVER (PARTITION BY q.id) AS total_lifetime_value
FROM q1_users q
JOIN purchases p ON q.id = p.user_id
WHERE julianday(p.purchase_date) - julianday(q.signup_date) <= 30
GROUP BY q.id, q.username, q.signup_date
)
SELECT
user_id,
username,
signup_date,
first_purchase_date,
CAST(julianday(q1_users.signup_date) - julianday(first_purchase_date) AS INTEGER) AS days_to_first_purchase,
purchases_in_first_30_days,
ROUND(total_lifetime_value, 2) AS total_lifetime_value
FROM user_purchase_stats
WHERE purchases_in_first_30_days >= 2
ORDER BY total_lifetime_value DESC"""
|
| 114 |
+
|
| 115 |
+
@property
def schema_sql(self) -> str:
    """DDL executed to create the task's SQLite tables (users, purchases)."""
    return """
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username TEXT NOT NULL,
email TEXT UNIQUE,
signup_date TEXT NOT NULL,
plan TEXT DEFAULT 'free'
);

CREATE TABLE purchases (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
product_name TEXT NOT NULL,
amount REAL NOT NULL,
purchase_date TEXT NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
)"""
|
| 134 |
+
|
| 135 |
+
@property
def seed_data_sql(self) -> str:
    """INSERT statements that seed the fixture data.

    The rows are crafted so exactly two users qualify (Maya, James);
    Sophie has only one in-window purchase and Raj/Anna fall outside
    the Q1 2023 signup window.
    """
    return """
INSERT INTO users VALUES (1,'maya_torres','maya@ex.com','2023-01-15','pro');
INSERT INTO users VALUES (2,'james_osei','james@ex.com','2023-02-10','pro');
INSERT INTO users VALUES (3,'sophie_liang','sophie@ex.com','2023-03-05','free');
INSERT INTO users VALUES (4,'raj_mehta','raj@ex.com','2023-06-01','free');
INSERT INTO users VALUES (5,'anna_kovacs','anna@ex.com','2022-12-20','pro');

-- Maya: 2 purchases in first 30 days (days 5 and 18), more later
INSERT INTO purchases VALUES (1,1,'Pro Plan',99.00,'2023-01-20');
INSERT INTO purchases VALUES (2,1,'Add-on Pack',29.00,'2023-02-02');
INSERT INTO purchases VALUES (3,1,'Pro Renewal',99.00,'2023-04-15');
INSERT INTO purchases VALUES (4,1,'Consulting',150.00,'2023-07-01');

-- James: 2 purchases in first 30 days (days 3 and 25)
INSERT INTO purchases VALUES (5,2,'Starter Plan',49.00,'2023-02-13');
INSERT INTO purchases VALUES (6,2,'Storage Add-on',19.00,'2023-03-07');
INSERT INTO purchases VALUES (7,2,'Starter Renewal',49.00,'2023-05-10');

-- Sophie: only 1 purchase in first 30 days (should NOT qualify)
INSERT INTO purchases VALUES (8,3,'Free Trial Upgrade',9.00,'2023-03-10');
INSERT INTO purchases VALUES (9,3,'Pro Plan',99.00,'2023-04-20');

-- Raj: signed up Q2, not Q1 (should NOT qualify)
INSERT INTO purchases VALUES (10,4,'Starter Plan',49.00,'2023-06-05');
INSERT INTO purchases VALUES (11,4,'Add-on',19.00,'2023-06-10');

-- Anna: signed up Q4 2022, not Q1 2023 (should NOT qualify)
INSERT INTO purchases VALUES (12,5,'Pro Plan',99.00,'2023-01-01');
INSERT INTO purchases VALUES (13,5,'Consulting',150.00,'2023-03-15')"""
|
| 166 |
+
|
| 167 |
+
@property
def expected_output(self) -> List[Dict[str, Any]]:
    """Ground-truth rows used by the grader.

    Maya: signup 2023-01-15, first purchase 2023-01-20 (day 5);
    in-window purchases Jan-20 and Feb-02 → 2; LTV 99+29+99+150 = 377.
    James: signup 2023-02-10, first purchase 2023-02-13 (day 3);
    in-window purchases Feb-13 and Mar-07 → 2; LTV 49+19+49 = 117.
    Ordered by total_lifetime_value DESC, matching the task spec.
    """
    return [
        {
            "user_id": 1,
            "username": "maya_torres",
            "signup_date": "2023-01-15",
            "first_purchase_date": "2023-01-20",
            "days_to_first_purchase": 5,
            "purchases_in_first_30_days": 2,
            "total_lifetime_value": 377.00,
        },
        {
            "user_id": 2,
            "username": "james_osei",
            "signup_date": "2023-02-10",
            "first_purchase_date": "2023-02-13",
            "days_to_first_purchase": 3,
            "purchases_in_first_30_days": 2,
            "total_lifetime_value": 117.00,
        },
    ]
|
| 195 |
+
|
| 196 |
+
@property
def hint(self) -> str:
    """Optional hint enumerating where each of the five seeded bugs lives."""
    return "Hint: There are 5 bugs total. Check: (1) the subquery for first_purchase_date needs a WHERE correlation, (2) the date subtraction direction for days_to_first_purchase, (3) COUNT(*) vs COUNT(DISTINCT) when JOINs can multiply rows, (4) window functions need ORDER BY for meaningful results, (5) the q1_users CTE may be filtering on the wrong table's date column."
|
| 199 |
+
|
server/tasks/task_medium.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TASK 2 — MEDIUM: Logic Error Fix
|
| 3 |
+
Difficulty: Medium
|
| 4 |
+
Bug types: Wrong JOIN type causing missing rows, incorrect aggregation logic,
|
| 5 |
+
missing HAVING clause, wrong date filter
|
| 6 |
+
Max steps: 20
|
| 7 |
+
Expected baseline model score: 0.3-0.6
|
| 8 |
+
"""
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
from .base import BaseTask
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MediumTask(BaseTask):
    """Medium-difficulty SQL debugging task: department headcount report.

    Scenario: an HR analytics team wants headcount and average salary per
    department for 2023 hires, with every department present — including
    those where nobody joined in 2023.

    Seeded defects in the broken query:
    1. INNER JOIN instead of LEFT JOIN — drops departments with no 2023 hires.
    2. AVG(salary) averages over employees from every year, not just 2023.
    3. The hire_date filter sits in WHERE, defeating outer-join semantics
       (it belongs in the ON clause, or behind a CASE).
    4. GROUP BY omits department_id (ambiguous grouping).
    """

    @property
    def task_id(self) -> str:
        """Stable machine-readable identifier for this task."""
        return "medium_logic_fix"

    @property
    def name(self) -> str:
        """Human-readable task title."""
        return "Department Headcount Report — Logic Error Fix"

    @property
    def difficulty(self) -> str:
        """Difficulty bucket."""
        return "medium"

    @property
    def description(self) -> str:
        """Problem statement shown to the agent; exact wording is contractual."""
        return """You are debugging a HR analytics SQL query.

The query should produce a monthly department headcount report showing:
- department_name
- headcount: number of employees who joined IN 2023
- avg_salary: average salary of employees who joined IN 2023
- All departments must appear, even those with 0 new hires in 2023

The current query has 3 logic bugs:
1. It uses the wrong JOIN type, which silently drops departments with no 2023 hires
2. The WHERE clause on hire_date breaks the outer join semantics
3. The AVG calculation includes employees from all years, not just 2023

Fix these logic errors. The result should be ordered by department_name ascending."""

    @property
    def expected_output_description(self) -> str:
        """One-line summary of the correct result."""
        return "4 rows (all departments), headcount=0 for 'Legal', correct avg_salary only from 2023 hires."

    @property
    def broken_query(self) -> str:
        """The buggy starting SQL; its defects ARE the exercise — do not repair."""
        return """SELECT
d.name AS department_name,
COUNT(e.id) AS headcount,
ROUND(AVG(e.salary), 2) AS avg_salary
FROM departments d
INNER JOIN employees e ON d.id = e.department_id
WHERE strftime('%Y', e.hire_date) = '2023'
GROUP BY d.name
ORDER BY department_name ASC"""

    @property
    def schema_sql(self) -> str:
        """DDL for the departments/employees fixture tables."""
        return """
CREATE TABLE departments (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
budget REAL
);

CREATE TABLE employees (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
department_id INTEGER NOT NULL,
salary REAL NOT NULL,
hire_date TEXT NOT NULL,
FOREIGN KEY (department_id) REFERENCES departments(id)
)"""

    @property
    def seed_data_sql(self) -> str:
        """Fixture rows; Legal (id=4) deliberately has no employees at all."""
        return """
INSERT INTO departments VALUES (1,'Engineering',500000);
INSERT INTO departments VALUES (2,'Marketing',200000);
INSERT INTO departments VALUES (3,'Sales',300000);
INSERT INTO departments VALUES (4,'Legal',150000);

INSERT INTO employees VALUES (1,'Ana Lima',1,95000,'2023-03-15');
INSERT INTO employees VALUES (2,'Ben Sharma',1,102000,'2023-06-01');
INSERT INTO employees VALUES (3,'Chris Wang',1,88000,'2022-01-10');
INSERT INTO employees VALUES (4,'Diana Patel',2,72000,'2023-04-20');
INSERT INTO employees VALUES (5,'Erik Johnson',2,68000,'2022-11-05');
INSERT INTO employees VALUES (6,'Fatima Al-Hassan',3,55000,'2023-01-08');
INSERT INTO employees VALUES (7,'George Okafor',3,61000,'2023-07-22');
INSERT INTO employees VALUES (8,'Hannah Kim',3,58000,'2022-05-30');
INSERT INTO employees VALUES (9,'Ivan Petrov',1,91000,'2022-08-14')"""

    @property
    def expected_output(self) -> List[Dict[str, Any]]:
        """Ground-truth rows, ordered by department_name ascending.

        Engineering 2023 hires: Ana 95000, Ben 102000 → count=2, avg=98500.
        Marketing 2023 hires: Diana 72000 → count=1, avg=72000.
        Sales 2023 hires: Fatima 55000, George 61000 → count=2, avg=58000.
        Legal: no 2023 hires → count=0, avg is NULL (None).
        """
        return [
            {"department_name": "Engineering", "headcount": 2, "avg_salary": 98500.00},
            {"department_name": "Legal", "headcount": 0, "avg_salary": None},
            {"department_name": "Marketing", "headcount": 1, "avg_salary": 72000.00},
            {"department_name": "Sales", "headcount": 2, "avg_salary": 58000.00},
        ]

    @property
    def hint(self) -> str:
        """Optional nudge toward the JOIN type and WHERE-vs-ON placement."""
        return "Hint: When you want ALL rows from the left table even when there's no match on the right, think about which JOIN type preserves those rows. Also, WHERE on a nullable column after a join changes join semantics — consider moving that condition."
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class MediumTaskGrader:
    """Custom grader for the medium task — handles the NULL avg_salary row.

    Scores an agent-produced result set against the known-good department
    report. Each of the 4 expected rows contributes equally; the score is
    matched_rows / 4 rounded to 3 decimals.

    Robustness: agent queries are frequently broken, so malformed cell
    values (None or non-numeric headcount / department_name) must score
    as mismatches rather than raising from inside the environment.
    """

    # Known-good output, pre-sorted by department_name to mirror the
    # task's ORDER BY department_name ASC.
    _EXPECTED = [
        {"department_name": "Engineering", "headcount": 2, "avg_salary": 98500.00},
        {"department_name": "Legal", "headcount": 0, "avg_salary": None},
        {"department_name": "Marketing", "headcount": 1, "avg_salary": 72000.00},
        {"department_name": "Sales", "headcount": 2, "avg_salary": 58000.00},
    ]

    @staticmethod
    def grade(actual) -> float:
        """Return the fraction of correct rows in ``actual``.

        Args:
            actual: list of result rows as dicts (may be None/empty or
                contain malformed values).

        Returns:
            0.0 when the row count is not exactly 4; otherwise
            ``round(matches / 4, 3)`` where a row matches when its
            department name (case-insensitive), headcount, and avg_salary
            (±1.0 tolerance; None or 0 accepted for the NULL row) agree.
        """
        if not actual or len(actual) != 4:
            return 0.0

        # Sort by department name for positional comparison; coerce via
        # str() so a None/odd-typed name cannot crash the sort.
        actual_sorted = sorted(actual, key=lambda r: str(r.get("department_name") or ""))

        matches = 0
        for a, e in zip(actual_sorted, MediumTaskGrader._EXPECTED):
            dept_ok = str(a.get("department_name", "")).lower() == str(e["department_name"]).lower()

            # A broken query can yield None or non-numeric headcount;
            # treat that as a mismatch instead of raising TypeError.
            try:
                count_ok = int(a.get("headcount", -1)) == e["headcount"]
            except (TypeError, ValueError):
                count_ok = False

            e_salary = e["avg_salary"]
            a_salary = a.get("avg_salary")
            if e_salary is None:
                # SQLite AVG over zero rows is NULL; some pipelines map it to 0.
                salary_ok = a_salary is None or a_salary == 0
            else:
                try:
                    salary_ok = abs(float(a_salary) - float(e_salary)) < 1.0
                except (TypeError, ValueError):
                    salary_ok = False

            if dept_ok and count_ok and salary_ok:
                matches += 1

        return round(matches / 4, 3)
|
| 163 |
+
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import unittest
|
| 3 |
+
|
| 4 |
+
from server.env import SQLDebugEnv
|
| 5 |
+
from server.models import SQLDebugAction, ActionType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestEnv(unittest.TestCase):
    """Smoke tests for the SQLDebugEnv reset/step lifecycle."""

    def test_reset_and_inspect_schema(self):
        async def scenario():
            env = SQLDebugEnv(task_id="easy_syntax_fix")
            first_obs, _ = await env.reset()
            self.assertFalse(first_obs.is_done)

            inspect = SQLDebugAction(action_type=ActionType.INSPECT_SCHEMA)
            next_obs, reward, done, _ = await env.step(inspect)

            # Inspecting the schema must not end the episode, must populate
            # schema_info, and must not be penalized below zero.
            self.assertFalse(done)
            self.assertIsNotNone(next_obs.schema_info)
            self.assertGreaterEqual(reward, 0.0)

        asyncio.run(scenario())

    def test_submit_broken_query_does_not_finish(self):
        async def scenario():
            env = SQLDebugEnv(task_id="easy_syntax_fix")
            await env.reset()

            submit = SQLDebugAction(
                action_type=ActionType.SUBMIT_QUERY,
                query=env.task.broken_query,
            )
            next_obs, reward, done, _ = await env.step(submit)

            # Re-submitting the still-broken query must keep the episode
            # open, keep the reward in the low band, and echo the query back.
            self.assertFalse(done)
            self.assertLessEqual(reward, 0.2)
            self.assertGreaterEqual(reward, -1.0)
            self.assertEqual(next_obs.current_query, env.task.broken_query)

        asyncio.run(scenario())


if __name__ == "__main__":
    unittest.main()
|
| 44 |
+
|
tests/test_graders.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
|
| 3 |
+
from server.tasks.task_easy import EasyTask
|
| 4 |
+
from server.tasks.task_medium import MediumTask, MediumTaskGrader
|
| 5 |
+
from server.tasks.task_hard import HardTask
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestGraders(unittest.TestCase):
    """Sanity checks for the per-task grading functions."""

    def test_easy_grade_perfect(self):
        # Grading a task's own expected output must yield a perfect score.
        easy = EasyTask()
        self.assertAlmostEqual(easy.grade(easy.expected_output), 1.0, places=3)

    def test_hard_grade_perfect(self):
        hard = HardTask()
        self.assertAlmostEqual(hard.grade(hard.expected_output), 1.0, places=3)

    def test_easy_grade_empty(self):
        # A missing result set scores zero.
        self.assertEqual(EasyTask().grade(None), 0.0)

    def test_medium_grader_perfect(self):
        medium = MediumTask()
        self.assertAlmostEqual(
            MediumTaskGrader.grade(medium.expected_output), 1.0, places=3
        )

    def test_medium_grader_partial(self):
        # Corrupt exactly one row: the expected avg_salary for "Legal" is
        # None, so any non-None/non-zero value must fail that row.
        medium = MediumTask()
        rows = [dict(r) for r in medium.expected_output]
        for row in rows:
            if row["department_name"] == "Legal":
                row["avg_salary"] = 12345.0

        partial = MediumTaskGrader.grade(rows)
        self.assertLess(partial, 1.0)
        self.assertAlmostEqual(partial, 0.75, places=3)


if __name__ == "__main__":
    unittest.main()
|
| 46 |
+
|
tests/test_reward.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
|
| 3 |
+
from server.reward import compute_reward
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TestReward(unittest.TestCase):
    """Unit tests for compute_reward across representative action types."""

    def test_submit_query_perfect_reward(self):
        # A perfect first-step submission should earn the maximum reward.
        result = compute_reward(
            action_type="submit_query",
            query_result={"success": True},
            grade_score=1.0,
            steps_taken=1,
            max_steps=10,
            previous_best_score=0.0,
            schema_tables=["t1", "t2"],
            submitted_query="SELECT * FROM t1 JOIN t2",
        )
        self.assertAlmostEqual(result.value, 1.0, places=4)

    def test_reset_query_penalty(self):
        # Resetting the query yields a neutral (zero) reward.
        result = compute_reward(
            action_type="reset_query",
            query_result=None,
            grade_score=0.0,
            steps_taken=1,
            max_steps=10,
            previous_best_score=0.0,
            schema_tables=[],
            submitted_query=None,
        )
        self.assertAlmostEqual(result.value, 0.0, places=4)

    def test_inspect_schema_urgency_penalty(self):
        # steps_remaining <= 2 with grade_score < 0.5 triggers the urgency
        # penalty, which cancels the small inspection bonus.
        result = compute_reward(
            action_type="inspect_schema",
            query_result=None,
            grade_score=0.0,
            steps_taken=8,
            max_steps=9,
            previous_best_score=0.0,
            schema_tables=[],
            submitted_query=None,
        )
        # syntax_progress=0.01, penalty=0.03 => total_raw=-0.02, clamped to 0.0
        self.assertAlmostEqual(result.value, 0.0, places=4)


if __name__ == "__main__":
    unittest.main()
|
| 51 |
+
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|