Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +81 -0
- README.md +231 -10
- __init__.py +16 -0
- client.py +79 -0
- grader.py +74 -0
- inference.py +194 -0
- models.py +38 -0
- openenv.yaml +7 -0
- openenv_sql_debug.egg-info/PKG-INFO +11 -0
- openenv_sql_debug.egg-info/SOURCES.txt +18 -0
- openenv_sql_debug.egg-info/dependency_links.txt +1 -0
- openenv_sql_debug.egg-info/entry_points.txt +2 -0
- openenv_sql_debug.egg-info/requires.txt +7 -0
- openenv_sql_debug.egg-info/top_level.txt +1 -0
- pyproject.toml +39 -0
- runner.py +19 -0
- server/__init__.py +11 -0
- server/app.py +58 -0
- server/requirements.txt +3 -0
- server/sql_debug_environment.py +160 -0
- tasks/__init__.py +0 -0
- tasks/task_easy.py +31 -0
- tasks/task_hard.py +51 -0
- tasks/task_medium.py +38 -0
- test.py +29 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Multi-stage build using openenv-base
# This Dockerfile is flexible and works for both:
# - In-repo environments (with local OpenEnv sources)
# - Standalone environments (with openenv from PyPI/Git)
# The build script (openenv build) handles context detection and sets appropriate build args.

ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# Ensure git is available (required for installing dependencies from VCS)
RUN apt-get update && \
    apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

# Build argument to control whether we're building standalone or in-repo
ARG BUILD_MODE=in-repo
ARG ENV_NAME=sql_debug

# Copy environment code (always at root of build context)
COPY . /app/env

# For in-repo builds, openenv is already vendored in the build context
# For standalone builds, openenv will be installed via pyproject.toml
WORKDIR /app/env

# Ensure uv is available (for local builds where base image lacks it)
# NOTE(review): assumes curl is present in the base image — confirm.
RUN if ! command -v uv >/dev/null 2>&1; then \
        curl -LsSf https://astral.sh/uv/install.sh | sh && \
        mv /root/.local/bin/uv /usr/local/bin/uv && \
        mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install dependencies using uv sync
# If uv.lock exists, use it; otherwise resolve on the fly
# First pass: dependencies only (--no-install-project) so the layer caches
# independently of project source changes.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-install-project --no-editable; \
    else \
        uv sync --no-install-project --no-editable; \
    fi

# Second pass: installs the project itself into the venv (deps already cached).
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
        uv sync --frozen --no-editable; \
    else \
        uv sync --no-editable; \
    fi

# Final runtime stage
FROM ${BASE_IMAGE}

WORKDIR /app

# Copy the virtual environment from builder
COPY --from=builder /app/env/.venv /app/.venv

# Copy the environment code
COPY --from=builder /app/env /app/env

# Set PATH to use the virtual environment
ENV PATH="/app/.venv/bin:$PATH"

