Rithwik Ravi commited on
Commit ·
de07414
1
Parent(s): c861e8f
Upload full OpenEnv project structure
Browse files- Dockerfile +25 -0
- __init__.py +1 -0
- __pycache__/env.cpython-314.pyc +0 -0
- __pycache__/inference.cpython-314.pyc +0 -0
- __pycache__/models.cpython-314.pyc +0 -0
- __pycache__/tasks.cpython-314.pyc +0 -0
- client.py +40 -0
- env.py +1 -0
- inference.py +144 -0
- models.py +8 -0
- openenv.yaml +20 -0
- pyproject.toml +24 -0
- requirements.txt +5 -0
- server/__init__.py +2 -0
- server/__pycache__/__init__.cpython-314.pyc +0 -0
- server/__pycache__/app.cpython-314.pyc +0 -0
- server/__pycache__/models.cpython-314.pyc +0 -0
- server/__pycache__/tasks.cpython-314.pyc +0 -0
- server/app.py +178 -0
- server/models.py +29 -0
- server/tasks.py +196 -0
- tmp_env.db +0 -0
- uv.lock +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim

# Creates a non-root user with an explicit UID required by Hugging Face Spaces
RUN useradd -m -u 1000 user
USER user

# Set environment variables: run from the user's HOME and make user-level
# pip installs (~/.local/bin) visible on PATH.
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Create the working directory
WORKDIR $HOME/app

# Install dependencies first for Docker caching (requirements change less
# often than source, so this layer is usually reused).
COPY --chown=user:user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the files
COPY --chown=user:user . .

# Expose the standard Hugging Face Space port (must match openenv.yaml `port`)
EXPOSE 7860

# Start the environment API server (`env:app` re-exports server.app:app)
CMD ["uvicorn", "env:app", "--host", "0.0.0.0", "--port", "7860"]
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Root package convenience export: the singleton environment created in server/app.py.
from server import env_instance
|
__pycache__/env.cpython-314.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
__pycache__/inference.cpython-314.pyc
ADDED
|
Binary file (7.41 kB). View file
|
|
|
__pycache__/models.cpython-314.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
__pycache__/tasks.cpython-314.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
client.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from typing import Dict, Any, Optional
|
| 3 |
+
|
| 4 |
+
from server.models import Action, StepResult, ResetResult
|
| 5 |
+
|
| 6 |
+
class EnvironmentalClient:
    """
    Official OpenEnv client wrapper to interface with the containerized environment endpoints.
    Can be used by baseline inference scripts or remote agent evaluations.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0):
        # Normalize the base URL (no trailing slash) and keep a per-request
        # timeout so a hung server cannot block the caller forever.
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout

    def reset(self, task_id: int = 1) -> ResetResult:
        """
        Calls the /reset endpoint to begin a new episode for the specified task.

        Raises requests.HTTPError on a non-2xx response.
        """
        response = requests.post(
            f"{self.base_url}/reset", json={"task_id": task_id}, timeout=self.timeout
        )
        response.raise_for_status()
        return ResetResult(**response.json())

    def step(self, action: Action) -> StepResult:
        """
        Submits a step (action) to the environment and returns the updated state/reward.
        """
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()``; support both so
        # the client works regardless of the installed pydantic major version.
        payload = action.model_dump() if hasattr(action, "model_dump") else action.dict()
        response = requests.post(
            f"{self.base_url}/step", json=payload, timeout=self.timeout
        )
        response.raise_for_status()
        return StepResult(**response.json())

    def state(self) -> Dict[str, Any]:
        """
        Extracts the unstructured metadata state of the running environment.
        """
        response = requests.get(f"{self.base_url}/state", timeout=self.timeout)
        response.raise_for_status()
        return response.json()

# Provide a default singleton for ease of use
client = EnvironmentalClient()
|
env.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# ASGI entrypoint module: uvicorn loads `env:app` (see Dockerfile CMD); env_instance is re-exported for scripts.
from server.app import app, env_instance
|
inference.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import json
from openai import OpenAI
from server.models import Action, BrowserGymAction  # using our local Action model alias
from server.app import env_instance as env

# Environment Configuration
# API_BASE_URL / MODEL_NAME / API_KEY come from the environment so the same
# script can target OpenAI, a router, or any local OpenAI-compatible endpoint.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("HF_TOKEN", "")

# Episode limit and sampling parameters for the baseline agent.
MAX_STEPS = 15
TEMPERATURE = 0.2
MAX_TOKENS = 512

# System prompt shown once per request; demands a bare-JSON {"action_str": ...} reply.
SYSTEM_PROMPT = """
You are an expert Data Engineer interacting with a simulated SQLite database.
You will be given a task goal, the current database schema, and the most recent step's SQL output or error.
Your goal is to complete the task by executing SQL commands.

