Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- Dockerfile +42 -0
- README.md +128 -10
- __init__.py +1 -0
- client.py +65 -0
- models.py +53 -0
- openenv.yaml +6 -0
- pyproject.toml +26 -0
- server/__init__.py +1 -0
- server/app.py +11 -0
- server/challenges.py +209 -0
- server/requirements.txt +3 -0
- server/sql_environment.py +155 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
+
FROM ${BASE_IMAGE} AS builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends git curl \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
COPY . /app/env
|
| 10 |
+
WORKDIR /app/env
|
| 11 |
+
|
| 12 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 13 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 14 |
+
mv /root/.local/bin/uv /usr/local/bin/uv; \
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 18 |
+
if [ -f uv.lock ]; then uv sync --frozen --no-install-project --no-editable; \
|
| 19 |
+
else uv sync --no-install-project --no-editable; fi
|
| 20 |
+
|
| 21 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 22 |
+
if [ -f uv.lock ]; then uv sync --frozen --no-editable; \
|
| 23 |
+
else uv sync --no-editable; fi
|
| 24 |
+
|
| 25 |
+
# Final image
|
| 26 |
+
FROM ${BASE_IMAGE}
|
| 27 |
+
|
| 28 |
+
WORKDIR /app
|
| 29 |
+
|
| 30 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 31 |
+
COPY --from=builder /app/env /app/env
|
| 32 |
+
|
| 33 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 34 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 35 |
+
|
| 36 |
+
EXPOSE 8000
|
| 37 |
+
|
| 38 |
+
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
| 39 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 40 |
+
|
| 41 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 42 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,128 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SQL Tutor Env
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: indigo
|
| 5 |
+
sdk: docker
|
| 6 |
+
pinned: false
|
| 7 |
+
tags:
|
| 8 |
+
- openenv
|
| 9 |
+
- openenv-main
|
| 10 |
+
- rl-environment
|
| 11 |
+
base_path: /web
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# SQL Tutor Environment
|
| 15 |
+
|
| 16 |
+
An **OpenEnv** reinforcement learning environment that trains LLM agents to identify and fix bugs in SQL queries.
|
| 17 |
+
|
| 18 |
+
Built for the [Meta x Hugging Face x PyTorch India Hackathon 2026](https://www.scaler.com/school-of-technology/meta-pytorch-hackathon).
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Task
|
| 23 |
+
|
| 24 |
+
At each episode the agent receives:
|
| 25 |
+
- A **broken SQL query** with a deliberate bug
|
| 26 |
+
- The **database schema** (tables and columns)
|
| 27 |
+
- A **task description** of what the correct query should return
|
| 28 |
+
|
| 29 |
+
The agent must either:
|
| 30 |
+
1. **Submit a fix** (`submit_fix`) - provide a corrected SQL query
|
| 31 |
+
2. **Request a hint** (`request_hint`) - get a progressive hint (with a small reward penalty)
|
| 32 |
+
|
| 33 |
+
The episode ends when the agent submits a correct query or exhausts its 5 allowed actions.
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Reward Structure
|
| 38 |
+
|
| 39 |
+
| Outcome | Reward |
|
| 40 |
+
|---|---|
|
| 41 |
+
| Correct fix, no hints, first try | **+1.0** |
|
| 42 |
+
| Correct fix with hints / retries | **+0.1 to +0.95** (1.0 minus 0.15 per hint used and 0.05 per earlier action in the episode, floored at 0.1) |
|
| 43 |
+
| SQL syntax error | **-0.1** |
|
| 44 |
+
| Wrong query (valid SQL, wrong result) | **-0.05** |
|
| 45 |
+
| Requesting a hint | **-0.1** |
|
| 46 |
+
| Max steps reached without solving | No extra bonus or penalty; the final action's own reward (e.g. **-0.05**) still applies |
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## Challenge Types (5 built-in)
|
| 51 |
+
|
| 52 |
+
| ID | Bug Type |
|
| 53 |
+
|---|---|
|
| 54 |
+
| `wrong_aggregate` | Missing `SUM()` + `GROUP BY` |
|
| 55 |
+
| `wrong_join` | `INNER JOIN` should be `LEFT JOIN` |
|
| 56 |
+
| `off_by_one_filter` | Wrong comparison operator in `WHERE` |
|
| 57 |
+
| `missing_having` | `WHERE` used instead of `HAVING` for aggregate filter |
|
| 58 |
+
| `wrong_order_limit` | `ASC` should be `DESC` for top-N query |
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## Quick Start
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
from openenv.core import EnvClient
|
| 66 |
+
from sql_tutor_env.client import SQLTutorEnv
|
| 67 |
+
from sql_tutor_env.models import SQLAction
|
| 68 |
+
|
| 69 |
+
# Connect to the running HF Space
|
| 70 |
+
env = SQLTutorEnv(base_url="https://your-username-sql-tutor-env.hf.space")
|
| 71 |
+
|
| 72 |
+
# Start an episode
|
| 73 |
+
obs, state = env.reset()
|
| 74 |
+
print(f"Task: {obs.task_description}")
|
| 75 |
+
print(f"Broken query: {obs.broken_query}")
|
| 76 |
+
|
| 77 |
+
# Submit a fix
|
| 78 |
+
result = env.step(SQLAction(
|
| 79 |
+
action_type="submit_fix",
|
| 80 |
+
sql_query="SELECT customer_id, SUM(amount) AS total_amount FROM orders WHERE status = 'completed' GROUP BY customer_id ORDER BY customer_id;"
|
| 81 |
+
))
|
| 82 |
+
print(f"Correct: {result.observation.is_correct}, Reward: {result.reward}")
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## Integration with TRL / GRPOTrainer
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 91 |
+
from sql_tutor_env.client import SQLTutorEnv
|
| 92 |
+
from sql_tutor_env.models import SQLAction
|
| 93 |
+
|
| 94 |
+
def rollout_func(prompts, env):
|
| 95 |
+
obs, _ = env.reset()
|
| 96 |
+
# ... build prompt from obs, call model, parse SQL, step env
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
env = SQLTutorEnv(base_url="https://your-space.hf.space")
|
| 100 |
+
trainer = GRPOTrainer(
|
| 101 |
+
model=model,
|
| 102 |
+
config=GRPOConfig(...),
|
| 103 |
+
rollout_func=rollout_func,
|
| 104 |
+
env=env,
|
| 105 |
+
)
|
| 106 |
+
trainer.train()
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
---
|
| 110 |
+
|
| 111 |
+
## Project Structure
|
| 112 |
+
|
| 113 |
+
```
|
| 114 |
+
sql_tutor_env/
|
| 115 |
+
|-- __init__.py
|
| 116 |
+
|-- models.py # SQLAction, SQLObservation, SQLState
|
| 117 |
+
|-- client.py # SQLTutorEnv (EnvClient subclass)
|
| 118 |
+
|-- openenv.yaml
|
| 119 |
+
|-- pyproject.toml
|
| 120 |
+
|-- README.md
|
| 121 |
+
`-- server/
|
| 122 |
+
|-- __init__.py
|
| 123 |
+
|-- app.py # FastAPI app via create_app()
|
| 124 |
+
|-- sql_environment.py # Core reset/step/state logic
|
| 125 |
+
|-- challenges.py # Bank of SQL bug challenges
|
| 126 |
+
|-- requirements.txt
|
| 127 |
+
`-- Dockerfile
|
| 128 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# sql_tutor_env
|
client.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
|
| 3 |
+
from openenv.core import EnvClient
|
| 4 |
+
from openenv.core.client_types import StepResult
|
| 5 |
+
from models import SQLAction, SQLObservation, SQLState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SQLTutorEnv(EnvClient[SQLAction, SQLObservation, SQLState]):
    """
    Client-side handle for a SQL Tutor environment server.

    Translates ``SQLAction`` objects into request payloads and turns the
    server's response dictionaries back into ``SQLObservation`` /
    ``SQLState`` models.

    Example:
        # Point at a deployed HF Space
        env = SQLTutorEnv(base_url="https://your-space.hf.space")

        # Or load locally from the Hub
        env = SQLTutorEnv.from_hub("your-username/sql-tutor-env")

        obs, state = env.reset()
        result = env.step(SQLAction(action_type="submit_fix", sql_query="SELECT ..."))
    """

    def __init__(self, base_url: str, **kwargs):
        """Forward ``base_url`` and any extra keyword arguments to ``EnvClient``."""
        super().__init__(base_url=base_url, **kwargs)

    def _step_payload(self, action: SQLAction) -> Dict:
        """Serialize an action into the dictionary sent to the server."""
        body = {}
        body["action_type"] = action.action_type
        body["sql_query"] = action.sql_query
        return body

    def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
        """
        Build a ``StepResult`` from a step/reset response.

        Any observation field missing from the payload falls back to the
        default listed below; ``reward`` and ``done`` fall back to ``0.0``
        and ``False``.
        """
        raw = payload.get("observation", {})
        # Field name -> fallback used when the server omits the field.
        fallbacks = {
            "broken_query": "",
            "schema_description": "",
            "task_description": "",
            "execution_result": "",
            "is_correct": False,
            "hint": None,
            "steps_taken": 0,
            "max_steps": 5,
            "hints_used": 0,
        }
        parsed = {name: raw.get(name, fallback) for name, fallback in fallbacks.items()}
        observation = SQLObservation(**parsed)
        reward = payload.get("reward", 0.0)
        finished = payload.get("done", False)
        return StepResult(observation=observation, reward=reward, done=finished)

    def _parse_state(self, payload: Dict) -> SQLState:
        """Build an ``SQLState`` from a state response, defaulting missing fields."""
        # Built per call so the mutable ``hints`` fallback is never shared.
        fallbacks = {
            "challenge_id": "",
            "broken_query": "",
            "correct_query": "",
            "schema_sql": "",
            "schema_description": "",
            "task_description": "",
            "hints": [],
            "steps_taken": 0,
            "max_steps": 5,
            "hints_used": 0,
            "is_resolved": False,
            "cumulative_reward": 0.0,
        }
        parsed = {name: payload.get(name, fallback) for name, fallback in fallbacks.items()}
        return SQLState(**parsed)
|
models.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, List
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from openenv.core.env_server.types import Action, Observation, State
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# One agent move: either a fix submission or a hint request.
class SQLAction(Action):
    # Required (no default). The server's step() handles 'submit_fix' and
    # 'request_hint'; any other value is answered with an error message.
    action_type: str = Field(
        ...,
        description="Either 'submit_fix' to submit a corrected SQL query, or 'request_hint' to get a hint.",
    )
    # Optional at the model level (defaults to None) so hint requests can omit it.
    sql_query: Optional[str] = Field(
        default=None,
        description="The corrected SQL query. Required when action_type is 'submit_fix'.",
    )
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# What the agent sees after reset() and after every step().
class SQLObservation(Observation):
    # The broken query the agent needs to fix (all three fields are required)
    broken_query: str = Field(..., description="The SQL query that contains a bug.")
    schema_description: str = Field(..., description="Description of the database schema.")
    task_description: str = Field(..., description="What the correct query should return.")

    # Feedback from the last action
    execution_result: str = Field(
        default="", description="Output or error from executing the submitted query."
    )
    is_correct: bool = Field(
        default=False, description="Whether the last submitted query was correct."
    )
    # None unless the last action was a hint request.
    hint: Optional[str] = Field(
        default=None, description="A hint, if one was requested."
    )

    # Progress info
    steps_taken: int = Field(default=0, description="Number of actions taken so far.")
    max_steps: int = Field(default=5, description="Maximum allowed actions.")
    hints_used: int = Field(default=0, description="Number of hints used so far.")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Full episode state kept by the environment; every field has a default.
class SQLState(State):
    # Challenge definition, copied from the selected challenge on reset.
    challenge_id: str = ""
    broken_query: str = ""
    correct_query: str = ""  # reference query that submissions are compared against
    schema_sql: str = ""  # DDL + seed data used to build the in-memory database
    schema_description: str = ""
    task_description: str = ""
    hints: List[str] = Field(default_factory=list)  # fresh list per instance

    # Episode progress counters.
    steps_taken: int = 0
    max_steps: int = 5
    hints_used: int = 0
    is_resolved: bool = False  # set once a correct fix has been submitted
    cumulative_reward: float = 0.0  # running sum of per-step rewards
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: sql_tutor_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 8000
|
pyproject.toml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "sql-tutor-env"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "An OpenEnv RL environment for training LLM agents to fix broken SQL queries."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"openenv-core>=0.2.2",
|
| 13 |
+
"fastapi>=0.104.0",
|
| 14 |
+
"uvicorn>=0.24.0",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
[project.optional-dependencies]
|
| 18 |
+
dev = ["pytest", "httpx"]
|
| 19 |
+
|
| 20 |
+
[tool.hatch.build.targets.wheel]
|
| 21 |
+
only-include = [
|
| 22 |
+
"__init__.py",
|
| 23 |
+
"client.py",
|
| 24 |
+
"models.py",
|
| 25 |
+
"server",
|
| 26 |
+
]
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# server
|
server/app.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openenv.core.env_server.http_server import create_app
|
| 2 |
+
from models import SQLAction, SQLObservation
|
| 3 |
+
from server.sql_environment import SQLTutorEnvironment
|
| 4 |
+
|
| 5 |
+
# Module-level application object. Both the Dockerfile's uvicorn command and
# openenv.yaml refer to it as `server.app:app`.
app = create_app(
    SQLTutorEnvironment,
    SQLAction,
    SQLObservation,
    env_name="sql_tutor_env",
    max_concurrent_envs=10,
)
|
server/challenges.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A bank of SQL challenges. Each challenge has:
|
| 3 |
+
- schema_sql: DDL + seed data to create an in-memory SQLite DB
|
| 4 |
+
- schema_description: human-readable description of tables
|
| 5 |
+
- task_description: what the correct query must return
|
| 6 |
+
- broken_query: a query with a deliberate bug
|
| 7 |
+
- correct_query: the canonical correct query
|
| 8 |
+
- hints: progressive hints (most general to most specific)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
CHALLENGES = [
|
| 12 |
+
{
|
| 13 |
+
"id": "wrong_aggregate",
|
| 14 |
+
"schema_sql": """
|
| 15 |
+
CREATE TABLE orders (
|
| 16 |
+
id INTEGER PRIMARY KEY,
|
| 17 |
+
customer_id INTEGER,
|
| 18 |
+
amount REAL,
|
| 19 |
+
status TEXT
|
| 20 |
+
);
|
| 21 |
+
INSERT INTO orders VALUES
|
| 22 |
+
(1, 1, 120.50, 'completed'),
|
| 23 |
+
(2, 1, 45.00, 'completed'),
|
| 24 |
+
(3, 2, 200.00, 'pending'),
|
| 25 |
+
(4, 2, 80.00, 'completed'),
|
| 26 |
+
(5, 3, 60.00, 'completed');
|
| 27 |
+
""",
|
| 28 |
+
"schema_description": (
|
| 29 |
+
"Table `orders`: id (int), customer_id (int), amount (real), status (text)."
|
| 30 |
+
),
|
| 31 |
+
"task_description": (
|
| 32 |
+
"Find each customer's total spending on completed orders. "
|
| 33 |
+
"Return customer_id and total_amount, ordered by customer_id."
|
| 34 |
+
),
|
| 35 |
+
"broken_query": (
|
| 36 |
+
"SELECT customer_id, amount AS total_amount "
|
| 37 |
+
"FROM orders "
|
| 38 |
+
"WHERE status = 'completed' "
|
| 39 |
+
"ORDER BY customer_id;"
|
| 40 |
+
),
|
| 41 |
+
"correct_query": (
|
| 42 |
+
"SELECT customer_id, SUM(amount) AS total_amount "
|
| 43 |
+
"FROM orders "
|
| 44 |
+
"WHERE status = 'completed' "
|
| 45 |
+
"GROUP BY customer_id "
|
| 46 |
+
"ORDER BY customer_id;"
|
| 47 |
+
),
|
| 48 |
+
"hints": [
|
| 49 |
+
"The query is missing an aggregation - each customer can have multiple orders.",
|
| 50 |
+
"You need to use SUM() to add up amounts, and GROUP BY to group per customer.",
|
| 51 |
+
"Add `SUM(amount) AS total_amount` and `GROUP BY customer_id` to your query.",
|
| 52 |
+
],
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"id": "wrong_join",
|
| 56 |
+
"schema_sql": """
|
| 57 |
+
CREATE TABLE employees (
|
| 58 |
+
id INTEGER PRIMARY KEY,
|
| 59 |
+
name TEXT,
|
| 60 |
+
department_id INTEGER
|
| 61 |
+
);
|
| 62 |
+
CREATE TABLE departments (
|
| 63 |
+
id INTEGER PRIMARY KEY,
|
| 64 |
+
name TEXT
|
| 65 |
+
);
|
| 66 |
+
INSERT INTO departments VALUES (1, 'Engineering'), (2, 'Marketing'), (3, 'HR');
|
| 67 |
+
INSERT INTO employees VALUES
|
| 68 |
+
(1, 'Alice', 1),
|
| 69 |
+
(2, 'Bob', 2),
|
| 70 |
+
(3, 'Carol', 1),
|
| 71 |
+
(4, 'Dave', NULL);
|
| 72 |
+
""",
|
| 73 |
+
"schema_description": (
|
| 74 |
+
"Table `employees`: id, name, department_id. "
|
| 75 |
+
"Table `departments`: id, name."
|
| 76 |
+
),
|
| 77 |
+
"task_description": (
|
| 78 |
+
"List all employees and their department name. "
|
| 79 |
+
"Include employees with no department (show NULL for their department). "
|
| 80 |
+
"Return employee name and department name."
|
| 81 |
+
),
|
| 82 |
+
"broken_query": (
|
| 83 |
+
"SELECT e.name, d.name AS department "
|
| 84 |
+
"FROM employees e "
|
| 85 |
+
"INNER JOIN departments d ON e.department_id = d.id;"
|
| 86 |
+
),
|
| 87 |
+
"correct_query": (
|
| 88 |
+
"SELECT e.name, d.name AS department "
|
| 89 |
+
"FROM employees e "
|
| 90 |
+
"LEFT JOIN departments d ON e.department_id = d.id;"
|
| 91 |
+
),
|
| 92 |
+
"hints": [
|
| 93 |
+
"The result is missing some employees. Check whether all rows from `employees` appear.",
|
| 94 |
+
"INNER JOIN only returns rows with a match in both tables. Dave has no department.",
|
| 95 |
+
"Change INNER JOIN to LEFT JOIN so employees without a department are still included.",
|
| 96 |
+
],
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"id": "off_by_one_filter",
|
| 100 |
+
"schema_sql": """
|
| 101 |
+
CREATE TABLE products (
|
| 102 |
+
id INTEGER PRIMARY KEY,
|
| 103 |
+
name TEXT,
|
| 104 |
+
price REAL,
|
| 105 |
+
stock INTEGER
|
| 106 |
+
);
|
| 107 |
+
INSERT INTO products VALUES
|
| 108 |
+
(1, 'Widget A', 9.99, 50),
|
| 109 |
+
(2, 'Widget B', 14.99, 0),
|
| 110 |
+
(3, 'Gadget X', 49.99, 10),
|
| 111 |
+
(4, 'Gadget Y', 99.99, 0),
|
| 112 |
+
(5, 'Gizmo Z', 4.99, 200);
|
| 113 |
+
""",
|
| 114 |
+
"schema_description": (
|
| 115 |
+
"Table `products`: id, name, price (real), stock (int)."
|
| 116 |
+
),
|
| 117 |
+
"task_description": (
|
| 118 |
+
"Find all products that are currently out of stock (stock = 0). "
|
| 119 |
+
"Return their name and price."
|
| 120 |
+
),
|
| 121 |
+
"broken_query": (
|
| 122 |
+
"SELECT name, price FROM products WHERE stock < 0;"
|
| 123 |
+
),
|
| 124 |
+
"correct_query": (
|
| 125 |
+
"SELECT name, price FROM products WHERE stock = 0;"
|
| 126 |
+
),
|
| 127 |
+
"hints": [
|
| 128 |
+
"The query returns no rows, but some products have zero stock.",
|
| 129 |
+
"Check the WHERE condition - you want products where stock equals zero, not less than zero.",
|
| 130 |
+
"Change `stock < 0` to `stock = 0`.",
|
| 131 |
+
],
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"id": "missing_having",
|
| 135 |
+
"schema_sql": """
|
| 136 |
+
CREATE TABLE sales (
|
| 137 |
+
id INTEGER PRIMARY KEY,
|
| 138 |
+
region TEXT,
|
| 139 |
+
salesperson TEXT,
|
| 140 |
+
revenue REAL
|
| 141 |
+
);
|
| 142 |
+
INSERT INTO sales VALUES
|
| 143 |
+
(1, 'North', 'Alice', 5000),
|
| 144 |
+
(2, 'North', 'Bob', 3000),
|
| 145 |
+
(3, 'South', 'Carol', 8000),
|
| 146 |
+
(4, 'South', 'Dave', 2000),
|
| 147 |
+
(5, 'East', 'Eve', 1500),
|
| 148 |
+
(6, 'North', 'Alice', 4000),
|
| 149 |
+
(7, 'South', 'Carol', 7000);
|
| 150 |
+
""",
|
| 151 |
+
"schema_description": (
|
| 152 |
+
"Table `sales`: id, region (text), salesperson (text), revenue (real)."
|
| 153 |
+
),
|
| 154 |
+
"task_description": (
|
| 155 |
+
"Find regions whose total revenue exceeds 10000. "
|
| 156 |
+
"Return region and total_revenue."
|
| 157 |
+
),
|
| 158 |
+
"broken_query": (
|
| 159 |
+
"SELECT region, SUM(revenue) AS total_revenue "
|
| 160 |
+
"FROM sales "
|
| 161 |
+
"WHERE SUM(revenue) > 10000 "
|
| 162 |
+
"GROUP BY region;"
|
| 163 |
+
),
|
| 164 |
+
"correct_query": (
|
| 165 |
+
"SELECT region, SUM(revenue) AS total_revenue "
|
| 166 |
+
"FROM sales "
|
| 167 |
+
"GROUP BY region "
|
| 168 |
+
"HAVING SUM(revenue) > 10000;"
|
| 169 |
+
),
|
| 170 |
+
"hints": [
|
| 171 |
+
"You cannot use aggregate functions like SUM() inside a WHERE clause.",
|
| 172 |
+
"To filter on aggregated values, use HAVING instead of WHERE.",
|
| 173 |
+
"Move the condition to a HAVING clause after GROUP BY: `HAVING SUM(revenue) > 10000`.",
|
| 174 |
+
],
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"id": "wrong_order_limit",
|
| 178 |
+
"schema_sql": """
|
| 179 |
+
CREATE TABLE students (
|
| 180 |
+
id INTEGER PRIMARY KEY,
|
| 181 |
+
name TEXT,
|
| 182 |
+
score INTEGER
|
| 183 |
+
);
|
| 184 |
+
INSERT INTO students VALUES
|
| 185 |
+
(1, 'Alice', 88),
|
| 186 |
+
(2, 'Bob', 72),
|
| 187 |
+
(3, 'Carol', 95),
|
| 188 |
+
(4, 'Dave', 60),
|
| 189 |
+
(5, 'Eve', 91);
|
| 190 |
+
""",
|
| 191 |
+
"schema_description": (
|
| 192 |
+
"Table `students`: id, name, score (int)."
|
| 193 |
+
),
|
| 194 |
+
"task_description": (
|
| 195 |
+
"Find the top 3 students by score. Return name and score, highest score first."
|
| 196 |
+
),
|
| 197 |
+
"broken_query": (
|
| 198 |
+
"SELECT name, score FROM students ORDER BY score ASC LIMIT 3;"
|
| 199 |
+
),
|
| 200 |
+
"correct_query": (
|
| 201 |
+
"SELECT name, score FROM students ORDER BY score DESC LIMIT 3;"
|
| 202 |
+
),
|
| 203 |
+
"hints": [
|
| 204 |
+
"The query returns the lowest scores, not the highest.",
|
| 205 |
+
"Check the ORDER BY direction - ASC sorts smallest first.",
|
| 206 |
+
"Change `ORDER BY score ASC` to `ORDER BY score DESC` to get the top scores.",
|
| 207 |
+
],
|
| 208 |
+
},
|
| 209 |
+
]
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core>=0.2.2
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn>=0.24.0
|
server/sql_environment.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import random
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
from openenv.core.env_server.types import EnvBase
|
| 6 |
+
from models import SQLAction, SQLObservation, SQLState
|
| 7 |
+
from server.challenges import CHALLENGES
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _run_query(schema_sql: str, query: str) -> Tuple[bool, str]:
    """
    Execute ``query`` against a fresh in-memory SQLite DB seeded with ``schema_sql``.

    Args:
        schema_sql: DDL and seed statements, run with ``executescript``.
        query: The SQL statement to execute once the schema is in place.

    Returns:
        ``(True, text)`` when the query ran, where ``text`` is
        ``"(no rows returned)"`` or a plain-text table: a header of column
        names, a dashed separator, then one line per row with cells joined
        by ``" | "``.
        ``(False, "ERROR: <message>")`` when seeding or the query raised.
    """
    conn = None
    try:
        conn = sqlite3.connect(":memory:")
        conn.executescript(schema_sql)
        cursor = conn.execute(query)
        rows = cursor.fetchall()
        col_names = [desc[0] for desc in cursor.description] if cursor.description else []
    except Exception as e:
        # Kept broad on purpose: any failure while running the supplied SQL is
        # reported back as text instead of propagating to the caller.
        return False, f"ERROR: {e}"
    finally:
        # Close on every path; previously a failing query left the
        # connection open until garbage collection.
        if conn is not None:
            conn.close()

    if not rows:
        return True, "(no rows returned)"

    # Format as a simple text table
    header = " | ".join(col_names)
    sep = "-" * len(header)
    row_lines = [" | ".join(str(v) for v in row) for row in rows]
    return True, "\n".join([header, sep] + row_lines)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _results_match(schema_sql: str, query_a: str, query_b: str) -> bool:
    """
    Check whether two queries return the same rows on a freshly seeded DB.

    Rows are compared as a multiset: row order is ignored, but the number of
    times each row appears must agree. (A plain ``set`` comparison would
    treat a result containing duplicate rows as equal to its de-duplicated
    form.)

    Args:
        schema_sql: DDL and seed statements, run with ``executescript``.
        query_a: First query to run.
        query_b: Second query to run against the same database.

    Returns:
        ``True`` if both queries ran and produced the same rows with the same
        multiplicities; ``False`` otherwise, including when seeding or either
        query raised.
    """
    # Local import so this function does not depend on the module's import block.
    from collections import Counter

    conn = None
    try:
        conn = sqlite3.connect(":memory:")
        conn.executescript(schema_sql)

        rows_a = Counter(conn.execute(query_a).fetchall())
        rows_b = Counter(conn.execute(query_b).fetchall())
    except Exception:
        # Any failure to execute counts as "does not match".
        return False
    finally:
        # Close on every path, including when a query fails.
        if conn is not None:
            conn.close()
    return rows_a == rows_b
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SQLTutorEnvironment(EnvBase[SQLAction, SQLObservation, SQLState]):
    """
    SQL-debugging environment.

    Each episode presents one broken query drawn from ``CHALLENGES`` and
    scores the agent's attempts to repair it. Episode state is stored in
    ``self._state``, which ``reset`` creates; ``step`` and ``get_state`` read
    it, so ``reset`` must be called first.
    """

    def reset(self) -> Tuple[SQLObservation, SQLState]:
        """
        Start a new episode on a randomly chosen challenge.

        Returns:
            ``(observation, state)``. The observation's ``execution_result``
            contains the output (or error) of running the broken query.
        """
        challenge = random.choice(CHALLENGES)

        state = SQLState(
            challenge_id=challenge["id"],
            broken_query=challenge["broken_query"],
            correct_query=challenge["correct_query"],
            schema_sql=challenge["schema_sql"],
            schema_description=challenge["schema_description"],
            task_description=challenge["task_description"],
            hints=challenge["hints"],
            steps_taken=0,
            max_steps=5,
            hints_used=0,
            is_resolved=False,
            cumulative_reward=0.0,
        )
        self._state = state

        # Show the agent the broken query output so it understands what's wrong
        _, broken_result = _run_query(state.schema_sql, state.broken_query)

        observation = SQLObservation(
            broken_query=state.broken_query,
            schema_description=state.schema_description,
            task_description=state.task_description,
            execution_result=f"Current (broken) query output:\n{broken_result}",
            is_correct=False,
            hint=None,
            steps_taken=0,
            max_steps=state.max_steps,
            hints_used=0,
        )
        return observation, state

    def step(self, action: SQLAction) -> Tuple[SQLObservation, float, bool, SQLState]:
        """
        Apply one agent action and score it.

        Rewards:
            * ``request_hint``: -0.1.
            * ``submit_fix`` with no query, or a query whose rows differ from
              the reference query's: -0.05.
            * ``submit_fix`` with a query that fails to execute: -0.1.
            * ``submit_fix`` with a matching query:
              ``max(0.1, 1.0 - 0.15 * hints_used - 0.05 * (steps_taken - 1))``.
            * any other ``action_type``: -0.05.

        The episode ends on a correct fix, or once ``steps_taken`` reaches
        ``max_steps``.

        Returns:
            ``(observation, reward, done, state)``.
        """
        state = self._state
        state.steps_taken += 1
        reward = 0.0
        done = False
        hint = None

        if action.action_type == "request_hint":
            if state.hints:
                # Clamp the index so repeated requests keep returning the
                # last (most specific) hint instead of running off the end.
                hint = state.hints[min(state.hints_used, len(state.hints) - 1)]
            else:
                # An empty hint list used to raise IndexError here.
                hint = "No hints are available for this challenge."
            state.hints_used += 1
            reward = -0.1  # small penalty for using a hint
            _, broken_output = _run_query(state.schema_sql, state.broken_query)
            execution_result = f"(Hint requested - no query executed)\nBroken query output:\n{broken_output}"
            is_correct = False

        elif action.action_type == "submit_fix":
            if not action.sql_query:
                execution_result = "ERROR: You chose 'submit_fix' but provided no sql_query."
                is_correct = False
                reward = -0.05
            else:
                success, execution_result = _run_query(state.schema_sql, action.sql_query)

                if not success:
                    is_correct = False
                    reward = -0.1
                else:
                    is_correct = _results_match(
                        state.schema_sql, action.sql_query, state.correct_query
                    )
                    if is_correct:
                        # Reward decreases with hints used and steps taken
                        base_reward = 1.0
                        hint_penalty = 0.15 * state.hints_used
                        step_penalty = 0.05 * max(0, state.steps_taken - 1)
                        reward = max(0.1, base_reward - hint_penalty - step_penalty)
                        state.is_resolved = True
                        done = True
                    else:
                        reward = -0.05
        else:
            execution_result = f"ERROR: Unknown action_type '{action.action_type}'. Use 'submit_fix' or 'request_hint'."
            is_correct = False
            reward = -0.05

        # End episode if max steps reached
        if state.steps_taken >= state.max_steps and not done:
            done = True

        state.cumulative_reward += reward

        observation = SQLObservation(
            broken_query=state.broken_query,
            schema_description=state.schema_description,
            task_description=state.task_description,
            execution_result=execution_result,
            is_correct=is_correct,
            hint=hint,
            steps_taken=state.steps_taken,
            max_steps=state.max_steps,
            hints_used=state.hints_used,
        )

        return observation, reward, done, state

    def get_state(self) -> SQLState:
        """Return the current episode state created by ``reset``."""
        return self._state
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|