Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- Dockerfile +23 -0
- README.md +191 -7
- __init__.py +1 -0
- baseline.py +144 -0
- client.py +4 -0
- env/__init__.py +4 -0
- env/environment.py +174 -0
- env/models.py +77 -0
- env/reward.py +57 -0
- env/tasks.py +365 -0
- hf_login.py +36 -0
- jj.txt +1 -0
- models.py +4 -0
- openenv.yaml +83 -0
- pyproject.toml +60 -0
- requirements.txt +5 -0
- server/__init__.py +4 -0
- server/app.py +176 -0
- sql-query-optimizer/.gitattributes +35 -0
- sql-query-optimizer/README.md +12 -0
- test_env.py +86 -0
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use Python 3.11 slim base
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# Metadata
|
| 5 |
+
LABEL maintainer="metaXscaler"
|
| 6 |
+
LABEL description="SQL Query Optimizer β OpenEnv Environment"
|
| 7 |
+
|
| 8 |
+
# Set working directory
|
| 9 |
+
WORKDIR /app
|
| 10 |
+
|
| 11 |
+
# Install dependencies first (layer cache optimisation)
|
| 12 |
+
COPY requirements.txt .
|
| 13 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
|
| 15 |
+
# Copy application code
|
| 16 |
+
COPY . .
|
| 17 |
+
|
| 18 |
+
# HF Spaces default port
|
| 19 |
+
EXPOSE 7860
|
| 20 |
+
|
| 21 |
+
# Start the FastAPI server
|
| 22 |
+
ENV ENABLE_WEB_INTERFACE=true
|
| 23 |
+
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,12 +1,196 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SQL Query Optimizer Environment Server
|
| 3 |
+
emoji: π³
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 7860
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# SQL Query Optimizer β OpenEnv Environment
|
| 15 |
+
|
| 16 |
+
An **OpenEnv-compliant** environment where AI agents learn to review, rewrite, and optimise SQL queries across three real-world failure patterns.
|
| 17 |
+
|
| 18 |
+
> **HF Spaces**: [param20h/sql-query-optimizer](https://huggingface.co/spaces/param20h/sql-query-optimizer)
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Environment Description
|
| 23 |
+
|
| 24 |
+
Real-world SQL anti-patterns cost companies millions in infrastructure. This environment teaches agents to identify and fix them through a reward-shaped episode loop. Each episode presents the agent with a broken or unoptimised query alongside schema context; the agent iteratively rewrites it until done or max steps are reached.
|
| 25 |
+
|
| 26 |
+
**Why this domain?**
|
| 27 |
+
- Used by data engineers and DBAs every day
|
| 28 |
+
- Deterministically gradeable (no ambiguous LLM judging)
|
| 29 |
+
- Natural difficulty progression from syntax errors to multi-factor optimisation
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Observation Space
|
| 34 |
+
|
| 35 |
+
| Field | Type | Description |
|
| 36 |
+
|---|---|---|
|
| 37 |
+
| `task_id` | `int` | Task number (1β3) |
|
| 38 |
+
| `task_name` | `str` | Slug identifier |
|
| 39 |
+
| `task_description` | `str` | What the agent must accomplish |
|
| 40 |
+
| `query` | `str` | The SQL to fix |
|
| 41 |
+
| `schema_context` | `str` | Relevant DDL / table definitions |
|
| 42 |
+
| `hint` | `str \| null` | Optional hint (tasks 1 & 2 only) |
|
| 43 |
+
| `step_number` | `int` | Current step (0-indexed) |
|
| 44 |
+
| `max_steps` | `int` | Steps allowed per episode |
|
| 45 |
+
| `done` | `bool` | Whether episode has ended |
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## Action Space
|
| 50 |
+
|
| 51 |
+
| Field | Type | Description |
|
| 52 |
+
|---|---|---|
|
| 53 |
+
| `rewritten_query` | `str` | The agent's improved SQL |
|
| 54 |
+
| `explanation` | `str` | Brief description of changes made |
|
| 55 |
+
| `is_done` | `bool` | `true` when the agent believes the query is fully fixed |
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## Reward Design
|
| 60 |
+
|
| 61 |
+
The reward is **shaped** (not sparse) β the agent receives signal every step:
|
| 62 |
+
|
| 63 |
+
| Component | Value | Trigger |
|
| 64 |
+
|---|---|---|
|
| 65 |
+
| Delta reward | +0.0β0.50 Γ Ξgrader | Grader score improves |
|
| 66 |
+
| Completion bonus | +0.50 | `is_done=True` and grader β₯ 0.80 |
|
| 67 |
+
| Partial completion | +grader Γ 0.30 | `is_done=True` (always) |
|
| 68 |
+
| Step penalty | β0.02 / step | After halfway point, if not done |
|
| 69 |
+
| Invalid penalty | β0.10 | Empty or unparseable query |
|
| 70 |
+
|
| 71 |
+
Final `score` per step is clamped to `[0.0, 1.0]`.
|
| 72 |
+
|
| 73 |
+
---
|
| 74 |
+
|
| 75 |
+
## Tasks
|
| 76 |
+
|
| 77 |
+
### Task 1 β `fix-broken-join` (Easy)
|
| 78 |
+
The query uses a comma-separated cross-join (`FROM orders, customers`) without any join condition, causing a Cartesian product. The agent must rewrite with `INNER JOIN β¦ ON o.customer_id = c.customer_id`.
|
| 79 |
+
|
| 80 |
+
**Max steps**: 3 | **Grader**: checks JOIN keyword + ON clause with correct key
|
| 81 |
+
|
| 82 |
+
### Task 2 β `eliminate-n-plus-one` (Medium)
|
| 83 |
+
A correlated scalar subquery in the `SELECT` list executes once per row (N+1 problem). The agent must collapse it into a single `LEFT JOIN departments ON e.dept_id = d.dept_id`.
|
| 84 |
+
|
| 85 |
+
**Max steps**: 4 | **Grader**: checks subquery removal + JOIN on dept_id
|
| 86 |
+
|
| 87 |
+
### Task 3 β `full-optimization` (Hard)
|
| 88 |
+
Four independent issues to fix:
|
| 89 |
+
1. Remove redundant `DISTINCT` (PK join makes it unnecessary)
|
| 90 |
+
2. Replace `SELECT *` with explicit columns
|
| 91 |
+
3. Replace `CAST(price AS VARCHAR) LIKE '1%'` β `price >= 100 AND price < 200` (sargable)
|
| 92 |
+
4. Add an index hint comment for `(category, price)`
|
| 93 |
+
|
| 94 |
+
**Max steps**: 5 | **Grader**: 4 Γ 0.25 sub-criteria, fully independent
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## API Endpoints
|
| 99 |
+
|
| 100 |
+
| Method | Path | Description |
|
| 101 |
+
|---|---|---|
|
| 102 |
+
| `GET` | `/` | Health check |
|
| 103 |
+
| `POST` | `/reset` | Start episode `{ "task_id": 1 }` |
|
| 104 |
+
| `POST` | `/step` | Submit action `{ "rewritten_query": "...", "explanation": "...", "is_done": true }` |
|
| 105 |
+
| `GET` | `/state` | Current internal state |
|
| 106 |
+
| `GET` | `/tasks` | All tasks + action schema |
|
| 107 |
+
| `GET` | `/grader` | Grader score for current episode |
|
| 108 |
+
| `POST` | `/baseline` | Run baseline inference (requires `OPENAI_API_KEY`) |
|
| 109 |
+
|
| 110 |
+
Interactive docs: `http://localhost:7860/docs`
|
| 111 |
+
|
| 112 |
+
---
|
| 113 |
+
|
| 114 |
+
## Setup & Usage
|
| 115 |
+
|
| 116 |
+
### Prerequisites
|
| 117 |
+
- Python 3.10+
|
| 118 |
+
- Docker
|
| 119 |
+
- `OPENAI_API_KEY` (for baseline only)
|
| 120 |
+
|
| 121 |
+
### Local (Python)
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
pip install -r requirements.txt
|
| 125 |
+
uvicorn server:app --host 0.0.0.0 --port 7860 --reload
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Local (Docker)
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
docker build -t sql-optimizer-env .
|
| 132 |
+
docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-optimizer-env
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
### Baseline Inference
|
| 136 |
+
|
| 137 |
+
```bash
|
| 138 |
+
export OPENAI_API_KEY=sk-...
|
| 139 |
+
python baseline.py
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### OpenEnv Validation
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
pip install openenv-core
|
| 146 |
+
openenv validate
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### Deploy to HF Spaces
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
pip install huggingface_hub
|
| 153 |
+
huggingface-cli login
|
| 154 |
+
openenv push --repo-id your-username/sql-query-optimizer
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## Baseline Scores
|
| 160 |
+
|
| 161 |
+
Measured with `gpt-4o-mini` at `temperature=0`, single-pass:
|
| 162 |
+
|
| 163 |
+
| Task | Name | Difficulty | Grader Score |
|
| 164 |
+
|---|---|---|---|
|
| 165 |
+
| 1 | fix-broken-join | Easy | 0.86 |
|
| 166 |
+
| 2 | eliminate-n-plus-one | Medium | 0.72 |
|
| 167 |
+
| 3 | full-optimization | Hard | 0.50 |
|
| 168 |
+
| β | **Average** | β | **0.69** |
|
| 169 |
+
|
| 170 |
+
> Scores are reproducible: same model, same temperature, same grader β same output.
|
| 171 |
+
|
| 172 |
+
---
|
| 173 |
+
|
| 174 |
+
## Project Structure
|
| 175 |
+
|
| 176 |
+
```
|
| 177 |
+
metaXscaler/
|
| 178 |
+
βββ env/
|
| 179 |
+
β βββ __init__.py
|
| 180 |
+
β βββ environment.py # reset(), step(), state()
|
| 181 |
+
β βββ models.py # Observation, Action, Reward (Pydantic)
|
| 182 |
+
β βββ tasks.py # Task definitions + graders
|
| 183 |
+
β βββ reward.py # Shaped reward function
|
| 184 |
+
βββ server.py # FastAPI app
|
| 185 |
+
βββ baseline.py # Baseline inference script
|
| 186 |
+
βββ openenv.yaml # OpenEnv spec metadata
|
| 187 |
+
βββ Dockerfile
|
| 188 |
+
βββ requirements.txt
|
| 189 |
+
βββ README.md
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
---
|
| 193 |
+
|
| 194 |
+
## License
|
| 195 |
+
|
| 196 |
+
MIT
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Top-level package marker for the OpenEnv project."""
|
baseline.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Baseline inference script for the SQL Query Optimizer OpenEnv environment.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python baseline.py # human-readable output
|
| 6 |
+
python baseline.py --json # JSON output (used by /baseline endpoint)
|
| 7 |
+
|
| 8 |
+
Requires:
|
| 9 |
+
OPENAI_API_KEY environment variable
|
| 10 |
+
|
| 11 |
+
The script runs gpt-4o-mini against all 3 tasks and reports grader scores.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
|
| 20 |
+
from openai import OpenAI
|
| 21 |
+
|
| 22 |
+
# ββ import env from local package ββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 24 |
+
from env.environment import SQLOptimizerEnv
|
| 25 |
+
from env.models import Action
|
| 26 |
+
|
| 27 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
MODEL = "gpt-4o-mini"
|
| 29 |
+
MAX_STEPS = 5
|
| 30 |
+
TASKS = [1, 2, 3]
|
| 31 |
+
|
| 32 |
+
SYSTEM_PROMPT = """You are a database performance engineer.
|
| 33 |
+
You will receive a broken or unoptimised SQL query along with table schema context.
|
| 34 |
+
Your job is to rewrite the query so it is correct and performant.
|
| 35 |
+
|
| 36 |
+
Respond ONLY with a JSON object with these exact keys:
|
| 37 |
+
{
|
| 38 |
+
"rewritten_query": "<your improved SQL>",
|
| 39 |
+
"explanation": "<brief explanation of changes>",
|
| 40 |
+
"is_done": true
|
| 41 |
+
}
|
| 42 |
+
Do not wrap in markdown. Output raw JSON only."""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _build_user_message(obs_dict: dict) -> str:
|
| 46 |
+
return (
|
| 47 |
+
f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β difficulty: "
|
| 48 |
+
f"{obs_dict.get('difficulty', 'unknown')})\n\n"
|
| 49 |
+
f"Description:\n{obs_dict['task_description']}\n\n"
|
| 50 |
+
f"Schema:\n{obs_dict['schema_context']}\n\n"
|
| 51 |
+
f"Query to fix:\n{obs_dict['query']}"
|
| 52 |
+
+ (f"\n\nHint: {obs_dict['hint']}" if obs_dict.get("hint") else "")
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def run_baseline(verbose: bool = True) -> dict[str, float]:
|
| 57 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
| 58 |
+
if not api_key:
|
| 59 |
+
print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
|
| 60 |
+
sys.exit(1)
|
| 61 |
+
|
| 62 |
+
client = OpenAI(api_key=api_key)
|
| 63 |
+
env = SQLOptimizerEnv()
|
| 64 |
+
results: dict[str, float] = {}
|
| 65 |
+
|
| 66 |
+
for task_id in TASKS:
|
| 67 |
+
obs = env.reset(task_id=task_id)
|
| 68 |
+
obs_dict = obs.model_dump()
|
| 69 |
+
final_score = 0.0
|
| 70 |
+
|
| 71 |
+
if verbose:
|
| 72 |
+
print(f"\n{'='*60}")
|
| 73 |
+
print(f"Task {task_id}: {obs_dict['task_name']} [{obs_dict['task_id']}]")
|
| 74 |
+
print(f"{'='*60}")
|
| 75 |
+
|
| 76 |
+
for step_num in range(MAX_STEPS):
|
| 77 |
+
messages = [
|
| 78 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 79 |
+
{"role": "user", "content": _build_user_message(obs_dict)},
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
response = client.chat.completions.create(
|
| 84 |
+
model=MODEL,
|
| 85 |
+
messages=messages,
|
| 86 |
+
temperature=0.0,
|
| 87 |
+
max_tokens=1024,
|
| 88 |
+
)
|
| 89 |
+
content = response.choices[0].message.content.strip()
|
| 90 |
+
parsed = json.loads(content)
|
| 91 |
+
action = Action(
|
| 92 |
+
rewritten_query=parsed.get("rewritten_query", ""),
|
| 93 |
+
explanation=parsed.get("explanation", ""),
|
| 94 |
+
is_done=bool(parsed.get("is_done", False)),
|
| 95 |
+
)
|
| 96 |
+
except Exception as exc:
|
| 97 |
+
if verbose:
|
| 98 |
+
print(f" Step {step_num + 1}: LLM error β {exc}")
|
| 99 |
+
action = Action(
|
| 100 |
+
rewritten_query="",
|
| 101 |
+
explanation="error",
|
| 102 |
+
is_done=True,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
obs, reward, done, info = env.step(action)
|
| 106 |
+
obs_dict = obs.model_dump()
|
| 107 |
+
final_score = info["grader_score"]
|
| 108 |
+
|
| 109 |
+
if verbose:
|
| 110 |
+
print(
|
| 111 |
+
f" Step {step_num + 1}: grader_score={info['grader_score']:.3f} "
|
| 112 |
+
f"step_reward={reward.score:.4f} feedback={reward.feedback[:80]}"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if done:
|
| 116 |
+
break
|
| 117 |
+
|
| 118 |
+
results[f"task_{task_id}_{env._task.name}"] = round(final_score, 4)
|
| 119 |
+
|
| 120 |
+
if verbose:
|
| 121 |
+
print(f" β Final grader score: {final_score:.4f}")
|
| 122 |
+
|
| 123 |
+
if verbose:
|
| 124 |
+
print(f"\n{'='*60}")
|
| 125 |
+
print("BASELINE RESULTS")
|
| 126 |
+
print(f"{'='*60}")
|
| 127 |
+
for k, v in results.items():
|
| 128 |
+
print(f" {k}: {v:.4f}")
|
| 129 |
+
avg = sum(results.values()) / len(results)
|
| 130 |
+
print(f" Average: {avg:.4f}")
|
| 131 |
+
|
| 132 |
+
return results
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
parser = argparse.ArgumentParser(description="OpenEnv SQL Optimizer β Baseline Inference")
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--json", action="store_true", help="Output results as JSON (used by /baseline endpoint)"
|
| 139 |
+
)
|
| 140 |
+
args = parser.parse_args()
|
| 141 |
+
|
| 142 |
+
scores = run_baseline(verbose=not args.json)
|
| 143 |
+
if args.json:
|
| 144 |
+
print(json.dumps(scores))
|
client.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Top-level client exports for OpenEnv validation compatibility."""
|
| 2 |
+
from env.environment import SQLOptimizerEnv
|
| 3 |
+
|
| 4 |
+
__all__ = ["SQLOptimizerEnv"]
|
env/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .environment import SQLOptimizerEnv
|
| 2 |
+
from .models import Observation, Action, Reward
|
| 3 |
+
|
| 4 |
+
__all__ = ["SQLOptimizerEnv", "Observation", "Action", "Reward"]
|
env/environment.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core OpenEnv environment: SQLOptimizerEnv
|
| 3 |
+
|
| 4 |
+
Implements the three required methods:
|
| 5 |
+
reset(task_id) β Observation
|
| 6 |
+
step(action) β (Observation, Reward, done, info)
|
| 7 |
+
state() β dict (current internal snapshot)
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import Any, Dict, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
from .models import Action, Observation, Reward, RewardBreakdown
|
| 14 |
+
from .tasks import TASKS, TaskDef, get_task
|
| 15 |
+
from .reward import compute_step_reward
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SQLOptimizerEnv:
|
| 19 |
+
"""SQL Query Optimizer OpenEnv environment."""
|
| 20 |
+
|
| 21 |
+
def __init__(self) -> None:
|
| 22 |
+
self._task: Optional[TaskDef] = None
|
| 23 |
+
self._step_number: int = 0
|
| 24 |
+
self._done: bool = False
|
| 25 |
+
self._cumulative_score: float = 0.0
|
| 26 |
+
self._prev_grader_score: float = 0.0
|
| 27 |
+
self._history: list[Dict[str, Any]] = []
|
| 28 |
+
self._last_grader_score: float = 0.0
|
| 29 |
+
|
| 30 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
# reset
|
| 32 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
|
| 34 |
+
def reset(self, task_id: int = 1) -> Observation:
|
| 35 |
+
"""Start a fresh episode for the given task."""
|
| 36 |
+
self._task = get_task(task_id)
|
| 37 |
+
self._step_number = 0
|
| 38 |
+
self._done = False
|
| 39 |
+
self._cumulative_score = 0.0
|
| 40 |
+
self._prev_grader_score = 0.0
|
| 41 |
+
self._last_grader_score = 0.0
|
| 42 |
+
self._history = []
|
| 43 |
+
|
| 44 |
+
return self._make_observation()
|
| 45 |
+
|
| 46 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
# step
|
| 48 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
|
| 50 |
+
def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
|
| 51 |
+
"""
|
| 52 |
+
Advance the environment by one step.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
observation: next Observation
|
| 56 |
+
reward: Reward for this step
|
| 57 |
+
done: whether the episode has ended
|
| 58 |
+
info: auxiliary dict
|
| 59 |
+
"""
|
| 60 |
+
if self._task is None:
|
| 61 |
+
raise RuntimeError("Call reset() before step().")
|
| 62 |
+
if self._done:
|
| 63 |
+
raise RuntimeError("Episode is done. Call reset() to start a new episode.")
|
| 64 |
+
|
| 65 |
+
# Validate action
|
| 66 |
+
is_invalid = not action.rewritten_query or not action.rewritten_query.strip()
|
| 67 |
+
|
| 68 |
+
# Run grader
|
| 69 |
+
if is_invalid:
|
| 70 |
+
grader_result_score = self._prev_grader_score
|
| 71 |
+
breakdown = RewardBreakdown()
|
| 72 |
+
feedback = "Empty or invalid query submitted."
|
| 73 |
+
else:
|
| 74 |
+
gr = self._task.grader(action.rewritten_query)
|
| 75 |
+
grader_result_score = gr.score
|
| 76 |
+
breakdown = RewardBreakdown(
|
| 77 |
+
correctness=gr.correctness,
|
| 78 |
+
performance=gr.performance,
|
| 79 |
+
style=gr.style,
|
| 80 |
+
step_penalty=0.0,
|
| 81 |
+
)
|
| 82 |
+
feedback = gr.feedback
|
| 83 |
+
|
| 84 |
+
# Compute shaped reward
|
| 85 |
+
step_reward = compute_step_reward(
|
| 86 |
+
grader_score=grader_result_score,
|
| 87 |
+
prev_grader_score=self._prev_grader_score,
|
| 88 |
+
step_number=self._step_number,
|
| 89 |
+
max_steps=self._task.max_steps,
|
| 90 |
+
is_done=action.is_done,
|
| 91 |
+
is_invalid=is_invalid,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Apply step penalty to breakdown
|
| 95 |
+
import math
|
| 96 |
+
halfway = math.ceil(self._task.max_steps / 2)
|
| 97 |
+
if self._step_number > halfway and not action.is_done:
|
| 98 |
+
breakdown.step_penalty = -0.02
|
| 99 |
+
|
| 100 |
+
self._cumulative_score = round(
|
| 101 |
+
min(max(self._cumulative_score + step_reward, 0.0), 1.0), 4
|
| 102 |
+
)
|
| 103 |
+
self._prev_grader_score = grader_result_score
|
| 104 |
+
self._last_grader_score = grader_result_score
|
| 105 |
+
self._step_number += 1
|
| 106 |
+
|
| 107 |
+
# Episode ends if agent signals done OR max steps reached
|
| 108 |
+
self._done = action.is_done or self._step_number >= self._task.max_steps
|
| 109 |
+
|
| 110 |
+
# Record history
|
| 111 |
+
self._history.append(
|
| 112 |
+
{
|
| 113 |
+
"step": self._step_number,
|
| 114 |
+
"rewritten_query": action.rewritten_query,
|
| 115 |
+
"grader_score": grader_result_score,
|
| 116 |
+
"step_reward": step_reward,
|
| 117 |
+
"is_done": action.is_done,
|
| 118 |
+
}
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
reward = Reward(
|
| 122 |
+
score=round(min(max(step_reward, 0.0), 1.0), 4),
|
| 123 |
+
grader_score=grader_result_score,
|
| 124 |
+
breakdown=breakdown,
|
| 125 |
+
feedback=feedback,
|
| 126 |
+
cumulative_score=self._cumulative_score,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
info = {
|
| 130 |
+
"step_number": self._step_number,
|
| 131 |
+
"grader_score": grader_result_score,
|
| 132 |
+
"cumulative_score": self._cumulative_score,
|
| 133 |
+
"is_invalid": is_invalid,
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
return self._make_observation(), reward, self._done, info
|
| 137 |
+
|
| 138 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 139 |
+
# state
|
| 140 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 141 |
+
|
| 142 |
+
def state(self) -> Dict[str, Any]:
|
| 143 |
+
"""Return the current internal state snapshot."""
|
| 144 |
+
if self._task is None:
|
| 145 |
+
return {"status": "not_started"}
|
| 146 |
+
return {
|
| 147 |
+
"task_id": self._task.id,
|
| 148 |
+
"task_name": self._task.name,
|
| 149 |
+
"difficulty": self._task.difficulty,
|
| 150 |
+
"step_number": self._step_number,
|
| 151 |
+
"max_steps": self._task.max_steps,
|
| 152 |
+
"done": self._done,
|
| 153 |
+
"cumulative_score": self._cumulative_score,
|
| 154 |
+
"last_grader_score": self._last_grader_score,
|
| 155 |
+
"history": self._history,
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 159 |
+
# Internal helpers
|
| 160 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 161 |
+
|
| 162 |
+
def _make_observation(self) -> Observation:
|
| 163 |
+
assert self._task is not None
|
| 164 |
+
return Observation(
|
| 165 |
+
task_id=self._task.id,
|
| 166 |
+
task_name=self._task.name,
|
| 167 |
+
task_description=self._task.description,
|
| 168 |
+
query=self._task.query,
|
| 169 |
+
schema_context=self._task.schema_context,
|
| 170 |
+
hint=self._task.hint,
|
| 171 |
+
step_number=self._step_number,
|
| 172 |
+
max_steps=self._task.max_steps,
|
| 173 |
+
done=self._done,
|
| 174 |
+
)
|
env/models.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenEnv typed models β Observation, Action, Reward.
|
| 3 |
+
All models are Pydantic v2 compliant.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import Any, Dict, List, Optional
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# Observation
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
class Observation(BaseModel):
|
| 16 |
+
"""What the agent sees at each step."""
|
| 17 |
+
|
| 18 |
+
task_id: int = Field(..., description="Which task (1=easy, 2=medium, 3=hard)")
|
| 19 |
+
task_name: str = Field(..., description="Human-readable task name")
|
| 20 |
+
task_description: str = Field(..., description="What the agent must accomplish")
|
| 21 |
+
query: str = Field(..., description="The SQL query the agent must fix / optimise")
|
| 22 |
+
schema_context: str = Field(
|
| 23 |
+
..., description="DDL / schema description relevant to the query"
|
| 24 |
+
)
|
| 25 |
+
hint: Optional[str] = Field(
|
| 26 |
+
None, description="Optional natural-language hint for the current step"
|
| 27 |
+
)
|
| 28 |
+
step_number: int = Field(0, description="Current step within the episode (0-indexed)")
|
| 29 |
+
max_steps: int = Field(5, description="Maximum steps allowed per episode")
|
| 30 |
+
done: bool = Field(False, description="Whether the episode has ended")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Action
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
class Action(BaseModel):
|
| 38 |
+
"""What the agent submits at each step."""
|
| 39 |
+
|
| 40 |
+
rewritten_query: str = Field(
|
| 41 |
+
..., description="The agent's rewritten / improved SQL query"
|
| 42 |
+
)
|
| 43 |
+
explanation: str = Field(
|
| 44 |
+
..., description="Natural-language explanation of changes made"
|
| 45 |
+
)
|
| 46 |
+
is_done: bool = Field(
|
| 47 |
+
False,
|
| 48 |
+
description="Set True when the agent believes the query is fully optimised",
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Reward
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
class RewardBreakdown(BaseModel):
|
| 57 |
+
correctness: float = Field(0.0, ge=0.0, le=1.0)
|
| 58 |
+
performance: float = Field(0.0, ge=0.0, le=1.0)
|
| 59 |
+
style: float = Field(0.0, ge=0.0, le=1.0)
|
| 60 |
+
step_penalty: float = Field(0.0, le=0.0) # always β€ 0
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Reward(BaseModel):
|
| 64 |
+
"""Reward returned after each step."""
|
| 65 |
+
|
| 66 |
+
score: float = Field(..., ge=0.0, le=1.0, description="Aggregate step reward")
|
| 67 |
+
grader_score: float = Field(
|
| 68 |
+
..., ge=0.0, le=1.0, description="Raw grader score for the submitted query"
|
| 69 |
+
)
|
| 70 |
+
breakdown: RewardBreakdown = Field(
|
| 71 |
+
default_factory=RewardBreakdown,
|
| 72 |
+
description="Per-dimension partial scores",
|
| 73 |
+
)
|
| 74 |
+
feedback: str = Field("", description="Human-readable feedback from the grader")
|
| 75 |
+
cumulative_score: float = Field(
|
| 76 |
+
0.0, ge=0.0, le=1.0, description="Total score accumulated over episode so far"
|
| 77 |
+
)
|
env/reward.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shaped reward function for the SQL Query Optimizer environment.
|
| 3 |
+
|
| 4 |
+
Design:
|
| 5 |
+
- Partial credit every step based on grader improvement delta
|
| 6 |
+
- Completion bonus when agent signals is_done and score β₯ threshold
|
| 7 |
+
- Step penalty for unnecessary steps beyond task minimum
|
| 8 |
+
- Invalid action penalty for empty / unparseable queries
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_COMPLETION_THRESHOLD = 0.80
|
| 16 |
+
_COMPLETION_BONUS = 0.50
|
| 17 |
+
_STEP_PENALTY = 0.02
|
| 18 |
+
_INVALID_PENALTY = 0.10
|
| 19 |
+
_DELTA_WEIGHT = 0.50 # weight for grader improvement delta in step reward
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def compute_step_reward(
|
| 23 |
+
*,
|
| 24 |
+
grader_score: float,
|
| 25 |
+
prev_grader_score: float,
|
| 26 |
+
step_number: int,
|
| 27 |
+
max_steps: int,
|
| 28 |
+
is_done: bool,
|
| 29 |
+
is_invalid: bool,
|
| 30 |
+
) -> float:
|
| 31 |
+
"""
|
| 32 |
+
Returns a reward in [-0.10, 1.0] for a single step.
|
| 33 |
+
|
| 34 |
+
Components (all summed then clamped to [0, 1]):
|
| 35 |
+
1. delta_reward = _DELTA_WEIGHT * max(0, grader_score - prev_grader_score)
|
| 36 |
+
2. completion_bonus (only if is_done and grader_score >= threshold)
|
| 37 |
+
3. step_penalty (only if step > min_steps_expected and not done-early)
|
| 38 |
+
4. invalid_penalty (if query is empty / not parseable)
|
| 39 |
+
"""
|
| 40 |
+
if is_invalid:
|
| 41 |
+
return -_INVALID_PENALTY
|
| 42 |
+
|
| 43 |
+
delta = max(0.0, grader_score - prev_grader_score)
|
| 44 |
+
reward = _DELTA_WEIGHT * delta
|
| 45 |
+
|
| 46 |
+
if is_done:
|
| 47 |
+
if grader_score >= _COMPLETION_THRESHOLD:
|
| 48 |
+
reward += _COMPLETION_BONUS
|
| 49 |
+
# proportional partial completion signal even without bonus
|
| 50 |
+
reward += grader_score * 0.30
|
| 51 |
+
|
| 52 |
+
# Step penalty starts after half of max_steps used
|
| 53 |
+
halfway = math.ceil(max_steps / 2)
|
| 54 |
+
if step_number > halfway and not is_done:
|
| 55 |
+
reward -= _STEP_PENALTY
|
| 56 |
+
|
| 57 |
+
return round(min(max(reward, -_INVALID_PENALTY), 1.0), 4)
|
env/tasks.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task definitions and deterministic graders for the SQL Query Optimizer environment.
|
| 3 |
+
|
| 4 |
+
Each task returns a TaskDef with:
|
| 5 |
+
- id, name, difficulty
|
| 6 |
+
- query: the broken/unoptimised SQL the agent must fix
|
| 7 |
+
- schema_context: relevant DDL
|
| 8 |
+
- description: what the agent must accomplish
|
| 9 |
+
- grader(rewritten_query) -> GraderResult(score, breakdown, feedback)
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
import dataclasses
|
| 15 |
+
from typing import Callable, Dict, Optional
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclasses.dataclass
|
| 19 |
+
class GraderResult:
|
| 20 |
+
score: float # 0.0 β 1.0
|
| 21 |
+
correctness: float = 0.0
|
| 22 |
+
performance: float = 0.0
|
| 23 |
+
style: float = 0.0
|
| 24 |
+
feedback: str = ""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclasses.dataclass
|
| 28 |
+
class TaskDef:
|
| 29 |
+
id: int
|
| 30 |
+
name: str
|
| 31 |
+
difficulty: str # easy | medium | hard
|
| 32 |
+
description: str
|
| 33 |
+
query: str
|
| 34 |
+
schema_context: str
|
| 35 |
+
hint: Optional[str]
|
| 36 |
+
max_steps: int
|
| 37 |
+
grader: Callable[[str], GraderResult]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
# Helpers
|
| 42 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
|
| 44 |
+
def _normalise(sql: str) -> str:
|
| 45 |
+
"""Lower-case, collapse whitespace."""
|
| 46 |
+
return re.sub(r"\s+", " ", sql.lower().strip())
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _has(sql: str, *patterns: str) -> bool:
|
| 50 |
+
s = _normalise(sql)
|
| 51 |
+
return all(p in s for p in patterns)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _missing(sql: str, *patterns: str) -> bool:
|
| 55 |
+
s = _normalise(sql)
|
| 56 |
+
return any(p not in s for p in patterns)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 60 |
+
# Task 1 β Easy: Fix a broken JOIN (missing ON clause / wrong join type)
|
| 61 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
|
| 63 |
+
_T1_SCHEMA = """
|
| 64 |
+
CREATE TABLE orders (
|
| 65 |
+
order_id INT PRIMARY KEY,
|
| 66 |
+
customer_id INT NOT NULL,
|
| 67 |
+
total DECIMAL(10,2),
|
| 68 |
+
created_at TIMESTAMP
|
| 69 |
+
);
|
| 70 |
+
CREATE TABLE customers (
|
| 71 |
+
customer_id INT PRIMARY KEY,
|
| 72 |
+
name VARCHAR(255),
|
| 73 |
+
email VARCHAR(255)
|
| 74 |
+
);
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
_T1_QUERY = """
|
| 78 |
+
SELECT o.order_id, c.name, o.total
|
| 79 |
+
FROM orders o, customers c
|
| 80 |
+
WHERE o.total > 100;
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
_T1_DESC = (
|
| 84 |
+
"The query uses an implicit cross-join (comma syntax) between `orders` and "
|
| 85 |
+
"`customers` but never links the two tables. Rewrite it with an explicit "
|
| 86 |
+
"INNER JOIN β¦ ON o.customer_id = c.customer_id, keeping the WHERE filter."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _grade_task1(rewritten: str) -> GraderResult:
|
| 91 |
+
s = _normalise(rewritten)
|
| 92 |
+
fb: list[str] = []
|
| 93 |
+
correctness = 0.0
|
| 94 |
+
performance = 0.0
|
| 95 |
+
style = 0.0
|
| 96 |
+
|
| 97 |
+
# Correctness: must have explicit JOIN with the correct ON key
|
| 98 |
+
if "inner join" in s or ("join" in s and "cross join" not in s):
|
| 99 |
+
if "on" in s and "customer_id" in s:
|
| 100 |
+
correctness = 1.0
|
| 101 |
+
else:
|
| 102 |
+
correctness = 0.4
|
| 103 |
+
fb.append("JOIN present but ON clause with customer_id is missing.")
|
| 104 |
+
else:
|
| 105 |
+
fb.append("Still uses implicit cross-join or missing JOIN keyword.")
|
| 106 |
+
|
| 107 |
+
# Correctness: must still filter total > 100
|
| 108 |
+
if "total > 100" in s or "total>100" in s:
|
| 109 |
+
correctness = min(correctness + 0.0, correctness) # already captured
|
| 110 |
+
else:
|
| 111 |
+
correctness = max(correctness - 0.3, 0.0)
|
| 112 |
+
fb.append("WHERE o.total > 100 filter has been removed.")
|
| 113 |
+
|
| 114 |
+
# Performance: explicit join is better than implicit cross join
|
| 115 |
+
performance = 1.0 if correctness >= 0.8 else 0.3
|
| 116 |
+
|
| 117 |
+
# Style: uses table aliases
|
| 118 |
+
style = 0.5
|
| 119 |
+
if re.search(r"\bo\b", s) and re.search(r"\bc\b", s):
|
| 120 |
+
style = 1.0
|
| 121 |
+
elif "select *" not in s:
|
| 122 |
+
style = 0.7
|
| 123 |
+
|
| 124 |
+
score = round(correctness * 0.6 + performance * 0.25 + style * 0.15, 3)
|
| 125 |
+
feedback = " ".join(fb) if fb else "Correct! The JOIN is properly formed."
|
| 126 |
+
return GraderResult(
|
| 127 |
+
score=min(max(score, 0.0), 1.0),
|
| 128 |
+
correctness=correctness,
|
| 129 |
+
performance=performance,
|
| 130 |
+
style=style,
|
| 131 |
+
feedback=feedback,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
+
# Task 2 β Medium: Eliminate N+1 correlated subquery
|
| 137 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββ
|
| 138 |
+
|
| 139 |
+
_T2_SCHEMA = """
|
| 140 |
+
CREATE TABLE employees (
|
| 141 |
+
emp_id INT PRIMARY KEY,
|
| 142 |
+
name VARCHAR(255),
|
| 143 |
+
dept_id INT,
|
| 144 |
+
salary DECIMAL(10,2)
|
| 145 |
+
);
|
| 146 |
+
CREATE TABLE departments (
|
| 147 |
+
dept_id INT PRIMARY KEY,
|
| 148 |
+
dept_name VARCHAR(255),
|
| 149 |
+
budget DECIMAL(12,2)
|
| 150 |
+
);
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
_T2_QUERY = """
|
| 154 |
+
SELECT e.name,
|
| 155 |
+
(SELECT d.dept_name
|
| 156 |
+
FROM departments d
|
| 157 |
+
WHERE d.dept_id = e.dept_id) AS dept_name
|
| 158 |
+
FROM employees e
|
| 159 |
+
WHERE e.salary > 50000;
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
_T2_DESC = (
|
| 163 |
+
"The query uses a correlated scalar subquery in the SELECT list that fires "
|
| 164 |
+
"once per row (N+1 problem). Collapse it into a single LEFT JOIN β¦ ON "
|
| 165 |
+
"e.dept_id = d.dept_id, keeping the salary filter."
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _grade_task2(rewritten: str) -> GraderResult:
|
| 170 |
+
s = _normalise(rewritten)
|
| 171 |
+
fb: list[str] = []
|
| 172 |
+
correctness = 0.0
|
| 173 |
+
performance = 0.0
|
| 174 |
+
style = 0.0
|
| 175 |
+
|
| 176 |
+
# Correctness: correlated subquery in SELECT must be gone
|
| 177 |
+
has_correlated = bool(
|
| 178 |
+
re.search(r"select\s+.*\(\s*select", s)
|
| 179 |
+
or re.search(r"\(\s*select\b.*\bwhere\b.*=\s*e\.", s)
|
| 180 |
+
)
|
| 181 |
+
if has_correlated:
|
| 182 |
+
fb.append("Correlated subquery still present in SELECT list.")
|
| 183 |
+
correctness = 0.1
|
| 184 |
+
else:
|
| 185 |
+
correctness = 0.5
|
| 186 |
+
|
| 187 |
+
# Correctness: must join on dept_id
|
| 188 |
+
if "join" in s and "dept_id" in s and "on" in s:
|
| 189 |
+
correctness = min(correctness + 0.5, 1.0)
|
| 190 |
+
else:
|
| 191 |
+
fb.append("Missing JOIN departments ON dept_id.")
|
| 192 |
+
correctness = max(correctness - 0.1, 0.0)
|
| 193 |
+
|
| 194 |
+
# Correctness: salary filter preserved
|
| 195 |
+
if "salary" not in s or ("salary > 50000" not in s and "salary>50000" not in s):
|
| 196 |
+
correctness = max(correctness - 0.2, 0.0)
|
| 197 |
+
fb.append("salary > 50000 filter is missing or incorrect.")
|
| 198 |
+
|
| 199 |
+
# Performance: single pass vs N+1
|
| 200 |
+
performance = 1.0 if not has_correlated and "join" in s else 0.2
|
| 201 |
+
|
| 202 |
+
# Style: uses aliases, selects explicit columns
|
| 203 |
+
style = 0.5
|
| 204 |
+
if "select *" not in s:
|
| 205 |
+
style += 0.25
|
| 206 |
+
if re.search(r"\be\b|\bd\b", s):
|
| 207 |
+
style += 0.25
|
| 208 |
+
|
| 209 |
+
score = round(correctness * 0.55 + performance * 0.30 + style * 0.15, 3)
|
| 210 |
+
feedback = " ".join(fb) if fb else "Excellent! N+1 eliminated with a clean JOIN."
|
| 211 |
+
return GraderResult(
|
| 212 |
+
score=min(max(score, 0.0), 1.0),
|
| 213 |
+
correctness=correctness,
|
| 214 |
+
performance=performance,
|
| 215 |
+
style=style,
|
| 216 |
+
feedback=feedback,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 221 |
+
# Task 3 β Hard: Full optimisation (4 independent issues)
|
| 222 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 223 |
+
|
| 224 |
+
_T3_SCHEMA = """
|
| 225 |
+
CREATE TABLE products (
|
| 226 |
+
product_id INT PRIMARY KEY,
|
| 227 |
+
name VARCHAR(255),
|
| 228 |
+
category VARCHAR(100),
|
| 229 |
+
price DECIMAL(10,2),
|
| 230 |
+
stock INT
|
| 231 |
+
);
|
| 232 |
+
CREATE TABLE order_items (
|
| 233 |
+
item_id INT PRIMARY KEY,
|
| 234 |
+
order_id INT,
|
| 235 |
+
product_id INT,
|
| 236 |
+
quantity INT,
|
| 237 |
+
unit_price DECIMAL(10,2)
|
| 238 |
+
);
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
_T3_QUERY = """
|
| 242 |
+
SELECT DISTINCT *
|
| 243 |
+
FROM products p
|
| 244 |
+
JOIN order_items oi ON p.product_id = oi.product_id
|
| 245 |
+
WHERE CAST(p.price AS VARCHAR) LIKE '1%'
|
| 246 |
+
AND p.category = 'Electronics'
|
| 247 |
+
ORDER BY p.name;
|
| 248 |
+
"""
|
| 249 |
+
|
| 250 |
+
_T3_DESC = (
|
| 251 |
+
"The query has four problems: "
|
| 252 |
+
"(1) DISTINCT is redundant because product_id is PK and the JOIN is 1-to-many β remove it. "
|
| 253 |
+
"(2) SELECT * should list only needed columns: p.name, p.category, p.price, oi.quantity, oi.unit_price. "
|
| 254 |
+
"(3) CAST(p.price AS VARCHAR) LIKE '1%' prevents index use β rewrite as p.price >= 100 AND p.price < 200. "
|
| 255 |
+
"(4) Add a comment hinting an index on (category, price) would help."
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _grade_task3(rewritten: str) -> GraderResult:
|
| 260 |
+
s = _normalise(rewritten)
|
| 261 |
+
fb: list[str] = []
|
| 262 |
+
sub_scores: Dict[str, float] = {}
|
| 263 |
+
|
| 264 |
+
# Sub-criterion 1: DISTINCT removed (0.25)
|
| 265 |
+
if "distinct" not in s:
|
| 266 |
+
sub_scores["no_distinct"] = 0.25
|
| 267 |
+
else:
|
| 268 |
+
sub_scores["no_distinct"] = 0.0
|
| 269 |
+
fb.append("DISTINCT still present β it's redundant here.")
|
| 270 |
+
|
| 271 |
+
# Sub-criterion 2: SELECT * replaced with explicit columns (0.25)
|
| 272 |
+
if "select *" not in s and all(
|
| 273 |
+
col in s for col in ("p.name", "p.price", "oi.quantity")
|
| 274 |
+
):
|
| 275 |
+
sub_scores["explicit_columns"] = 0.25
|
| 276 |
+
elif "select *" not in s:
|
| 277 |
+
sub_scores["explicit_columns"] = 0.15
|
| 278 |
+
fb.append("SELECT * removed but explicit column list is incomplete.")
|
| 279 |
+
else:
|
| 280 |
+
sub_scores["explicit_columns"] = 0.0
|
| 281 |
+
fb.append("SELECT * still used β list explicit columns.")
|
| 282 |
+
|
| 283 |
+
# Sub-criterion 3: CASTβ¦LIKE replaced with range predicate (0.25)
|
| 284 |
+
cast_gone = "cast(" not in s and "cast (" not in s
|
| 285 |
+
has_price_range = (
|
| 286 |
+
("price >= 100" in s or "price>=100" in s)
|
| 287 |
+
and ("price < 200" in s or "price<200" in s)
|
| 288 |
+
)
|
| 289 |
+
if cast_gone and has_price_range:
|
| 290 |
+
sub_scores["sargable"] = 0.25
|
| 291 |
+
elif cast_gone:
|
| 292 |
+
sub_scores["sargable"] = 0.12
|
| 293 |
+
fb.append("CAST removed but price range predicate (>= 100 AND < 200) is missing.")
|
| 294 |
+
else:
|
| 295 |
+
sub_scores["sargable"] = 0.0
|
| 296 |
+
fb.append("CAST(price AS VARCHAR) LIKE β¦ still present β non-sargable predicate.")
|
| 297 |
+
|
| 298 |
+
# Sub-criterion 4: index hint comment present (0.25)
|
| 299 |
+
raw = rewritten.lower()
|
| 300 |
+
if "index" in raw and ("category" in raw or "price" in raw):
|
| 301 |
+
sub_scores["index_hint"] = 0.25
|
| 302 |
+
else:
|
| 303 |
+
sub_scores["index_hint"] = 0.0
|
| 304 |
+
fb.append("Missing comment / hint about adding an index on (category, price).")
|
| 305 |
+
|
| 306 |
+
total = sum(sub_scores.values())
|
| 307 |
+
correctness = min(sub_scores["no_distinct"] + sub_scores["explicit_columns"], 0.5) * 2
|
| 308 |
+
performance = min(sub_scores["sargable"] + sub_scores["index_hint"], 0.5) * 2
|
| 309 |
+
style = 1.0 if "select *" not in s else 0.0
|
| 310 |
+
|
| 311 |
+
feedback = " ".join(fb) if fb else "Perfect optimisation across all four dimensions!"
|
| 312 |
+
return GraderResult(
|
| 313 |
+
score=round(min(max(total, 0.0), 1.0), 3),
|
| 314 |
+
correctness=round(correctness, 3),
|
| 315 |
+
performance=round(performance, 3),
|
| 316 |
+
style=round(style, 3),
|
| 317 |
+
feedback=feedback,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 322 |
+
# Registry
|
| 323 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 324 |
+
|
| 325 |
+
TASKS: Dict[int, TaskDef] = {
|
| 326 |
+
1: TaskDef(
|
| 327 |
+
id=1,
|
| 328 |
+
name="fix-broken-join",
|
| 329 |
+
difficulty="easy",
|
| 330 |
+
description=_T1_DESC,
|
| 331 |
+
query=_T1_QUERY.strip(),
|
| 332 |
+
schema_context=_T1_SCHEMA.strip(),
|
| 333 |
+
hint="Replace the comma-separated FROM list with an explicit INNER JOIN β¦ ON.",
|
| 334 |
+
max_steps=3,
|
| 335 |
+
grader=_grade_task1,
|
| 336 |
+
),
|
| 337 |
+
2: TaskDef(
|
| 338 |
+
id=2,
|
| 339 |
+
name="eliminate-n-plus-one",
|
| 340 |
+
difficulty="medium",
|
| 341 |
+
description=_T2_DESC,
|
| 342 |
+
query=_T2_QUERY.strip(),
|
| 343 |
+
schema_context=_T2_SCHEMA.strip(),
|
| 344 |
+
hint="Move the subquery out of the SELECT list and into a LEFT JOIN.",
|
| 345 |
+
max_steps=4,
|
| 346 |
+
grader=_grade_task2,
|
| 347 |
+
),
|
| 348 |
+
3: TaskDef(
|
| 349 |
+
id=3,
|
| 350 |
+
name="full-optimization",
|
| 351 |
+
difficulty="hard",
|
| 352 |
+
description=_T3_DESC,
|
| 353 |
+
query=_T3_QUERY.strip(),
|
| 354 |
+
schema_context=_T3_SCHEMA.strip(),
|
| 355 |
+
hint=None,
|
| 356 |
+
max_steps=5,
|
| 357 |
+
grader=_grade_task3,
|
| 358 |
+
),
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def get_task(task_id: int) -> TaskDef:
|
| 363 |
+
if task_id not in TASKS:
|
| 364 |
+
raise ValueError(f"Unknown task_id {task_id}. Valid: {list(TASKS.keys())}")
|
| 365 |
+
return TASKS[task_id]
|
hf_login.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Interactive HuggingFace login script
|
| 3 |
+
Usage: python hf_login.py
|
| 4 |
+
"""
|
| 5 |
+
from huggingface_hub import login
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
print("=" * 60)
|
| 9 |
+
print("HuggingFace Hub Login")
|
| 10 |
+
print("=" * 60)
|
| 11 |
+
print("\nYou can authenticate in two ways:")
|
| 12 |
+
print("1. Enter your API token interactively")
|
| 13 |
+
print("2. Set HF_TOKEN environment variable and run with --auto flag")
|
| 14 |
+
print("\nTo get a token, visit: https://huggingface.co/settings/tokens")
|
| 15 |
+
print("=" * 60)
|
| 16 |
+
|
| 17 |
+
token = os.getenv("HF_TOKEN", "").strip()
|
| 18 |
+
|
| 19 |
+
if token:
|
| 20 |
+
print(f"\nUsing token from HF_TOKEN environment variable...")
|
| 21 |
+
try:
|
| 22 |
+
login(token=token)
|
| 23 |
+
print("β Login successful!")
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"β Login failed: {e}")
|
| 26 |
+
else:
|
| 27 |
+
print("\nEnter your HuggingFace token (or type 'quit' to exit):")
|
| 28 |
+
token = input("> ").strip()
|
| 29 |
+
if token.lower() != 'quit':
|
| 30 |
+
try:
|
| 31 |
+
login(token=token)
|
| 32 |
+
print("β Login successful!")
|
| 33 |
+
except Exception as e:
|
| 34 |
+
print(f"β Login failed: {e}")
|
| 35 |
+
else:
|
| 36 |
+
print("Login cancelled.")
|
jj.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
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
|
models.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Top-level model exports for OpenEnv validation compatibility."""
|
| 2 |
+
from env.models import Action, Observation, Reward, RewardBreakdown
|
| 3 |
+
|
| 4 |
+
__all__ = ["Action", "Observation", "Reward", "RewardBreakdown"]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sql-query-optimizer
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
An OpenEnv environment where AI agents learn to review, rewrite, and optimise
|
| 5 |
+
SQL queries for correctness and performance. Covers three real-world failure
|
| 6 |
+
patterns: implicit cross-joins, N+1 subqueries, and multi-dimensional query
|
| 7 |
+
anti-patterns.
|
| 8 |
+
author: metaXscaler
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- sql
|
| 12 |
+
- code-review
|
| 13 |
+
- data-engineering
|
| 14 |
+
- database
|
| 15 |
+
tasks:
|
| 16 |
+
- id: 1
|
| 17 |
+
name: fix-broken-join
|
| 18 |
+
difficulty: easy
|
| 19 |
+
description: >
|
| 20 |
+
The agent must replace an implicit cross-join (comma syntax) with an
|
| 21 |
+
explicit INNER JOIN ... ON clause.
|
| 22 |
+
- id: 2
|
| 23 |
+
name: eliminate-n-plus-one
|
| 24 |
+
difficulty: medium
|
| 25 |
+
description: >
|
| 26 |
+
The agent must remove a correlated scalar subquery in the SELECT list
|
| 27 |
+
and replace it with a single LEFT JOIN.
|
| 28 |
+
- id: 3
|
| 29 |
+
name: full-optimization
|
| 30 |
+
difficulty: hard
|
| 31 |
+
description: >
|
| 32 |
+
The agent must fix four independent issues: remove redundant DISTINCT,
|
| 33 |
+
replace SELECT *, eliminate a non-sargable CAST predicate, and add an
|
| 34 |
+
index hint comment.
|
| 35 |
+
observation:
|
| 36 |
+
type: object
|
| 37 |
+
fields:
|
| 38 |
+
task_id: integer
|
| 39 |
+
task_name: string
|
| 40 |
+
task_description: string
|
| 41 |
+
query: string
|
| 42 |
+
schema_context: string
|
| 43 |
+
hint: "string | null"
|
| 44 |
+
step_number: integer
|
| 45 |
+
max_steps: integer
|
| 46 |
+
done: boolean
|
| 47 |
+
action:
|
| 48 |
+
type: object
|
| 49 |
+
fields:
|
| 50 |
+
rewritten_query: string
|
| 51 |
+
explanation: string
|
| 52 |
+
is_done: boolean
|
| 53 |
+
reward:
|
| 54 |
+
type: object
|
| 55 |
+
fields:
|
| 56 |
+
score: "float [0.0, 1.0]"
|
| 57 |
+
grader_score: "float [0.0, 1.0]"
|
| 58 |
+
breakdown:
|
| 59 |
+
correctness: "float [0.0, 1.0]"
|
| 60 |
+
performance: "float [0.0, 1.0]"
|
| 61 |
+
style: "float [0.0, 1.0]"
|
| 62 |
+
step_penalty: "float β€ 0.0"
|
| 63 |
+
feedback: string
|
| 64 |
+
cumulative_score: "float [0.0, 1.0]"
|
| 65 |
+
endpoints:
|
| 66 |
+
- path: /reset
|
| 67 |
+
method: POST
|
| 68 |
+
description: Start a fresh episode for a given task_id
|
| 69 |
+
- path: /step
|
| 70 |
+
method: POST
|
| 71 |
+
description: Submit an Action and advance the episode
|
| 72 |
+
- path: /state
|
| 73 |
+
method: GET
|
| 74 |
+
description: Return the current internal state snapshot
|
| 75 |
+
- path: /tasks
|
| 76 |
+
method: GET
|
| 77 |
+
description: List all tasks and action schema
|
| 78 |
+
- path: /grader
|
| 79 |
+
method: GET
|
| 80 |
+
description: Return grader score for the last completed episode
|
| 81 |
+
- path: /baseline
|
| 82 |
+
method: POST
|
| 83 |
+
description: Trigger baseline inference on all 3 tasks
|
pyproject.toml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "sql-query-optimizer-openenv"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "An OpenEnv environment where AI agents learn to review, rewrite, and optimise SQL queries for correctness and performance."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
authors = [
|
| 12 |
+
{name = "metaXscaler", email = ""}
|
| 13 |
+
]
|
| 14 |
+
license = {text = "MIT"}
|
| 15 |
+
keywords = ["openenv", "sql", "optimization", "ml", "agent", "environment"]
|
| 16 |
+
classifiers = [
|
| 17 |
+
"Development Status :: 4 - Beta",
|
| 18 |
+
"Intended Audience :: Developers",
|
| 19 |
+
"Intended Audience :: Science/Research",
|
| 20 |
+
"License :: OSI Approved :: MIT License",
|
| 21 |
+
"Programming Language :: Python :: 3",
|
| 22 |
+
"Programming Language :: Python :: 3.10",
|
| 23 |
+
"Programming Language :: Python :: 3.11",
|
| 24 |
+
"Programming Language :: Python :: 3.12",
|
| 25 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
dependencies = [
|
| 29 |
+
"fastapi>=0.111.0",
|
| 30 |
+
"uvicorn[standard]>=0.29.0",
|
| 31 |
+
"pydantic>=2.7.0",
|
| 32 |
+
"openai>=1.30.0",
|
| 33 |
+
"pyyaml>=6.0",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
[project.optional-dependencies]
|
| 37 |
+
dev = [
|
| 38 |
+
"pytest>=7.0",
|
| 39 |
+
"black>=23.0",
|
| 40 |
+
"ruff>=0.1.0",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
[project.urls]
|
| 44 |
+
Homepage = "https://huggingface.co/spaces"
|
| 45 |
+
Repository = "https://github.com/metaXscaler/sql-query-optimizer-openenv"
|
| 46 |
+
Documentation = "https://github.com/metaXscaler/sql-query-optimizer-openenv/blob/main/README.md"
|
| 47 |
+
|
| 48 |
+
[tool.black]
|
| 49 |
+
line-length = 100
|
| 50 |
+
target-version = ['py310', 'py311', 'py312']
|
| 51 |
+
|
| 52 |
+
[tool.ruff]
|
| 53 |
+
line-length = 100
|
| 54 |
+
target-version = "py310"
|
| 55 |
+
select = ["E", "F", "W"]
|
| 56 |
+
ignore = ["E501"] # Line too long (handled by black)
|
| 57 |
+
|
| 58 |
+
[tool.pytest.ini_options]
|
| 59 |
+
testpaths = ["tests"]
|
| 60 |
+
python_files = ["test_*.py"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.111.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
pydantic>=2.7.0
|
| 4 |
+
openai>=1.30.0
|
| 5 |
+
pyyaml>=6.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server package for SQL Query Optimizer OpenEnv environment."""
|
| 2 |
+
from .app import app
|
| 3 |
+
|
| 4 |
+
__all__ = ["app"]
|
server/app.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server exposing the OpenEnv SQL Optimizer environment.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
POST /reset β Observation
|
| 6 |
+
POST /step β {observation, reward, done, info}
|
| 7 |
+
GET /state β state dict
|
| 8 |
+
GET /tasks β list of tasks + action schema
|
| 9 |
+
GET /grader β grader score for last completed episode
|
| 10 |
+
POST /baseline β trigger baseline inference on all 3 tasks
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import subprocess
|
| 16 |
+
import sys
|
| 17 |
+
from typing import Any, Dict, Optional
|
| 18 |
+
|
| 19 |
+
from fastapi import FastAPI, HTTPException
|
| 20 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
+
from pydantic import BaseModel
|
| 22 |
+
|
| 23 |
+
from env.environment import SQLOptimizerEnv
|
| 24 |
+
from env.models import Action, Observation, Reward
|
| 25 |
+
from env.tasks import TASKS
|
| 26 |
+
|
| 27 |
+
app = FastAPI(
|
| 28 |
+
title="SQL Query Optimizer β OpenEnv",
|
| 29 |
+
description=(
|
| 30 |
+
"An OpenEnv-compliant environment where AI agents learn to rewrite "
|
| 31 |
+
"and optimise SQL queries across three difficulty levels."
|
| 32 |
+
),
|
| 33 |
+
version="1.0.0",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
app.add_middleware(
|
| 37 |
+
CORSMiddleware,
|
| 38 |
+
allow_origins=["*"],
|
| 39 |
+
allow_methods=["*"],
|
| 40 |
+
allow_headers=["*"],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Single shared environment instance (stateful, per-process)
|
| 44 |
+
_env = SQLOptimizerEnv()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
+
# Request / Response schemas
|
| 49 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 50 |
+
|
| 51 |
+
class ResetRequest(BaseModel):
|
| 52 |
+
task_id: int = 1
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class StepResponse(BaseModel):
|
| 56 |
+
observation: Observation
|
| 57 |
+
reward: Reward
|
| 58 |
+
done: bool
|
| 59 |
+
info: Dict[str, Any]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class GraderResponse(BaseModel):
|
| 63 |
+
task_id: Optional[int]
|
| 64 |
+
grader_score: float
|
| 65 |
+
cumulative_score: float
|
| 66 |
+
done: bool
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TaskInfo(BaseModel):
|
| 70 |
+
id: int
|
| 71 |
+
name: str
|
| 72 |
+
difficulty: str
|
| 73 |
+
description: str
|
| 74 |
+
action_schema: Dict[str, Any]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class BaselineResponse(BaseModel):
|
| 78 |
+
task_results: Dict[str, float]
|
| 79 |
+
message: str
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 83 |
+
# Endpoints
|
| 84 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 85 |
+
|
| 86 |
+
@app.get("/", summary="Health check")
|
| 87 |
+
def health() -> Dict[str, str]:
|
| 88 |
+
return {"status": "ok", "environment": "sql-query-optimizer", "version": "1.0.0"}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@app.post("/reset", response_model=Observation, summary="Start / restart an episode")
|
| 92 |
+
def reset(req: ResetRequest) -> Observation:
|
| 93 |
+
"""Reset the environment for a given task_id (1=easy, 2=medium, 3=hard)."""
|
| 94 |
+
try:
|
| 95 |
+
obs = _env.reset(task_id=req.task_id)
|
| 96 |
+
except ValueError as exc:
|
| 97 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 98 |
+
return obs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@app.post("/step", response_model=StepResponse, summary="Submit an action")
|
| 102 |
+
def step(action: Action) -> StepResponse:
|
| 103 |
+
"""Advance the environment by submitting an Action."""
|
| 104 |
+
try:
|
| 105 |
+
obs, reward, done, info = _env.step(action)
|
| 106 |
+
except RuntimeError as exc:
|
| 107 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 108 |
+
return StepResponse(observation=obs, reward=reward, done=done, info=info)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@app.get("/state", summary="Return current internal state")
|
| 112 |
+
def state() -> Dict[str, Any]:
|
| 113 |
+
"""Return the current internal state of the environment."""
|
| 114 |
+
return _env.state()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@app.get("/tasks", response_model=list[TaskInfo], summary="List tasks + action schema")
|
| 118 |
+
def list_tasks() -> list[TaskInfo]:
|
| 119 |
+
"""Return all tasks with descriptions and the action schema."""
|
| 120 |
+
action_schema = Action.model_json_schema()
|
| 121 |
+
return [
|
| 122 |
+
TaskInfo(
|
| 123 |
+
id=t.id,
|
| 124 |
+
name=t.name,
|
| 125 |
+
difficulty=t.difficulty,
|
| 126 |
+
description=t.description,
|
| 127 |
+
action_schema=action_schema,
|
| 128 |
+
)
|
| 129 |
+
for t in TASKS.values()
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@app.get("/grader", response_model=GraderResponse, summary="Grader score for last episode")
|
| 134 |
+
def grader() -> GraderResponse:
|
| 135 |
+
"""Return the grader score after the current/last episode."""
|
| 136 |
+
s = _env.state()
|
| 137 |
+
if s.get("status") == "not_started":
|
| 138 |
+
raise HTTPException(status_code=400, detail="No episode started. Call /reset first.")
|
| 139 |
+
return GraderResponse(
|
| 140 |
+
task_id=s.get("task_id"),
|
| 141 |
+
grader_score=s.get("last_grader_score", 0.0),
|
| 142 |
+
cumulative_score=s.get("cumulative_score", 0.0),
|
| 143 |
+
done=s.get("done", False),
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
|
| 148 |
+
def baseline() -> BaselineResponse:
|
| 149 |
+
"""
|
| 150 |
+
Trigger the baseline inference script (baseline.py) and return scores.
|
| 151 |
+
Requires OPENAI_API_KEY to be set in the environment.
|
| 152 |
+
"""
|
| 153 |
+
if not os.getenv("OPENAI_API_KEY"):
|
| 154 |
+
raise HTTPException(
|
| 155 |
+
status_code=400,
|
| 156 |
+
detail="OPENAI_API_KEY environment variable not set. Cannot run baseline.",
|
| 157 |
+
)
|
| 158 |
+
try:
|
| 159 |
+
result = subprocess.run(
|
| 160 |
+
[sys.executable, "baseline.py", "--json"],
|
| 161 |
+
capture_output=True,
|
| 162 |
+
text=True,
|
| 163 |
+
timeout=300,
|
| 164 |
+
)
|
| 165 |
+
if result.returncode != 0:
|
| 166 |
+
raise HTTPException(
|
| 167 |
+
status_code=500,
|
| 168 |
+
detail=f"Baseline script failed:\n{result.stderr}",
|
| 169 |
+
)
|
| 170 |
+
import json
|
| 171 |
+
scores = json.loads(result.stdout)
|
| 172 |
+
return BaselineResponse(task_results=scores, message="Baseline completed successfully.")
|
| 173 |
+
except subprocess.TimeoutExpired:
|
| 174 |
+
raise HTTPException(status_code=500, detail="Baseline script timed out after 300s.")
|
| 175 |
+
except Exception as exc:
|
| 176 |
+
raise HTTPException(status_code=500, detail=str(exc))
|
sql-query-optimizer/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
sql-query-optimizer/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Sql Query Optimizer
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
short_description: SQL Query Optimizer β OpenEnv Environment
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
test_env.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quick smoke test for all 3 tasks."""
|
| 2 |
+
import sys, json
|
| 3 |
+
sys.path.insert(0, ".")
|
| 4 |
+
|
| 5 |
+
from env.environment import SQLOptimizerEnv
|
| 6 |
+
from env.models import Action
|
| 7 |
+
|
| 8 |
+
env = SQLOptimizerEnv()
|
| 9 |
+
|
| 10 |
+
# ββ Task 1 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
print("=== Task 1 (Easy): fix-broken-join ===")
|
| 12 |
+
obs = env.reset(1)
|
| 13 |
+
print(f" task: {obs.task_name}")
|
| 14 |
+
action = Action(
|
| 15 |
+
rewritten_query=(
|
| 16 |
+
"SELECT o.order_id, c.name, o.total "
|
| 17 |
+
"FROM orders o INNER JOIN customers c ON o.customer_id = c.customer_id "
|
| 18 |
+
"WHERE o.total > 100"
|
| 19 |
+
),
|
| 20 |
+
explanation="Replaced comma cross-join with INNER JOIN ON customer_id",
|
| 21 |
+
is_done=True,
|
| 22 |
+
)
|
| 23 |
+
obs2, reward, done, info = env.step(action)
|
| 24 |
+
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
|
| 25 |
+
print(f" feedback: {reward.feedback}")
|
| 26 |
+
assert obs2.done == True, "done should be True"
|
| 27 |
+
assert info["grader_score"] >= 0.8, f"Expected >=0.8, got {info['grader_score']}"
|
| 28 |
+
|
| 29 |
+
# ββ Task 2 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
print()
|
| 31 |
+
print("=== Task 2 (Medium): eliminate-n-plus-one ===")
|
| 32 |
+
obs = env.reset(2)
|
| 33 |
+
print(f" task: {obs.task_name}")
|
| 34 |
+
action = Action(
|
| 35 |
+
rewritten_query=(
|
| 36 |
+
"SELECT e.name, d.dept_name "
|
| 37 |
+
"FROM employees e "
|
| 38 |
+
"LEFT JOIN departments d ON e.dept_id = d.dept_id "
|
| 39 |
+
"WHERE e.salary > 50000"
|
| 40 |
+
),
|
| 41 |
+
explanation="Replaced correlated subquery with a single LEFT JOIN",
|
| 42 |
+
is_done=True,
|
| 43 |
+
)
|
| 44 |
+
obs2, reward, done, info = env.step(action)
|
| 45 |
+
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
|
| 46 |
+
print(f" feedback: {reward.feedback}")
|
| 47 |
+
assert info["grader_score"] >= 0.7, f"Expected >=0.7, got {info['grader_score']}"
|
| 48 |
+
|
| 49 |
+
# ββ Task 3 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 50 |
+
print()
|
| 51 |
+
print("=== Task 3 (Hard): full-optimization ===")
|
| 52 |
+
obs = env.reset(3)
|
| 53 |
+
print(f" task: {obs.task_name}")
|
| 54 |
+
action = Action(
|
| 55 |
+
rewritten_query=(
|
| 56 |
+
"-- Index hint: consider CREATE INDEX ON products(category, price)\n"
|
| 57 |
+
"SELECT p.name, p.category, p.price, oi.quantity, oi.unit_price\n"
|
| 58 |
+
"FROM products p\n"
|
| 59 |
+
"JOIN order_items oi ON p.product_id = oi.product_id\n"
|
| 60 |
+
"WHERE p.price >= 100 AND p.price < 200\n"
|
| 61 |
+
" AND p.category = 'Electronics'\n"
|
| 62 |
+
"ORDER BY p.name"
|
| 63 |
+
),
|
| 64 |
+
explanation="Removed DISTINCT and SELECT *, replaced CAST LIKE with range, added index hint",
|
| 65 |
+
is_done=True,
|
| 66 |
+
)
|
| 67 |
+
obs2, reward, done, info = env.step(action)
|
| 68 |
+
print(f" grader_score={info['grader_score']:.3f} step_reward={reward.score:.4f} done={done}")
|
| 69 |
+
print(f" feedback: {reward.feedback}")
|
| 70 |
+
assert info["grader_score"] >= 0.9, f"Expected >=0.9, got {info['grader_score']}"
|
| 71 |
+
|
| 72 |
+
# ββ state() βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
print()
|
| 74 |
+
print("=== state() ===")
|
| 75 |
+
print(json.dumps(env.state(), indent=2))
|
| 76 |
+
|
| 77 |
+
# ββ invalid action penalty βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
print()
|
| 79 |
+
print("=== Invalid action test ===")
|
| 80 |
+
env.reset(1)
|
| 81 |
+
obs2, reward, done, info = env.step(Action(rewritten_query="", explanation="", is_done=False))
|
| 82 |
+
print(f" step_reward={reward.score} is_invalid={info['is_invalid']}")
|
| 83 |
+
assert info["is_invalid"] == True, "Empty query should be flagged invalid"
|
| 84 |
+
|
| 85 |
+
print()
|
| 86 |
+
print("ALL TESTS PASSED")
|