# Set PYTHONPATH so imports work correctly
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Health check
# NOTE(review): relies on curl existing in the runtime image — confirm.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run the FastAPI server
# The module path is constructed to work with the /app/env structure
ENV ENABLE_WEB_INTERFACE=true
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,231 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Sql Debug
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: docker
|
| 7 |
-
pinned: false
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Sql Debug Environment Server
|
| 3 |
+
emoji: 🏒
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- sql
|
| 13 |
+
- debugging
|
| 14 |
+
- optimization
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# 🏒 OpenEnv: SQL Debug Environment
|
| 18 |
+
|
| 19 |
+
An [OpenEnv](https://openenv.dev)-compliant environment where AI agents fix broken SQL queries and optimize slow ones against in-memory SQLite databases.
|
| 20 |
+
|
| 21 |
+
> ✅ **Validator:** `openenv validate` passes when the environment is wired up correctly
|
| 22 |
+
> 🚀 **Live API:** `https://abhinavthedev-sql-debug.hf.space`
|
| 23 |
+
> 📖 **Swagger UI:** `https://abhinavthedev-sql-debug.hf.space/docs`
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## 🎯 Environment Description
|
| 28 |
+
|
| 29 |
+
This environment simulates the work of a SQL engineer who must repair syntax errors, correct logic bugs, and improve query performance. Agents receive a schema, a broken or slow query, and a natural-language target description. They submit SQL queries, observe the execution result and query plan, and are scored on correctness and efficiency.
|
| 30 |
+
|
| 31 |
+
The environment is intentionally practical: each task mirrors a real debugging pattern used in analytics, reporting, and data engineering workflows.
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## 📋 Tasks
|
| 36 |
+
|
| 37 |
+
### Task 1 - Syntax Fix *(Easy)*
|
| 38 |
+
**Task ID:** `syntax_fix_001`
|
| 39 |
+
|
| 40 |
+
**Objective:** Fix a malformed query so it returns all orders where `amount > 500`.
|
| 41 |
+
|
| 42 |
+
| Field | Description |
|
| 43 |
+
|---|---|
|
| 44 |
+
| `schema` | `orders` table with `id`, `customer`, `amount`, `order_date` |
|
| 45 |
+
| `broken_query` | `SELEC * FORM orders WERE amount > 500` |
|
| 46 |
+
| `target` | Return all orders where amount is greater than 500 |
|
| 47 |
+
|
| 48 |
+
**Max steps:** 5 | **Difficulty:** Easy
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
### Task 2 - Logic Fix *(Medium)*
|
| 53 |
+
**Task ID:** `logic_fix_001`
|
| 54 |
+
|
| 55 |
+
**Objective:** Correct a join bug so only employees in valid departments are returned.
|
| 56 |
+
|
| 57 |
+
| Field | Description |
|
| 58 |
+
|---|---|
|
| 59 |
+
| `schema` | `employees` and `departments` tables |
|
| 60 |
+
| `broken_query` | Query uses `LEFT JOIN` but should exclude missing departments |
|
| 61 |
+
| `target` | Return employees in departments with budget > 400000 |
|
| 62 |
+
|
| 63 |
+
**Max steps:** 8 | **Difficulty:** Medium
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
### Task 3 - Query Optimization *(Hard)*
|
| 68 |
+
**Task ID:** `optimize_001`
|
| 69 |
+
|
| 70 |
+
**Objective:** Rewrite a correlated subquery into an efficient CTE or grouped subquery.
|
| 71 |
+
|
| 72 |
+
| Field | Description |
|
| 73 |
+
|---|---|
|
| 74 |
+
| `schema` | `transactions` table with generated sample rows |
|
| 75 |
+
| `broken_query` | Correlated subquery that scans per row |
|
| 76 |
+
| `target` | Return completed transactions above the user's average amount |
|
| 77 |
+
|
| 78 |
+
**Max steps:** 10 | **Difficulty:** Hard
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## 🔌 API Reference
|
| 83 |
+
|
| 84 |
+
### Base URL
|
| 85 |
+
```text
|
| 86 |
+
https://abhinavthedev-sql-debug.hf.space
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### Core Endpoints
|
| 90 |
+
|
| 91 |
+
| Method | Endpoint | Description |
|
| 92 |
+
|---|---|---|
|
| 93 |
+
| `POST` | `/reset` | Start a new episode; pass `task_id` to choose a task |
|
| 94 |
+
| `POST` | `/step` | Submit a SQL query and receive the next observation |
|
| 95 |
+
| `GET` | `/state/{session_id}` | Inspect the current episode state |
|
| 96 |
+
| `GET` | `/schema` | View action, observation, and state schemas |
|
| 97 |
+
| `GET` | `/ws` | WebSocket endpoint for low-latency sessions |
|
| 98 |
+
| `GET` | `/health` | Health check |
|
| 99 |
+
| `GET` | `/docs` | Swagger UI |
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## 🎮 Action Space
|
| 104 |
+
|
| 105 |
+
The agent submits a single SQL query each step.
|
| 106 |
+
|
| 107 |
+
```json
|
| 108 |
+
{
|
| 109 |
+
"query": "SELECT * FROM orders WHERE amount > 500"
|
| 110 |
+
}
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### Example Actions
|
| 114 |
+
|
| 115 |
+
```json
|
| 116 |
+
{ "query": "SELECT * FROM orders WHERE amount > 500" }
|
| 117 |
+
|
| 118 |
+
{ "query": "SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id = d.id WHERE d.budget > 400000" }
|
| 119 |
+
|
| 120 |
+
{ "query": "WITH avg_amount AS (SELECT user_id, AVG(amount) AS avg_amount FROM transactions GROUP BY user_id) SELECT t.* FROM transactions t JOIN avg_amount a ON t.user_id = a.user_id WHERE t.status = 'completed' AND t.amount > a.avg_amount" }
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
---
|
| 124 |
+
|
| 125 |
+
## 📊 Observation Space
|
| 126 |
+
|
| 127 |
+
```json
|
| 128 |
+
{
|
| 129 |
+
"task_id": "syntax_fix_001",
|
| 130 |
+
"schema_sql": "CREATE TABLE orders (...)",
|
| 131 |
+
"current_query": "SELEC * FORM orders WERE amount > 500",
|
| 132 |
+
"error_message": "near \"SELEC\": syntax error",
|
| 133 |
+
"query_result": [],
|
| 134 |
+
"execution_plan": "",
|
| 135 |
+
"step_count": 0,
|
| 136 |
+
"target_description": "Return all orders where amount is greater than 500",
|
| 137 |
+
"reward_so_far": 0.0,
|
| 138 |
+
"available_tasks": ["syntax_fix_001", "logic_fix_001", "optimize_001"],
|
| 139 |
+
"done": false,
|
| 140 |
+
"reward": 0.05
|
| 141 |
+
}
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
## 💰 Reward Function
|
| 147 |
+
|
| 148 |
+
The reward is computed from syntax validity, result correctness, and query plan quality.
|
| 149 |
+
|
| 150 |
+
| Event | Reward |
|
| 151 |
+
|---|---|
|
| 152 |
+
| Query fails with syntax error | `0.05` |
|
| 153 |
+
| Query runs successfully | contributes to the main score |
|
| 154 |
+
| Correct row match on easy and medium tasks | up to `0.6` of the score |
|
| 155 |
+
| Good query plan on hard task | up to `0.2` of the score |
|
| 156 |
+
| Uses correlated-subquery pattern on hard task | heavy plan penalty |
|
| 157 |
+
| Excessively long query | length penalty |
|
| 158 |
+
|
| 159 |
+
Final scores are clamped to the range `[0.0, 1.0]`.
|
| 160 |
+
|
| 161 |
+
---
|
| 162 |
+
|
| 163 |
+
## 🚀 Setup & Usage
|
| 164 |
+
|
| 165 |
+
### Option 1 - Run Locally
|
| 166 |
+
|
| 167 |
+
```bash
|
| 168 |
+
pip install -e .
|
| 169 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 170 |
+
# Open http://localhost:8000/docs
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
### Option 2 - Run with Docker
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
docker build -t sql-debug-env -f server/Dockerfile .
|
| 177 |
+
docker run -p 8000:8000 sql-debug-env
|
| 178 |
+
curl http://localhost:8000/health
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
### Option 3 - Run the Inference Loop
|
| 182 |
+
|
| 183 |
+
```bash
|
| 184 |
+
export SERVER_URL=https://abhinavthedev-sql-debug.hf.space
|
| 185 |
+
export API_KEY=sk-...
|
| 186 |
+
python inference.py
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
The inference script defaults to `syntax_fix_001`, logs each step, and stops when the episode ends or the step budget is reached.
|
| 190 |
+
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
## 🏗️ Project Structure
|
| 194 |
+
|
| 195 |
+
```text
|
| 196 |
+
sql_exp/
|
| 197 |
+
├── client.py # OpenEnv client wrapper
|
| 198 |
+
├── grader.py # Reward computation
|
| 199 |
+
├── inference.py # LLM-driven inference loop
|
| 200 |
+
├── models.py # Action and observation models
|
| 201 |
+
├── openenv.yaml # OpenEnv manifest
|
| 202 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 203 |
+
├── runner.py # SQLite query runner
|
| 204 |
+
├── server/
|
| 205 |
+
│ ├── app.py # FastAPI app and OpenEnv wiring
|
| 206 |
+
│ ├── Dockerfile # Container definition
|
| 207 |
+
│ └── sql_debug_environment.py # Core environment logic
|
| 208 |
+
├── tasks/
|
| 209 |
+
│ ├── task_easy.py # Syntax-fix task
|
| 210 |
+
│ ├── task_medium.py # Join logic task
|
| 211 |
+
│ └── task_hard.py # Query optimization task
|
| 212 |
+
├── test.py # Manual websocket smoke test
|
| 213 |
+
└── README.md # Project overview
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
---
|
| 217 |
+
|
| 218 |
+
## 🛠️ Tech Stack
|
| 219 |
+
|
| 220 |
+
- **Python 3.10+** - Runtime
|
| 221 |
+
- **FastAPI** - HTTP framework
|
| 222 |
+
- **OpenEnv Core** - Environment server and client primitives
|
| 223 |
+
- **SQLite** - Query execution engine
|
| 224 |
+
- **Uvicorn** - ASGI server
|
| 225 |
+
- **Docker** - Containerization
|
| 226 |
+
|
| 227 |
+
---
|
| 228 |
+
|
| 229 |
+
## 📝 License
|
| 230 |
+
|
| 231 |
+
BSD-style license, matching the source headers in this repository.
|
__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""SQL Debug Environment package exports."""

# Fix: client.py defines SQLDebugEnv and models.py defines SQLDebugAction /
# SQLDebugObservation. The previous imports referenced non-existent
# SqlExp* names, so importing this package raised ImportError.
from .client import SQLDebugEnv
from .models import SQLDebugAction, SQLDebugObservation

__all__ = [
    "SQLDebugAction",
    "SQLDebugObservation",
    "SQLDebugEnv",
]
|
client.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# client.py
|
| 8 |
+
"""
|
| 9 |
+
SQL Debug Environment client.
|
| 10 |
+
This is what inference.py uses to talk to the running server.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Dict
|
| 14 |
+
|
| 15 |
+
from openenv.core import EnvClient
|
| 16 |
+
from openenv.core.client_types import StepResult
|
| 17 |
+
from openenv.core.env_server.types import State
|
| 18 |
+
|
| 19 |
+
from models import SQLDebugAction, SQLDebugObservation
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SQLDebugEnv(EnvClient[SQLDebugAction, SQLDebugObservation, State]):
    """Typed client for the SQL Debug & Optimizer environment.

    Holds a persistent WebSocket connection to the server; each instance
    owns a dedicated environment session on the server side.

    Example (direct server)::

        with SQLDebugEnv(base_url="http://localhost:8000") as env:
            result = env.reset()
            print(result.observation.target_description)
            result = env.step(SQLDebugAction(query="SELECT * FROM orders"))
            print(result.reward)

    Example (Docker)::

        env = SQLDebugEnv.from_docker_image("sql-debug-env:latest")
        try:
            result = env.reset()
            result = env.step(SQLDebugAction(query="SELECT * FROM orders WHERE amount > 500"))
        finally:
            env.close()
    """

    def _step_payload(self, action: SQLDebugAction) -> Dict:
        """Serialize an action into the JSON body sent to the server."""
        return {"query": action.query}

    def _parse_result(self, payload: Dict) -> StepResult[SQLDebugObservation]:
        """Turn the server's JSON reply into a typed StepResult."""
        data = payload.get("observation", {})
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        # Observation field name -> fallback used when the server omits it.
        fallbacks = {
            "task_id": "",
            "schema_sql": "",
            "current_query": "",
            "error_message": "",
            "query_result": [],
            "execution_plan": "",
            "step_count": 0,
            "target_description": "",
            "reward_so_far": 0.0,
            "available_tasks": [],
        }
        fields = {name: data.get(name, default) for name, default in fallbacks.items()}
        obs = SQLDebugObservation(done=done, reward=reward, **fields)

        return StepResult(
            observation=obs,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """Turn the server's JSON reply into a State object."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
grader.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def compute_reward(task: dict, agent_query: str, run_result: dict) -> dict:
    """Grade a submitted SQL query against a task definition.

    Args:
        task: One of the TASK dicts from ``tasks/`` (reads ``expected_rows``,
            ``check_plan``, ``good_patterns``).
        agent_query: The SQL string the agent submitted.
        run_result: Output of ``runner.run_query()``; must expose ``"error"``
            (str or None) and ``"rows"`` (list).

    Returns:
        Dict with keys ``value``, ``syntax_ok``, ``result_match_pct``,
        ``plan_score``, ``message``. ``value`` is clamped to [0.0, 1.0].
    """
    error = run_result["error"]

    # Step 1: the query must at least execute. A failed run still earns a
    # small constant reward so the agent gets a gradient signal, not zero.
    if error is not None:
        return {
            "value": 0.05,
            "syntax_ok": False,
            "result_match_pct": 0.0,
            "plan_score": 0.0,
            "message": f"Syntax error: {error[:100]}",
        }

    # Step 2: row correctness. Tasks without fixed expected rows (the hard
    # optimization task) get full match credit just for running cleanly.
    expected = task["expected_rows"]
    if expected is None:
        match_pct = 1.0
    else:
        returned = run_result["rows"]
        hit_count = sum(row in returned for row in expected)
        match_pct = hit_count / max(len(expected), 1)
        # Far too many rows back means the query is wrong: 30% penalty.
        if len(returned) > 2 * len(expected):
            match_pct *= 0.7

    # Step 3: query-plan quality — only graded when the task asks for it.
    plan_score = 0.0
    if task.get("check_plan"):
        upper_sql = agent_query.upper()
        patterns = task.get("good_patterns", [])
        hits = sum(p.upper() in upper_sql for p in patterns)
        plan_score = hits / max(len(patterns), 1)
        # Still using the correlated AVG subquery? Not really optimized.
        if "WHERE" in upper_sql and "SELECT AVG" in upper_sql:
            plan_score *= 0.3

    # Step 4: weighted blend (syntax 20% + correctness 60% + plan 20%),
    # minus a penalty for absurdly long queries, clamped to [0, 1].
    blended = 0.2 + (0.6 * match_pct) + (0.2 * plan_score)
    over_length = max(0.0, (len(agent_query) - 800) / 2000)
    final = max(0.0, min(1.0, blended - over_length))

    status = "perfect" if final >= 0.99 else "partial" if final > 0.2 else "wrong"
    msg = f"{status} | rows matched: {match_pct:.0%} | plan: {plan_score:.0%}"

    return {
        "value": round(final, 3),
        "syntax_ok": True,
        "result_match_pct": match_pct,
        "plan_score": plan_score,
        "message": msg,
    }
|
inference.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference.py
|
| 2 |
+
"""
|
| 3 |
+
SQL Debug & Optimizer — OpenEnv Inference Script
|
| 4 |
+
|
| 5 |
+
Mandatory stdout format:
|
| 6 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 7 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 8 |
+
[END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import asyncio
|
| 12 |
+
import os
|
| 13 |
+
import textwrap
|
| 14 |
+
from typing import List, Optional
|
| 15 |
+
|
| 16 |
+
from openai import OpenAI
|
| 17 |
+
from client import SQLDebugEnv, SQLDebugAction
|
| 18 |
+
|
| 19 |
+
# ── Mandatory env vars (injected by evaluator on submission) ──────────────────
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")  # Docker image name; may be None for a local run
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")  # LLM credential; HF_TOKEN takes precedence
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")  # OpenAI-compatible endpoint
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")  # chat model used by call_llm()

# ── Task + run config ─────────────────────────────────────────────────────────
TASK_NAME = os.getenv("SQL_ENV_TASK", "syntax_fix_001")  # which environment task to run
BENCHMARK = "sql-debug-optimizer"  # env label reported in the [START] line
MAX_STEPS = 8  # well under 20 min limit; each step is ~2s
TEMPERATURE = 0.0  # deterministic = reproducible scores
MAX_TOKENS = 400  # cap on LLM completion length
SUCCESS_THRESHOLD = 0.5  # reward >= 0.5 = success
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ── Mandatory stdout loggers — DO NOT change field names or order ─────────────
|
| 35 |
+
|
| 36 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the mandatory [START] marker line (field names/order are fixed)."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit the mandatory [STEP] marker line for one environment step.

    The action is collapsed to a single line (newlines -> spaces, carriage
    returns dropped) because multi-line actions break log parsing.
    """
    flat_action = action.replace("\n", " ").replace("\r", "").strip()
    fields = [
        f"step={step}",
        f"action={flat_action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the mandatory [END] marker line summarizing the episode."""
    parts = (
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.2f}",
        "rewards=" + ",".join(f"{r:.2f}" for r in rewards),
    )
    print("[END] " + " ".join(parts), flush=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── Prompt design ─────────────────────────────────────────────────────────────
|
| 62 |
+
|
| 63 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 64 |
+
You are an expert SQL engineer helping debug and optimize SQL queries.
|
| 65 |
+
|
| 66 |
+
Rules (follow exactly):
|
| 67 |
+
- Respond with ONLY the corrected SQL query.
|
| 68 |
+
- No markdown, no code fences (no ```sql), no explanation.
|
| 69 |
+
- No comments inside the SQL.
|
| 70 |
+
- If the query has a syntax error, fix it first.
|
| 71 |
+
- If the query has a logic bug (wrong JOIN, wrong WHERE), fix the logic.
|
| 72 |
+
- If asked to optimize, replace correlated subqueries with CTEs using WITH.
|
| 73 |
+
- Output raw SQL only — it will be executed directly.
|
| 74 |
+
""").strip()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_prompt(obs) -> str:
    """Build the user prompt from the current observation.

    Args:
        obs: Observation object; reads ``target_description``, ``schema_sql``,
            ``current_query``, ``error_message``, ``query_result`` and
            ``step_count``.

    Returns:
        A single prompt string containing the task, a schema excerpt
        (truncated to 800 chars), the broken query, the last error and a
        short result preview.
    """
    # Only the first 3 rows are shown, to keep the prompt small.
    result_preview = str(obs.query_result[:3]) if obs.query_result else "empty / error"
    # NOTE(review): dedent runs on the already-interpolated text, so
    # interpolated values containing their own newlines could shift the
    # common indentation prefix — confirm observation fields are single-block.
    return textwrap.dedent(f"""
        TASK: {obs.target_description}

        DATABASE SCHEMA:
        {obs.schema_sql.strip()[:800]}

        CURRENT QUERY (this is broken or slow — fix it):
        {obs.current_query.strip()}

        ERROR: {obs.error_message or "none"}
        CURRENT RESULT (first 3 rows): {result_preview}
        STEP: {obs.step_count + 1} of {MAX_STEPS}

        Write the corrected SQL query:
    """).strip()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def call_llm(client: OpenAI, obs) -> str:
    """Ask the LLM for a better SQL query. Returns a clean SQL string.

    Never raises: on any API failure the fallback query ``"SELECT 1"`` is
    returned so the episode loop keeps running.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_prompt(obs)},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw = (completion.choices[0].message.content or "").strip()

        # Strip markdown code fences if model adds them despite instructions
        if "```" in raw:
            lines = raw.split("\n")
            raw = "\n".join(
                line for line in lines if not line.strip().startswith("```")
            ).strip()

        # Guard against an empty completion — return a harmless query instead.
        return raw if raw else "SELECT 1"

    except Exception as exc:
        # Broad catch is deliberate: a failed LLM call must not kill the run.
        print(f"[DEBUG] LLM call failed: {exc}", flush=True)
        return "SELECT 1"
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ── Main loop ─────────────────────────────────────────────────────────────────
|
| 127 |
+
|
| 128 |
+
async def main() -> None:
    """Run one graded episode: reset, loop LLM -> env steps, emit [END].

    The mandatory [START]/[STEP]/[END] stdout markers are always emitted,
    even when the episode fails partway through.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Connect to the environment (Docker or local server)
    SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
    env = SQLDebugEnv(base_url=SERVER_URL)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        # Reset — get the broken query and task info
        # NOTE(review): client.py's docstring shows env.reset()/env.step()
        # used synchronously (no await) and reset() without a task_id
        # argument — confirm EnvClient is async here and accepts task_id.
        result = await env.reset(task_id=TASK_NAME)
        obs = result.observation

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            # Ask LLM for a better query
            sql_query = call_llm(client, obs)

            # Submit to environment
            result = await env.step(SQLDebugAction(query=sql_query))
            obs = result.observation

            reward = result.reward or 0.0
            done = result.done
            error = obs.error_message if obs.error_message else None

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action=sql_query,
                reward=reward,
                done=done,
                error=error,
            )

            if done:
                break

        # Score = best reward achieved (already 0.0–1.0 from grader)
        score = max(rewards) if rewards else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_THRESHOLD

    except Exception as exc:
        # Any failure still falls through to the mandatory [END] line below.
        print(f"[DEBUG] Episode error: {exc}", flush=True)

    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error: {e}", flush=True)

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
asyncio.run(main())
|
models.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for the SQL Debug & Optimizer Environment.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, List
|
| 12 |
+
from pydantic import Field
|
| 13 |
+
from openenv.core.env_server.types import Action, Observation
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SQLDebugAction(Action):
    """
    What the agent submits each step — just a SQL query string.

    The environment runs the query against the task's SQLite schema,
    grades it, and returns a new :class:`SQLDebugObservation`.
    """
    query: str = Field(..., description="The SQL query the agent wants to try")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SQLDebugObservation(Observation):
    """
    What the agent sees after each reset/step.

    Contains everything it needs to improve its next query: the schema,
    the last query tried, any SQLite error, a preview of the result rows,
    and the EXPLAIN QUERY PLAN output for optimization tasks.
    """
    task_id: str = Field(default="", description="Which task is active")
    schema_sql: str = Field(default="", description="CREATE TABLE statements for this task")
    current_query: str = Field(default="", description="Last query that was run")
    error_message: str = Field(default="", description="SQLite error if query failed, else empty string")
    query_result: List[Dict[str, Any]] = Field(default_factory=list, description="First 10 rows returned")
    execution_plan: str = Field(default="", description="EXPLAIN QUERY PLAN output")
    step_count: int = Field(default=0, description="How many steps taken so far")
    target_description: str = Field(default="", description="Plain English goal for this task")
    reward_so_far: float = Field(default=0.0, description="Best reward achieved this episode")
    available_tasks: List[str] = Field(default_factory=list, description="All task IDs you can reset to")
|
openenv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: sql_debug
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
| 7 |
+
|
openenv_sql_debug.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: openenv-sql_debug
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Sql Debug environment for OpenEnv
|
| 5 |
+
Requires-Python: >=3.10
|
| 6 |
+
Requires-Dist: openenv-core[core]>=0.2.2
|
| 7 |
+
Requires-Dist: openai>=2.30.0
|
| 8 |
+
Requires-Dist: uvicorn>=0.43.0
|
| 9 |
+
Provides-Extra: dev
|
| 10 |
+
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
| 11 |
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
openenv_sql_debug.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
./__init__.py
|
| 4 |
+
./client.py
|
| 5 |
+
./grader.py
|
| 6 |
+
./inference.py
|
| 7 |
+
./models.py
|
| 8 |
+
./runner.py
|
| 9 |
+
./test.py
|
| 10 |
+
openenv_sql_debug.egg-info/PKG-INFO
|
| 11 |
+
openenv_sql_debug.egg-info/SOURCES.txt
|
| 12 |
+
openenv_sql_debug.egg-info/dependency_links.txt
|
| 13 |
+
openenv_sql_debug.egg-info/entry_points.txt
|
| 14 |
+
openenv_sql_debug.egg-info/requires.txt
|
| 15 |
+
openenv_sql_debug.egg-info/top_level.txt
|
| 16 |
+
server/__init__.py
|
| 17 |
+
server/app.py
|
| 18 |
+
server/sql_debug_environment.py
|
openenv_sql_debug.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
openenv_sql_debug.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = sql_debug.server.app:main
|
openenv_sql_debug.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
openai>=2.30.0
|
| 3 |
+
uvicorn>=0.43.0
|
| 4 |
+
|
| 5 |
+
[dev]
|
| 6 |
+
pytest>=8.0.0
|
| 7 |
+
pytest-cov>=4.0.0
|
openenv_sql_debug.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
sql_debug
|
pyproject.toml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-sql_debug"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "Sql Debug environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.2",
|
| 21 |
+
"openai>=2.30.0",
|
| 22 |
+
"uvicorn>=0.43.0",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.optional-dependencies]
|
| 26 |
+
dev = [
|
| 27 |
+
"pytest>=8.0.0",
|
| 28 |
+
"pytest-cov>=4.0.0",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.scripts]
|
| 32 |
+
# Server entry point - enables running via: uv run --project . server
|
| 33 |
+
# or: python -m sql_debug.server.app
|
| 34 |
+
server = "sql_debug.server.app:main"
|
| 35 |
+
|
| 36 |
+
[tool.setuptools]
|
| 37 |
+
include-package-data = true
|
| 38 |
+
packages = ["sql_debug", "sql_debug.server"]
|
| 39 |
+
package-dir = { "sql_debug" = ".", "sql_debug.server" = "server" }
|
runner.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
|
| 3 |
+
def run_query(schema_sql: str, query: str) -> dict:
    """
    Execute *query* against a fresh in-memory SQLite database seeded
    with *schema_sql*.

    Returns a dict with three keys:
        rows  — result rows as plain dicts (empty list on any failure)
        error — the exception text if anything went wrong, else None
        plan  — "EXPLAIN QUERY PLAN" output joined into a single string
    """
    db = sqlite3.connect(":memory:")
    db.row_factory = sqlite3.Row
    try:
        # Seed the database, then capture the plan before running the query.
        db.executescript(schema_sql)
        plan_cursor = db.execute(f"EXPLAIN QUERY PLAN {query}")
        plan_text = " | ".join(str(dict(row)) for row in plan_cursor.fetchall())
        rows = [dict(row) for row in db.execute(query).fetchall()]
        return {"rows": rows, "error": None, "plan": plan_text}
    except Exception as exc:
        # Any failure (bad schema, bad SQL) is reported, never raised.
        return {"rows": [], "error": str(exc), "plan": ""}
    finally:
        db.close()
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Sql Exp environment server components."""
|
| 8 |
+
|
| 9 |
+
from .sql_debug_environment import SQLDebugEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["SQLDebugEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI server for the SQL Debug & Optimizer Environment.
|
| 9 |
+
|
| 10 |
+
Exposes the environment over HTTP + WebSocket so inference.py
|
| 11 |
+
(and the OpenEnv evaluator) can interact with it remotely.
|
| 12 |
+
|
| 13 |
+
Endpoints created automatically by openenv:
|
| 14 |
+
POST /reset — start new episode (optionally pass task_id in body)
|
| 15 |
+
POST /step — submit an action, get observation + reward
|
| 16 |
+
GET /state — current episode state
|
| 17 |
+
GET /schema — action/observation JSON schemas
|
| 18 |
+
WS /ws — WebSocket for persistent low-latency sessions
|
| 19 |
+
|
| 20 |
+
Run locally:
|
| 21 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 22 |
+
|
| 23 |
+
Or via Docker (defined in Dockerfile):
|
| 24 |
+
docker build -t sql-debug-env .
|
| 25 |
+
docker run -p 8000:8000 sql-debug-env
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:
    raise ImportError(
        "openenv-core is required. Install with: pip install openenv-core"
    ) from e

try:
    # Flat layout: running from the repo root (models.py next to server/).
    from models import SQLDebugAction, SQLDebugObservation
    from .sql_debug_environment import SQLDebugEnvironment
except ModuleNotFoundError:
    # Installed layout: pyproject.toml maps this project to the
    # ``sql_debug`` package (see [tool.setuptools].packages), so the
    # fallback must import from ``sql_debug.server`` — the original
    # referenced a nonexistent ``sql_exp`` package and could never succeed.
    from models import SQLDebugAction, SQLDebugObservation
    from sql_debug.server.sql_debug_environment import SQLDebugEnvironment


# Wires the environment, action and observation types into the standard
# OpenEnv FastAPI app (/reset, /step, /state, /schema, /ws).
app = create_app(
    SQLDebugEnvironment,
    SQLDebugAction,
    SQLDebugObservation,
    env_name="sql_debug_optimizer",
    max_concurrent_envs=4,  # one per task running in parallel
)


def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve the app with uvicorn (the ``server`` script in pyproject.toml)."""
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.2
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
server/sql_debug_environment.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
SQL Debug & Optimizer Environment — server-side implementation.
|
| 9 |
+
|
| 10 |
+
The server runs this. The agent never touches this file directly.
|
| 11 |
+
It loads tasks, runs queries in SQLite, grades them, and returns observations.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from uuid import uuid4
|
| 15 |
+
from openenv.core.env_server.interfaces import Environment
|
| 16 |
+
from openenv.core.env_server.types import State
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from ..models import SQLDebugAction, SQLDebugObservation
|
| 20 |
+
except ImportError:
|
| 21 |
+
from models import SQLDebugAction, SQLDebugObservation
|
| 22 |
+
|
| 23 |
+
from runner import run_query
|
| 24 |
+
from grader import compute_reward
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _load_all_tasks() -> dict:
    """Load every task from the tasks/ folder into a dict keyed by task_id."""
    from tasks.task_easy import TASK as EASY
    from tasks.task_medium import TASK as MEDIUM
    from tasks.task_hard import TASK as HARD

    # Key each task definition by its own task_id.
    return {task["task_id"]: task for task in (EASY, MEDIUM, HARD)}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SQLDebugEnvironment(Environment):
    """
    SQL Debug & Optimizer environment.

    The agent receives a broken or slow SQL query and must fix/optimize it.
    Each step the agent submits a new query — the environment runs it in
    SQLite, grades it (0.0–1.0), and returns the result as an observation.

    Three tasks:
        syntax_fix_001 (easy)   — fix typos in SQL keywords
        logic_fix_001  (medium) — fix wrong JOIN type causing bad results
        optimize_001   (hard)   — rewrite correlated subquery as a CTE
    """

    # Each HTTP/WebSocket session gets its own instance, so per-episode
    # state on ``self`` is safe to keep.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self) -> None:
        self._all_tasks = _load_all_tasks()      # task_id -> task dict
        self._current_task = None                # set by reset()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._best_reward = 0.0                  # best grade seen this episode
        self._current_query = ""                 # last query executed

    # ── reset ────────────────────────────────────────────────────────────────

    def reset(self, task_id: str | None = None) -> SQLDebugObservation:
        """
        Start a new episode.
        Pass task_id to pick a specific task, or leave None for the default (easy).
        """
        if task_id is None:
            task_id = list(self._all_tasks.keys())[0]  # default: easy

        if task_id not in self._all_tasks:
            # Unknown task — return error observation instead of crashing
            # NOTE(review): the previous episode's task (if any) stays active
            # in this case — confirm callers treat this as a failed reset.
            return SQLDebugObservation(
                task_id=task_id,
                error_message=f"Unknown task_id '{task_id}'. Available: {list(self._all_tasks.keys())}",
                available_tasks=list(self._all_tasks.keys()),
            )

        self._current_task = self._all_tasks[task_id]
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._best_reward = 0.0
        self._current_query = self._current_task["broken_query"]

        # Run the broken query so the agent sees the starting error
        run_result = run_query(
            self._current_task["schema_sql"],
            self._current_query,
        )

        return SQLDebugObservation(
            task_id=task_id,
            schema_sql=self._current_task["schema_sql"],
            current_query=self._current_query,
            error_message=run_result["error"] or "",
            query_result=run_result["rows"][:10],  # cap preview at 10 rows
            execution_plan=run_result["plan"],
            step_count=0,
            target_description=self._current_task["target_description"],
            reward_so_far=0.0,
            available_tasks=list(self._all_tasks.keys()),
            done=False,
            reward=0.0,
        )

    # ── step ─────────────────────────────────────────────────────────────────

    def step(self, action: SQLDebugAction) -> SQLDebugObservation:
        """
        Agent submits a query.
        We run it, grade it, and return the new observation + reward.
        """
        if self._current_task is None:
            # step() before reset() — terminal error observation, zero reward.
            return SQLDebugObservation(
                error_message="Call reset() before step()",
                available_tasks=list(self._all_tasks.keys()),
                done=True,
                reward=0.0,
            )

        self._state.step_count += 1
        self._current_query = action.query

        # Run the query in SQLite
        run_result = run_query(
            self._current_task["schema_sql"],
            action.query,
        )

        # Grade it (returns dict with value, syntax_ok, result_match_pct, etc.)
        reward_dict = compute_reward(self._current_task, action.query, run_result)
        reward_value = reward_dict["value"]

        # Track the best reward this episode
        self._best_reward = max(self._best_reward, reward_value)

        # Episode ends on perfect score or max steps
        max_steps = self._current_task.get("max_steps", 8)
        done = (reward_value >= 0.99) or (self._state.step_count >= max_steps)

        return SQLDebugObservation(
            task_id=self._current_task["task_id"],
            schema_sql=self._current_task["schema_sql"],
            current_query=action.query,
            error_message=run_result["error"] or "",
            query_result=run_result["rows"][:10],  # cap preview at 10 rows
            execution_plan=run_result["plan"],
            step_count=self._state.step_count,
            target_description=self._current_task["target_description"],
            reward_so_far=self._best_reward,
            available_tasks=list(self._all_tasks.keys()),
            done=done,
            reward=reward_value,
        )

    # ── state ─────────────────────────────────────────────────────────────────

    @property
    def state(self) -> State:
        # Current episode id + step counter, exposed via GET /state.
        return self._state
|
tasks/__init__.py
ADDED
|
File without changes
|
tasks/task_easy.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Easy task: the query is syntactically broken (misspelled keywords);
# the agent only has to produce a query that runs and returns the
# expected rows.
TASK = {
    "task_id": "syntax_fix_001",
    "difficulty": "easy",
    "max_steps": 5,

    # This creates the database the agent works with
    "schema_sql": """
CREATE TABLE orders (
    id INTEGER, customer TEXT, amount REAL, order_date TEXT
);
INSERT INTO orders VALUES (1, 'Alice', 520.0, '2024-01-15');
INSERT INTO orders VALUES (2, 'Bob', 90.0, '2024-01-16');
INSERT INTO orders VALUES (3, 'Carol', 800.0, '2024-01-17');
INSERT INTO orders VALUES (4, 'Dan', 150.0, '2024-01-18');
""",

    # This is the broken query the agent must fix
    # (SELEC/FORM/WERE are deliberate typos)
    "broken_query": "SELEC * FORM orders WERE amount > 500",

    # Plain English: what should the fixed query do?
    "target_description": "Return all orders where amount is greater than 500",

    # What the correct answer looks like — used by grader to check
    "expected_rows": [
        {"id": 1, "customer": "Alice", "amount": 520.0, "order_date": "2024-01-15"},
        {"id": 3, "customer": "Carol", "amount": 800.0, "order_date": "2024-01-17"},
    ],

    # For easy task, plan quality doesn't matter
    "check_plan": False,
}
|
tasks/task_hard.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tasks/task_hard.py
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
def generate_schema(n_rows=5000, seed=42):
    """
    Generate schema + INSERT statements for the hard task's
    ``transactions`` table.

    Args:
        n_rows: number of rows to generate (output is capped at 200
            rows to keep the demo fast, matching the original behavior).
        seed: RNG seed so the generated data is deterministic.

    Returns:
        A SQL script: one CREATE TABLE followed by the INSERT statements.
    """
    rng = random.Random(seed)
    statuses = ['completed', 'pending', 'failed']
    inserts = []
    for i in range(1, n_rows + 1):
        user_id = rng.randint(1, 100)
        amount = round(rng.uniform(10, 1000), 2)
        status = rng.choice(statuses)
        # Bug fix: the table has 5 columns (id, user_id, amount, ts, status)
        # but the original INSERT supplied only 4 values — and discarded the
        # randomly drawn status in favor of a hard-coded 'completed' — so
        # SQLite rejected every row and the task schema never loaded.
        # Supply a deterministic ts and use the drawn status.
        ts = f"2024-01-{(i % 28) + 1:02d}"
        inserts.append(
            f"INSERT INTO transactions VALUES "
            f"({i}, {user_id}, {amount}, '{ts}', '{status}');"
        )
    return (
        "CREATE TABLE transactions (id INTEGER, user_id INTEGER, amount REAL, ts TEXT, status TEXT);\n"
        + "\n".join(inserts[:200])  # Keep it fast for demo (200 rows)
    )
|
| 18 |
+
|
| 19 |
+
TASK = {
|
| 20 |
+
"task_id": "optimize_001",
|
| 21 |
+
"difficulty": "hard",
|
| 22 |
+
"max_steps": 10,
|
| 23 |
+
|
| 24 |
+
"schema_sql": generate_schema(200), # Use 200 rows for speed in hackathon
|
| 25 |
+
|
| 26 |
+
# Slow: correlated subquery — runs inner SELECT once per outer row
|
| 27 |
+
"broken_query": """
|
| 28 |
+
SELECT *
|
| 29 |
+
FROM transactions t1
|
| 30 |
+
WHERE amount > (
|
| 31 |
+
SELECT AVG(amount)
|
| 32 |
+
FROM transactions t2
|
| 33 |
+
WHERE t2.user_id = t1.user_id
|
| 34 |
+
)
|
| 35 |
+
AND t1.status = 'completed'
|
| 36 |
+
""",
|
| 37 |
+
|
| 38 |
+
"target_description": (
|
| 39 |
+
"Return all completed transactions where the amount exceeds that user's average. "
|
| 40 |
+
"Optimize it — avoid correlated subqueries. Use a CTE or subquery with GROUP BY."
|
| 41 |
+
),
|
| 42 |
+
|
| 43 |
+
# For hard task we grade differently — no fixed expected_rows
|
| 44 |
+
"expected_rows": None,
|
| 45 |
+
|
| 46 |
+
# We check that the query plan is efficient (no per-row correlated scans)
|
| 47 |
+
"check_plan": True,
|
| 48 |
+
|
| 49 |
+
# Keywords we look for in the agent's solution
|
| 50 |
+
"good_patterns": ["WITH", "GROUP BY", "AVG("],
|
| 51 |
+
}
|
tasks/task_medium.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medium task: the query runs without error but returns wrong rows —
# the agent must spot the JOIN-semantics bug.
TASK = {
    "task_id": "logic_fix_001",
    "difficulty": "medium",
    "max_steps": 8,

    "schema_sql": """
CREATE TABLE employees (id INTEGER, name TEXT, dept_id INTEGER, salary REAL);
CREATE TABLE departments (id INTEGER, dept_name TEXT, budget REAL);

INSERT INTO departments VALUES (1, 'Engineering', 500000);
INSERT INTO departments VALUES (2, 'Sales', 300000);

INSERT INTO employees VALUES (1, 'Alice', 1, 95000);
INSERT INTO employees VALUES (2, 'Bob', 2, 60000);
INSERT INTO employees VALUES (3, 'Carol', 1, 85000);
INSERT INTO employees VALUES (4, 'Dan', 99, 55000); -- dept 99 doesn't exist!
""",

    # Bug: LEFT JOIN means Dan (no dept) appears in results. Should be INNER JOIN.
    "broken_query": """
SELECT e.name, d.dept_name
FROM employees e
LEFT JOIN departments d ON e.dept_id = d.id
WHERE d.budget > 400000
""",

    "target_description": (
        "Return names of employees in departments with budget > 400000. "
        "Do NOT include employees whose department doesn't exist."
    ),

    # Expected output — only Engineering clears the budget threshold.
    "expected_rows": [
        {"name": "Alice", "dept_name": "Engineering"},
        {"name": "Carol", "dept_name": "Engineering"},
    ],

    "check_plan": False,
}
|
test.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test_websocket.py
#
# Manual smoke test: exercises reset/step over WebSocket against a
# locally running server (uvicorn server.app:app --port 8000).
from client import SQLDebugEnv
from models import SQLDebugAction


def test():
    """Reset each registered task and take one step, printing what we get."""
    # Use WebSocket URL
    env = SQLDebugEnv(base_url="ws://localhost:8000")

    try:
        # Bug fix: the environment registers syntax_fix_001, logic_fix_001
        # and optimize_001 (see server/sql_debug_environment.py) — the
        # original iterated *_002 / pipeline_audit_001 IDs that do not
        # exist, so every reset returned an "Unknown task_id" observation.
        for task_id in ["syntax_fix_001", "logic_fix_001", "optimize_001"]:
            print(f"\n{'='*60}")
            print(f"Testing: {task_id}")

            # Connect and reset
            result = env.reset(task_id=task_id)
            obs = result.observation

            print(f"✓ task_id: {obs.task_id}")
            print(f"✓ description: {obs.target_description[:50]}...")
            print(f"✓ query: {obs.current_query[:60]}...")

            # Try one step
            result = env.step(SQLDebugAction(query="SELECT 1"))
            print(f"✓ step reward: {result.reward}")

    finally:
        env.close()


if __name__ == "__main__":
    # Guarded so importing this module doesn't open a connection.
    test()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|