CRITICAL RULES:
1. You may only execute ONE SQL statement at a time. Do not chain statements with semicolons.
2. If you need to review data, use short SELECT queries.
3. If your previous action resulted in an SQL error, fix the error and try again.
4. If you need multiple steps to achieve the goal (e.g. create tables, then insert data), execute them one by one.
5. You MUST output ONLY a valid JSON object matching this schema:
{
"action_str": "YOUR SQL QUERY HERE"
}
Do not wrap your response in markdown code blocks. Just valid JSON.
"""
|
| 32 |
+
|
| 33 |
+
def build_user_prompt(step: int, observation, history: list) -> str:
    """Assemble the per-step user message sent to the LLM.

    Combines the task goal, the optional schema dump, the previous step's
    SQL output/error, and up to the last three history lines into one string.
    """
    parts = [f"--- Step {step} ---\n", f"Goal: {observation.goal}\n\n"]
    if observation.schema_dump:
        parts.append(f"Current DB Schema:\n{observation.schema_dump}\n\n")
    parts.append(f"Last Result (or Error):\n{observation.result}\n\n")
    if history:
        parts.append("Action History (Last 3 steps):\n")
        parts.extend(entry + "\n" for entry in history[-3:])
    parts.append("\nProvide the JSON with your next `action_str`:")
    return "".join(parts)
|
| 48 |
+
|
| 49 |
+
def parse_model_action(response_text: str) -> str:
    """Extract the SQL command from a model response.

    The model is asked for a bare JSON object like {"action_str": "..."}.
    This strips optional markdown code fences, parses the JSON, and falls
    back to the raw text when parsing fails (or to a harmless "SELECT 1;"
    when the JSON object lacks the key).
    """
    text = response_text.strip()
    # Tolerate models that wrap the JSON in markdown fences despite the prompt.
    if text.startswith("```json"):
        text = text[7:]
    elif text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Fallback if the model doesn't follow the JSON format correctly.
        return text
    # json.loads can legally return a non-dict (bare string/number/list);
    # only a dict carries the expected "action_str" key — previously this
    # crashed with AttributeError on data.get().
    if isinstance(data, dict):
        return data.get("action_str", "SELECT 1;")
    return text
|
| 63 |
+
|
| 64 |
+
def run_task(task_id: int) -> float:
    """Run one full episode of `task_id` with the baseline LLM agent.

    Loops up to MAX_STEPS: prompt the model, parse its SQL action, step the
    in-process environment, and accumulate a short history for the next
    prompt. Returns the final grader score in [0.0, 1.0].
    """
    print(f"\n{'='*50}\nStarting Task {task_id}\n{'='*50}")

    # OpenAI-compatible client; base URL / key come from module-level config.
    client = OpenAI(
        base_url=API_BASE_URL,
        api_key=API_KEY
    )

    history = []

    # Using the local env object wrapper (in-process, not HTTP)
    result = env.reset(task_id=task_id)
    observation = result.observation
    print(f"Episode goal: {observation.goal}\n")

    for step in range(1, MAX_STEPS + 1):
        # We handle done from the step result, but for initial step we check just in case
        user_prompt = build_user_prompt(step, observation, history)

        # print("PROMPT:", user_prompt)

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
                response_format={"type": "json_object"}  # enforce json output
            )
            response_text = completion.choices[0].message.content or ""
            action_str = parse_model_action(response_text)
        except Exception as exc:
            # Any API failure degrades to a harmless no-op query so the
            # episode keeps running instead of crashing.
            failure_msg = f"Model request failed ({exc}). Using fallback action."
            print(failure_msg)
            action_str = "SELECT 1;"

        print(f"Step {step}: model suggested -> {action_str[:100]}...")

        # Step the environment
        step_result = env.step(BrowserGymAction(action_str=action_str))
        observation = step_result.observation
        reward = step_result.reward
        done = step_result.done

        # Truncated one-line summary fed back into the next prompt's history.
        error_flag = " ERROR" if observation.last_action_error else ""
        history_line = f"Step {step}: {action_str[:50]}... -> reward {reward:+.2f}{error_flag}"
        history.append(history_line)

        print(f" Reward: {reward:+.2f} | Done: {done} | Last action error: {observation.last_action_error}")

        if done:
            final_score = step_result.info.get("current_score", 0.0)
            print(f"\nEpisode complete! Final Score: {final_score}/1.0")
            break
    else:
        # for/else: runs only when the loop exhausted MAX_STEPS without `break`.
        final_score = env.state().get("current_score", 0.0)
        print(f"\nReached max steps ({MAX_STEPS}). Final Score: {final_score}/1.0")

    return final_score
|
| 129 |
+
|
| 130 |
+
def main():
    """Evaluate the baseline agent on every benchmark task and print a summary."""
    print("Testing OpenEnv Data Engineer Inference Baseline")

    if not API_KEY:
        print("Warning: API_KEY/HF_TOKEN not set. Will likely fail unless local LLM.")

    scores = {}
    for task_id in (1, 2, 3):
        scores[f"Task_{task_id}"] = run_task(task_id)

    print(f"\n{'*'*50}\nEVALUATION COMPLETE\n{scores}\n{'*'*50}")

if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Top-level re-exports so openenv.yaml can reference `models:Action` etc.
# without importing the server package directly.
from server.models import (
    Action,
    BrowserGymAction,
    Observation,
    Reward,
    StepResult,
    ResetResult
)
|
openenv.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv environment manifest.
name: OpenEnv-SQL-Data-Engineer
version: 1.0.0
description: >
  A real-world OpenEnv environment simulating a Database Administrator/Data Engineer.
  The agent must interact with a live SQLite mock database using SQL commands to perform data extraction, data cleaning, and schema normalization.
# ASGI app exported from env.py (module:attribute form).
entrypoint: env:app
# Pydantic models describing the action/observation/reward spaces.
models:
  action: models:Action
  observation: models:Observation
  reward: models:Reward
tags:
  - database
  - sql
  - data-engineering
  - structured-data
  - real-world
# Must match the EXPOSE/CMD port in the Dockerfile.
port: 7860
dockerfile: Dockerfile
# Secrets/variables forwarded into the container at runtime.
environment_vars:
  - HF_TOKEN
|
pyproject.toml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "sql-data-engineer-env"
version = "0.1.0"
description = "A real-world SQL data engineering environment for agent evaluation."
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "openenv-core",
    "fastapi",
    "uvicorn",
    "pydantic",
    "openai",
    "pandas"
]

[tool.setuptools]
# NOTE(review): packages = ["."] is unusual for setuptools — normally this
# lists real importable package names (e.g. ["server"]) or uses
# packages.find directives; confirm the project installs as intended.
packages = ["."]

[project.scripts]
# Console script entry point; runs uvicorn via server.app:main.
server = "server.app:main"
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runtime dependencies installed by the Dockerfile.
# NOTE(review): pyproject.toml additionally lists openenv-core and pandas —
# confirm whether the container image needs them too.
fastapi
uvicorn
pydantic
openai
requests
|
server/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package exports: the shared environment singleton and its public schema models.
from .app import env_instance
from .models import Observation, Action
|
server/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (271 Bytes). View file
|
|
|
server/__pycache__/app.cpython-314.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
server/__pycache__/models.cpython-314.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
server/__pycache__/tasks.cpython-314.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
server/app.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import os
|
| 3 |
+
from typing import Dict, Any, Optional
|
| 4 |
+
from fastapi import FastAPI, HTTPException
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
import uvicorn
|
| 7 |
+
|
| 8 |
+
from .models import Action, Observation, Reward, StepResult, ResetResult
|
| 9 |
+
from .tasks import TASKS
|
| 10 |
+
|
| 11 |
+
app = FastAPI(title="OpenEnv SQL Data Engineer")
|
| 12 |
+
|
| 13 |
+
class SQLEnvironment:
    """Stateful SQLite sandbox implementing the OpenEnv reset/step/state API.

    One episode = one on-disk SQLite database seeded by a Task; the agent
    submits raw SQL strings and receives a dense reward equal to the change
    in the task's grade (minus a small penalty for SQL errors).
    """

    def __init__(self):
        self.conn: Optional[sqlite3.Connection] = None  # live DB handle, created by reset()
        self.task_id = 1
        self.step_count = 0
        self.current_score = 0.0  # last grade; step() rewards the delta against this
        self.db_path = "tmp_env.db"

    def get_schema_dump(self) -> str:
        """Return a readable dump of all user tables/views, or '' before reset()."""
        if not self.conn:
            return ""
        try:
            c = self.conn.cursor()
            c.execute("SELECT type, name, sql FROM sqlite_master WHERE type='table' OR type='view'")
            rows = c.fetchall()
            dump = []
            for t, name, sql in rows:
                if name.startswith('sqlite_'):
                    continue  # skip SQLite-internal bookkeeping tables
                dump.append(f"[{t.upper()}] {name}:\n {sql}")
            return "\n".join(dump) if dump else "Database is empty."
        except Exception as e:
            return f"Error extracting schema: {e}"

    def reset(self, task_id: int = 1) -> ResetResult:
        """Start a fresh episode: rebuild the DB file and seed the task data.

        Raises ValueError when `task_id` is not a known task.
        """
        if self.conn:
            self.conn.close()

        # Clean up existing temp db
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

        self.task_id = task_id
        if self.task_id not in TASKS:
            raise ValueError(f"Task ID {self.task_id} not found.")

        # check_same_thread=False: FastAPI runs sync endpoints on a thread
        # pool, so this connection is used from threads other than the one
        # that created it; the default check would raise ProgrammingError.
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.step_count = 0
        self.current_score = 0.0

        # Initialize standard SQLite settings
        self.conn.execute("PRAGMA foreign_keys = ON")

        # Setup specific task data
        task = TASKS[self.task_id]
        task.setup_db(self.conn)
        self.current_score = task.grade(self.conn)  # Base score

        goal_text = task.get_goal()
        # Add basic info about actual task goal
        instructions = f"Task Goal: {goal_text}\n"

        obs = Observation(
            goal=instructions,
            result="Environment initialized. Schema ready.",
            step=self.step_count,
            last_action_error=False,
            schema_dump=self.get_schema_dump()
        )
        return ResetResult(observation=obs, info={"task_id": self.task_id, "initial_score": self.current_score})

    def step(self, action: Action) -> StepResult:
        """Execute one SQL statement, re-grade the DB, and return the reward.

        Raises ValueError if called before reset().
        """
        if not self.conn:
            raise ValueError("Environment not initialized. Call reset() first.")

        self.step_count += 1
        last_action_error = False
        query_result = ""

        try:
            c = self.conn.cursor()
            query = action.action_str.strip()
            # Basic mitigation of forbidden queries just in case (though we're in mock).
            # BUGFIX: the literal must be fully uppercase — query.upper() can
            # never start with the old mixed-case "DROP TABLE sqlite_", so the
            # guard was dead code.
            if query.upper().startswith("DROP TABLE SQLITE_"):
                raise Exception("Cannot modify system tables.")

            c.execute(query)

            if query.upper().startswith("SELECT") or query.upper().startswith("PRAGMA"):
                rows = c.fetchmany(10)  # limit output size for LLM observation
                col_names = [description[0] for description in c.description] if c.description else []
                # Format tabular output
                result_str = " | ".join(col_names) + "\n"
                result_str += "-" * len(result_str) + "\n"
                for r in rows:
                    result_str += " | ".join([str(val) for val in r]) + "\n"
                if len(rows) == 10:
                    result_str += "... (output truncated)"
                query_result = result_str
            else:
                self.conn.commit()
                query_result = f"Command executed successfully. Rowcount: {c.rowcount}"

        except Exception as e:
            last_action_error = True
            query_result = f"SQL Error: {str(e)}"

        # Run grader
        task = TASKS[self.task_id]
        new_score = task.grade(self.conn)

        # Reward is dense: change in score + small penalty for errors
        reward_value = new_score - self.current_score
        if last_action_error:
            # minor penalty for syntax errors
            reward_value -= 0.05

        self.current_score = new_score

        # Episode terminates when score is 1.0 (or after 30 steps as a hard cap)
        done = (self.current_score >= 1.0) or (self.step_count > 30)

        obs = Observation(
            goal=task.get_goal(),
            result=query_result,
            step=self.step_count,
            last_action_error=last_action_error,
            schema_dump=self.get_schema_dump() if not last_action_error else None  # only dump if no error to save tokens
        )

        return StepResult(
            observation=obs,
            reward=reward_value,
            done=done,
            info={"current_score": self.current_score}
        )

    def state(self) -> Any:
        """Return the unstructured environment state dict per the standard API."""
        return {
            "task_id": self.task_id,
            "step": self.step_count,
            "current_score": self.current_score,
            "schema_dump": self.get_schema_dump()
        }
|
| 149 |
+
|
| 150 |
+
# Global instance
|
| 151 |
+
# Global instance shared by every request handler below.
env_instance = SQLEnvironment()

class ResetRequest(BaseModel):
    # Request body for POST /reset; selects which benchmark task to load.
    task_id: int = 1

@app.post("/reset", response_model=ResetResult)
def reset(req: ResetRequest):
    # Start a new episode; environment errors surface as HTTP 400.
    try:
        return env_instance.reset(task_id=req.task_id)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.post("/step", response_model=StepResult)
def step(action: Action):
    # Execute one SQL action against the running episode.
    try:
        return env_instance.step(action)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.get("/state")
def state():
    # Unstructured snapshot: task id, step count, score, schema dump.
    return env_instance.state()

def main():
    # Entry point for the `server` console script declared in pyproject.toml.
    uvicorn.run("server.app:app", host="0.0.0.0", port=7860, reload=False)

if __name__ == "__main__":
    main()
|
server/models.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional, Any, Dict, List
|
| 3 |
+
|
| 4 |
+
class Action(BaseModel):
    """Raw SQL command submitted by the agent for one environment step."""
    action_str: str

# Alias for compatibility with the sample inference script
BrowserGymAction = Action

class Observation(BaseModel):
    """What the agent sees after reset() or step()."""
    goal: str                           # natural-language task goal
    result: str                         # last SQL output or error message
    step: int                           # step counter within the episode
    last_action_error: bool             # True when the previous SQL raised
    schema_dump: Optional[str] = None   # omitted after errors to save tokens

class Reward(BaseModel):
    """Standalone reward container (StepResult inlines a plain float instead)."""
    value: float
    reason: Optional[str] = None

class StepResult(BaseModel):
    """Response payload of POST /step."""
    observation: Observation
    reward: float
    done: bool
    info: Dict[str, Any]

class ResetResult(BaseModel):
    """Response payload of POST /reset."""
    observation: Observation
    info: Dict[str, Any]
|
server/tasks.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
class Task:
    """Abstract benchmark task: seeds a SQLite DB and grades agent progress."""

    def __init__(self, task_id: int):
        self.task_id = task_id

    def setup_db(self, conn: sqlite3.Connection):
        """Populate a fresh connection with the task's starting data."""
        raise NotImplementedError

    def get_goal(self) -> str:
        """Natural-language instruction shown to the agent."""
        raise NotImplementedError

    def grade(self, conn: sqlite3.Connection) -> float:
        """Score the current DB state in [0.0, 1.0]; implementations must not raise."""
        raise NotImplementedError


class EasyTask(Task):
    """Task 1: create a filtered view over a small customers table."""

    def __init__(self):
        super().__init__(1)

    def setup_db(self, conn: sqlite3.Connection):
        c = conn.cursor()
        c.execute("CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, total_spent REAL)")
        c.executemany("INSERT INTO customers (name, total_spent) VALUES (?, ?)", [
            ("Alice", 500.0),
            ("Bob", 1200.0),
            ("Charlie", 50.0),
            ("Diana", 3000.0),
            ("Eve", 1000.01)  # just over the 1000 threshold
        ])
        conn.commit()

    def get_goal(self) -> str:
        return "Create a view named 'high_value_customers' containing all customers who have a 'total_spent' greater than 1000.0. The view should contain the exact same columns as the customers table."

    def grade(self, conn: sqlite3.Connection) -> float:
        c = conn.cursor()
        try:
            # Check if view exists
            c.execute("SELECT name FROM sqlite_master WHERE type='view' AND name='high_value_customers'")
            if not c.fetchone():
                return 0.0

            # Check rows
            c.execute("SELECT name, total_spent FROM high_value_customers ORDER BY name")
            rows = c.fetchall()

            if len(rows) != 3:
                return 0.5  # partially correct: view exists but wrong rows

            expected = [("Bob", 1200.0), ("Diana", 3000.0), ("Eve", 1000.01)]
            if rows == expected:
                return 1.0
            return 0.5
        except Exception:
            return 0.0


class MediumTask(Task):
    """Task 2: clean a messy products table (category case + numeric price column)."""

    def __init__(self):
        super().__init__(2)

    def setup_db(self, conn: sqlite3.Connection):
        c = conn.cursor()
        c.execute("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, category TEXT, price TEXT)")
        c.executemany("INSERT INTO products (name, category, price) VALUES (?, ?, ?)", [
            ("Laptop", "Electronics", "$999.99"),
            ("Mouse", "electronics", "25.50 USD"),
            ("Desk", "FURNITURE", "150.0"),
            ("Chair", "furniture", "$85.00"),
            ("Headphones", "ELEC", "€45.00")  # We'll just ask them to remove letters/symbols
        ])
        conn.commit()

    def get_goal(self) -> str:
        return (
            "The 'products' table is messy. "
            "1. Standardize the 'category' column to be fully UPPERCASE. (Hint: treat 'ELEC' as 'ELECTRONICS'). "
            "2. Create a new column 'price_usd' of type REAL. Extract the numeric value from the 'price' string and populate 'price_usd'. "
            "Do not drop any original columns."
        )

    def grade(self, conn: sqlite3.Connection) -> float:
        # BUGFIX: each sub-goal is now graded independently. Previously the
        # SELECT on a not-yet-created 'price_usd' column raised and aborted
        # the whole grade, so the documented "up to 0.3 for correct
        # categories" partial credit was unreachable before that column
        # existed.
        score = 0.0
        c = conn.cursor()

        try:
            c.execute("PRAGMA table_info(products)")
            columns = [row[1] for row in c.fetchall()]
        except Exception:
            return score  # table itself is gone/broken — no credit possible

        has_price_col = 'price_usd' in columns
        if has_price_col:
            score += 0.3

        if has_price_col:
            try:
                # Check data accuracy for price
                c.execute("SELECT price_usd FROM products ORDER BY id")
                prices = [row[0] for row in c.fetchall()]
                expected_prices = [999.99, 25.50, 150.0, 85.0, 45.0]
                # allow small float diffs
                correct_prices = sum(1 for p, e in zip(prices, expected_prices) if p is not None and abs(p - e) < 0.01)
                score += (correct_prices / 5.0) * 0.4  # up to 0.4 for correct prices
            except Exception:
                pass

        try:
            # Check category uppercase
            c.execute("SELECT category FROM products ORDER BY id")
            categories = [row[0] for row in c.fetchall()]
            expected_cats = ["ELECTRONICS", "ELECTRONICS", "FURNITURE", "FURNITURE", "ELECTRONICS"]
            correct_cats = sum(1 for got, want in zip(categories, expected_cats) if got == want)
            score += (correct_cats / 5.0) * 0.3  # up to 0.3 for correct categories
        except Exception:
            pass

        return min(1.0, score)


class HardTask(Task):
    """Task 3: normalize a flat hospital_records table into 3 linked tables."""

    def __init__(self):
        super().__init__(3)

    def setup_db(self, conn: sqlite3.Connection):
        c = conn.cursor()
        c.execute("""
            CREATE TABLE hospital_records (
                patient_name TEXT,
                patient_dob TEXT,
                doctor_name TEXT,
                doctor_specialty TEXT,
                appointment_date TEXT,
                diagnosis TEXT
            )
        """)
        records = [
            ("John Doe", "1980-01-01", "Dr. Smith", "Cardiology", "2023-10-01", "Hypertension"),
            ("Jane Roe", "1992-05-15", "Dr. Jones", "Neurology", "2023-10-02", "Migraine"),
            ("John Doe", "1980-01-01", "Dr. Smith", "Cardiology", "2023-11-01", "Follow-up"),
            ("Bob Guy", "1975-11-20", "Dr. Smith", "Cardiology", "2023-10-05", "Checkup")
        ]
        c.executemany("INSERT INTO hospital_records VALUES (?, ?, ?, ?, ?, ?)", records)
        conn.commit()

    def get_goal(self) -> str:
        return (
            "Normalize the flat 'hospital_records' table into 3 tables: "
            "'patients' (id INTEGER PRIMARY KEY, name TEXT, dob TEXT), "
            "'doctors' (id INTEGER PRIMARY KEY, name TEXT, specialty TEXT), and "
            "'appointments' (id INTEGER PRIMARY KEY, patient_id INTEGER, doctor_id INTEGER, date TEXT, diagnosis TEXT). "
            "Migrate all data from 'hospital_records' correctly without duplication. "
            "Ensure foreign keys are correctly pointing to the new IDs."
        )

    def grade(self, conn: sqlite3.Connection) -> float:
        score = 0.0
        c = conn.cursor()
        try:
            # Check tables exist
            c.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in c.fetchall()]

            if 'patients' in tables: score += 0.1
            if 'doctors' in tables: score += 0.1
            if 'appointments' in tables: score += 0.2

            # All three tables must exist before data checks make sense.
            if score < 0.4:
                return score

            # Check data counts (3 unique patients, 2 unique doctors, 4 appointments)
            c.execute("SELECT COUNT(*) FROM patients")
            if c.fetchone()[0] == 3: score += 0.1

            c.execute("SELECT COUNT(*) FROM doctors")
            if c.fetchone()[0] == 2: score += 0.1

            c.execute("SELECT COUNT(*) FROM appointments")
            if c.fetchone()[0] == 4: score += 0.1

            # Check referential integrity (can we reconstruct the original view?)
            query = """
                SELECT p.name, p.dob, d.name, d.specialty, a.date, a.diagnosis
                FROM appointments a
                JOIN patients p ON a.patient_id = p.id
                JOIN doctors d ON a.doctor_id = d.id
                ORDER BY p.name, a.date
            """
            c.execute(query)
            reconstructed = c.fetchall()
            if len(reconstructed) == 4:
                score += 0.3

            return min(1.0, score)
        except Exception:
            return score

# Registry consumed by SQLEnvironment.reset()/step().
TASKS = {
    1: EasyTask(),
    2: MediumTask(),
    3: HardTask()
}
|
tmp_env.db
ADDED
|
Binary file (16.4 kB). View file
|
|
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|