Param20h committed on
Commit
2541228
Β·
verified Β·
1 Parent(s): 35c8316

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. README.md +17 -3
  2. baseline.py +8 -144
  3. inference.py +197 -0
  4. pyproject.toml +4 -0
  5. server/app.py +35 -11
  6. uv.lock +0 -0
README.md CHANGED
@@ -116,7 +116,9 @@ Interactive docs: `http://localhost:7860/docs`
116
  ### Prerequisites
117
  - Python 3.10+
118
  - Docker
119
- - `OPENAI_API_KEY` (for baseline only)
 
 
120
 
121
  ### Local (Python)
122
 
@@ -135,8 +137,10 @@ docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-optimizer-env
135
  ### Baseline Inference
136
 
137
  ```bash
138
- export OPENAI_API_KEY=sk-...
139
- python baseline.py
 
 
140
  ```
141
 
142
  ### OpenEnv Validation
@@ -154,6 +158,16 @@ huggingface-cli login
154
  openenv push --repo-id your-username/sql-query-optimizer
155
  ```
156
 
 
 
 
 
 
 
 
 
 
 
157
  ---
158
 
159
  ## Baseline Scores
 
116
  ### Prerequisites
117
  - Python 3.10+
118
  - Docker
119
+ - `API_BASE_URL` (OpenAI-compatible endpoint for inference)
120
+ - `MODEL_NAME` (model identifier for inference)
121
+ - `HF_TOKEN` (API key / bearer token for inference)
122
 
123
  ### Local (Python)
124
 
 
137
  ### Baseline Inference
138
 
139
  ```bash
140
+ export API_BASE_URL="https://api.openai.com/v1"
141
+ export MODEL_NAME="gpt-4o-mini"
142
+ export HF_TOKEN="hf_or_openai_api_key_here"
143
+ python inference.py
144
  ```
145
 
146
  ### OpenEnv Validation
 
158
  openenv push --repo-id your-username/sql-query-optimizer
159
  ```
160
 
161
+ ### Environment Configuration
162
+
163
+ Define these variables before running inference or `/baseline`:
164
+
165
+ ```powershell
166
+ $env:API_BASE_URL = "https://api.openai.com/v1"
167
+ $env:MODEL_NAME = "gpt-4o-mini"
168
+ $env:HF_TOKEN = "your_api_key"
169
+ ```
170
+
171
  ---
172
 
173
  ## Baseline Scores
baseline.py CHANGED
@@ -1,144 +1,8 @@
1
- """
2
- Baseline inference script for the SQL Query Optimizer OpenEnv environment.
3
-
4
- Usage:
5
- python baseline.py # human-readable output
6
- python baseline.py --json # JSON output (used by /baseline endpoint)
7
-
8
- Requires:
9
- OPENAI_API_KEY environment variable
10
-
11
- The script runs gpt-4o-mini against all 3 tasks and reports grader scores.
12
- """
13
- from __future__ import annotations
14
-
15
- import argparse
16
- import json
17
- import os
18
- import sys
19
-
20
- from openai import OpenAI
21
-
22
- # ── import env from local package ──────────────────────────────────────────
23
- sys.path.insert(0, os.path.dirname(__file__))
24
- from env.environment import SQLOptimizerEnv
25
- from env.models import Action
26
-
27
- # ──────────────────────────────────────────────────────────────────────────────
28
- MODEL = "gpt-4o-mini"
29
- MAX_STEPS = 5
30
- TASKS = [1, 2, 3]
31
-
32
- SYSTEM_PROMPT = """You are a database performance engineer.
33
- You will receive a broken or unoptimised SQL query along with table schema context.
34
- Your job is to rewrite the query so it is correct and performant.
35
-
36
- Respond ONLY with a JSON object with these exact keys:
37
- {
38
- "rewritten_query": "<your improved SQL>",
39
- "explanation": "<brief explanation of changes>",
40
- "is_done": true
41
- }
42
- Do not wrap in markdown. Output raw JSON only."""
43
-
44
-
45
- def _build_user_message(obs_dict: dict) -> str:
46
- return (
47
- f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β€” difficulty: "
48
- f"{obs_dict.get('difficulty', 'unknown')})\n\n"
49
- f"Description:\n{obs_dict['task_description']}\n\n"
50
- f"Schema:\n{obs_dict['schema_context']}\n\n"
51
- f"Query to fix:\n{obs_dict['query']}"
52
- + (f"\n\nHint: {obs_dict['hint']}" if obs_dict.get("hint") else "")
53
- )
54
-
55
-
56
- def run_baseline(verbose: bool = True) -> dict[str, float]:
57
- api_key = os.getenv("OPENAI_API_KEY")
58
- if not api_key:
59
- print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
60
- sys.exit(1)
61
-
62
- client = OpenAI(api_key=api_key)
63
- env = SQLOptimizerEnv()
64
- results: dict[str, float] = {}
65
-
66
- for task_id in TASKS:
67
- obs = env.reset(task_id=task_id)
68
- obs_dict = obs.model_dump()
69
- final_score = 0.0
70
-
71
- if verbose:
72
- print(f"\n{'='*60}")
73
- print(f"Task {task_id}: {obs_dict['task_name']} [{obs_dict['task_id']}]")
74
- print(f"{'='*60}")
75
-
76
- for step_num in range(MAX_STEPS):
77
- messages = [
78
- {"role": "system", "content": SYSTEM_PROMPT},
79
- {"role": "user", "content": _build_user_message(obs_dict)},
80
- ]
81
-
82
- try:
83
- response = client.chat.completions.create(
84
- model=MODEL,
85
- messages=messages,
86
- temperature=0.0,
87
- max_tokens=1024,
88
- )
89
- content = response.choices[0].message.content.strip()
90
- parsed = json.loads(content)
91
- action = Action(
92
- rewritten_query=parsed.get("rewritten_query", ""),
93
- explanation=parsed.get("explanation", ""),
94
- is_done=bool(parsed.get("is_done", False)),
95
- )
96
- except Exception as exc:
97
- if verbose:
98
- print(f" Step {step_num + 1}: LLM error β€” {exc}")
99
- action = Action(
100
- rewritten_query="",
101
- explanation="error",
102
- is_done=True,
103
- )
104
-
105
- obs, reward, done, info = env.step(action)
106
- obs_dict = obs.model_dump()
107
- final_score = info["grader_score"]
108
-
109
- if verbose:
110
- print(
111
- f" Step {step_num + 1}: grader_score={info['grader_score']:.3f} "
112
- f"step_reward={reward.score:.4f} feedback={reward.feedback[:80]}"
113
- )
114
-
115
- if done:
116
- break
117
-
118
- results[f"task_{task_id}_{env._task.name}"] = round(final_score, 4)
119
-
120
- if verbose:
121
- print(f" β†’ Final grader score: {final_score:.4f}")
122
-
123
- if verbose:
124
- print(f"\n{'='*60}")
125
- print("BASELINE RESULTS")
126
- print(f"{'='*60}")
127
- for k, v in results.items():
128
- print(f" {k}: {v:.4f}")
129
- avg = sum(results.values()) / len(results)
130
- print(f" Average: {avg:.4f}")
131
-
132
- return results
133
-
134
-
135
- if __name__ == "__main__":
136
- parser = argparse.ArgumentParser(description="OpenEnv SQL Optimizer β€” Baseline Inference")
137
- parser.add_argument(
138
- "--json", action="store_true", help="Output results as JSON (used by /baseline endpoint)"
139
- )
140
- args = parser.parse_args()
141
-
142
- scores = run_baseline(verbose=not args.json)
143
- if args.json:
144
- print(json.dumps(scores))
 
1
"""Compatibility wrapper for the required inference.py entrypoint."""
from __future__ import annotations

# baseline.py previously contained the full baseline loop; it now only
# delegates to inference.run_inference so both entrypoints share one
# implementation and cannot drift apart.
from inference import run_inference


if __name__ == "__main__":
    # run_inference handles its own [START]/[STEP]/[END] logging and errors.
    run_inference()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI-based inference runner for the SQL Query Optimizer OpenEnv environment.
2
+
3
+ Environment variables:
4
+ API_BASE_URL: OpenAI-compatible API endpoint
5
+ MODEL_NAME: model identifier to use for inference
6
+ HF_TOKEN: API key / bearer token for the LLM provider
7
+
8
+ The script emits structured stdout logs in three sections only:
9
+ [START] ...
10
+ [STEP] ...
11
+ [END] ...
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import sys
18
+ from collections import OrderedDict
19
+ from typing import Any, Dict
20
+
21
+ from openai import OpenAI
22
+
23
+ sys.path.insert(0, os.path.dirname(__file__))
24
+
25
+ from env.environment import SQLOptimizerEnv
26
+ from env.models import Action
27
+
28
# Maximum environment steps attempted per task before moving on.
DEFAULT_MAX_STEPS = 5
# Task ids exposed by SQLOptimizerEnv; run_inference() iterates all of them.
TASK_IDS = (1, 2, 3)

# Instructs the model to answer with raw JSON whose keys mirror env.models.Action.
SYSTEM_PROMPT = """You are a database performance engineer.
You will receive a broken or unoptimised SQL query along with table schema context.
Your job is to rewrite the query so it is correct and performant.

Respond ONLY with a JSON object with these exact keys:
{
"rewritten_query": "<your improved SQL>",
"explanation": "<brief explanation of changes>",
"is_done": true
}
Do not wrap in markdown. Output raw JSON only."""
42
+
43
+
44
+ def _load_runtime_config() -> Dict[str, str]:
45
+ api_base_url = os.getenv("API_BASE_URL", "").strip()
46
+ model_name = os.getenv("MODEL_NAME", "").strip()
47
+ hf_token = os.getenv("HF_TOKEN", "").strip()
48
+
49
+ missing = [
50
+ name
51
+ for name, value in (
52
+ ("API_BASE_URL", api_base_url),
53
+ ("MODEL_NAME", model_name),
54
+ ("HF_TOKEN", hf_token),
55
+ )
56
+ if not value
57
+ ]
58
+ if missing:
59
+ raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
60
+
61
+ return {
62
+ "API_BASE_URL": api_base_url,
63
+ "MODEL_NAME": model_name,
64
+ "HF_TOKEN": hf_token,
65
+ }
66
+
67
+
68
+ def _build_user_message(obs_dict: dict) -> str:
69
+ message = (
70
+ f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β€” difficulty: "
71
+ f"{obs_dict.get('difficulty', 'unknown')})\n\n"
72
+ f"Description:\n{obs_dict['task_description']}\n\n"
73
+ f"Schema:\n{obs_dict['schema_context']}\n\n"
74
+ f"Query to fix:\n{obs_dict['query']}"
75
+ )
76
+ if obs_dict.get("hint"):
77
+ message += f"\n\nHint: {obs_dict['hint']}"
78
+ return message
79
+
80
+
81
+ def _log(prefix: str, payload: Dict[str, Any]) -> None:
82
+ print(f"{prefix} {json.dumps(payload, ensure_ascii=True, separators=(',', ':'))}")
83
+
84
+
85
def _parse_json_action(text: str) -> Action:
    """Decode the model's raw JSON reply into an Action.

    Missing keys fall back to empty strings / False; raises
    json.JSONDecodeError when *text* is not valid JSON.
    """
    fields = json.loads(text)
    return Action(
        rewritten_query=fields.get("rewritten_query", ""),
        explanation=fields.get("explanation", ""),
        is_done=bool(fields.get("is_done", False)),
    )
