jaivardhan2409 committed on
Commit
e4c32ce
·
verified ·
1 Parent(s): 5abdb9a

Upload folder using huggingface_hub

Browse files
Files changed (17) hide show
  1. Dockerfile +8 -0
  2. README.md +48 -6
  3. __init__.py +1 -0
  4. baseline.py +71 -0
  5. client.py +56 -0
  6. env/__init__.py +0 -0
  7. env/environment.py +106 -0
  8. env/models.py +20 -0
  9. env/reward.py +27 -0
  10. env/tasks.py +137 -0
  11. models.py +20 -0
  12. openenv.yaml +14 -0
  13. pyproject.toml +22 -0
  14. requirements.txt +6 -0
  15. server/__init__.py +1 -0
  16. server/app.py +72 -0
  17. uv.lock +0 -0
Dockerfile ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Container image for the SQL Query Optimizer environment server.
FROM python:3.11-slim
WORKDIR /app
# Install dependencies before copying the sources so this layer stays cached
# when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 7860
ENV ENABLE_WEB_INTERFACE=true
# The FastAPI instance is defined in server/app.py; server/__init__.py does not
# re-export it, so the import string must name the submodule.
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,52 @@
1
  ---
2
- title: Sql Query Optimizer
3
- emoji: 👀
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: docker
7
- pinned: false
 
8
  ---
 
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: SQL Query Optimizer
3
+ emoji: 🦀
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
+ app_port: 7860
8
+ base_path: /web
9
  ---
10
+ # SQL Query Optimizer OpenEnv
11
 
12
+ ## Environment Description & Motivation
13
+ This domain models a genuine, high-value task performed by data engineers and DBAs every day: reviewing and optimizing slow SQL queries. Instead of a toy environment, this is a real-world task where the agent must rewrite SQL queries to be syntactically correct and performant.
14
+
15
+ ## Action Space
16
+ - `rewritten_query` (str): The optimized SQL query.
17
+ - `explanation` (str): A brief explanation of the changes made and why they improve the query.
18
+ - `is_done` (bool): Set to true if finished to submit query for final scoring.
19
+
20
+ ## Observation Space
21
+ - `task_id` (int): The ID of the task to perform.
22
+ - `query` (str): The SQL query to review and optimize.
23
+ - `schema_context` (str): Database schema context (CREATE statements).
24
+ - `hint` (str): Optional natural-language hints.
25
+ - `step_number` (int): Current step in the episode.
26
+ - `max_steps` (int): Maximum allowed steps.
27
+
28
+ ## Tasks
29
+ 1. **fix-broken-join (Easy)**: Identify and repair a query with an issue such as a missing ON clause.
30
+ 2. **eliminate-n-plus-one (Medium)**: Remove correlated subqueries and replace them with properly structured JOINs.
31
+ 3. **full-optimization (Hard)**: Remove redundant DISTINCT clauses, avoid SELECT *, use index hints, and fix implicit type casts in a more complex query.
32
+
33
+ ## Setup & Testing
34
+ ```bash
35
+ # Verify using openenv
36
+ openenv validate
37
+
38
+ # Local testing
39
+ uvicorn server.app:app --host 0.0.0.0 --port 7860
40
+
41
+ # Docker build
42
+ docker build -t sql-optimizer-env .
43
+ docker run -p 7860:7860 sql-optimizer-env
44
+ ```
45
+
46
+ ## Baseline Evaluation
47
+ The provided `baseline.py` script runs a single-shot LLM baseline (`gpt-3.5-turbo`) against all three tasks and prints the grader score for each.
48
+ Usage:
49
+ ```bash
50
+ export OPENAI_API_KEY=sk-...
51
+ python baseline.py
52
+ ```
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Root package
baseline.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from openai import OpenAI
3
+ from env.environment import SQLEnv
4
+ from env.models import Action
5
+
6
def _strip_code_fence(text: str) -> str:
    """Return *text* with a surrounding Markdown code fence removed, if any.

    Handles a bare fence, a fence followed by a language tag on its own line,
    and the single-line form where ``sql`` directly follows the backticks.
    """
    known_tags = ("", "sql", "postgresql", "postgres", "mysql", "sqlite")
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        first_line, newline, remainder = text.partition("\n")
        if newline and first_line.strip().lower() in known_tags:
            # Opening fence sits on its own line, optionally with a language tag.
            text = remainder
        elif text[:3].lower() == "sql":
            # Single-line form such as "```sql SELECT 1```".
            text = text[3:]
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def run_task(env: SQLEnv, task_id: int) -> float:
    """Run one single-step baseline episode and return its grader score.

    Resets ``env`` to ``task_id``, asks ``gpt-3.5-turbo`` (temperature 0) for a
    rewritten query, submits it with ``is_done=True`` and returns
    ``env.final_grader_score``.

    If the API call fails or the model returns no text, the task's original
    query is submitted instead so that the episode still completes.
    """
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    obs = env.reset(task_id)

    prompt = "\n".join([
        f"Task # {obs.task_id}",
        f"Original Query: {obs.query}",
        f"Database Schema Context: {obs.schema_context}",
        f"Hint: {obs.hint}",
        "",
        "Please provide the optimized query. Output ONLY the raw SQL query, no markdown formatting, no explanation.",
    ])
    messages = [
        {"role": "system", "content": "You are an expert SQL DBA. You rewrite SQL queries to be correct, optimized, and performant."},
        {"role": "user", "content": prompt},
    ]

    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=0.0,
        )
        # ``content`` may be None; treat that like an empty reply.
        rewritten_query = _strip_code_fence(response.choices[0].message.content or "")
    except Exception as e:
        # Best-effort baseline: any API failure degrades to the fallback below.
        print(f"Error calling OpenAI API: {e}")
        rewritten_query = ""

    if not rewritten_query:
        rewritten_query = obs.query

    action = Action(
        rewritten_query=rewritten_query,
        explanation="Baseline inference using LLM",
        is_done=True,
    )

    env.step(action)
    return env.final_grader_score
49
+
50
def run_all_tasks():
    """Run the baseline on tasks 1, 2 and 3 and return ``{task_id: score}``.

    Raises:
        ValueError: if ``OPENAI_API_KEY`` is unset or empty.
    """
    if not os.environ.get("OPENAI_API_KEY"):
        raise ValueError("OPENAI_API_KEY environment variable is required.")

    # One environment instance is reused; run_task() resets it for each task.
    environment = SQLEnv()
    results = {}
    for current_id in (1, 2, 3):
        print(f"Running baseline for Task {current_id}...")
        results[current_id] = run_task(environment, current_id)
        print(f"Task {current_id} Grader Score: {results[current_id]}")

    return results
63
+
64
# Script entry point: run every task and print a per-task summary.
if __name__ == "__main__":
    import sys

    try:
        scores = run_all_tasks()
    except Exception as e:
        # Report on stderr and exit non-zero so callers (CI, shells) can
        # detect a failed evaluation.
        print(f"Baseline Evaluation Failed: {e}", file=sys.stderr)
        sys.exit(1)

    print("\nBaseline Evaluation Results:")
    for t, s in scores.items():
        print(f"Task {t}: {s}/1.0")
client.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """SQL Query Optimizer Client."""
8
+
9
+ from typing import Dict
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from models import Action, Observation
16
+
17
+
18
class SQLEnvClient(EnvClient[Action, Observation, State]):
    """Client for the SQL Query Optimizer environment.

    Supplies the three (de)serialisation hooks used by ``EnvClient``: turning
    an ``Action`` into a payload, and turning server payloads back into a
    ``StepResult`` / ``State``.
    """

    def _step_payload(self, action: Action) -> Dict:
        """Convert ``action`` into the JSON-serialisable step payload."""
        return action.model_dump()

    def _parse_result(self, payload: Dict) -> StepResult[Observation]:
        """Build a ``StepResult`` from a step response payload.

        The reward may arrive either as a bare number or as a serialised
        reward object (a dict with a ``score`` key); in the latter case only
        the numeric ``score`` is forwarded.
        """
        observation = Observation(**payload.get("observation", {}))

        reward = payload.get("reward")
        if isinstance(reward, dict):
            # NOTE(review): assumes StepResult.reward expects a scalar --
            # confirm against the openenv client types.
            reward = reward.get("score")

        return StepResult(
            observation=observation,
            reward=reward,
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from a state response payload."""
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
env/__init__.py ADDED
File without changes
env/environment.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Dict, Any, List
2
+ from .models import Observation, Action, Reward
3
+ from .tasks import TASKS, grade_action, get_task
4
+ from .reward import compute_reward
5
+
6
class SQLEnv:
    """Episode-based environment in which an agent rewrites a SQL query.

    Call ``reset(task_id)`` to start an episode, then ``step(action)`` until a
    step reports ``done``. An episode ends when the agent sets ``is_done`` or
    when ``max_steps`` is reached; after that ``step`` refuses further actions
    until ``reset`` is called again.
    """

    def __init__(self):
        self.current_task_id = None
        self.task = None
        self.step_number = 0
        self.max_steps = 0
        self.history = []
        self.cumulative_score = 0.0
        self.previous_grader_score = 0.0
        self.final_grader_score = 0.0
        # True once a step has returned done=True; cleared by reset().
        self._episode_done = False

    def _observation(self, query: str, step_number: int) -> Observation:
        """Build an observation for the current task around ``query``."""
        return Observation(
            task_id=self.current_task_id,
            query=query,
            schema_context=self.task["schema_context"],
            hint=self.task["hint"],
            step_number=step_number,
            max_steps=self.max_steps,
        )

    def reset(self, task_id: int) -> Observation:
        """Start a new episode for ``task_id`` and return its first observation.

        Raises:
            ValueError: if no task with that id exists.
        """
        task = get_task(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found.")

        self.current_task_id = task_id
        self.task = task
        self.step_number = 1  # steps are 1-indexed
        self.max_steps = task["max_steps"]
        self.history = []
        self.cumulative_score = 0.0
        self.previous_grader_score = 0.0
        self.final_grader_score = 0.0
        self._episode_done = False

        obs = self._observation(self.task["initial_query"], self.step_number)
        self.history.append({"step": 0, "type": "reset", "observation": obs.model_dump()})
        return obs

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Grade ``action`` and advance the episode by one step.

        Returns:
            ``(observation, reward, done, info)``. The observation echoes the
            submitted query and carries the next step number; ``info`` holds
            ``cumulative_score`` and the raw ``grader_score``.

        Raises:
            RuntimeError: if ``reset`` has not been called, or if the current
                episode has already finished.
        """
        if not self.task:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if self._episode_done:
            # Without this guard a finished episode could be continued,
            # overwriting final_grader_score and re-earning the completion bonus.
            raise RuntimeError("Episode already finished. Call reset() to start a new one.")

        grader_score, breakdown, feedback = grade_action(self.current_task_id, action.rewritten_query)
        action_valid = len(action.rewritten_query.strip()) > 0

        done = action.is_done or self.step_number >= self.max_steps

        step_reward = compute_reward(
            grader_score=grader_score,
            previous_score=self.previous_grader_score,
            step_number=self.step_number,
            max_steps=self.max_steps,
            is_done=done,
            action_valid=action_valid,
        )

        self.cumulative_score += step_reward
        self.previous_grader_score = grader_score

        reward = Reward(score=step_reward, breakdown=breakdown, feedback=feedback)
        obs = self._observation(action.rewritten_query, self.step_number + 1)
        info = {
            "cumulative_score": self.cumulative_score,
            "grader_score": grader_score,
        }

        if done:
            self.final_grader_score = grader_score
            self._episode_done = True

        self.history.append({
            "step": self.step_number,
            "type": "step",
            "action": action.model_dump(),
            "reward": reward.model_dump(),
            "done": done,
            "info": info,
        })

        self.step_number += 1
        return obs, reward, done, info

    def state(self) -> Dict[str, Any]:
        """Return a snapshot of the episode counters, scores and history."""
        return {
            "current_task_id": self.current_task_id,
            "step_number": self.step_number,
            "max_steps": self.max_steps,
            "cumulative_score": self.cumulative_score,
            "final_grader_score": self.final_grader_score,
            "history": self.history,
        }
env/models.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict
2
+ from pydantic import BaseModel, Field
3
+
4
+ class Observation(BaseModel):
5
+ task_id: int = Field(description="The ID of the task to perform.")
6
+ query: str = Field(description="The SQL query to review and optimize.")
7
+ schema_context: str = Field(description="The database schema context for the query, such as CREATE TABLE statements.")
8
+ hint: Optional[str] = Field(default=None, description="An optional natural-language hint or description of the problem.")
9
+ step_number: int = Field(description="The current step number in the episode (1-indexed).")
10
+ max_steps: int = Field(description="The maximum allowed steps for this task.")
11
+
12
+ class Action(BaseModel):
13
+ rewritten_query: str = Field(description="The rewritten, optimized SQL query.")
14
+ explanation: str = Field(description="A brief explanation of the changes made and why they improve the query.")
15
+ is_done: bool = Field(description="Set to true if you are finished and want to submit the query for final scoring.")
16
+
17
class Reward(BaseModel):
    """Reward returned for a single environment step."""

    # The environment stores the per-step shaped reward here, not an
    # episode-level score; an empty query is penalised with -0.1.
    score: float = Field(description="The shaped reward for this step (0.0 to 1.0, or -0.1 for an empty query).")
    breakdown: Dict[str, float] = Field(default_factory=dict, description="A breakdown of the score by sub-criteria.")
    feedback: str = Field(description="Specific feedback on the rewritten query or action taken.")
env/reward.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
def compute_reward(
    grader_score: float,
    previous_score: float,
    step_number: int,
    max_steps: int,
    is_done: bool,
    action_valid: bool,
    *,
    improvement_weight: float = 0.5,
    completion_bonus: float = 0.5,
    completion_threshold: float = 0.8,
    overrun_penalty: float = 0.02,
    invalid_penalty: float = 0.1,
) -> float:
    """Compute the shaped reward for one step.

    Components (defaults reproduce the original fixed constants):

    - Invalid action: returns ``-invalid_penalty`` immediately. This is the
      only path that can yield a negative value; it is *not* clamped.
    - Improvement: ``improvement_weight`` times the positive part of
      ``grader_score - previous_score`` (regressions earn nothing).
    - Completion bonus: ``completion_bonus`` when ``is_done`` and
      ``grader_score >= completion_threshold``.
    - Overrun penalty: ``overrun_penalty`` for every step beyond ``max_steps``.

    Returns:
        ``-invalid_penalty`` for an invalid action, otherwise the sum of the
        components clamped to ``[0.0, 1.0]``.
    """
    if not action_valid:
        return -invalid_penalty

    reward = max(0.0, grader_score - previous_score) * improvement_weight

    if is_done and grader_score >= completion_threshold:
        reward += completion_bonus

    if step_number > max_steps:
        reward -= overrun_penalty * (step_number - max_steps)

    return max(0.0, min(1.0, reward))
env/tasks.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlglot
2
+ from sqlglot import exp
3
+ from typing import Dict, Any, Tuple
4
+
5
+ TASKS = {
6
+ 1: {
7
+ "name": "fix-broken-join",
8
+ "difficulty": "easy",
9
+ "schema_context": "CREATE TABLE users (id INT, name VARCHAR); CREATE TABLE orders (id INT, user_id INT, amount DECIMAL);",
10
+ "hint": "The query is trying to join users and orders, but it is missing the ON clause, creating a cross join.",
11
+ "initial_query": "SELECT users.name, orders.amount FROM users JOIN orders;",
12
+ "max_steps": 3,
13
+ },
14
+ 2: {
15
+ "name": "eliminate-n-plus-one",
16
+ "difficulty": "medium",
17
+ "schema_context": "CREATE TABLE employees (id INT, dept_id INT, name VARCHAR); CREATE TABLE departments (id INT, name VARCHAR);",
18
+ "hint": "The query uses a correlated subquery in the WHERE clause. Rewrite it using a JOIN to improve performance.",
19
+ "initial_query": "SELECT e.name FROM employees e WHERE e.dept_id IN (SELECT d.id FROM departments d WHERE d.name = 'Engineering');",
20
+ "max_steps": 4,
21
+ },
22
+ 3: {
23
+ "name": "full-optimization",
24
+ "difficulty": "hard",
25
+ "schema_context": "CREATE TABLE sales (id INT, product_id INT, sale_date DATE, amount DECIMAL); CREATE INDEX idx_sales_date ON sales(sale_date);",
26
+ "hint": "Optimize the query: remove redundant DISTINCT, avoid SELECT *, use index hint if applicable, and fix implicit type casts.",
27
+ "initial_query": "SELECT DISTINCT * FROM sales s WHERE CAST(s.sale_date AS VARCHAR) = '2023-01-01';",
28
+ "max_steps": 5,
29
+ }
30
+ }
31
+
32
def grade_task_1(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 1 (fix-broken-join).

    Full credit requires that the query parses, contains at least one JOIN,
    and that *every* JOIN carries a join condition (``ON`` or ``USING``).

    Returns:
        ``(score, breakdown, feedback)`` with ``score`` either 0.0 or 1.0.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        return 0.0, {"missing_join": 1.0}, "No JOIN found in the query."

    # Check all joins, not just the first: one condition-less join is enough
    # to reintroduce the cross product this task is about.
    if all(join.args.get("on") or join.args.get("using") for join in joins):
        return 1.0, {"has_on_clause": 1.0}, "Successfully added the ON clause."

    return 0.0, {"has_on_clause": 0.0}, "The JOIN is still missing an ON clause."
56
+
57
def grade_task_2(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 2 (eliminate-n-plus-one).

    Two criteria, 0.5 each:

    - no subquery remains (a parenthesised subquery, ``EXISTS (...)`` or
      ``IN (SELECT ...)``); an ``IN`` over a literal list is not a subquery;
    - the query uses at least one JOIN and every JOIN has an ``ON``/``USING``
      condition.

    Returns:
        ``(score, breakdown, feedback)``.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    score = 0.0
    breakdown = {}
    feedback = []

    has_subquery = (
        parsed.find(exp.Subquery) is not None
        or parsed.find(exp.Exists) is not None
        or any(node.args.get("query") for node in parsed.find_all(exp.In))
    )
    if not has_subquery:
        score += 0.5
        breakdown["removed_correlated_subquery"] = 0.5
        feedback.append("Removed correlated subquery.")
    else:
        breakdown["removed_correlated_subquery"] = 0.0
        feedback.append("Correlated subquery still present.")

    joins = list(parsed.find_all(exp.Join))
    if not joins:
        breakdown["added_join"] = 0.0
        feedback.append("Missing JOIN.")
    elif all(join.args.get("on") or join.args.get("using") for join in joins):
        score += 0.5
        breakdown["added_join"] = 0.5
        feedback.append("Added JOIN successfully.")
    else:
        # A JOIN without a condition is a cross join, not a valid rewrite.
        breakdown["added_join"] = 0.0
        feedback.append("JOIN is missing an ON/USING condition.")

    return score, breakdown, " ".join(feedback)
86
+
87
def grade_task_3(rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Grade task 3 (full-optimization).

    Four independent criteria, 0.25 each:

    - ``no_distinct``: the top-level statement has no DISTINCT;
    - ``no_select_star``: no SELECT projects ``*`` or ``alias.*``
      (``COUNT(*)`` is not a projection star and is allowed);
    - ``fixed_cast``: no CAST is applied directly to a ``sale_date`` column;
    - ``has_index_hint``: the text contains ``INDEX`` (case-insensitive).

    Feedback lists both satisfied and unsatisfied criteria.

    Returns:
        ``(score, breakdown, feedback)``.
    """
    try:
        parsed = sqlglot.parse_one(rewritten_query, read="postgres")
    except Exception as e:
        return 0.0, {"parse_error": 1.0}, f"Query could not be parsed: {e}"

    has_select_star = any(
        isinstance(projection, exp.Star)
        or (isinstance(projection, exp.Column) and isinstance(projection.this, exp.Star))
        for select in parsed.find_all(exp.Select)
        for projection in select.expressions
    )

    def _casts_sale_date(cast) -> bool:
        # True when the CAST operand is a column named sale_date.
        operand = cast.args.get("this")
        return isinstance(operand, exp.Column) and operand.name.lower() == "sale_date"

    cast_on_date = any(_casts_sale_date(cast) for cast in parsed.find_all(exp.Cast))

    # (breakdown key, criterion met?, feedback when met, feedback when unmet)
    criteria = [
        (
            "no_distinct",
            not parsed.args.get("distinct"),
            "Removed redundant DISTINCT.",
            "DISTINCT is still present.",
        ),
        (
            "no_select_star",
            not has_select_star,
            "Replaced SELECT * with explicit columns.",
            "SELECT * is still used; list the columns explicitly.",
        ),
        (
            "fixed_cast",
            not cast_on_date,
            "Fixed implicit type cast on sale_date.",
            "sale_date is still wrapped in a CAST.",
        ),
        (
            "has_index_hint",
            "INDEX" in rewritten_query.upper(),
            "Added index hint.",
            "No index hint found.",
        ),
    ]

    score = 0.0
    breakdown = {}
    feedback = []
    for key, met, met_message, unmet_message in criteria:
        breakdown[key] = 0.25 if met else 0.0
        score += breakdown[key]
        feedback.append(met_message if met else unmet_message)

    return score, breakdown, " ".join(feedback)
126
+
127
def grade_action(task_id: int, rewritten_query: str) -> Tuple[float, Dict[str, float], str]:
    """Route ``rewritten_query`` to the grader registered for ``task_id``.

    Returns the grader's ``(score, breakdown, feedback)`` triple, or
    ``(0.0, {}, "Unknown task.")`` when no grader exists for the id.
    """
    graders = {
        1: grade_task_1,
        2: grade_task_2,
        3: grade_task_3,
    }
    grader = graders.get(task_id)
    if grader is None:
        return 0.0, {}, "Unknown task."
    return grader(rewritten_query)
135
+
136
def get_task(task_id: int) -> "Dict[str, Any] | None":
    """Return the task definition for ``task_id``, or ``None`` if it is unknown."""
    return TASKS.get(task_id)
models.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict
2
+ from pydantic import BaseModel, Field
3
+
4
+ class Observation(BaseModel):
5
+ task_id: int = Field(description="The ID of the task to perform.")
6
+ query: str = Field(description="The SQL query to review and optimize.")
7
+ schema_context: str = Field(description="The database schema context for the query, such as CREATE TABLE statements.")
8
+ hint: Optional[str] = Field(default=None, description="An optional natural-language hint or description of the problem.")
9
+ step_number: int = Field(description="The current step number in the episode (1-indexed).")
10
+ max_steps: int = Field(description="The maximum allowed steps for this task.")
11
+
12
+ class Action(BaseModel):
13
+ rewritten_query: str = Field(description="The rewritten, optimized SQL query.")
14
+ explanation: str = Field(description="A brief explanation of the changes made and why they improve the query.")
15
+ is_done: bool = Field(description="Set to true if you are finished and want to submit the query for final scoring.")
16
+
17
class Reward(BaseModel):
    """Reward returned for a single environment step."""

    # The environment stores the per-step shaped reward here, not an
    # episode-level score; an empty query is penalised with -0.1.
    score: float = Field(description="The shaped reward for this step (0.0 to 1.0, or -0.1 for an empty query).")
    breakdown: Dict[str, float] = Field(default_factory=dict, description="A breakdown of the score by sub-criteria.")
    feedback: str = Field(description="Specific feedback on the rewritten query or action taken.")
openenv.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sql-query-optimizer
2
+ version: "1.0.0"
3
+ description: "AI agent reviews and rewrites SQL queries for correctness and performance."
4
+ tags: [openenv, sql, code-review, data-engineering]
5
+ tasks:
6
+ - id: 1
7
+ name: fix-broken-join
8
+ difficulty: easy
9
+ - id: 2
10
+ name: eliminate-n-plus-one
11
+ difficulty: medium
12
+ - id: 3
13
+ name: full-optimization
14
+ difficulty: hard
pyproject.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sql-query-optimizer"
7
+ version = "1.0.0"
8
+ description = "AI agent reviews and rewrites SQL queries for correctness and performance."
9
+ dependencies = [
10
+ "fastapi>=0.111.0",
11
+ "uvicorn>=0.30.1",
12
+ "pydantic>=2.7.4",
13
+ "openai>=1.35.3",
14
+ "sqlglot>=25.5.0",
15
+ "openenv-core>=0.2.0"
16
+ ]
17
+
18
+ [project.scripts]
19
+ server = "server.app:main"
20
+
21
+ [tool.setuptools]
22
+ packages = ["server", "env"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn>=0.30.1
3
+ pydantic>=2.7.4
4
+ openai>=1.35.3
5
+ sqlglot>=25.5.0
6
+ openenv-core>=0.2.0
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Server package
server/app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import Dict, Any, List
4
+ import asyncio
5
+
6
+ from env.environment import SQLEnv
7
+ from env.models import Observation, Action, Reward
8
+ from env.tasks import TASKS
9
+
10
+ app = FastAPI(title="SQL Query Optimizer OpenEnv")
11
+ env = SQLEnv()
12
+
13
+ class ResetRequest(BaseModel):
14
+ task_id: int
15
+
16
+ @app.post("/reset", response_model=Observation)
17
+ async def reset(req: ResetRequest):
18
+ try:
19
+ return env.reset(req.task_id)
20
+ except ValueError as e:
21
+ raise HTTPException(status_code=400, detail=str(e))
22
+
23
+ @app.post("/step")
24
+ async def step(action: Action):
25
+ try:
26
+ obs, reward, done, info = env.step(action)
27
+ return {
28
+ "observation": obs.model_dump(),
29
+ "reward": reward.model_dump(),
30
+ "done": done,
31
+ "info": info
32
+ }
33
+ except RuntimeError as e:
34
+ raise HTTPException(status_code=400, detail=str(e))
35
+
36
+ @app.get("/state")
37
+ async def state():
38
+ return env.state()
39
+
40
+ @app.get("/tasks")
41
+ async def get_tasks():
42
+ action_schema = Action.model_json_schema()
43
+ task_list = [{"id": k, **v} for k, v in TASKS.items()]
44
+ return {
45
+ "tasks": task_list,
46
+ "action_schema": action_schema
47
+ }
48
+
49
+ @app.get("/grader")
50
+ async def grader():
51
+ if not env.task:
52
+ raise HTTPException(status_code=400, detail="Environment not initialized.")
53
+ return {"grader_score": env.final_grader_score}
54
+
55
+ class BaselineResponse(BaseModel):
56
+ scores: Dict[int, float]
57
+
58
@app.post("/baseline", response_model=BaselineResponse)
async def run_baseline():
    """Run the LLM baseline on every task and return the per-task scores.

    The baseline performs blocking network calls, so it is executed in a
    worker thread to keep the event loop free for other requests.

    Raises:
        HTTPException: status 500 with the error text if the baseline fails.
    """
    # Imported lazily so the server can start without the baseline's dependencies.
    import baseline

    try:
        scores = await asyncio.to_thread(baseline.run_all_tasks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return BaselineResponse(scores=scores)
66
+
67
+ def main(host: str = "0.0.0.0", port: int = 8000):
68
+ import uvicorn
69
+ uvicorn.run(app, host=host, port=port)
70
+
71
+ if __name__ == '__main__':
72
+ main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff