Param20h committed
Commit 210535c · verified · 1 parent: 1f1f54b

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ # Use Python 3.11 slim base
+ FROM python:3.11-slim
+
+ # Metadata
+ LABEL maintainer="metaXscaler"
+ LABEL description="SQL Query Optimizer — OpenEnv Environment"
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install dependencies first (layer cache optimisation)
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # HF Spaces default port
+ EXPOSE 7860
+
+ # Start the FastAPI server
+ ENV ENABLE_WEB_INTERFACE=true
+ CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,196 @@
  ---
- title: Sql Query Optimizer
- emoji: 🚀
- colorFrom: purple
- colorTo: gray
  sdk: docker
  pinned: false
- license: mit
- short_description: SQL Query Optimizer — OpenEnv Environment
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: SQL Query Optimizer Environment Server
+ emoji: 🐳
+ colorFrom: blue
+ colorTo: indigo
  sdk: docker
  pinned: false
+ app_port: 7860
+ base_path: /web
+ tags:
+ - openenv
  ---

+ # SQL Query Optimizer — OpenEnv Environment
+
+ An **OpenEnv-compliant** environment where AI agents learn to review, rewrite, and optimise SQL queries across three real-world failure patterns.
+
+ > **HF Spaces**: [param20h/sql-query-optimizer](https://huggingface.co/spaces/param20h/sql-query-optimizer)
+
+ ---
+
+ ## Environment Description
+
+ Real-world SQL anti-patterns waste significant infrastructure spend. This environment teaches agents to identify and fix them through a reward-shaped episode loop. Each episode presents the agent with a broken or unoptimised query alongside schema context; the agent iteratively rewrites it until it signals done or the step limit is reached.
+
+ **Why this domain?**
+ - Used by data engineers and DBAs every day
+ - Deterministically gradeable (no ambiguous LLM judging)
+ - Natural difficulty progression from a broken join to multi-factor optimisation
+
+ ---
+
+ ## Observation Space
+
+ | Field | Type | Description |
+ |---|---|---|
+ | `task_id` | `int` | Task number (1–3) |
+ | `task_name` | `str` | Slug identifier |
+ | `task_description` | `str` | What the agent must accomplish |
+ | `query` | `str` | The SQL to fix |
+ | `schema_context` | `str` | Relevant DDL / table definitions |
+ | `hint` | `str \| null` | Optional hint (tasks 1 & 2 only) |
+ | `step_number` | `int` | Current step (0-indexed) |
+ | `max_steps` | `int` | Steps allowed per episode |
+ | `done` | `bool` | Whether episode has ended |
+
+ ---
+
+ ## Action Space
+
+ | Field | Type | Description |
+ |---|---|---|
+ | `rewritten_query` | `str` | The agent's improved SQL |
+ | `explanation` | `str` | Brief description of changes made |
+ | `is_done` | `bool` | `true` when the agent believes the query is fully fixed |
+
+ ---
+
+ ## Reward Design
+
+ The reward is **shaped** (not sparse) — the agent receives signal every step:
+
+ | Component | Value | Trigger |
+ |---|---|---|
+ | Delta reward | +0.0–0.50 × Δgrader | Grader score improves |
+ | Completion bonus | +0.50 | `is_done=True` and grader ≥ 0.80 |
+ | Partial completion | +grader × 0.30 | `is_done=True` (always) |
+ | Step penalty | −0.02 / step | After halfway point, if not done |
+ | Invalid penalty | −0.10 | Empty or unparseable query |
+
+ Final `score` per step is clamped to `[0.0, 1.0]`.
+
+ ---
+
+ ## Tasks
+
+ ### Task 1 — `fix-broken-join` (Easy)
+ The query uses a comma-separated cross-join (`FROM orders, customers`) without any join condition, causing a Cartesian product. The agent must rewrite it with `INNER JOIN … ON o.customer_id = c.customer_id`.
+
+ **Max steps**: 3 | **Grader**: checks JOIN keyword + ON clause with correct key
+
+ ### Task 2 — `eliminate-n-plus-one` (Medium)
+ A correlated scalar subquery in the `SELECT` list executes once per row (N+1 problem). The agent must collapse it into a single `LEFT JOIN departments ON e.dept_id = d.dept_id`.
+
+ **Max steps**: 4 | **Grader**: checks subquery removal + JOIN on dept_id
+
+ ### Task 3 — `full-optimization` (Hard)
+ Four independent issues to fix:
+ 1. Remove redundant `DISTINCT` (PK join makes it unnecessary)
+ 2. Replace `SELECT *` with explicit columns
+ 3. Replace `CAST(price AS VARCHAR) LIKE '1%'` → `price >= 100 AND price < 200` (sargable)
+ 4. Add an index hint comment for `(category, price)`
+
+ **Max steps**: 5 | **Grader**: 4 × 0.25 sub-criteria, fully independent
+
+ ---
+
+ ## API Endpoints
+
+ | Method | Path | Description |
+ |---|---|---|
+ | `GET` | `/` | Health check |
+ | `POST` | `/reset` | Start episode `{ "task_id": 1 }` |
+ | `POST` | `/step` | Submit action `{ "rewritten_query": "...", "explanation": "...", "is_done": true }` |
+ | `GET` | `/state` | Current internal state |
+ | `GET` | `/tasks` | All tasks + action schema |
+ | `GET` | `/grader` | Grader score for current episode |
+ | `POST` | `/baseline` | Run baseline inference (requires `OPENAI_API_KEY`) |
+
+ Interactive docs: `http://localhost:7860/docs`
+
+ ---
+
+ ## Setup & Usage
+
+ ### Prerequisites
+ - Python 3.10+
+ - Docker
+ - `OPENAI_API_KEY` (for baseline only)
+
+ ### Local (Python)
+
+ ```bash
+ pip install -r requirements.txt
+ uvicorn server:app --host 0.0.0.0 --port 7860 --reload
+ ```
+
+ ### Local (Docker)
+
+ ```bash
+ docker build -t sql-optimizer-env .
+ docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-optimizer-env
+ ```
+
+ ### Baseline Inference
+
+ ```bash
+ export OPENAI_API_KEY=sk-...
+ python baseline.py
+ ```
+
+ ### OpenEnv Validation
+
+ ```bash
+ pip install openenv-core
+ openenv validate
+ ```
+
+ ### Deploy to HF Spaces
+
+ ```bash
+ pip install huggingface_hub
+ huggingface-cli login
+ openenv push --repo-id your-username/sql-query-optimizer
+ ```
+
+ ---
+
+ ## Baseline Scores
+
+ Measured with `gpt-4o-mini` at `temperature=0`, single-pass:
+
+ | Task | Name | Difficulty | Grader Score |
+ |---|---|---|---|
+ | 1 | fix-broken-join | Easy | 0.86 |
+ | 2 | eliminate-n-plus-one | Medium | 0.72 |
+ | 3 | full-optimization | Hard | 0.50 |
+ | — | **Average** | — | **0.69** |
+
+ > Scores are reproducible: same model, same temperature, same grader → same output.
+
+ ---
+
+ ## Project Structure
+
+ ```
+ metaXscaler/
+ ├── env/
+ │   ├── __init__.py
+ │   ├── environment.py   # reset(), step(), state()
+ │   ├── models.py        # Observation, Action, Reward (Pydantic)
+ │   ├── tasks.py         # Task definitions + graders
+ │   └── reward.py        # Shaped reward function
+ ├── server.py            # FastAPI app
+ ├── baseline.py          # Baseline inference script
+ ├── openenv.yaml         # OpenEnv spec metadata
+ ├── Dockerfile
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ---
+
+ ## License
+
+ MIT
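The `/reset` and `/step` endpoints above map directly onto plain-stdlib HTTP calls. A minimal sketch that builds (but does not send) both requests; the URL, port, and payload keys mirror the API table, and the SQL is the Task 1 fix:

```python
import json
from urllib import request

BASE = "http://localhost:7860"  # the port exposed by the Dockerfile / HF Spaces


def make_post(path: str, payload: dict) -> request.Request:
    """Build a JSON POST request for the environment server."""
    return request.Request(
        BASE + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Start an episode on task 1, then submit one rewrite and finish.
reset_req = make_post("/reset", {"task_id": 1})
step_req = make_post("/step", {
    "rewritten_query": (
        "SELECT o.order_id, c.name, o.total "
        "FROM orders o INNER JOIN customers c "
        "ON o.customer_id = c.customer_id "
        "WHERE o.total > 100;"
    ),
    "explanation": "Replaced the implicit cross-join with an explicit INNER JOIN.",
    "is_done": True,
})
# Once the server is running, send with e.g.: json.load(request.urlopen(reset_req))
```

The requests are only constructed here so the sketch works without a live server; `urllib.request.urlopen` does the actual round trip.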
__init__.py ADDED
@@ -0,0 +1 @@
+ """Top-level package marker for the OpenEnv project."""
baseline.py ADDED
@@ -0,0 +1,144 @@
+ """
+ Baseline inference script for the SQL Query Optimizer OpenEnv environment.
+
+ Usage:
+     python baseline.py          # human-readable output
+     python baseline.py --json   # JSON output (used by /baseline endpoint)
+
+ Requires:
+     OPENAI_API_KEY environment variable
+
+ The script runs gpt-4o-mini against all 3 tasks and reports grader scores.
+ """
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import sys
+
+ from openai import OpenAI
+
+ # ── import env from local package ──────────────────────────────────────────
+ sys.path.insert(0, os.path.dirname(__file__))
+ from env.environment import SQLOptimizerEnv
+ from env.models import Action
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ MODEL = "gpt-4o-mini"
+ MAX_STEPS = 5
+ TASKS = [1, 2, 3]
+
+ SYSTEM_PROMPT = """You are a database performance engineer.
+ You will receive a broken or unoptimised SQL query along with table schema context.
+ Your job is to rewrite the query so it is correct and performant.
+
+ Respond ONLY with a JSON object with these exact keys:
+ {
+     "rewritten_query": "<your improved SQL>",
+     "explanation": "<brief explanation of changes>",
+     "is_done": true
+ }
+ Do not wrap in markdown. Output raw JSON only."""
+
+
+ def _build_user_message(obs_dict: dict) -> str:
+     return (
+         f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} — difficulty: "
+         f"{obs_dict.get('difficulty', 'unknown')})\n\n"
+         f"Description:\n{obs_dict['task_description']}\n\n"
+         f"Schema:\n{obs_dict['schema_context']}\n\n"
+         f"Query to fix:\n{obs_dict['query']}"
+         + (f"\n\nHint: {obs_dict['hint']}" if obs_dict.get("hint") else "")
+     )
+
+
+ def run_baseline(verbose: bool = True) -> dict[str, float]:
+     api_key = os.getenv("OPENAI_API_KEY")
+     if not api_key:
+         print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
+         sys.exit(1)
+
+     client = OpenAI(api_key=api_key)
+     env = SQLOptimizerEnv()
+     results: dict[str, float] = {}
+
+     for task_id in TASKS:
+         obs = env.reset(task_id=task_id)
+         obs_dict = obs.model_dump()
+         final_score = 0.0
+
+         if verbose:
+             print(f"\n{'='*60}")
+             print(f"Task {task_id}: {obs_dict['task_name']}")
+             print(f"{'='*60}")
+
+         for step_num in range(MAX_STEPS):
+             messages = [
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": _build_user_message(obs_dict)},
+             ]
+
+             try:
+                 response = client.chat.completions.create(
+                     model=MODEL,
+                     messages=messages,
+                     temperature=0.0,
+                     max_tokens=1024,
+                 )
+                 content = response.choices[0].message.content.strip()
+                 parsed = json.loads(content)
+                 action = Action(
+                     rewritten_query=parsed.get("rewritten_query", ""),
+                     explanation=parsed.get("explanation", ""),
+                     is_done=bool(parsed.get("is_done", False)),
+                 )
+             except Exception as exc:
+                 if verbose:
+                     print(f"  Step {step_num + 1}: LLM error — {exc}")
+                 action = Action(
+                     rewritten_query="",
+                     explanation="error",
+                     is_done=True,
+                 )
+
+             obs, reward, done, info = env.step(action)
+             obs_dict = obs.model_dump()
+             final_score = info["grader_score"]
+
+             if verbose:
+                 print(
+                     f"  Step {step_num + 1}: grader_score={info['grader_score']:.3f} "
+                     f"step_reward={reward.score:.4f} feedback={reward.feedback[:80]}"
+                 )
+
+             if done:
+                 break
+
+         results[f"task_{task_id}_{env._task.name}"] = round(final_score, 4)
+
+         if verbose:
+             print(f"  → Final grader score: {final_score:.4f}")
+
+     if verbose:
+         print(f"\n{'='*60}")
+         print("BASELINE RESULTS")
+         print(f"{'='*60}")
+         for k, v in results.items():
+             print(f"  {k}: {v:.4f}")
+         avg = sum(results.values()) / len(results)
+         print(f"  Average: {avg:.4f}")
+
+     return results
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="OpenEnv SQL Optimizer — Baseline Inference")
+     parser.add_argument(
+         "--json", action="store_true", help="Output results as JSON (used by /baseline endpoint)"
+     )
+     args = parser.parse_args()
+
+     scores = run_baseline(verbose=not args.json)
+     if args.json:
+         print(json.dumps(scores))
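The `try/except` around `json.loads` in the script treats any non-JSON reply as a failed step, yet models sometimes wrap output in a markdown fence despite the prompt. A defensive parser one might add (a hypothetical helper, not part of the script):

```python
import json
import re


def parse_model_json(content: str) -> dict:
    """Parse a model reply, tolerating an accidental ```json fence."""
    content = content.strip()
    # strip a leading ```json / ``` fence and a trailing ``` if present
    content = re.sub(r"^```(?:json)?\s*", "", content)
    content = re.sub(r"\s*```$", "", content)
    return json.loads(content)
```

Both a raw JSON reply and a fenced one then parse to the same action dict, so the episode survives a formatting slip instead of burning a step on the error branch.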
client.py ADDED
@@ -0,0 +1,4 @@
+ """Top-level client exports for OpenEnv validation compatibility."""
+ from env.environment import SQLOptimizerEnv
+
+ __all__ = ["SQLOptimizerEnv"]
env/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .environment import SQLOptimizerEnv
+ from .models import Observation, Action, Reward
+
+ __all__ = ["SQLOptimizerEnv", "Observation", "Action", "Reward"]
env/environment.py ADDED
@@ -0,0 +1,174 @@
+ """
+ Core OpenEnv environment: SQLOptimizerEnv
+
+ Implements the three required methods:
+     reset(task_id) → Observation
+     step(action)   → (Observation, Reward, done, info)
+     state()        → dict (current internal snapshot)
+ """
+ from __future__ import annotations
+
+ import math
+ from typing import Any, Dict, Optional, Tuple
+
+ from .models import Action, Observation, Reward, RewardBreakdown
+ from .tasks import TASKS, TaskDef, get_task
+ from .reward import compute_step_reward
+
+
+ class SQLOptimizerEnv:
+     """SQL Query Optimizer OpenEnv environment."""
+
+     def __init__(self) -> None:
+         self._task: Optional[TaskDef] = None
+         self._step_number: int = 0
+         self._done: bool = False
+         self._cumulative_score: float = 0.0
+         self._prev_grader_score: float = 0.0
+         self._history: list[Dict[str, Any]] = []
+         self._last_grader_score: float = 0.0
+
+     # ──────────────────────────────────────────────────────────────────────
+     # reset
+     # ──────────────────────────────────────────────────────────────────────
+
+     def reset(self, task_id: int = 1) -> Observation:
+         """Start a fresh episode for the given task."""
+         self._task = get_task(task_id)
+         self._step_number = 0
+         self._done = False
+         self._cumulative_score = 0.0
+         self._prev_grader_score = 0.0
+         self._last_grader_score = 0.0
+         self._history = []
+
+         return self._make_observation()
+
+     # ──────────────────────────────────────────────────────────────────────
+     # step
+     # ──────────────────────────────────────────────────────────────────────
+
+     def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
+         """
+         Advance the environment by one step.
+
+         Returns:
+             observation: next Observation
+             reward: Reward for this step
+             done: whether the episode has ended
+             info: auxiliary dict
+         """
+         if self._task is None:
+             raise RuntimeError("Call reset() before step().")
+         if self._done:
+             raise RuntimeError("Episode is done. Call reset() to start a new episode.")
+
+         # Validate action
+         is_invalid = not action.rewritten_query or not action.rewritten_query.strip()
+
+         # Run grader
+         if is_invalid:
+             grader_result_score = self._prev_grader_score
+             breakdown = RewardBreakdown()
+             feedback = "Empty or invalid query submitted."
+         else:
+             gr = self._task.grader(action.rewritten_query)
+             grader_result_score = gr.score
+             breakdown = RewardBreakdown(
+                 correctness=gr.correctness,
+                 performance=gr.performance,
+                 style=gr.style,
+                 step_penalty=0.0,
+             )
+             feedback = gr.feedback
+
+         # Compute shaped reward
+         step_reward = compute_step_reward(
+             grader_score=grader_result_score,
+             prev_grader_score=self._prev_grader_score,
+             step_number=self._step_number,
+             max_steps=self._task.max_steps,
+             is_done=action.is_done,
+             is_invalid=is_invalid,
+         )
+
+         # Apply step penalty to breakdown
+         halfway = math.ceil(self._task.max_steps / 2)
+         if self._step_number > halfway and not action.is_done:
+             breakdown.step_penalty = -0.02
+
+         self._cumulative_score = round(
+             min(max(self._cumulative_score + step_reward, 0.0), 1.0), 4
+         )
+         self._prev_grader_score = grader_result_score
+         self._last_grader_score = grader_result_score
+         self._step_number += 1
+
+         # Episode ends if agent signals done OR max steps reached
+         self._done = action.is_done or self._step_number >= self._task.max_steps
+
+         # Record history
+         self._history.append(
+             {
+                 "step": self._step_number,
+                 "rewritten_query": action.rewritten_query,
+                 "grader_score": grader_result_score,
+                 "step_reward": step_reward,
+                 "is_done": action.is_done,
+             }
+         )
+
+         reward = Reward(
+             score=round(min(max(step_reward, 0.0), 1.0), 4),
+             grader_score=grader_result_score,
+             breakdown=breakdown,
+             feedback=feedback,
+             cumulative_score=self._cumulative_score,
+         )
+
+         info = {
+             "step_number": self._step_number,
+             "grader_score": grader_result_score,
+             "cumulative_score": self._cumulative_score,
+             "is_invalid": is_invalid,
+         }
+
+         return self._make_observation(), reward, self._done, info
+
+     # ──────────────────────────────────────────────────────────────────────
+     # state
+     # ──────────────────────────────────────────────────────────────────────
+
+     def state(self) -> Dict[str, Any]:
+         """Return the current internal state snapshot."""
+         if self._task is None:
+             return {"status": "not_started"}
+         return {
+             "task_id": self._task.id,
+             "task_name": self._task.name,
+             "difficulty": self._task.difficulty,
+             "step_number": self._step_number,
+             "max_steps": self._task.max_steps,
+             "done": self._done,
+             "cumulative_score": self._cumulative_score,
+             "last_grader_score": self._last_grader_score,
+             "history": self._history,
+         }
+
+     # ──────────────────────────────────────────────────────────────────────
+     # Internal helpers
+     # ──────────────────────────────────────────────────────────────────────
+
+     def _make_observation(self) -> Observation:
+         assert self._task is not None
+         return Observation(
+             task_id=self._task.id,
+             task_name=self._task.name,
+             task_description=self._task.description,
+             query=self._task.query,
+             schema_context=self._task.schema_context,
+             hint=self._task.hint,
+             step_number=self._step_number,
+             max_steps=self._task.max_steps,
+             done=self._done,
+         )
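The cumulative-score update in `step()` clamps before rounding, so an episode total saturates instead of growing unbounded. A tiny standalone illustration (the per-step reward values are hypothetical):

```python
def clamp01(x: float) -> float:
    """Clamp a value into [0.0, 1.0]."""
    return min(max(x, 0.0), 1.0)


# Mirrors: self._cumulative_score = round(min(max(cum + step_reward, 0.0), 1.0), 4)
cum = 0.0
for step_reward in [0.3, 0.6, 0.5]:  # hypothetical per-step rewards
    cum = round(clamp01(cum + step_reward), 4)
# 0.3 -> 0.9 -> saturates at 1.0 rather than reaching 1.4
```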
env/models.py ADDED
@@ -0,0 +1,77 @@
+ """
+ OpenEnv typed models — Observation, Action, Reward.
+ All models are Pydantic v2 compliant.
+ """
+ from __future__ import annotations
+
+ from typing import Optional
+
+ from pydantic import BaseModel, Field
+
+
+ # ---------------------------------------------------------------------------
+ # Observation
+ # ---------------------------------------------------------------------------
+
+ class Observation(BaseModel):
+     """What the agent sees at each step."""
+
+     task_id: int = Field(..., description="Which task (1=easy, 2=medium, 3=hard)")
+     task_name: str = Field(..., description="Human-readable task name")
+     task_description: str = Field(..., description="What the agent must accomplish")
+     query: str = Field(..., description="The SQL query the agent must fix / optimise")
+     schema_context: str = Field(
+         ..., description="DDL / schema description relevant to the query"
+     )
+     hint: Optional[str] = Field(
+         None, description="Optional natural-language hint for the current step"
+     )
+     step_number: int = Field(0, description="Current step within the episode (0-indexed)")
+     max_steps: int = Field(5, description="Maximum steps allowed per episode")
+     done: bool = Field(False, description="Whether the episode has ended")
+
+
+ # ---------------------------------------------------------------------------
+ # Action
+ # ---------------------------------------------------------------------------
+
+ class Action(BaseModel):
+     """What the agent submits at each step."""
+
+     rewritten_query: str = Field(
+         ..., description="The agent's rewritten / improved SQL query"
+     )
+     explanation: str = Field(
+         ..., description="Natural-language explanation of changes made"
+     )
+     is_done: bool = Field(
+         False,
+         description="Set True when the agent believes the query is fully optimised",
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # Reward
+ # ---------------------------------------------------------------------------
+
+ class RewardBreakdown(BaseModel):
+     correctness: float = Field(0.0, ge=0.0, le=1.0)
+     performance: float = Field(0.0, ge=0.0, le=1.0)
+     style: float = Field(0.0, ge=0.0, le=1.0)
+     step_penalty: float = Field(0.0, le=0.0)  # always ≤ 0
+
+
+ class Reward(BaseModel):
+     """Reward returned after each step."""
+
+     score: float = Field(..., ge=0.0, le=1.0, description="Aggregate step reward")
+     grader_score: float = Field(
+         ..., ge=0.0, le=1.0, description="Raw grader score for the submitted query"
+     )
+     breakdown: RewardBreakdown = Field(
+         default_factory=RewardBreakdown,
+         description="Per-dimension partial scores",
+     )
+     feedback: str = Field("", description="Human-readable feedback from the grader")
+     cumulative_score: float = Field(
+         0.0, ge=0.0, le=1.0, description="Total score accumulated over episode so far"
+     )
env/reward.py ADDED
@@ -0,0 +1,57 @@
+ """
+ Shaped reward function for the SQL Query Optimizer environment.
+
+ Design:
+     - Partial credit every step based on grader improvement delta
+     - Completion bonus when the agent signals is_done and score ≥ threshold
+     - Step penalty for lingering past the halfway point without finishing
+     - Invalid-action penalty for empty / unparseable queries
+ """
+ from __future__ import annotations
+
+ import math
+
+ _COMPLETION_THRESHOLD = 0.80
+ _COMPLETION_BONUS = 0.50
+ _STEP_PENALTY = 0.02
+ _INVALID_PENALTY = 0.10
+ _DELTA_WEIGHT = 0.50  # weight for grader improvement delta in step reward
+
+
+ def compute_step_reward(
+     *,
+     grader_score: float,
+     prev_grader_score: float,
+     step_number: int,
+     max_steps: int,
+     is_done: bool,
+     is_invalid: bool,
+ ) -> float:
+     """
+     Returns a reward in [-0.10, 1.0] for a single step.
+
+     Components (all summed then clamped to [-0.10, 1.0]):
+         1. delta_reward = _DELTA_WEIGHT * max(0, grader_score - prev_grader_score)
+         2. completion_bonus (only if is_done and grader_score >= threshold)
+         3. partial-completion credit = grader_score * 0.30 (whenever is_done)
+         4. step_penalty (after half of max_steps used, if not done)
+         5. invalid_penalty (if query is empty / not parseable)
+     """
+     if is_invalid:
+         return -_INVALID_PENALTY
+
+     delta = max(0.0, grader_score - prev_grader_score)
+     reward = _DELTA_WEIGHT * delta
+
+     if is_done:
+         if grader_score >= _COMPLETION_THRESHOLD:
+             reward += _COMPLETION_BONUS
+         # proportional partial-completion signal even without the bonus
+         reward += grader_score * 0.30
+
+     # Step penalty starts after half of max_steps used
+     halfway = math.ceil(max_steps / 2)
+     if step_number > halfway and not is_done:
+         reward -= _STEP_PENALTY
+
+     return round(min(max(reward, -_INVALID_PENALTY), 1.0), 4)
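The reward table in the README can be sanity-checked by mirroring the function above standalone and feeding it a couple of hand-computed cases (this is a restatement for illustration, not the imported module):

```python
import math


def step_reward(*, grader, prev, step, max_steps, done, invalid):
    """Standalone restatement of compute_step_reward from env/reward.py."""
    if invalid:
        return -0.10
    r = 0.50 * max(0.0, grader - prev)  # delta reward
    if done:
        if grader >= 0.80:
            r += 0.50                    # completion bonus
        r += grader * 0.30               # partial-completion credit
    if step > math.ceil(max_steps / 2) and not done:
        r -= 0.02                        # step penalty
    return round(min(max(r, -0.10), 1.0), 4)


# Invalid query: flat penalty.
assert step_reward(grader=0.0, prev=0.0, step=0, max_steps=3, done=False, invalid=True) == -0.10
# Done at 0.9 after improving from 0.4: 0.25 + 0.50 + 0.27 = 1.02, clamped to 1.0.
assert step_reward(grader=0.9, prev=0.4, step=1, max_steps=3, done=True, invalid=False) == 1.0
# Late step, no improvement, not done: only the step penalty applies.
assert step_reward(grader=0.6, prev=0.6, step=3, max_steps=4, done=False, invalid=False) == -0.02
```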
env/tasks.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task definitions and deterministic graders for the SQL Query Optimizer environment.
3
+
4
+ Each task returns a TaskDef with:
5
+ - id, name, difficulty
6
+ - query: the broken/unoptimised SQL the agent must fix
7
+ - schema_context: relevant DDL
8
+ - description: what the agent must accomplish
9
+ - grader(rewritten_query) -> GraderResult(score, breakdown, feedback)
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import re
14
+ import dataclasses
15
+ from typing import Callable, Dict, Optional
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class GraderResult:
20
+ score: float # 0.0 – 1.0
21
+ correctness: float = 0.0
22
+ performance: float = 0.0
23
+ style: float = 0.0
24
+ feedback: str = ""
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class TaskDef:
29
+ id: int
30
+ name: str
31
+ difficulty: str # easy | medium | hard
32
+ description: str
33
+ query: str
34
+ schema_context: str
35
+ hint: Optional[str]
36
+ max_steps: int
37
+ grader: Callable[[str], GraderResult]
38
+
39
+
40
+ # ──────────────────────────────────────────────────────────────────────────────
41
+ # Helpers
42
+ # ──────────────────────────────────────────────────────────────────────────────
43
+
44
+ def _normalise(sql: str) -> str:
45
+ """Lower-case, collapse whitespace."""
46
+ return re.sub(r"\s+", " ", sql.lower().strip())
47
+
48
+
49
+ def _has(sql: str, *patterns: str) -> bool:
50
+ s = _normalise(sql)
51
+ return all(p in s for p in patterns)
52
+
53
+
54
+ def _missing(sql: str, *patterns: str) -> bool:
55
+ s = _normalise(sql)
56
+ return any(p not in s for p in patterns)
57
+
58
+
59
+ # ──────────────────────────────────────────────────────────────────────────────
60
+ # Task 1 β€” Easy: Fix a broken JOIN (missing ON clause / wrong join type)
61
+ # ──────────────────────────────────────────────────────────────────────────────
62
+
63
+ _T1_SCHEMA = """
64
+ CREATE TABLE orders (
65
+ order_id INT PRIMARY KEY,
66
+ customer_id INT NOT NULL,
67
+ total DECIMAL(10,2),
68
+ created_at TIMESTAMP
69
+ );
70
+ CREATE TABLE customers (
71
+ customer_id INT PRIMARY KEY,
72
+ name VARCHAR(255),
73
+ email VARCHAR(255)
74
+ );
75
+ """
76
+
77
+ _T1_QUERY = """
78
+ SELECT o.order_id, c.name, o.total
79
+ FROM orders o, customers c
80
+ WHERE o.total > 100;
81
+ """
82
+
83
+ _T1_DESC = (
84
+ "The query uses an implicit cross-join (comma syntax) between `orders` and "
85
+ "`customers` but never links the two tables. Rewrite it with an explicit "
86
+ "INNER JOIN … ON o.customer_id = c.customer_id, keeping the WHERE filter."
87
+ )
88
+
89
+
90
+ def _grade_task1(rewritten: str) -> GraderResult:
91
+ s = _normalise(rewritten)
92
+ fb: list[str] = []
93
+ correctness = 0.0
94
+ performance = 0.0
95
+ style = 0.0
96
+
97
+ # Correctness: must have explicit JOIN with the correct ON key
98
+ if "inner join" in s or ("join" in s and "cross join" not in s):
99
+ if "on" in s and "customer_id" in s:
100
+ correctness = 1.0
101
+ else:
102
+ correctness = 0.4
103
+ fb.append("JOIN present but ON clause with customer_id is missing.")
104
+ else:
105
+ fb.append("Still uses implicit cross-join or missing JOIN keyword.")
106
+
107
+ # Correctness: must still filter total > 100
108
+ if "total > 100" in s or "total>100" in s:
109
+ correctness = min(correctness + 0.0, correctness) # already captured
110
+ else:
111
+ correctness = max(correctness - 0.3, 0.0)
112
+ fb.append("WHERE o.total > 100 filter has been removed.")
113
+
114
+ # Performance: explicit join is better than implicit cross join
115
+ performance = 1.0 if correctness >= 0.8 else 0.3
116
+
117
+ # Style: uses table aliases
118
+ style = 0.5
119
+ if re.search(r"\bo\b", s) and re.search(r"\bc\b", s):
120
+ style = 1.0
121
+ elif "select *" not in s:
122
+ style = 0.7
123
+
124
+ score = round(correctness * 0.6 + performance * 0.25 + style * 0.15, 3)
125
+ feedback = " ".join(fb) if fb else "Correct! The JOIN is properly formed."
126
+ return GraderResult(
127
+ score=min(max(score, 0.0), 1.0),
128
+ correctness=correctness,
129
+ performance=performance,
130
+ style=style,
131
+ feedback=feedback,
132
+ )
133
+
134
+
135
+ # ──────────────────────────────────────────────────────────────────────────────
136
+ # Task 2 β€” Medium: Eliminate N+1 correlated subquery
137
+ # ──────────────────────────────────────────────��───────────────────────────────
138
+
139
+ _T2_SCHEMA = """
+ CREATE TABLE employees (
+     emp_id INT PRIMARY KEY,
+     name VARCHAR(255),
+     dept_id INT,
+     salary DECIMAL(10,2)
+ );
+ CREATE TABLE departments (
+     dept_id INT PRIMARY KEY,
+     dept_name VARCHAR(255),
+     budget DECIMAL(12,2)
+ );
+ """
+
+ _T2_QUERY = """
+ SELECT e.name,
+        (SELECT d.dept_name
+         FROM departments d
+         WHERE d.dept_id = e.dept_id) AS dept_name
+ FROM employees e
+ WHERE e.salary > 50000;
+ """
+
+ _T2_DESC = (
+     "The query uses a correlated scalar subquery in the SELECT list that fires "
+     "once per row (N+1 problem). Collapse it into a single LEFT JOIN … ON "
+     "e.dept_id = d.dept_id, keeping the salary filter."
+ )
+
+
+ def _grade_task2(rewritten: str) -> GraderResult:
+     s = _normalise(rewritten)
+     fb: list[str] = []
+     correctness = 0.0
+     performance = 0.0
+     style = 0.0
+
+     # Correctness: correlated subquery in SELECT must be gone
+     has_correlated = bool(
+         re.search(r"select\s+.*\(\s*select", s)
+         or re.search(r"\(\s*select\b.*\bwhere\b.*=\s*e\.", s)
+     )
+     if has_correlated:
+         fb.append("Correlated subquery still present in SELECT list.")
+         correctness = 0.1
+     else:
+         correctness = 0.5
+
+     # Correctness: must join on dept_id
+     if "join" in s and "dept_id" in s and "on" in s:
+         correctness = min(correctness + 0.5, 1.0)
+     else:
+         fb.append("Missing JOIN departments ON dept_id.")
+         correctness = max(correctness - 0.1, 0.0)
+
+     # Correctness: salary filter preserved
+     if "salary" not in s or ("salary > 50000" not in s and "salary>50000" not in s):
+         correctness = max(correctness - 0.2, 0.0)
+         fb.append("salary > 50000 filter is missing or incorrect.")
+
+     # Performance: single pass vs N+1
+     performance = 1.0 if not has_correlated and "join" in s else 0.2
+
+     # Style: uses aliases, selects explicit columns
+     style = 0.5
+     if "select *" not in s:
+         style += 0.25
+     if re.search(r"\be\b|\bd\b", s):
+         style += 0.25
+
+     score = round(correctness * 0.55 + performance * 0.30 + style * 0.15, 3)
+     feedback = " ".join(fb) if fb else "Excellent! N+1 eliminated with a clean JOIN."
+     return GraderResult(
+         score=min(max(score, 0.0), 1.0),
+         correctness=correctness,
+         performance=performance,
+         style=style,
+         feedback=feedback,
+     )
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Task 3 — Hard: Full optimisation (4 independent issues)
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ _T3_SCHEMA = """
+ CREATE TABLE products (
+     product_id INT PRIMARY KEY,
+     name VARCHAR(255),
+     category VARCHAR(100),
+     price DECIMAL(10,2),
+     stock INT
+ );
+ CREATE TABLE order_items (
+     item_id INT PRIMARY KEY,
+     order_id INT,
+     product_id INT,
+     quantity INT,
+     unit_price DECIMAL(10,2)
+ );
+ """
+
+ _T3_QUERY = """
+ SELECT DISTINCT *
+ FROM products p
+ JOIN order_items oi ON p.product_id = oi.product_id
+ WHERE CAST(p.price AS VARCHAR) LIKE '1%'
+   AND p.category = 'Electronics'
+ ORDER BY p.name;
+ """
+
+ _T3_DESC = (
+     "The query has four problems: "
+     "(1) DISTINCT is redundant because product_id is PK and the JOIN is 1-to-many — remove it. "
+     "(2) SELECT * should list only needed columns: p.name, p.category, p.price, oi.quantity, oi.unit_price. "
+     "(3) CAST(p.price AS VARCHAR) LIKE '1%' prevents index use — rewrite as p.price >= 100 AND p.price < 200. "
+     "(4) Add a comment hinting an index on (category, price) would help."
+ )
+
+
+ def _grade_task3(rewritten: str) -> GraderResult:
+     s = _normalise(rewritten)
+     fb: list[str] = []
+     sub_scores: Dict[str, float] = {}
+
+     # Sub-criterion 1: DISTINCT removed (0.25)
+     if "distinct" not in s:
+         sub_scores["no_distinct"] = 0.25
+     else:
+         sub_scores["no_distinct"] = 0.0
+         fb.append("DISTINCT still present — it's redundant here.")
+
+     # Sub-criterion 2: SELECT * replaced with explicit columns (0.25)
+     if "select *" not in s and all(
+         col in s for col in ("p.name", "p.price", "oi.quantity")
+     ):
+         sub_scores["explicit_columns"] = 0.25
+     elif "select *" not in s:
+         sub_scores["explicit_columns"] = 0.15
+         fb.append("SELECT * removed but explicit column list is incomplete.")
+     else:
+         sub_scores["explicit_columns"] = 0.0
+         fb.append("SELECT * still used — list explicit columns.")
+
+     # Sub-criterion 3: CAST…LIKE replaced with range predicate (0.25)
+     cast_gone = "cast(" not in s and "cast (" not in s
+     has_price_range = (
+         ("price >= 100" in s or "price>=100" in s)
+         and ("price < 200" in s or "price<200" in s)
+     )
+     if cast_gone and has_price_range:
+         sub_scores["sargable"] = 0.25
+     elif cast_gone:
+         sub_scores["sargable"] = 0.12
+         fb.append("CAST removed but price range predicate (>= 100 AND < 200) is missing.")
+     else:
+         sub_scores["sargable"] = 0.0
+         fb.append("CAST(price AS VARCHAR) LIKE … still present — non-sargable predicate.")
+
+     # Sub-criterion 4: index hint comment present (0.25)
+     raw = rewritten.lower()
+     if "index" in raw and ("category" in raw or "price" in raw):
+         sub_scores["index_hint"] = 0.25
+     else:
+         sub_scores["index_hint"] = 0.0
+         fb.append("Missing comment / hint about adding an index on (category, price).")
+
+     total = sum(sub_scores.values())
+     correctness = min(sub_scores["no_distinct"] + sub_scores["explicit_columns"], 0.5) * 2
+     performance = min(sub_scores["sargable"] + sub_scores["index_hint"], 0.5) * 2
+     style = 1.0 if "select *" not in s else 0.0
+
+     feedback = " ".join(fb) if fb else "Perfect optimisation across all four dimensions!"
+     return GraderResult(
+         score=round(min(max(total, 0.0), 1.0), 3),
+         correctness=round(correctness, 3),
+         performance=round(performance, 3),
+         style=round(style, 3),
+         feedback=feedback,
+     )
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Registry
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ TASKS: Dict[int, TaskDef] = {
+     1: TaskDef(
+         id=1,
+         name="fix-broken-join",
+         difficulty="easy",
+         description=_T1_DESC,
+         query=_T1_QUERY.strip(),
+         schema_context=_T1_SCHEMA.strip(),
+         hint="Replace the comma-separated FROM list with an explicit INNER JOIN … ON.",
+         max_steps=3,
+         grader=_grade_task1,
+     ),
+     2: TaskDef(
+         id=2,
+         name="eliminate-n-plus-one",
+         difficulty="medium",
+         description=_T2_DESC,
+         query=_T2_QUERY.strip(),
+         schema_context=_T2_SCHEMA.strip(),
+         hint="Move the subquery out of the SELECT list and into a LEFT JOIN.",
+         max_steps=4,
+         grader=_grade_task2,
+     ),
+     3: TaskDef(
+         id=3,
+         name="full-optimization",
+         difficulty="hard",
+         description=_T3_DESC,
+         query=_T3_QUERY.strip(),
+         schema_context=_T3_SCHEMA.strip(),
+         hint=None,
+         max_steps=5,
+         grader=_grade_task3,
+     ),
+ }
+
+
+ def get_task(task_id: int) -> TaskDef:
+     if task_id not in TASKS:
+         raise ValueError(f"Unknown task_id {task_id}. Valid: {list(TASKS.keys())}")
+     return TASKS[task_id]
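
The graders above are plain substring and regex tests over a whitespace-normalised, lowercased query string. As a standalone illustration (a simplified sketch, not the module's actual code: it omits the partial-credit branches and the `GraderResult` packaging), the Task 3 sub-criteria boil down to:

```python
import re

def normalise(sql: str) -> str:
    """Lowercase and collapse whitespace, mirroring the graders' preprocessing."""
    return re.sub(r"\s+", " ", sql.lower()).strip()

def task3_subscores(rewritten: str) -> dict:
    """All-or-nothing version of the four 0.25-point Task 3 sub-criteria."""
    s = normalise(rewritten)
    return {
        "no_distinct": 0.25 if "distinct" not in s else 0.0,
        "explicit_columns": 0.25 if "select *" not in s else 0.0,
        "sargable": 0.25 if "cast(" not in s and "price >= 100" in s else 0.0,
        "index_hint": 0.25 if "index" in rewritten.lower() else 0.0,
    }

good = (
    "-- consider an index on (category, price)\n"
    "SELECT p.name, p.price FROM products p "
    "JOIN order_items oi ON p.product_id = oi.product_id "
    "WHERE p.price >= 100 AND p.price < 200"
)
print(sum(task3_subscores(good).values()))  # 1.0
```

A fully optimised rewrite passes all four checks; dropping any one of them (say, the index-hint comment) loses exactly that sub-criterion's 0.25.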
hf_login.py ADDED
@@ -0,0 +1,36 @@
+ """
+ Interactive HuggingFace login script
+ Usage: python hf_login.py
+ """
+ from huggingface_hub import login
+ import os
+
+ print("=" * 60)
+ print("HuggingFace Hub Login")
+ print("=" * 60)
+ print("\nYou can authenticate in two ways:")
+ print("1. Enter your API token interactively")
+ print("2. Set the HF_TOKEN environment variable (it is picked up automatically)")
+ print("\nTo get a token, visit: https://huggingface.co/settings/tokens")
+ print("=" * 60)
+
+ token = os.getenv("HF_TOKEN", "").strip()
+
+ if token:
+     print("\nUsing token from HF_TOKEN environment variable...")
+     try:
+         login(token=token)
+         print("✓ Login successful!")
+     except Exception as e:
+         print(f"✗ Login failed: {e}")
+ else:
+     print("\nEnter your HuggingFace token (or type 'quit' to exit):")
+     token = input("> ").strip()
+     if token.lower() != 'quit':
+         try:
+             login(token=token)
+             print("✓ Login successful!")
+         except Exception as e:
+             print(f"✗ Login failed: {e}")
+     else:
+         print("Login cancelled.")
jj.txt ADDED
@@ -0,0 +1 @@
+ [REDACTED: leaked OpenAI API key removed]
models.py ADDED
@@ -0,0 +1,4 @@
+ """Top-level model exports for OpenEnv validation compatibility."""
+ from env.models import Action, Observation, Reward, RewardBreakdown
+
+ __all__ = ["Action", "Observation", "Reward", "RewardBreakdown"]
openenv.yaml ADDED
@@ -0,0 +1,83 @@
+ name: sql-query-optimizer
+ version: "1.0.0"
+ description: >
+   An OpenEnv environment where AI agents learn to review, rewrite, and optimise
+   SQL queries for correctness and performance. Covers three real-world failure
+   patterns: implicit cross-joins, N+1 subqueries, and multi-dimensional query
+   anti-patterns.
+ author: metaXscaler
+ tags:
+   - openenv
+   - sql
+   - code-review
+   - data-engineering
+   - database
+ tasks:
+   - id: 1
+     name: fix-broken-join
+     difficulty: easy
+     description: >
+       The agent must replace an implicit cross-join (comma syntax) with an
+       explicit INNER JOIN ... ON clause.
+   - id: 2
+     name: eliminate-n-plus-one
+     difficulty: medium
+     description: >
+       The agent must remove a correlated scalar subquery in the SELECT list
+       and replace it with a single LEFT JOIN.
+   - id: 3
+     name: full-optimization
+     difficulty: hard
+     description: >
+       The agent must fix four independent issues: remove redundant DISTINCT,
+       replace SELECT *, eliminate a non-sargable CAST predicate, and add an
+       index hint comment.
+ observation:
+   type: object
+   fields:
+     task_id: integer
+     task_name: string
+     task_description: string
+     query: string
+     schema_context: string
+     hint: "string | null"
+     step_number: integer
+     max_steps: integer
+     done: boolean
+ action:
+   type: object
+   fields:
+     rewritten_query: string
+     explanation: string
+     is_done: boolean
+ reward:
+   type: object
+   fields:
+     score: "float [0.0, 1.0]"
+     grader_score: "float [0.0, 1.0]"
+     breakdown:
+       correctness: "float [0.0, 1.0]"
+       performance: "float [0.0, 1.0]"
+       style: "float [0.0, 1.0]"
+     step_penalty: "float ≤ 0.0"
+     feedback: string
+     cumulative_score: "float [0.0, 1.0]"
+ endpoints:
+   - path: /reset
+     method: POST
+     description: Start a fresh episode for a given task_id
+   - path: /step
+     method: POST
+     description: Submit an Action and advance the episode
+   - path: /state
+     method: GET
+     description: Return the current internal state snapshot
+   - path: /tasks
+     method: GET
+     description: List all tasks and action schema
+   - path: /grader
+     method: GET
+     description: Return grader score for the last completed episode
+   - path: /baseline
+     method: POST
+     description: Trigger baseline inference on all 3 tasks
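
The endpoint and schema declarations in openenv.yaml map directly onto HTTP calls. A minimal client sketch (the localhost URL, the port, and the use of the `requests` package are assumptions; any HTTP client works):

```python
import json

BASE_URL = "http://localhost:7860"  # assumed local dev server address

# Body for POST /reset: choose a task (1=easy, 2=medium, 3=hard)
reset_payload = {"task_id": 1}

# Body for POST /step: an Action per the schema declared in openenv.yaml
step_payload = {
    "rewritten_query": (
        "SELECT o.order_id, c.name, o.total "
        "FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id "
        "WHERE o.total > 100"
    ),
    "explanation": "Replaced the comma cross-join with an explicit INNER JOIN.",
    "is_done": True,
}

# With a running server (and `requests` installed) you would POST these:
#   obs = requests.post(f"{BASE_URL}/reset", json=reset_payload).json()
#   result = requests.post(f"{BASE_URL}/step", json=step_payload).json()
#   print(result["reward"]["score"], result["done"])
print(json.dumps(reset_payload))  # {"task_id": 1}
```

The `/step` response echoes the next observation plus a reward object whose `breakdown` fields follow the reward schema above.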
pyproject.toml ADDED
@@ -0,0 +1,60 @@
+ [build-system]
+ requires = ["setuptools>=68.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "sql-query-optimizer-openenv"
+ version = "1.0.0"
+ description = "An OpenEnv environment where AI agents learn to review, rewrite, and optimise SQL queries for correctness and performance."
+ readme = "README.md"
+ requires-python = ">=3.10"
+ authors = [
+     {name = "metaXscaler", email = ""}
+ ]
+ license = {text = "MIT"}
+ keywords = ["openenv", "sql", "optimization", "ml", "agent", "environment"]
+ classifiers = [
+     "Development Status :: 4 - Beta",
+     "Intended Audience :: Developers",
+     "Intended Audience :: Science/Research",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+
+ dependencies = [
+     "fastapi>=0.111.0",
+     "uvicorn[standard]>=0.29.0",
+     "pydantic>=2.7.0",
+     "openai>=1.30.0",
+     "pyyaml>=6.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=7.0",
+     "black>=23.0",
+     "ruff>=0.1.0",
+ ]
+
+ [project.urls]
+ Homepage = "https://huggingface.co/spaces"
+ Repository = "https://github.com/metaXscaler/sql-query-optimizer-openenv"
+ Documentation = "https://github.com/metaXscaler/sql-query-optimizer-openenv/blob/main/README.md"
+
+ [tool.black]
+ line-length = 100
+ target-version = ['py310', 'py311', 'py312']
+
+ [tool.ruff]
+ line-length = 100
+ target-version = "py310"
+ select = ["E", "F", "W"]
+ ignore = ["E501"]  # Line too long (handled by black)
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ python_files = ["test_*.py"]
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ fastapi>=0.111.0
+ uvicorn[standard]>=0.29.0
+ pydantic>=2.7.0
+ openai>=1.30.0
+ pyyaml>=6.0
server/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """FastAPI server package for SQL Query Optimizer OpenEnv environment."""
+ from .app import app
+
+ __all__ = ["app"]
server/app.py ADDED
@@ -0,0 +1,176 @@
+ """
+ FastAPI server exposing the OpenEnv SQL Optimizer environment.
+
+ Endpoints:
+     POST /reset    → Observation
+     POST /step     → {observation, reward, done, info}
+     GET  /state    → state dict
+     GET  /tasks    → list of tasks + action schema
+     GET  /grader   → grader score for last completed episode
+     POST /baseline → trigger baseline inference on all 3 tasks
+ """
+ from __future__ import annotations
+
+ import os
+ import subprocess
+ import sys
+ from typing import Any, Dict, Optional
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+ from env.environment import SQLOptimizerEnv
+ from env.models import Action, Observation, Reward
+ from env.tasks import TASKS
+
+ app = FastAPI(
+     title="SQL Query Optimizer — OpenEnv",
+     description=(
+         "An OpenEnv-compliant environment where AI agents learn to rewrite "
+         "and optimise SQL queries across three difficulty levels."
+     ),
+     version="1.0.0",
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Single shared environment instance (stateful, per-process)
+ _env = SQLOptimizerEnv()
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Request / Response schemas
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ class ResetRequest(BaseModel):
+     task_id: int = 1
+
+
+ class StepResponse(BaseModel):
+     observation: Observation
+     reward: Reward
+     done: bool
+     info: Dict[str, Any]
+
+
+ class GraderResponse(BaseModel):
+     task_id: Optional[int]
+     grader_score: float
+     cumulative_score: float
+     done: bool
+
+
+ class TaskInfo(BaseModel):
+     id: int
+     name: str
+     difficulty: str
+     description: str
+     action_schema: Dict[str, Any]
+
+
+ class BaselineResponse(BaseModel):
+     task_results: Dict[str, float]
+     message: str
+
+
+ # ──────────────────────────────────────────────────────────────────────────────
+ # Endpoints
+ # ──────────────────────────────────────────────────────────────────────────────
+
+ @app.get("/", summary="Health check")
+ def health() -> Dict[str, str]:
+     return {"status": "ok", "environment": "sql-query-optimizer", "version": "1.0.0"}
+
+
+ @app.post("/reset", response_model=Observation, summary="Start / restart an episode")
+ def reset(req: ResetRequest) -> Observation:
+     """Reset the environment for a given task_id (1=easy, 2=medium, 3=hard)."""
+     try:
+         obs = _env.reset(task_id=req.task_id)
+     except ValueError as exc:
+         raise HTTPException(status_code=400, detail=str(exc))
+     return obs
+
+
+ @app.post("/step", response_model=StepResponse, summary="Submit an action")
+ def step(action: Action) -> StepResponse:
+     """Advance the environment by submitting an Action."""
+     try:
+         obs, reward, done, info = _env.step(action)
+     except RuntimeError as exc:
+         raise HTTPException(status_code=400, detail=str(exc))
+     return StepResponse(observation=obs, reward=reward, done=done, info=info)
+
+
+ @app.get("/state", summary="Return current internal state")
+ def state() -> Dict[str, Any]:
+     """Return the current internal state of the environment."""
+     return _env.state()
+
+
+ @app.get("/tasks", response_model=list[TaskInfo], summary="List tasks + action schema")
+ def list_tasks() -> list[TaskInfo]:
+     """Return all tasks with descriptions and the action schema."""
+     action_schema = Action.model_json_schema()
+     return [
+         TaskInfo(
+             id=t.id,
+             name=t.name,
+             difficulty=t.difficulty,
+             description=t.description,
+             action_schema=action_schema,
+         )
+         for t in TASKS.values()
+     ]
+
+
+ @app.get("/grader", response_model=GraderResponse, summary="Grader score for last episode")
+ def grader() -> GraderResponse:
+     """Return the grader score after the current/last episode."""
+     s = _env.state()
+     if s.get("status") == "not_started":
+         raise HTTPException(status_code=400, detail="No episode started. Call /reset first.")
+     return GraderResponse(
+         task_id=s.get("task_id"),
+         grader_score=s.get("last_grader_score", 0.0),
+         cumulative_score=s.get("cumulative_score", 0.0),
+         done=s.get("done", False),
+     )
+
+
+ @app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
+ def baseline() -> BaselineResponse:
+     """
+     Trigger the baseline inference script (baseline.py) and return scores.
+     Requires OPENAI_API_KEY to be set in the environment.
+     """
+     if not os.getenv("OPENAI_API_KEY"):
+         raise HTTPException(
+             status_code=400,
+             detail="OPENAI_API_KEY environment variable not set. Cannot run baseline.",
+         )
+     try:
+         result = subprocess.run(
+             [sys.executable, "baseline.py", "--json"],
+             capture_output=True,
+             text=True,
+             timeout=300,
+         )
+         if result.returncode != 0:
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Baseline script failed:\n{result.stderr}",
+             )
+         import json
+         scores = json.loads(result.stdout)
+         return BaselineResponse(task_results=scores, message="Baseline completed successfully.")
+     except subprocess.TimeoutExpired:
+         raise HTTPException(status_code=500, detail="Baseline script timed out after 300s.")
+     except HTTPException:
+         # Re-raise as-is so the returncode check above isn't swallowed below
+         raise
+     except Exception as exc:
+         raise HTTPException(status_code=500, detail=str(exc))
sql-query-optimizer/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
sql-query-optimizer/README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Sql Query Optimizer
+ emoji: 🚀
+ colorFrom: purple
+ colorTo: gray
+ sdk: docker
+ pinned: false
+ license: mit
+ short_description: SQL Query Optimizer — OpenEnv Environment
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
test_env.py ADDED
@@ -0,0 +1,86 @@
+ """Quick smoke test for all 3 tasks."""
+ import sys, json
+ sys.path.insert(0, ".")
+
+ from env.environment import SQLOptimizerEnv
+ from env.models import Action
+
+ env = SQLOptimizerEnv()
+
+ # ── Task 1 ──────────────────────────────────────────────────────────────────
+ print("=== Task 1 (Easy): fix-broken-join ===")
+ obs = env.reset(1)
+ print(f"  task: {obs.task_name}")
+ action = Action(
+     rewritten_query=(
+         "SELECT o.order_id, c.name, o.total "
+         "FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id "
+         "WHERE o.total > 100"
+     ),
+     explanation="Replaced comma cross-join with INNER JOIN ON customer_id",
+     is_done=True,
+ )
+ obs2, reward, done, info = env.step(action)
+ print(f"  grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
+ print(f"  feedback: {reward.feedback}")
+ assert obs2.done is True, "done should be True"
+ assert info["grader_score"] >= 0.8, f"Expected >=0.8, got {info['grader_score']}"
+
+ # ── Task 2 ──────────────────────────────────────────────────────────────────
+ print()
+ print("=== Task 2 (Medium): eliminate-n-plus-one ===")
+ obs = env.reset(2)
+ print(f"  task: {obs.task_name}")
+ action = Action(
+     rewritten_query=(
+         "SELECT e.name, d.dept_name "
+         "FROM employees e "
+         "LEFT JOIN departments d ON e.dept_id = d.dept_id "
+         "WHERE e.salary > 50000"
+     ),
+     explanation="Replaced correlated subquery with a single LEFT JOIN",
+     is_done=True,
+ )
+ obs2, reward, done, info = env.step(action)
+ print(f"  grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
+ print(f"  feedback: {reward.feedback}")
+ assert info["grader_score"] >= 0.7, f"Expected >=0.7, got {info['grader_score']}"
+
+ # ── Task 3 ──────────────────────────────────────────────────────────────────
+ print()
+ print("=== Task 3 (Hard): full-optimization ===")
+ obs = env.reset(3)
+ print(f"  task: {obs.task_name}")
+ action = Action(
+     rewritten_query=(
+         "-- Index hint: consider CREATE INDEX ON products(category, price)\n"
+         "SELECT p.name, p.category, p.price, oi.quantity, oi.unit_price\n"
+         "FROM products p\n"
+         "JOIN order_items oi ON p.product_id = oi.product_id\n"
+         "WHERE p.price >= 100 AND p.price < 200\n"
+         "  AND p.category = 'Electronics'\n"
+         "ORDER BY p.name"
+     ),
+     explanation="Removed DISTINCT and SELECT *, replaced CAST LIKE with range, added index hint",
+     is_done=True,
+ )
+ obs2, reward, done, info = env.step(action)
+ print(f"  grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
+ print(f"  feedback: {reward.feedback}")
+ assert info["grader_score"] >= 0.9, f"Expected >=0.9, got {info['grader_score']}"
+
+ # ── state() ─────────────────────────────────────────────────────────────────
+ print()
+ print("=== state() ===")
+ print(json.dumps(env.state(), indent=2))
+
+ # ── invalid action penalty ───────────────────────────────────────────────────
+ print()
+ print("=== Invalid action test ===")
+ env.reset(1)
+ obs2, reward, done, info = env.step(Action(rewritten_query="", explanation="", is_done=False))
+ print(f"  step_reward={reward.score} is_invalid={info['is_invalid']}")
+ assert info["is_invalid"] is True, "Empty query should be flagged invalid"
+
+ print()
+ print("ALL TESTS PASSED")