92
+
93
+
94
def run_inference() -> Dict[str, float]:
    """Drive the LLM agent through every task and return per-task grader scores.

    Reads API_BASE_URL / MODEL_NAME / HF_TOKEN via _load_runtime_config(),
    runs SQLOptimizerEnv for each id in TASK_IDS, and emits [START], [STEP]
    and [END] JSON log lines on stdout (parsed by the server's /baseline
    endpoint).

    Returns:
        Mapping of "task_<id>_<name>" to the task's final grader score,
        rounded to 4 decimal places.

    Raises:
        RuntimeError: if a required environment variable is missing.
    """
    config = _load_runtime_config()
    # HF_TOKEN doubles as the bearer token for any OpenAI-compatible endpoint.
    client = OpenAI(api_key=config["HF_TOKEN"], base_url=config["API_BASE_URL"])
    env = SQLOptimizerEnv()

    _log(
        "[START]",
        OrderedDict(
            [
                ("script", "inference.py"),
                ("api_base_url", config["API_BASE_URL"]),
                ("model_name", config["MODEL_NAME"]),
                ("tasks", list(TASK_IDS)),
            ]
        ),
    )

    results: Dict[str, float] = {}
    total_score = 0.0

    for task_id in TASK_IDS:
        observation = env.reset(task_id=task_id)
        obs_dict = observation.model_dump()
        final_grader_score = 0.0
        step_count = 0

        for step_number in range(DEFAULT_MAX_STEPS):
            # The conversation is rebuilt each step from the latest observation;
            # no prior turns are carried over.
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _build_user_message(obs_dict)},
            ]

            try:
                response = client.chat.completions.create(
                    model=config["MODEL_NAME"],
                    messages=messages,
                    temperature=0.0,
                    max_tokens=1024,
                )
                # Guard against a None content field before stripping.
                content = (response.choices[0].message.content or "").strip()
                action = _parse_json_action(content)
                llm_status = "ok"
            except Exception as exc:
                # Any LLM/parse failure ends the task with an empty action so
                # the episode still steps and is scored.
                action = Action(rewritten_query="", explanation=f"error: {exc}", is_done=True)
                llm_status = "error"

            observation, reward, done, info = env.step(action)
            obs_dict = observation.model_dump()
            final_grader_score = float(info.get("grader_score", 0.0))
            step_count = step_number + 1

            _log(
                "[STEP]",
                OrderedDict(
                    [
                        ("task_id", task_id),
                        ("task_name", obs_dict["task_name"]),
                        ("step", step_count),
                        ("grader_score", round(final_grader_score, 4)),
                        ("reward_score", round(float(reward.score), 4)),
                        ("done", bool(done)),
                        ("llm_status", llm_status),
                    ]
                ),
            )

            if done:
                break

        # NOTE(review): reaches into the env's private `_task` attribute for
        # the task name β€” confirm whether a public accessor exists.
        task_key = f"task_{task_id}_{env._task.name}"
        results[task_key] = round(final_grader_score, 4)
        total_score += final_grader_score

    average_score = round(total_score / len(TASK_IDS), 4)

    _log(
        "[END]",
        OrderedDict(
            [
                ("task_results", results),
                ("average_score", average_score),
                ("status", "success"),
            ]
        ),
    )
    return results
180
+
181
+
182
if __name__ == "__main__":
    try:
        run_inference()
    except Exception as exc:
        # On failure, still emit a machine-readable [END] line so callers
        # parsing stdout always find a terminal record, then exit non-zero.
        failure = OrderedDict()
        failure["task_results"] = {}
        failure["average_score"] = 0.0
        failure["status"] = "error"
        failure["error"] = str(exc)
        _log("[END]", failure)
        sys.exit(1)
pyproject.toml CHANGED
@@ -30,9 +30,13 @@ dependencies = [
30
  "uvicorn[standard]>=0.29.0",
31
  "pydantic>=2.7.0",
32
  "openai>=1.30.0",
 
33
  "pyyaml>=6.0",
34
  ]
35
 
 
 
 
36
  [project.optional-dependencies]
37
  dev = [
38
  "pytest>=7.0",
 
30
  "uvicorn[standard]>=0.29.0",
31
  "pydantic>=2.7.0",
32
  "openai>=1.30.0",
33
+ "openenv-core>=0.2.0",
34
  "pyyaml>=6.0",
35
  ]
36
 
37
+ [project.scripts]
38
+ server = "server.app:main"
39
+
40
  [project.optional-dependencies]
41
  dev = [
42
  "pytest>=7.0",
server/app.py CHANGED
@@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
19
  from fastapi import FastAPI, HTTPException
20
  from fastapi.middleware.cors import CORSMiddleware
21
  from pydantic import BaseModel
 
22
 
23
  from env.environment import SQLOptimizerEnv
24
  from env.models import Action, Observation, Reward
@@ -79,6 +80,17 @@ class BaselineResponse(BaseModel):
79
  message: str
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
82
  # ──────────────────────────────────────────────────────────────────────────────
83
  # Endpoints
84
  # ──────────────────────────────────────────────────────────────────────────────
@@ -157,30 +169,42 @@ def grader() -> GraderResponse:
157
  @app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
158
  def baseline() -> BaselineResponse:
159
  """
160
- Trigger the baseline inference script (baseline.py) and return scores.
161
- Requires OPENAI_API_KEY to be set in the environment.
162
  """
163
- if not os.getenv("OPENAI_API_KEY"):
 
 
164
  raise HTTPException(
165
  status_code=400,
166
- detail="OPENAI_API_KEY environment variable not set. Cannot run baseline.",
167
  )
168
  try:
169
  result = subprocess.run(
170
- [sys.executable, "baseline.py", "--json"],
171
  capture_output=True,
172
  text=True,
173
- timeout=300,
174
  )
175
  if result.returncode != 0:
176
  raise HTTPException(
177
  status_code=500,
178
- detail=f"Baseline script failed:\n{result.stderr}",
179
  )
180
- import json
181
- scores = json.loads(result.stdout)
182
- return BaselineResponse(task_results=scores, message="Baseline completed successfully.")
 
 
183
  except subprocess.TimeoutExpired:
184
- raise HTTPException(status_code=500, detail="Baseline script timed out after 300s.")
185
  except Exception as exc:
186
  raise HTTPException(status_code=500, detail=str(exc))
 
 
 
 
 
 
 
 
 
19
  from fastapi import FastAPI, HTTPException
20
  from fastapi.middleware.cors import CORSMiddleware
21
  from pydantic import BaseModel
22
+ import uvicorn
23
 
24
  from env.environment import SQLOptimizerEnv
25
  from env.models import Action, Observation, Reward
 
80
  message: str
81
 
82
 
83
+ def _parse_end_payload(stdout: str) -> Dict[str, Any]:
84
+ for line in reversed(stdout.splitlines()):
85
+ if not line.startswith("[END] "):
86
+ continue
87
+ payload_text = line[len("[END] ") :].strip()
88
+ import json
89
+
90
+ return json.loads(payload_text)
91
+ raise ValueError("Could not find [END] payload in inference output")
92
+
93
+
94
  # ──────────────────────────────────────────────────────────────────────────────
95
  # Endpoints
96
  # ──────────────────────────────────────────────────────────────────────────────
 
169
@app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
def baseline() -> BaselineResponse:
    """
    Trigger the baseline inference script (inference.py) and return scores.
    Requires API_BASE_URL, MODEL_NAME, and HF_TOKEN to be set in the environment.
    """
    required_vars = ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN"]
    missing = [name for name in required_vars if not os.getenv(name)]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required environment variables: {', '.join(missing)}",
        )
    try:
        # Run the script in a subprocess so a crash cannot take the server down.
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            timeout=1200,
        )
        if result.returncode != 0:
            raise HTTPException(
                status_code=500,
                detail=f"Inference script failed:\n{result.stderr}",
            )
        payload = _parse_end_payload(result.stdout)
        return BaselineResponse(
            task_results=payload.get("task_results", {}),
            message="Baseline completed successfully.",
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(status_code=500, detail="Inference script timed out after 1200s.")
    except HTTPException:
        # BUG FIX: re-raise as-is. Previously the generic handler below caught
        # the HTTPException raised for a non-zero return code and re-wrapped it
        # in a fresh 500 whose detail was the stringified exception, losing the
        # original detail text.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
203
+
204
+
205
def main() -> None:
    """Entry point for the `server` console script declared in pyproject.toml.

    Serves the FastAPI app on all interfaces at port 7860 (the port used in
    the README's docs URL).
    """
    uvicorn.run("server.app:app", host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff