snigenigmatic commited on
Commit
0683cf4
·
verified ·
1 Parent(s): 51862df

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
2
+ FROM ${BASE_IMAGE} AS builder
3
+
4
+ WORKDIR /app
5
+
6
+ RUN apt-get update && apt-get install -y --no-install-recommends git curl \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ COPY . /app/env
10
+ WORKDIR /app/env
11
+
12
+ RUN if ! command -v uv >/dev/null 2>&1; then \
13
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
14
+ mv /root/.local/bin/uv /usr/local/bin/uv; \
15
+ fi
16
+
17
+ RUN --mount=type=cache,target=/root/.cache/uv \
18
+ if [ -f uv.lock ]; then uv sync --frozen --no-install-project --no-editable; \
19
+ else uv sync --no-install-project --no-editable; fi
20
+
21
+ RUN --mount=type=cache,target=/root/.cache/uv \
22
+ if [ -f uv.lock ]; then uv sync --frozen --no-editable; \
23
+ else uv sync --no-editable; fi
24
+
25
+ # Final image
26
+ FROM ${BASE_IMAGE}
27
+
28
+ WORKDIR /app
29
+
30
+ COPY --from=builder /app/env/.venv /app/.venv
31
+ COPY --from=builder /app/env /app/env
32
+
33
+ ENV PATH="/app/.venv/bin:$PATH"
34
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
35
+
36
+ EXPOSE 8000
37
+
38
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
39
+ CMD curl -f http://localhost:8000/health || exit 1
40
+
41
+ ENV ENABLE_WEB_INTERFACE=true
42
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,10 +1,128 @@
1
- ---
2
- title: Sql Tutor Env
3
- emoji: 🏃
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SQL Tutor Env
3
+ colorFrom: blue
4
+ colorTo: indigo
5
+ sdk: docker
6
+ pinned: false
7
+ tags:
8
+ - openenv
9
+ - openenv-main
10
+ - rl-environment
11
+ base_path: /web
12
+ ---
13
+
14
+ # SQL Tutor Environment
15
+
16
+ An **OpenEnv** reinforcement learning environment that trains LLM agents to identify and fix bugs in SQL queries.
17
+
18
+ Built for the [Meta x Hugging Face x PyTorch India Hackathon 2026](https://www.scaler.com/school-of-technology/meta-pytorch-hackathon).
19
+
20
+ ---
21
+
22
+ ## Task
23
+
24
+ At each episode the agent receives:
25
+ - A **broken SQL query** with a deliberate bug
26
+ - The **database schema** (tables and columns)
27
+ - A **task description** of what the correct query should return
28
+
29
+ The agent must either:
30
+ 1. **Submit a fix** (`submit_fix`) - provide a corrected SQL query
31
+ 2. **Request a hint** (`request_hint`) - get a progressive hint (with a small reward penalty)
32
+
33
+ The episode ends when the agent submits a correct query or exhausts its 5 allowed actions.
34
+
35
+ ---
36
+
37
+ ## Reward Structure
38
+
39
+ | Outcome | Reward |
40
+ |---|---|
41
+ | Correct fix, no hints, first try | **+1.0** |
42
+ | Correct fix with hints / retries | **+0.1 to +0.95** (scaled down) |
43
+ | SQL syntax error | **-0.1** |
44
+ | Wrong query (valid SQL, wrong result) | **-0.05** |
45
+ | Requesting a hint | **-0.1** |
46
+ | Max steps reached without solving | **0** |
47
+
48
+ ---
49
+
50
+ ## Challenge Types (5 built-in)
51
+
52
+ | ID | Bug Type |
53
+ |---|---|
54
+ | `wrong_aggregate` | Missing `SUM()` + `GROUP BY` |
55
+ | `wrong_join` | `INNER JOIN` should be `LEFT JOIN` |
56
+ | `off_by_one_filter` | Wrong comparison operator in `WHERE` |
57
+ | `missing_having` | `WHERE` used instead of `HAVING` for aggregate filter |
58
+ | `wrong_order_limit` | `ASC` should be `DESC` for top-N query |
59
+
60
+ ---
61
+
62
+ ## Quick Start
63
+
64
+ ```python
65
+ from openenv.core import EnvClient
66
+ from sql_tutor_env.client import SQLTutorEnv
67
+ from sql_tutor_env.models import SQLAction
68
+
69
+ # Connect to the running HF Space
70
+ env = SQLTutorEnv(base_url="https://your-username-sql-tutor-env.hf.space")
71
+
72
+ # Start an episode
73
+ obs, state = env.reset()
74
+ print(f"Task: {obs.task_description}")
75
+ print(f"Broken query: {obs.broken_query}")
76
+
77
+ # Submit a fix
78
+ result = env.step(SQLAction(
79
+ action_type="submit_fix",
80
+ sql_query="SELECT customer_id, SUM(amount) AS total_amount FROM orders WHERE status = 'completed' GROUP BY customer_id ORDER BY customer_id;"
81
+ ))
82
+ print(f"Correct: {result.observation.is_correct}, Reward: {result.reward}")
83
+ ```
84
+
85
+ ---
86
+
87
+ ## Integration with TRL / GRPOTrainer
88
+
89
+ ```python
90
+ from trl import GRPOTrainer, GRPOConfig
91
+ from sql_tutor_env.client import SQLTutorEnv
92
+ from sql_tutor_env.models import SQLAction
93
+
94
+ def rollout_func(prompts, env):
95
+ obs, _ = env.reset()
96
+ # ... build prompt from obs, call model, parse SQL, step env
97
+ pass
98
+
99
+ env = SQLTutorEnv(base_url="https://your-space.hf.space")
100
+ trainer = GRPOTrainer(
101
+ model=model,
102
+ config=GRPOConfig(...),
103
+ rollout_func=rollout_func,
104
+ env=env,
105
+ )
106
+ trainer.train()
107
+ ```
108
+
109
+ ---
110
+
111
+ ## Project Structure
112
+
113
+ ```
114
+ sql_tutor_env/
115
+ |-- __init__.py
116
+ |-- models.py # SQLAction, SQLObservation, SQLState
117
+ |-- client.py # SQLTutorEnv (EnvClient subclass)
118
+ |-- openenv.yaml
119
+ |-- pyproject.toml
120
+ |-- README.md
121
+ `-- server/
122
+ |-- __init__.py
123
+ |-- app.py # FastAPI app via create_app()
124
+ |-- sql_environment.py # Core reset/step/state logic
125
+ |-- challenges.py # Bank of SQL bug challenges
126
+ |-- requirements.txt
127
+ `-- Dockerfile
128
+ ```
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # sql_tutor_env
client.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from openenv.core import EnvClient
4
+ from openenv.core.client_types import StepResult
5
+ from models import SQLAction, SQLObservation, SQLState
6
+
7
+
8
+ class SQLTutorEnv(EnvClient[SQLAction, SQLObservation, SQLState]):
9
+ """
10
+ Client for the SQL Tutor environment.
11
+
12
+ Usage:
13
+ # Connect to a running HF Space
14
+ env = SQLTutorEnv(base_url="https://your-space.hf.space")
15
+
16
+ # Or load locally from Hub
17
+ env = SQLTutorEnv.from_hub("your-username/sql-tutor-env")
18
+
19
+ obs, state = env.reset()
20
+ result = env.step(SQLAction(action_type="submit_fix", sql_query="SELECT ..."))
21
+ """
22
+
23
+ def __init__(self, base_url: str, **kwargs):
24
+ super().__init__(base_url=base_url, **kwargs)
25
+
26
+ def _step_payload(self, action: SQLAction) -> Dict:
27
+ return {
28
+ "action_type": action.action_type,
29
+ "sql_query": action.sql_query,
30
+ }
31
+
32
+ def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
33
+ obs_data = payload.get("observation", {})
34
+ observation = SQLObservation(
35
+ broken_query=obs_data.get("broken_query", ""),
36
+ schema_description=obs_data.get("schema_description", ""),
37
+ task_description=obs_data.get("task_description", ""),
38
+ execution_result=obs_data.get("execution_result", ""),
39
+ is_correct=obs_data.get("is_correct", False),
40
+ hint=obs_data.get("hint"),
41
+ steps_taken=obs_data.get("steps_taken", 0),
42
+ max_steps=obs_data.get("max_steps", 5),
43
+ hints_used=obs_data.get("hints_used", 0),
44
+ )
45
+ return StepResult(
46
+ observation=observation,
47
+ reward=payload.get("reward", 0.0),
48
+ done=payload.get("done", False),
49
+ )
50
+
51
+ def _parse_state(self, payload: Dict) -> SQLState:
52
+ return SQLState(
53
+ challenge_id=payload.get("challenge_id", ""),
54
+ broken_query=payload.get("broken_query", ""),
55
+ correct_query=payload.get("correct_query", ""),
56
+ schema_sql=payload.get("schema_sql", ""),
57
+ schema_description=payload.get("schema_description", ""),
58
+ task_description=payload.get("task_description", ""),
59
+ hints=payload.get("hints", []),
60
+ steps_taken=payload.get("steps_taken", 0),
61
+ max_steps=payload.get("max_steps", 5),
62
+ hints_used=payload.get("hints_used", 0),
63
+ is_resolved=payload.get("is_resolved", False),
64
+ cumulative_reward=payload.get("cumulative_reward", 0.0),
65
+ )
models.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ from pydantic import Field
3
+ from openenv.core.env_server.types import Action, Observation, State
4
+
5
+
6
+ class SQLAction(Action):
7
+ action_type: str = Field(
8
+ ...,
9
+ description="Either 'submit_fix' to submit a corrected SQL query, or 'request_hint' to get a hint.",
10
+ )
11
+ sql_query: Optional[str] = Field(
12
+ default=None,
13
+ description="The corrected SQL query. Required when action_type is 'submit_fix'.",
14
+ )
15
+
16
+
17
+ class SQLObservation(Observation):
18
+ # The broken query the agent needs to fix
19
+ broken_query: str = Field(..., description="The SQL query that contains a bug.")
20
+ schema_description: str = Field(..., description="Description of the database schema.")
21
+ task_description: str = Field(..., description="What the correct query should return.")
22
+
23
+ # Feedback from the last action
24
+ execution_result: str = Field(
25
+ default="", description="Output or error from executing the submitted query."
26
+ )
27
+ is_correct: bool = Field(
28
+ default=False, description="Whether the last submitted query was correct."
29
+ )
30
+ hint: Optional[str] = Field(
31
+ default=None, description="A hint, if one was requested."
32
+ )
33
+
34
+ # Progress info
35
+ steps_taken: int = Field(default=0, description="Number of actions taken so far.")
36
+ max_steps: int = Field(default=5, description="Maximum allowed actions.")
37
+ hints_used: int = Field(default=0, description="Number of hints used so far.")
38
+
39
+
40
+ class SQLState(State):
41
+ challenge_id: str = ""
42
+ broken_query: str = ""
43
+ correct_query: str = ""
44
+ schema_sql: str = ""
45
+ schema_description: str = ""
46
+ task_description: str = ""
47
+ hints: List[str] = Field(default_factory=list)
48
+
49
+ steps_taken: int = 0
50
+ max_steps: int = 5
51
+ hints_used: int = 0
52
+ is_resolved: bool = False
53
+ cumulative_reward: float = 0.0
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: sql_tutor_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "sql-tutor-env"
7
+ version = "0.1.0"
8
+ description = "An OpenEnv RL environment for training LLM agents to fix broken SQL queries."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "openenv-core>=0.2.2",
13
+ "fastapi>=0.104.0",
14
+ "uvicorn>=0.24.0",
15
+ ]
16
+
17
+ [project.optional-dependencies]
18
+ dev = ["pytest", "httpx"]
19
+
20
+ [tool.hatch.build.targets.wheel]
21
+ only-include = [
22
+ "__init__.py",
23
+ "client.py",
24
+ "models.py",
25
+ "server",
26
+ ]
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # server
server/app.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openenv.core.env_server.http_server import create_app
2
+ from models import SQLAction, SQLObservation
3
+ from server.sql_environment import SQLTutorEnvironment
4
+
5
+ app = create_app(
6
+ SQLTutorEnvironment,
7
+ SQLAction,
8
+ SQLObservation,
9
+ env_name="sql_tutor_env",
10
+ max_concurrent_envs=10,
11
+ )
server/challenges.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A bank of SQL challenges. Each challenge has:
3
+ - schema_sql: DDL + seed data to create an in-memory SQLite DB
4
+ - schema_description: human-readable description of tables
5
+ - task_description: what the correct query must return
6
+ - broken_query: a query with a deliberate bug
7
+ - correct_query: the canonical correct query
8
+ - hints: progressive hints (easiest to most specific)
9
+ """
10
+
11
+ CHALLENGES = [
12
+ {
13
+ "id": "wrong_aggregate",
14
+ "schema_sql": """
15
+ CREATE TABLE orders (
16
+ id INTEGER PRIMARY KEY,
17
+ customer_id INTEGER,
18
+ amount REAL,
19
+ status TEXT
20
+ );
21
+ INSERT INTO orders VALUES
22
+ (1, 1, 120.50, 'completed'),
23
+ (2, 1, 45.00, 'completed'),
24
+ (3, 2, 200.00, 'pending'),
25
+ (4, 2, 80.00, 'completed'),
26
+ (5, 3, 60.00, 'completed');
27
+ """,
28
+ "schema_description": (
29
+ "Table `orders`: id (int), customer_id (int), amount (real), status (text)."
30
+ ),
31
+ "task_description": (
32
+ "Find each customer's total spending on completed orders. "
33
+ "Return customer_id and total_amount, ordered by customer_id."
34
+ ),
35
+ "broken_query": (
36
+ "SELECT customer_id, amount AS total_amount "
37
+ "FROM orders "
38
+ "WHERE status = 'completed' "
39
+ "ORDER BY customer_id;"
40
+ ),
41
+ "correct_query": (
42
+ "SELECT customer_id, SUM(amount) AS total_amount "
43
+ "FROM orders "
44
+ "WHERE status = 'completed' "
45
+ "GROUP BY customer_id "
46
+ "ORDER BY customer_id;"
47
+ ),
48
+ "hints": [
49
+ "The query is missing an aggregation - each customer can have multiple orders.",
50
+ "You need to use SUM() to add up amounts, and GROUP BY to group per customer.",
51
+ "Add `SUM(amount) AS total_amount` and `GROUP BY customer_id` to your query.",
52
+ ],
53
+ },
54
+ {
55
+ "id": "wrong_join",
56
+ "schema_sql": """
57
+ CREATE TABLE employees (
58
+ id INTEGER PRIMARY KEY,
59
+ name TEXT,
60
+ department_id INTEGER
61
+ );
62
+ CREATE TABLE departments (
63
+ id INTEGER PRIMARY KEY,
64
+ name TEXT
65
+ );
66
+ INSERT INTO departments VALUES (1, 'Engineering'), (2, 'Marketing'), (3, 'HR');
67
+ INSERT INTO employees VALUES
68
+ (1, 'Alice', 1),
69
+ (2, 'Bob', 2),
70
+ (3, 'Carol', 1),
71
+ (4, 'Dave', NULL);
72
+ """,
73
+ "schema_description": (
74
+ "Table `employees`: id, name, department_id. "
75
+ "Table `departments`: id, name."
76
+ ),
77
+ "task_description": (
78
+ "List all employees and their department name. "
79
+ "Include employees with no department (show NULL for their department). "
80
+ "Return employee name and department name."
81
+ ),
82
+ "broken_query": (
83
+ "SELECT e.name, d.name AS department "
84
+ "FROM employees e "
85
+ "INNER JOIN departments d ON e.department_id = d.id;"
86
+ ),
87
+ "correct_query": (
88
+ "SELECT e.name, d.name AS department "
89
+ "FROM employees e "
90
+ "LEFT JOIN departments d ON e.department_id = d.id;"
91
+ ),
92
+ "hints": [
93
+ "The result is missing some employees. Check whether all rows from `employees` appear.",
94
+ "INNER JOIN only returns rows with a match in both tables. Dave has no department.",
95
+ "Change INNER JOIN to LEFT JOIN so employees without a department are still included.",
96
+ ],
97
+ },
98
+ {
99
+ "id": "off_by_one_filter",
100
+ "schema_sql": """
101
+ CREATE TABLE products (
102
+ id INTEGER PRIMARY KEY,
103
+ name TEXT,
104
+ price REAL,
105
+ stock INTEGER
106
+ );
107
+ INSERT INTO products VALUES
108
+ (1, 'Widget A', 9.99, 50),
109
+ (2, 'Widget B', 14.99, 0),
110
+ (3, 'Gadget X', 49.99, 10),
111
+ (4, 'Gadget Y', 99.99, 0),
112
+ (5, 'Gizmo Z', 4.99, 200);
113
+ """,
114
+ "schema_description": (
115
+ "Table `products`: id, name, price (real), stock (int)."
116
+ ),
117
+ "task_description": (
118
+ "Find all products that are currently out of stock (stock = 0). "
119
+ "Return their name and price."
120
+ ),
121
+ "broken_query": (
122
+ "SELECT name, price FROM products WHERE stock < 0;"
123
+ ),
124
+ "correct_query": (
125
+ "SELECT name, price FROM products WHERE stock = 0;"
126
+ ),
127
+ "hints": [
128
+ "The query returns no rows, but some products have zero stock.",
129
+ "Check the WHERE condition - you want products where stock equals zero, not less than zero.",
130
+ "Change `stock < 0` to `stock = 0`.",
131
+ ],
132
+ },
133
+ {
134
+ "id": "missing_having",
135
+ "schema_sql": """
136
+ CREATE TABLE sales (
137
+ id INTEGER PRIMARY KEY,
138
+ region TEXT,
139
+ salesperson TEXT,
140
+ revenue REAL
141
+ );
142
+ INSERT INTO sales VALUES
143
+ (1, 'North', 'Alice', 5000),
144
+ (2, 'North', 'Bob', 3000),
145
+ (3, 'South', 'Carol', 8000),
146
+ (4, 'South', 'Dave', 2000),
147
+ (5, 'East', 'Eve', 1500),
148
+ (6, 'North', 'Alice', 4000),
149
+ (7, 'South', 'Carol', 7000);
150
+ """,
151
+ "schema_description": (
152
+ "Table `sales`: id, region (text), salesperson (text), revenue (real)."
153
+ ),
154
+ "task_description": (
155
+ "Find regions whose total revenue exceeds 10000. "
156
+ "Return region and total_revenue."
157
+ ),
158
+ "broken_query": (
159
+ "SELECT region, SUM(revenue) AS total_revenue "
160
+ "FROM sales "
161
+ "WHERE SUM(revenue) > 10000 "
162
+ "GROUP BY region;"
163
+ ),
164
+ "correct_query": (
165
+ "SELECT region, SUM(revenue) AS total_revenue "
166
+ "FROM sales "
167
+ "GROUP BY region "
168
+ "HAVING SUM(revenue) > 10000;"
169
+ ),
170
+ "hints": [
171
+ "You cannot use aggregate functions like SUM() inside a WHERE clause.",
172
+ "To filter on aggregated values, use HAVING instead of WHERE.",
173
+ "Move the condition to a HAVING clause after GROUP BY: `HAVING SUM(revenue) > 10000`.",
174
+ ],
175
+ },
176
+ {
177
+ "id": "wrong_order_limit",
178
+ "schema_sql": """
179
+ CREATE TABLE students (
180
+ id INTEGER PRIMARY KEY,
181
+ name TEXT,
182
+ score INTEGER
183
+ );
184
+ INSERT INTO students VALUES
185
+ (1, 'Alice', 88),
186
+ (2, 'Bob', 72),
187
+ (3, 'Carol', 95),
188
+ (4, 'Dave', 60),
189
+ (5, 'Eve', 91);
190
+ """,
191
+ "schema_description": (
192
+ "Table `students`: id, name, score (int)."
193
+ ),
194
+ "task_description": (
195
+ "Find the top 3 students by score. Return name and score, highest score first."
196
+ ),
197
+ "broken_query": (
198
+ "SELECT name, score FROM students ORDER BY score ASC LIMIT 3;"
199
+ ),
200
+ "correct_query": (
201
+ "SELECT name, score FROM students ORDER BY score DESC LIMIT 3;"
202
+ ),
203
+ "hints": [
204
+ "The query returns the lowest scores, not the highest.",
205
+ "Check the ORDER BY direction - ASC sorts smallest first.",
206
+ "Change `ORDER BY score ASC` to `ORDER BY score DESC` to get the top scores.",
207
+ ],
208
+ },
209
+ ]
server/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ openenv-core>=0.2.2
2
+ fastapi>=0.104.0
3
+ uvicorn>=0.24.0
server/sql_environment.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import random
3
+ from typing import Tuple
4
+
5
+ from openenv.core.env_server.types import EnvBase
6
+ from models import SQLAction, SQLObservation, SQLState
7
+ from server.challenges import CHALLENGES
8
+
9
+
10
+ def _run_query(schema_sql: str, query: str) -> Tuple[bool, str]:
11
+ """
12
+ Execute query against an in-memory SQLite DB seeded with schema_sql.
13
+ Returns (success: bool, result_string: str).
14
+ """
15
+ try:
16
+ conn = sqlite3.connect(":memory:")
17
+ conn.executescript(schema_sql)
18
+ cursor = conn.execute(query)
19
+ rows = cursor.fetchall()
20
+ col_names = [desc[0] for desc in cursor.description] if cursor.description else []
21
+ conn.close()
22
+
23
+ if not rows:
24
+ return True, "(no rows returned)"
25
+
26
+ # Format as a simple text table
27
+ header = " | ".join(col_names)
28
+ sep = "-" * len(header)
29
+ row_lines = [" | ".join(str(v) for v in row) for row in rows]
30
+ return True, "\n".join([header, sep] + row_lines)
31
+
32
+ except Exception as e:
33
+ return False, f"ERROR: {e}"
34
+
35
+
36
+ def _results_match(schema_sql: str, query_a: str, query_b: str) -> bool:
37
+ """Check whether two queries return identical result sets."""
38
+ try:
39
+ conn = sqlite3.connect(":memory:")
40
+ conn.executescript(schema_sql)
41
+
42
+ rows_a = set(conn.execute(query_a).fetchall())
43
+ rows_b = set(conn.execute(query_b).fetchall())
44
+ conn.close()
45
+ return rows_a == rows_b
46
+ except Exception:
47
+ return False
48
+
49
+
50
+ class SQLTutorEnvironment(EnvBase[SQLAction, SQLObservation, SQLState]):
51
+
52
+ def reset(self) -> Tuple[SQLObservation, SQLState]:
53
+ challenge = random.choice(CHALLENGES)
54
+
55
+ state = SQLState(
56
+ challenge_id=challenge["id"],
57
+ broken_query=challenge["broken_query"],
58
+ correct_query=challenge["correct_query"],
59
+ schema_sql=challenge["schema_sql"],
60
+ schema_description=challenge["schema_description"],
61
+ task_description=challenge["task_description"],
62
+ hints=challenge["hints"],
63
+ steps_taken=0,
64
+ max_steps=5,
65
+ hints_used=0,
66
+ is_resolved=False,
67
+ cumulative_reward=0.0,
68
+ )
69
+ self._state = state
70
+
71
+ # Show the agent the broken query output so it understands what's wrong
72
+ _, broken_result = _run_query(state.schema_sql, state.broken_query)
73
+
74
+ observation = SQLObservation(
75
+ broken_query=state.broken_query,
76
+ schema_description=state.schema_description,
77
+ task_description=state.task_description,
78
+ execution_result=f"Current (broken) query output:\n{broken_result}",
79
+ is_correct=False,
80
+ hint=None,
81
+ steps_taken=0,
82
+ max_steps=state.max_steps,
83
+ hints_used=0,
84
+ )
85
+ return observation, state
86
+
87
+ def step(self, action: SQLAction) -> Tuple[SQLObservation, float, bool, SQLState]:
88
+ state = self._state
89
+ state.steps_taken += 1
90
+ reward = 0.0
91
+ done = False
92
+ hint = None
93
+
94
+ if action.action_type == "request_hint":
95
+ hint_index = min(state.hints_used, len(state.hints) - 1)
96
+ hint = state.hints[hint_index]
97
+ state.hints_used += 1
98
+ reward = -0.1 # small penalty for using a hint
99
+ execution_result = f"Current (broken) query output shown for reference."
100
+ _, execution_result = _run_query(state.schema_sql, state.broken_query)
101
+ execution_result = f"(Hint requested - no query executed)\nBroken query output:\n{execution_result}"
102
+ is_correct = False
103
+
104
+ elif action.action_type == "submit_fix":
105
+ if not action.sql_query:
106
+ execution_result = "ERROR: You chose 'submit_fix' but provided no sql_query."
107
+ is_correct = False
108
+ reward = -0.05
109
+ else:
110
+ success, execution_result = _run_query(state.schema_sql, action.sql_query)
111
+
112
+ if not success:
113
+ is_correct = False
114
+ reward = -0.1
115
+ else:
116
+ is_correct = _results_match(
117
+ state.schema_sql, action.sql_query, state.correct_query
118
+ )
119
+ if is_correct:
120
+ # Reward decreases with hints used and steps taken
121
+ base_reward = 1.0
122
+ hint_penalty = 0.15 * state.hints_used
123
+ step_penalty = 0.05 * max(0, state.steps_taken - 1)
124
+ reward = max(0.1, base_reward - hint_penalty - step_penalty)
125
+ state.is_resolved = True
126
+ done = True
127
+ else:
128
+ reward = -0.05
129
+ else:
130
+ execution_result = f"ERROR: Unknown action_type '{action.action_type}'. Use 'submit_fix' or 'request_hint'."
131
+ is_correct = False
132
+ reward = -0.05
133
+
134
+ # End episode if max steps reached
135
+ if state.steps_taken >= state.max_steps and not done:
136
+ done = True
137
+
138
+ state.cumulative_reward += reward
139
+
140
+ observation = SQLObservation(
141
+ broken_query=state.broken_query,
142
+ schema_description=state.schema_description,
143
+ task_description=state.task_description,
144
+ execution_result=execution_result,
145
+ is_correct=is_correct,
146
+ hint=hint,
147
+ steps_taken=state.steps_taken,
148
+ max_steps=state.max_steps,
149
+ hints_used=state.hints_used,
150
+ )
151
+
152
+ return observation, reward, done, state
153
+
154
+ def get_state(self) -> SQLState:
155
+ return self._state
uv.lock ADDED
The diff for this file is too large to render. See raw diff