Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- README.md +17 -3
- baseline.py +8 -144
- inference.py +197 -0
- pyproject.toml +4 -0
- server/app.py +35 -11
- uv.lock +0 -0
README.md
CHANGED
|
@@ -116,7 +116,9 @@ Interactive docs: `http://localhost:7860/docs`
|
|
| 116 |
### Prerequisites
|
| 117 |
- Python 3.10+
|
| 118 |
- Docker
|
| 119 |
-
- `
|
|
|
|
|
|
|
| 120 |
|
| 121 |
### Local (Python)
|
| 122 |
|
|
@@ -135,8 +137,10 @@ docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-optimizer-env
|
|
| 135 |
### Baseline Inference
|
| 136 |
|
| 137 |
```bash
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
| 140 |
```
|
| 141 |
|
| 142 |
### OpenEnv Validation
|
|
@@ -154,6 +158,16 @@ huggingface-cli login
|
|
| 154 |
openenv push --repo-id your-username/sql-query-optimizer
|
| 155 |
```
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
---
|
| 158 |
|
| 159 |
## Baseline Scores
|
|
|
|
| 116 |
### Prerequisites
|
| 117 |
- Python 3.10+
|
| 118 |
- Docker
|
| 119 |
+
- `API_BASE_URL` (OpenAI-compatible endpoint for inference)
|
| 120 |
+
- `MODEL_NAME` (model identifier for inference)
|
| 121 |
+
- `HF_TOKEN` (API key / bearer token for inference)
|
| 122 |
|
| 123 |
### Local (Python)
|
| 124 |
|
|
|
|
| 137 |
### Baseline Inference
|
| 138 |
|
| 139 |
```bash
|
| 140 |
+
$env:API_BASE_URL="https://api.openai.com/v1"
|
| 141 |
+
$env:MODEL_NAME="gpt-4o-mini"
|
| 142 |
+
$env:HF_TOKEN="hf_or_openai_api_key_here"
|
| 143 |
+
python inference.py
|
| 144 |
```
|
| 145 |
|
| 146 |
### OpenEnv Validation
|
|
|
|
| 158 |
openenv push --repo-id your-username/sql-query-optimizer
|
| 159 |
```
|
| 160 |
|
| 161 |
+
### Environment Configuration
|
| 162 |
+
|
| 163 |
+
Define these variables before running inference or `/baseline`:
|
| 164 |
+
|
| 165 |
+
```powershell
|
| 166 |
+
$env:API_BASE_URL = "https://api.openai.com/v1"
|
| 167 |
+
$env:MODEL_NAME = "gpt-4o-mini"
|
| 168 |
+
$env:HF_TOKEN = "your_api_key"
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
---
|
| 172 |
|
| 173 |
## Baseline Scores
|
baseline.py
CHANGED
|
@@ -1,144 +1,8 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
OPENAI_API_KEY environment variable
|
| 10 |
-
|
| 11 |
-
The script runs gpt-4o-mini against all 3 tasks and reports grader scores.
|
| 12 |
-
"""
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
-
import argparse
|
| 16 |
-
import json
|
| 17 |
-
import os
|
| 18 |
-
import sys
|
| 19 |
-
|
| 20 |
-
from openai import OpenAI
|
| 21 |
-
|
| 22 |
-
# ββ import env from local package ββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
-
sys.path.insert(0, os.path.dirname(__file__))
|
| 24 |
-
from env.environment import SQLOptimizerEnv
|
| 25 |
-
from env.models import Action
|
| 26 |
-
|
| 27 |
-
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
-
MODEL = "gpt-4o-mini"
|
| 29 |
-
MAX_STEPS = 5
|
| 30 |
-
TASKS = [1, 2, 3]
|
| 31 |
-
|
| 32 |
-
SYSTEM_PROMPT = """You are a database performance engineer.
|
| 33 |
-
You will receive a broken or unoptimised SQL query along with table schema context.
|
| 34 |
-
Your job is to rewrite the query so it is correct and performant.
|
| 35 |
-
|
| 36 |
-
Respond ONLY with a JSON object with these exact keys:
|
| 37 |
-
{
|
| 38 |
-
"rewritten_query": "<your improved SQL>",
|
| 39 |
-
"explanation": "<brief explanation of changes>",
|
| 40 |
-
"is_done": true
|
| 41 |
-
}
|
| 42 |
-
Do not wrap in markdown. Output raw JSON only."""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _build_user_message(obs_dict: dict) -> str:
|
| 46 |
-
return (
|
| 47 |
-
f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β difficulty: "
|
| 48 |
-
f"{obs_dict.get('difficulty', 'unknown')})\n\n"
|
| 49 |
-
f"Description:\n{obs_dict['task_description']}\n\n"
|
| 50 |
-
f"Schema:\n{obs_dict['schema_context']}\n\n"
|
| 51 |
-
f"Query to fix:\n{obs_dict['query']}"
|
| 52 |
-
+ (f"\n\nHint: {obs_dict['hint']}" if obs_dict.get("hint") else "")
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def run_baseline(verbose: bool = True) -> dict[str, float]:
|
| 57 |
-
api_key = os.getenv("OPENAI_API_KEY")
|
| 58 |
-
if not api_key:
|
| 59 |
-
print("ERROR: OPENAI_API_KEY is not set.", file=sys.stderr)
|
| 60 |
-
sys.exit(1)
|
| 61 |
-
|
| 62 |
-
client = OpenAI(api_key=api_key)
|
| 63 |
-
env = SQLOptimizerEnv()
|
| 64 |
-
results: dict[str, float] = {}
|
| 65 |
-
|
| 66 |
-
for task_id in TASKS:
|
| 67 |
-
obs = env.reset(task_id=task_id)
|
| 68 |
-
obs_dict = obs.model_dump()
|
| 69 |
-
final_score = 0.0
|
| 70 |
-
|
| 71 |
-
if verbose:
|
| 72 |
-
print(f"\n{'='*60}")
|
| 73 |
-
print(f"Task {task_id}: {obs_dict['task_name']} [{obs_dict['task_id']}]")
|
| 74 |
-
print(f"{'='*60}")
|
| 75 |
-
|
| 76 |
-
for step_num in range(MAX_STEPS):
|
| 77 |
-
messages = [
|
| 78 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 79 |
-
{"role": "user", "content": _build_user_message(obs_dict)},
|
| 80 |
-
]
|
| 81 |
-
|
| 82 |
-
try:
|
| 83 |
-
response = client.chat.completions.create(
|
| 84 |
-
model=MODEL,
|
| 85 |
-
messages=messages,
|
| 86 |
-
temperature=0.0,
|
| 87 |
-
max_tokens=1024,
|
| 88 |
-
)
|
| 89 |
-
content = response.choices[0].message.content.strip()
|
| 90 |
-
parsed = json.loads(content)
|
| 91 |
-
action = Action(
|
| 92 |
-
rewritten_query=parsed.get("rewritten_query", ""),
|
| 93 |
-
explanation=parsed.get("explanation", ""),
|
| 94 |
-
is_done=bool(parsed.get("is_done", False)),
|
| 95 |
-
)
|
| 96 |
-
except Exception as exc:
|
| 97 |
-
if verbose:
|
| 98 |
-
print(f" Step {step_num + 1}: LLM error β {exc}")
|
| 99 |
-
action = Action(
|
| 100 |
-
rewritten_query="",
|
| 101 |
-
explanation="error",
|
| 102 |
-
is_done=True,
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
obs, reward, done, info = env.step(action)
|
| 106 |
-
obs_dict = obs.model_dump()
|
| 107 |
-
final_score = info["grader_score"]
|
| 108 |
-
|
| 109 |
-
if verbose:
|
| 110 |
-
print(
|
| 111 |
-
f" Step {step_num + 1}: grader_score={info['grader_score']:.3f} "
|
| 112 |
-
f"step_reward={reward.score:.4f} feedback={reward.feedback[:80]}"
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
if done:
|
| 116 |
-
break
|
| 117 |
-
|
| 118 |
-
results[f"task_{task_id}_{env._task.name}"] = round(final_score, 4)
|
| 119 |
-
|
| 120 |
-
if verbose:
|
| 121 |
-
print(f" β Final grader score: {final_score:.4f}")
|
| 122 |
-
|
| 123 |
-
if verbose:
|
| 124 |
-
print(f"\n{'='*60}")
|
| 125 |
-
print("BASELINE RESULTS")
|
| 126 |
-
print(f"{'='*60}")
|
| 127 |
-
for k, v in results.items():
|
| 128 |
-
print(f" {k}: {v:.4f}")
|
| 129 |
-
avg = sum(results.values()) / len(results)
|
| 130 |
-
print(f" Average: {avg:.4f}")
|
| 131 |
-
|
| 132 |
-
return results
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
if __name__ == "__main__":
|
| 136 |
-
parser = argparse.ArgumentParser(description="OpenEnv SQL Optimizer β Baseline Inference")
|
| 137 |
-
parser.add_argument(
|
| 138 |
-
"--json", action="store_true", help="Output results as JSON (used by /baseline endpoint)"
|
| 139 |
-
)
|
| 140 |
-
args = parser.parse_args()
|
| 141 |
-
|
| 142 |
-
scores = run_baseline(verbose=not args.json)
|
| 143 |
-
if args.json:
|
| 144 |
-
print(json.dumps(scores))
|
|
|
|
| 1 |
+
"""Compatibility wrapper for the required inference.py entrypoint."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from inference import run_inference
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
if __name__ == "__main__":
|
| 8 |
+
run_inference()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenAI-based inference runner for the SQL Query Optimizer OpenEnv environment.
|
| 2 |
+
|
| 3 |
+
Environment variables:
|
| 4 |
+
API_BASE_URL: OpenAI-compatible API endpoint
|
| 5 |
+
MODEL_NAME: model identifier to use for inference
|
| 6 |
+
HF_TOKEN: API key / bearer token for the LLM provider
|
| 7 |
+
|
| 8 |
+
The script emits structured stdout logs in three sections only:
|
| 9 |
+
[START] ...
|
| 10 |
+
[STEP] ...
|
| 11 |
+
[END] ...
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from typing import Any, Dict
|
| 20 |
+
|
| 21 |
+
from openai import OpenAI
|
| 22 |
+
|
| 23 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 24 |
+
|
| 25 |
+
from env.environment import SQLOptimizerEnv
|
| 26 |
+
from env.models import Action
|
| 27 |
+
|
| 28 |
+
DEFAULT_MAX_STEPS = 5
|
| 29 |
+
TASK_IDS = (1, 2, 3)
|
| 30 |
+
|
| 31 |
+
SYSTEM_PROMPT = """You are a database performance engineer.
|
| 32 |
+
You will receive a broken or unoptimised SQL query along with table schema context.
|
| 33 |
+
Your job is to rewrite the query so it is correct and performant.
|
| 34 |
+
|
| 35 |
+
Respond ONLY with a JSON object with these exact keys:
|
| 36 |
+
{
|
| 37 |
+
"rewritten_query": "<your improved SQL>",
|
| 38 |
+
"explanation": "<brief explanation of changes>",
|
| 39 |
+
"is_done": true
|
| 40 |
+
}
|
| 41 |
+
Do not wrap in markdown. Output raw JSON only."""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _load_runtime_config() -> Dict[str, str]:
|
| 45 |
+
api_base_url = os.getenv("API_BASE_URL", "").strip()
|
| 46 |
+
model_name = os.getenv("MODEL_NAME", "").strip()
|
| 47 |
+
hf_token = os.getenv("HF_TOKEN", "").strip()
|
| 48 |
+
|
| 49 |
+
missing = [
|
| 50 |
+
name
|
| 51 |
+
for name, value in (
|
| 52 |
+
("API_BASE_URL", api_base_url),
|
| 53 |
+
("MODEL_NAME", model_name),
|
| 54 |
+
("HF_TOKEN", hf_token),
|
| 55 |
+
)
|
| 56 |
+
if not value
|
| 57 |
+
]
|
| 58 |
+
if missing:
|
| 59 |
+
raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
|
| 60 |
+
|
| 61 |
+
return {
|
| 62 |
+
"API_BASE_URL": api_base_url,
|
| 63 |
+
"MODEL_NAME": model_name,
|
| 64 |
+
"HF_TOKEN": hf_token,
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _build_user_message(obs_dict: dict) -> str:
|
| 69 |
+
message = (
|
| 70 |
+
f"Task: {obs_dict['task_name']} ({obs_dict['task_id']} β difficulty: "
|
| 71 |
+
f"{obs_dict.get('difficulty', 'unknown')})\n\n"
|
| 72 |
+
f"Description:\n{obs_dict['task_description']}\n\n"
|
| 73 |
+
f"Schema:\n{obs_dict['schema_context']}\n\n"
|
| 74 |
+
f"Query to fix:\n{obs_dict['query']}"
|
| 75 |
+
)
|
| 76 |
+
if obs_dict.get("hint"):
|
| 77 |
+
message += f"\n\nHint: {obs_dict['hint']}"
|
| 78 |
+
return message
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _log(prefix: str, payload: Dict[str, Any]) -> None:
|
| 82 |
+
print(f"{prefix} {json.dumps(payload, ensure_ascii=True, separators=(',', ':'))}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _parse_json_action(text: str) -> Action:
|
| 86 |
+
parsed = json.loads(text)
|
| 87 |
+
return Action(
|
| 88 |
+
rewritten_query=parsed.get("rewritten_query", ""),
|
| 89 |
+
explanation=parsed.get("explanation", ""),
|
| 90 |
+
is_done=bool(parsed.get("is_done", False)),
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def run_inference() -> Dict[str, float]:
|
| 95 |
+
config = _load_runtime_config()
|
| 96 |
+
client = OpenAI(api_key=config["HF_TOKEN"], base_url=config["API_BASE_URL"])
|
| 97 |
+
env = SQLOptimizerEnv()
|
| 98 |
+
|
| 99 |
+
_log(
|
| 100 |
+
"[START]",
|
| 101 |
+
OrderedDict(
|
| 102 |
+
[
|
| 103 |
+
("script", "inference.py"),
|
| 104 |
+
("api_base_url", config["API_BASE_URL"]),
|
| 105 |
+
("model_name", config["MODEL_NAME"]),
|
| 106 |
+
("tasks", list(TASK_IDS)),
|
| 107 |
+
]
|
| 108 |
+
),
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
results: Dict[str, float] = {}
|
| 112 |
+
total_score = 0.0
|
| 113 |
+
|
| 114 |
+
for task_id in TASK_IDS:
|
| 115 |
+
observation = env.reset(task_id=task_id)
|
| 116 |
+
obs_dict = observation.model_dump()
|
| 117 |
+
final_grader_score = 0.0
|
| 118 |
+
step_count = 0
|
| 119 |
+
|
| 120 |
+
for step_number in range(DEFAULT_MAX_STEPS):
|
| 121 |
+
messages = [
|
| 122 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 123 |
+
{"role": "user", "content": _build_user_message(obs_dict)},
|
| 124 |
+
]
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
response = client.chat.completions.create(
|
| 128 |
+
model=config["MODEL_NAME"],
|
| 129 |
+
messages=messages,
|
| 130 |
+
temperature=0.0,
|
| 131 |
+
max_tokens=1024,
|
| 132 |
+
)
|
| 133 |
+
content = (response.choices[0].message.content or "").strip()
|
| 134 |
+
action = _parse_json_action(content)
|
| 135 |
+
llm_status = "ok"
|
| 136 |
+
except Exception as exc:
|
| 137 |
+
action = Action(rewritten_query="", explanation=f"error: {exc}", is_done=True)
|
| 138 |
+
llm_status = "error"
|
| 139 |
+
|
| 140 |
+
observation, reward, done, info = env.step(action)
|
| 141 |
+
obs_dict = observation.model_dump()
|
| 142 |
+
final_grader_score = float(info.get("grader_score", 0.0))
|
| 143 |
+
step_count = step_number + 1
|
| 144 |
+
|
| 145 |
+
_log(
|
| 146 |
+
"[STEP]",
|
| 147 |
+
OrderedDict(
|
| 148 |
+
[
|
| 149 |
+
("task_id", task_id),
|
| 150 |
+
("task_name", obs_dict["task_name"]),
|
| 151 |
+
("step", step_count),
|
| 152 |
+
("grader_score", round(final_grader_score, 4)),
|
| 153 |
+
("reward_score", round(float(reward.score), 4)),
|
| 154 |
+
("done", bool(done)),
|
| 155 |
+
("llm_status", llm_status),
|
| 156 |
+
]
|
| 157 |
+
),
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
if done:
|
| 161 |
+
break
|
| 162 |
+
|
| 163 |
+
task_key = f"task_{task_id}_{env._task.name}"
|
| 164 |
+
results[task_key] = round(final_grader_score, 4)
|
| 165 |
+
total_score += final_grader_score
|
| 166 |
+
|
| 167 |
+
average_score = round(total_score / len(TASK_IDS), 4)
|
| 168 |
+
|
| 169 |
+
_log(
|
| 170 |
+
"[END]",
|
| 171 |
+
OrderedDict(
|
| 172 |
+
[
|
| 173 |
+
("task_results", results),
|
| 174 |
+
("average_score", average_score),
|
| 175 |
+
("status", "success"),
|
| 176 |
+
]
|
| 177 |
+
),
|
| 178 |
+
)
|
| 179 |
+
return results
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
|
| 183 |
+
try:
|
| 184 |
+
run_inference()
|
| 185 |
+
except Exception as exc:
|
| 186 |
+
_log(
|
| 187 |
+
"[END]",
|
| 188 |
+
OrderedDict(
|
| 189 |
+
[
|
| 190 |
+
("task_results", {}),
|
| 191 |
+
("average_score", 0.0),
|
| 192 |
+
("status", "error"),
|
| 193 |
+
("error", str(exc)),
|
| 194 |
+
]
|
| 195 |
+
),
|
| 196 |
+
)
|
| 197 |
+
sys.exit(1)
|
pyproject.toml
CHANGED
|
@@ -30,9 +30,13 @@ dependencies = [
|
|
| 30 |
"uvicorn[standard]>=0.29.0",
|
| 31 |
"pydantic>=2.7.0",
|
| 32 |
"openai>=1.30.0",
|
|
|
|
| 33 |
"pyyaml>=6.0",
|
| 34 |
]
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
[project.optional-dependencies]
|
| 37 |
dev = [
|
| 38 |
"pytest>=7.0",
|
|
|
|
| 30 |
"uvicorn[standard]>=0.29.0",
|
| 31 |
"pydantic>=2.7.0",
|
| 32 |
"openai>=1.30.0",
|
| 33 |
+
"openenv-core>=0.2.0",
|
| 34 |
"pyyaml>=6.0",
|
| 35 |
]
|
| 36 |
|
| 37 |
+
[project.scripts]
|
| 38 |
+
server = "server.app:main"
|
| 39 |
+
|
| 40 |
[project.optional-dependencies]
|
| 41 |
dev = [
|
| 42 |
"pytest>=7.0",
|
server/app.py
CHANGED
|
@@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
|
|
| 19 |
from fastapi import FastAPI, HTTPException
|
| 20 |
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
from pydantic import BaseModel
|
|
|
|
| 22 |
|
| 23 |
from env.environment import SQLOptimizerEnv
|
| 24 |
from env.models import Action, Observation, Reward
|
|
@@ -79,6 +80,17 @@ class BaselineResponse(BaseModel):
|
|
| 79 |
message: str
|
| 80 |
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 83 |
# Endpoints
|
| 84 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -157,30 +169,42 @@ def grader() -> GraderResponse:
|
|
| 157 |
@app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
|
| 158 |
def baseline() -> BaselineResponse:
|
| 159 |
"""
|
| 160 |
-
Trigger the baseline inference script (
|
| 161 |
-
Requires
|
| 162 |
"""
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
raise HTTPException(
|
| 165 |
status_code=400,
|
| 166 |
-
detail="
|
| 167 |
)
|
| 168 |
try:
|
| 169 |
result = subprocess.run(
|
| 170 |
-
[sys.executable, "
|
| 171 |
capture_output=True,
|
| 172 |
text=True,
|
| 173 |
-
timeout=
|
| 174 |
)
|
| 175 |
if result.returncode != 0:
|
| 176 |
raise HTTPException(
|
| 177 |
status_code=500,
|
| 178 |
-
detail=f"
|
| 179 |
)
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
except subprocess.TimeoutExpired:
|
| 184 |
-
raise HTTPException(status_code=500, detail="
|
| 185 |
except Exception as exc:
|
| 186 |
raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from fastapi import FastAPI, HTTPException
|
| 20 |
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
from pydantic import BaseModel
|
| 22 |
+
import uvicorn
|
| 23 |
|
| 24 |
from env.environment import SQLOptimizerEnv
|
| 25 |
from env.models import Action, Observation, Reward
|
|
|
|
| 80 |
message: str
|
| 81 |
|
| 82 |
|
| 83 |
+
def _parse_end_payload(stdout: str) -> Dict[str, Any]:
|
| 84 |
+
for line in reversed(stdout.splitlines()):
|
| 85 |
+
if not line.startswith("[END] "):
|
| 86 |
+
continue
|
| 87 |
+
payload_text = line[len("[END] ") :].strip()
|
| 88 |
+
import json
|
| 89 |
+
|
| 90 |
+
return json.loads(payload_text)
|
| 91 |
+
raise ValueError("Could not find [END] payload in inference output")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 95 |
# Endpoints
|
| 96 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 169 |
@app.post("/baseline", response_model=BaselineResponse, summary="Run baseline inference on all tasks")
|
| 170 |
def baseline() -> BaselineResponse:
|
| 171 |
"""
|
| 172 |
+
Trigger the baseline inference script (inference.py) and return scores.
|
| 173 |
+
Requires API_BASE_URL, MODEL_NAME, and HF_TOKEN to be set in the environment.
|
| 174 |
"""
|
| 175 |
+
required_vars = ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN"]
|
| 176 |
+
missing = [name for name in required_vars if not os.getenv(name)]
|
| 177 |
+
if missing:
|
| 178 |
raise HTTPException(
|
| 179 |
status_code=400,
|
| 180 |
+
detail=f"Missing required environment variables: {', '.join(missing)}",
|
| 181 |
)
|
| 182 |
try:
|
| 183 |
result = subprocess.run(
|
| 184 |
+
[sys.executable, "inference.py"],
|
| 185 |
capture_output=True,
|
| 186 |
text=True,
|
| 187 |
+
timeout=1200,
|
| 188 |
)
|
| 189 |
if result.returncode != 0:
|
| 190 |
raise HTTPException(
|
| 191 |
status_code=500,
|
| 192 |
+
detail=f"Inference script failed:\n{result.stderr}",
|
| 193 |
)
|
| 194 |
+
payload = _parse_end_payload(result.stdout)
|
| 195 |
+
return BaselineResponse(
|
| 196 |
+
task_results=payload.get("task_results", {}),
|
| 197 |
+
message="Baseline completed successfully.",
|
| 198 |
+
)
|
| 199 |
except subprocess.TimeoutExpired:
|
| 200 |
+
raise HTTPException(status_code=500, detail="Inference script timed out after 1200s.")
|
| 201 |
except Exception as exc:
|
| 202 |
raise HTTPException(status_code=500, detail=str(exc))
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def main() -> None:
|
| 206 |
+
uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
|
| 210 |
+
main()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|