Rithwik Ravi committed on
Commit
de07414
·
1 Parent(s): c861e8f

Upload full OpenEnv project structure

Browse files
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Creates a non-root user with an explicit UID required by Hugging Face Spaces
4
+ RUN useradd -m -u 1000 user
5
+ USER user
6
+
7
+ # Set environment variables
8
+ ENV HOME=/home/user \
9
+ PATH=/home/user/.local/bin:$PATH
10
+
11
+ # Create the working directory
12
+ WORKDIR $HOME/app
13
+
14
+ # Install dependencies first for Docker caching
15
+ COPY --chown=user:user requirements.txt .
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy the rest of the files
19
+ COPY --chown=user:user . .
20
+
21
+ # Expose the standard Hugging Face Space port
22
+ EXPOSE 7860
23
+
24
+ # Start the environment API server
25
+ CMD ["uvicorn", "env:app", "--host", "0.0.0.0", "--port", "7860"]
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from server import env_instance
__pycache__/env.cpython-314.pyc ADDED
Binary file (10 kB). View file
 
__pycache__/inference.cpython-314.pyc ADDED
Binary file (7.41 kB). View file
 
__pycache__/models.cpython-314.pyc ADDED
Binary file (2.72 kB). View file
 
__pycache__/tasks.cpython-314.pyc ADDED
Binary file (12.6 kB). View file
 
client.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from typing import Dict, Any, Optional
3
+
4
+ from server.models import Action, StepResult, ResetResult
5
+
6
class EnvironmentalClient:
    """
    HTTP client wrapper for the containerized environment endpoints.

    Wraps the `/reset`, `/step` and `/state` routes and converts the JSON
    responses of the first two into `ResetResult` / `StepResult` models.
    Can be used by baseline inference scripts or remote agent evaluations.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0):
        """
        Args:
            base_url: Root URL of the environment server; a trailing slash is removed.
            timeout: Seconds to wait for each HTTP request. Without a timeout
                `requests` waits indefinitely on an unresponsive server.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout

    def reset(self, task_id: int = 1) -> ResetResult:
        """
        Calls the /reset endpoint to begin a new episode for the specified task.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.post(
            f"{self.base_url}/reset",
            json={"task_id": task_id},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return ResetResult(**response.json())

    def step(self, action: Action) -> StepResult:
        """
        Submits a step (action) to the environment and returns the updated state/reward.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        # pydantic v2 renamed `.dict()` to `.model_dump()`; support both so the
        # client works regardless of which major version is installed.
        payload = action.model_dump() if hasattr(action, "model_dump") else action.dict()
        response = requests.post(
            f"{self.base_url}/step",
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return StepResult(**response.json())

    def state(self) -> Dict[str, Any]:
        """
        Returns the JSON body of the /state endpoint as a plain dictionary.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(f"{self.base_url}/state", timeout=self.timeout)
        response.raise_for_status()
        return response.json()
38
+
39
+ # Provide a default singleton for ease of use
40
+ client = EnvironmentalClient()
env.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from server.app import app, env_instance
inference.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from openai import OpenAI
4
+ from server.models import Action, BrowserGymAction # using our local Action model alias
5
+ from server.app import env_instance as env
6
+
7
+ # Environment Configuration
8
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
9
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
10
+ API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("HF_TOKEN", "")
11
+
12
+ MAX_STEPS = 15
13
+ TEMPERATURE = 0.2
14
+ MAX_TOKENS = 512
15
+
16
+ SYSTEM_PROMPT = """
17
+ You are an expert Data Engineer interacting with a simulated SQLite database.
18
+ You will be given a task goal, the current database schema, and the most recent step's SQL output or error.
19
+ Your goal is to complete the task by executing SQL commands.
20
+
21
+ CRITICAL RULES:
22
+ 1. You may only execute ONE SQL statement at a time. Do not chain statements with semicolons.
23
+ 2. If you need to review data, use short SELECT queries.
24
+ 3. If your previous action resulted in an SQL error, fix the error and try again.
25
+ 4. If you need multiple steps to achieve the goal (e.g. create tables, then insert data), execute them one by one.
26
+ 5. You MUST output ONLY a valid JSON object matching this schema:
27
+ {
28
+ "action_str": "YOUR SQL QUERY HERE"
29
+ }
30
+ Do not wrap your response in markdown code blocks. Just valid JSON.
31
+ """
32
+
33
def build_user_prompt(step: int, observation, history: list) -> str:
    """Assemble the per-step user message sent to the model.

    Args:
        step: 1-based step number shown in the prompt header.
        observation: Object exposing `goal`, `schema_dump` and `result`.
        history: Previously recorded history lines; only the last three are shown.

    Returns:
        The prompt text, ending with the request for the next `action_str`.
    """
    sections = [
        f"--- Step {step} ---\n",
        f"Goal: {observation.goal}\n\n",
    ]

    # The schema is omitted when the environment did not supply one.
    if observation.schema_dump:
        sections.append(f"Current DB Schema:\n{observation.schema_dump}\n\n")

    sections.append(f"Last Result (or Error):\n{observation.result}\n\n")

    if history:
        sections.append("Action History (Last 3 steps):\n")
        sections.extend(entry + "\n" for entry in history[-3:])

    sections.append("\nProvide the JSON with your next `action_str`:")
    return "".join(sections)
48
+
49
def parse_model_action(response_text: str) -> str:
    """Extract the SQL statement from a model response.

    The model is asked for a JSON object of the form `{"action_str": "..."}`,
    optionally wrapped in a Markdown code fence. The function always returns a
    string, so callers can slice and execute the result safely.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        - the `action_str` value when the response is a JSON object holding a string;
        - the decoded string when the response is a bare JSON string;
        - `"SELECT 1;"` when the JSON is valid but carries no usable statement
          (missing/null/non-string `action_str`, or a list/number/boolean);
        - the fence-stripped text itself when it is not valid JSON.
    """
    fallback = "SELECT 1;"

    # Remove an optional Markdown code fence around the payload.
    text = response_text.strip()
    if text.startswith("```json"):
        text = text[7:]
    if text.startswith("```"):
        text = text[3:]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Fallback if the model doesn't follow the JSON format: treat the text as SQL.
        return text

    if isinstance(data, dict):
        action = data.get("action_str", fallback)
        # A null or non-string value would crash later string handling.
        return action if isinstance(action, str) else fallback
    if isinstance(data, str):
        return data
    # Lists, numbers, booleans and null cannot be executed as SQL.
    return fallback
63
+
64
def run_task(task_id: int):
    """Run one episode of the given task with the configured chat model.

    Resets the local environment, then for at most `MAX_STEPS` steps asks the
    model for a SQL statement, executes it through `env.step`, and records a
    short history line. Any exception raised while requesting or parsing the
    model output is replaced by the fallback statement `SELECT 1;`.

    Args:
        task_id: Identifier of the task passed to `env.reset`.

    Returns:
        The `current_score` reported when the episode finished, or the score
        read from `env.state()` if the step budget ran out first.
    """
    print(f"\n{'='*50}\nStarting Task {task_id}\n{'='*50}")

    llm = OpenAI(
        base_url=API_BASE_URL,
        api_key=API_KEY
    )

    history = []

    # Using the local env object wrapper
    observation = env.reset(task_id=task_id).observation
    print(f"Episode goal: {observation.goal}\n")

    for step in range(1, MAX_STEPS + 1):
        user_prompt = build_user_prompt(step, observation, history)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        try:
            completion = llm.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
                response_format={"type": "json_object"}  # enforce json output
            )
            action_str = parse_model_action(completion.choices[0].message.content or "")
        except Exception as exc:
            print(f"Model request failed ({exc}). Using fallback action.")
            action_str = "SELECT 1;"

        print(f"Step {step}: model suggested -> {action_str[:100]}...")

        # Step the environment
        step_result = env.step(BrowserGymAction(action_str=action_str))
        observation = step_result.observation
        reward, done = step_result.reward, step_result.done

        error_flag = " ERROR" if observation.last_action_error else ""
        history.append(f"Step {step}: {action_str[:50]}... -> reward {reward:+.2f}{error_flag}")

        print(f" Reward: {reward:+.2f} | Done: {done} | Last action error: {observation.last_action_error}")

        if done:
            final_score = step_result.info.get("current_score", 0.0)
            print(f"\nEpisode complete! Final Score: {final_score}/1.0")
            break
    else:
        # for/else: only reached when the loop was never broken out of.
        final_score = env.state().get("current_score", 0.0)
        print(f"\nReached max steps ({MAX_STEPS}). Final Score: {final_score}/1.0")

    return final_score
129
+
130
def main():
    """Run tasks 1, 2 and 3 in order and print the collected scores."""
    print("Testing OpenEnv Data Engineer Inference Baseline")

    if not API_KEY:
        print("Warning: API_KEY/HF_TOKEN not set. Will likely fail unless local LLM.")

    # Dict comprehensions evaluate in iteration order, so tasks still run 1 -> 2 -> 3.
    scores = {f"Task_{task_id}": run_task(task_id) for task_id in (1, 2, 3)}

    print(f"\n{'*'*50}\nEVALUATION COMPLETE\n{scores}\n{'*'*50}")
142
+
143
+ if __name__ == "__main__":
144
+ main()
models.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from server.models import (
2
+ Action,
3
+ BrowserGymAction,
4
+ Observation,
5
+ Reward,
6
+ StepResult,
7
+ ResetResult
8
+ )
openenv.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: OpenEnv-SQL-Data-Engineer
2
+ version: 1.0.0
3
+ description: >
4
+ A real-world OpenEnv environment simulating a Database Administrator/Data Engineer.
5
+ The agent must interact with a live SQLite mock database using SQL commands to perform data extraction, data cleaning, and schema normalization.
6
+ entrypoint: env:app
7
+ models:
8
+ action: models:Action
9
+ observation: models:Observation
10
+ reward: models:Reward
11
+ tags:
12
+ - database
13
+ - sql
14
+ - data-engineering
15
+ - structured-data
16
+ - real-world
17
+ port: 7860
18
+ dockerfile: Dockerfile
19
+ environment_vars:
20
+ - HF_TOKEN
pyproject.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sql-data-engineer-env"
7
+ version = "0.1.0"
8
+ description = "A real-world SQL data engineering environment for agent evaluation."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "openenv-core",
13
+ "fastapi",
14
+ "uvicorn",
15
+ "pydantic",
16
+ "openai",
17
+ "pandas"
18
+ ]
19
+
20
+ [tool.setuptools]
21
+ packages = ["."]
22
+
23
+ [project.scripts]
24
+ server = "server.app:main"
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic
4
+ openai
5
+ requests
server/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .app import env_instance
2
+ from .models import Observation, Action
server/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (271 Bytes). View file
 
server/__pycache__/app.cpython-314.pyc ADDED
Binary file (10.2 kB). View file
 
server/__pycache__/models.cpython-314.pyc ADDED
Binary file (2.78 kB). View file
 
server/__pycache__/tasks.cpython-314.pyc ADDED
Binary file (12.6 kB). View file
 
server/app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import os
3
+ from typing import Dict, Any, Optional
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
6
+ import uvicorn
7
+
8
+ from .models import Action, Observation, Reward, StepResult, ResetResult
9
+ from .tasks import TASKS
10
+
11
+ app = FastAPI(title="OpenEnv SQL Data Engineer")
12
+
13
class SQLEnvironment:
    """Single-episode SQL sandbox backed by a SQLite database file.

    `reset()` recreates the database for a task, `step()` executes one SQL
    statement and re-grades the database, and `state()` reports a summary.
    Rewards are the change in the task's grade between consecutive steps.
    """

    # Maximum number of result rows echoed back to the agent per query.
    MAX_RESULT_ROWS = 10
    # The episode is force-terminated once the step counter exceeds this value.
    MAX_EPISODE_STEPS = 30

    def __init__(self):
        self.conn: Optional[sqlite3.Connection] = None
        self.task_id = 1
        self.step_count = 0
        self.current_score = 0.0
        self.db_path = "tmp_env.db"

    def get_schema_dump(self) -> str:
        """Return the CREATE statements of all user tables and views.

        Returns:
            An empty string when no database is open, "Database is empty."
            when there are no user objects, otherwise one entry per object.
            Errors are reported in the returned text rather than raised.
        """
        if not self.conn:
            return ""
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT type, name, sql FROM sqlite_master WHERE type='table' OR type='view'")
            dump = [
                f"[{kind.upper()}] {name}:\n {sql}"
                for kind, name, sql in cursor.fetchall()
                if not name.startswith('sqlite_')  # hide SQLite's internal tables
            ]
            return "\n".join(dump) if dump else "Database is empty."
        except Exception as e:
            return f"Error extracting schema: {e}"

    def reset(self, task_id: int = 1) -> "ResetResult":
        """Start a new episode for `task_id` on a freshly created database.

        Raises:
            ValueError: if `task_id` is unknown. The check runs before anything
                is torn down, so a bad id leaves the current episode usable.
        """
        if task_id not in TASKS:
            raise ValueError(f"Task ID {task_id} not found.")

        if self.conn:
            self.conn.close()
            self.conn = None  # never keep a handle to a closed connection

        # Clean up existing temp db
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

        self.task_id = task_id
        self.step_count = 0
        self.current_score = 0.0

        # The web server may call reset() and step() from different worker
        # threads; sqlite3 rejects cross-thread use unless this check is off.
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.conn.execute("PRAGMA foreign_keys = ON")

        # Setup specific task data
        task = TASKS[self.task_id]
        task.setup_db(self.conn)
        self.current_score = task.grade(self.conn)  # Base score

        obs = Observation(
            goal=f"Task Goal: {task.get_goal()}\n",
            result="Environment initialized. Schema ready.",
            step=self.step_count,
            last_action_error=False,
            schema_dump=self.get_schema_dump()
        )
        return ResetResult(observation=obs, info={"task_id": self.task_id, "initial_score": self.current_score})

    @staticmethod
    def _is_system_table_drop(query: str) -> bool:
        """Return True if `query` tries to drop one of SQLite's `sqlite_*` tables."""
        # Upper-case and collapse whitespace so spacing/case variants match too.
        normalized = " ".join(query.upper().split())
        return normalized.startswith(("DROP TABLE SQLITE_", "DROP TABLE IF EXISTS SQLITE_"))

    @staticmethod
    def _format_rows(cursor: sqlite3.Cursor, limit: int = 10) -> str:
        """Render up to `limit` rows of an executed cursor as a text table."""
        header = " | ".join(column[0] for column in cursor.description or [])
        # Fetch one extra row so truncation is only reported when rows were cut.
        rows = cursor.fetchmany(limit + 1)
        lines = [header, "-" * len(header)]
        lines.extend(" | ".join(str(value) for value in row) for row in rows[:limit])
        text = "\n".join(lines) + "\n"
        if len(rows) > limit:
            text += "... (output truncated)"
        return text

    def step(self, action: "Action") -> "StepResult":
        """Execute one SQL statement and return the graded outcome.

        SQL failures do not raise: they are reported in the observation with
        `last_action_error=True` and cost an extra 0.05 reward.

        Raises:
            ValueError: if `reset()` has not been called yet.
        """
        if not self.conn:
            raise ValueError("Environment not initialized. Call reset() first.")

        self.step_count += 1
        last_action_error = False
        query_result = ""

        try:
            query = action.action_str.strip()
            # Basic mitigation of forbidden queries just in case (though we're in mock)
            if self._is_system_table_drop(query):
                raise ValueError("Cannot modify system tables.")

            cursor = self.conn.cursor()
            cursor.execute(query)

            # `description` is set exactly when the statement returns rows; this
            # also covers `WITH ... SELECT` and `RETURNING`, unlike a prefix test.
            if cursor.description is not None:
                query_result = self._format_rows(cursor, self.MAX_RESULT_ROWS)
            else:
                query_result = f"Command executed successfully. Rowcount: {cursor.rowcount}"
            self.conn.commit()

        except Exception as e:
            last_action_error = True
            query_result = f"SQL Error: {str(e)}"

        # Run grader
        task = TASKS[self.task_id]
        new_score = task.grade(self.conn)

        # Reward is dense: change in score + small penalty for errors
        reward_value = new_score - self.current_score
        if last_action_error:
            reward_value -= 0.05

        self.current_score = new_score

        # Episode terminates on a perfect score or when the step budget is exceeded
        done = (self.current_score >= 1.0) or (self.step_count > self.MAX_EPISODE_STEPS)

        obs = Observation(
            goal=task.get_goal(),
            result=query_result,
            step=self.step_count,
            last_action_error=last_action_error,
            # only dump if no error to save tokens
            schema_dump=self.get_schema_dump() if not last_action_error else None
        )

        return StepResult(
            observation=obs,
            reward=reward_value,
            done=done,
            info={"current_score": self.current_score}
        )

    def state(self) -> Any:
        """Return the task id, step counter, current score and schema dump as a dict."""
        return {
            "task_id": self.task_id,
            "step": self.step_count,
            "current_score": self.current_score,
            "schema_dump": self.get_schema_dump()
        }
149
+
150
+ # Global instance
151
+ env_instance = SQLEnvironment()
152
+
153
+ class ResetRequest(BaseModel):
154
+ task_id: int = 1
155
+
156
+ @app.post("/reset", response_model=ResetResult)
157
+ def reset(req: ResetRequest):
158
+ try:
159
+ return env_instance.reset(task_id=req.task_id)
160
+ except Exception as e:
161
+ raise HTTPException(status_code=400, detail=str(e))
162
+
163
+ @app.post("/step", response_model=StepResult)
164
+ def step(action: Action):
165
+ try:
166
+ return env_instance.step(action)
167
+ except Exception as e:
168
+ raise HTTPException(status_code=400, detail=str(e))
169
+
170
+ @app.get("/state")
171
+ def state():
172
+ return env_instance.state()
173
+
174
+ def main():
175
+ uvicorn.run("server.app:app", host="0.0.0.0", port=7860, reload=False)
176
+
177
+ if __name__ == "__main__":
178
+ main()
server/models.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Optional, Any, Dict, List
3
+
4
+ class Action(BaseModel):
5
+ action_str: str
6
+
7
+ # Alias for compatibility with the sample inference script
8
+ BrowserGymAction = Action
9
+
10
+ class Observation(BaseModel):
11
+ goal: str
12
+ result: str
13
+ step: int
14
+ last_action_error: bool
15
+ schema_dump: Optional[str] = None
16
+
17
+ class Reward(BaseModel):
18
+ value: float
19
+ reason: Optional[str] = None
20
+
21
+ class StepResult(BaseModel):
22
+ observation: Observation
23
+ reward: float
24
+ done: bool
25
+ info: Dict[str, Any]
26
+
27
+ class ResetResult(BaseModel):
28
+ observation: Observation
29
+ info: Dict[str, Any]
server/tasks.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import re
3
+
4
class Task:
    """Interface every graded task implements.

    Subclasses seed a database (`setup_db`), describe the objective
    (`get_goal`) and score the current database state (`grade`).
    """

    def __init__(self, task_id: int):
        self.task_id = task_id

    def setup_db(self, conn: sqlite3.Connection):
        """Create and populate the task's initial tables."""
        raise NotImplementedError

    def get_goal(self) -> str:
        """Return the natural-language objective shown to the agent."""
        raise NotImplementedError

    def grade(self, conn: sqlite3.Connection) -> float:
        """Return a score in [0.0, 1.0] for the current database state."""
        raise NotImplementedError


class EasyTask(Task):
    """Task 1: create a filtered view over the `customers` table."""

    # Rows the view must expose, as (id, name, total_spent) ordered by name.
    EXPECTED_ROWS = [(2, "Bob", 1200.0), (4, "Diana", 3000.0), (5, "Eve", 1000.01)]

    def __init__(self):
        super().__init__(1)

    def setup_db(self, conn: sqlite3.Connection):
        c = conn.cursor()
        c.execute("CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, total_spent REAL)")
        c.executemany("INSERT INTO customers (name, total_spent) VALUES (?, ?)", [
            ("Alice", 500.0),
            ("Bob", 1200.0),
            ("Charlie", 50.0),
            ("Diana", 3000.0),
            ("Eve", 1000.01)  # over 1000
        ])
        conn.commit()

    def get_goal(self) -> str:
        return "Create a view named 'high_value_customers' containing all customers who have a 'total_spent' greater than 1000.0. The view should contain the exact same columns as the customers table."

    def grade(self, conn: sqlite3.Connection) -> float:
        """Score the view.

        Returns:
            0.0 if the view is missing or lacks `name`/`total_spent`;
            0.5 if it exists but its rows or its `id` column are wrong;
            1.0 if it exposes exactly the expected ids, names and totals.
        """
        try:
            exists = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='view' AND name='high_value_customers'"
            ).fetchone()
            if not exists:
                return 0.0
            pairs = conn.execute(
                "SELECT name, total_spent FROM high_value_customers ORDER BY name"
            ).fetchall()
        except sqlite3.Error:
            return 0.0

        if pairs != [row[1:] for row in self.EXPECTED_ROWS]:
            return 0.5  # partially correct, exists but wrong rows

        # The goal requires the same columns as `customers`, so `id` must be there too.
        try:
            ids = [row[0] for row in conn.execute(
                "SELECT id FROM high_value_customers ORDER BY name"
            )]
        except sqlite3.Error:
            return 0.5
        return 1.0 if ids == [row[0] for row in self.EXPECTED_ROWS] else 0.5


class MediumTask(Task):
    """Task 2: clean the `category` and `price` columns of `products`."""

    # Expected values per product, ordered by id.
    EXPECTED_PRICES = (999.99, 25.50, 150.0, 85.0, 45.0)
    EXPECTED_CATEGORIES = ("ELECTRONICS", "ELECTRONICS", "FURNITURE", "FURNITURE", "ELECTRONICS")

    def __init__(self):
        super().__init__(2)

    def setup_db(self, conn: sqlite3.Connection):
        c = conn.cursor()
        c.execute("CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, category TEXT, price TEXT)")
        c.executemany("INSERT INTO products (name, category, price) VALUES (?, ?, ?)", [
            ("Laptop", "Electronics", "$999.99"),
            ("Mouse", "electronics", "25.50 USD"),
            ("Desk", "FURNITURE", "150.0"),
            ("Chair", "furniture", "$85.00"),
            ("Headphones", "ELEC", "€45.00")  # We'll just ask them to remove letters/symbols
        ])
        conn.commit()

    def get_goal(self) -> str:
        return (
            "The 'products' table is messy. "
            "1. Standardize the 'category' column to be fully UPPERCASE. (Hint: treat 'ELEC' as 'ELECTRONICS'). "
            "2. Create a new column 'price_usd' of type REAL. Extract the numeric value from the 'price' string and populate 'price_usd'. "
            "Do not drop any original columns."
        )

    def grade(self, conn: sqlite3.Connection) -> float:
        """Score the cleanup.

        Scoring: 0.30 for the `price_usd` column existing, 0.08 per correct
        price (within 0.01) and 0.06 per correct category. The two parts are
        graded independently, so category progress counts before `price_usd`
        has been created. Points are kept as integer hundredths so a perfect
        result is exactly 1.0.
        """
        points = 0
        try:
            columns = [row[1] for row in conn.execute("PRAGMA table_info(products)")]

            if 'price_usd' in columns:
                points += 30
                prices = [row[0] for row in conn.execute("SELECT price_usd FROM products ORDER BY id")]
                # Non-numeric values (NULL, leftover text) simply earn nothing.
                points += 8 * sum(
                    1 for price, expected in zip(prices, self.EXPECTED_PRICES)
                    if isinstance(price, (int, float)) and abs(price - expected) < 0.01
                )

            if 'category' in columns:
                categories = [row[0] for row in conn.execute("SELECT category FROM products ORDER BY id")]
                points += 6 * sum(
                    1 for category, expected in zip(categories, self.EXPECTED_CATEGORIES)
                    if category == expected
                )
        except sqlite3.Error:
            pass  # keep whatever was earned before the failing query

        return min(1.0, points / 100)


class HardTask(Task):
    """Task 3: normalize `hospital_records` into patients/doctors/appointments."""

    # Source rows: (patient_name, patient_dob, doctor_name, doctor_specialty,
    # appointment_date, diagnosis).
    RECORDS = [
        ("John Doe", "1980-01-01", "Dr. Smith", "Cardiology", "2023-10-01", "Hypertension"),
        ("Jane Roe", "1992-05-15", "Dr. Jones", "Neurology", "2023-10-02", "Migraine"),
        ("John Doe", "1980-01-01", "Dr. Smith", "Cardiology", "2023-11-01", "Follow-up"),
        ("Bob Guy", "1975-11-20", "Dr. Smith", "Cardiology", "2023-10-05", "Checkup")
    ]

    # Joins the normalized tables back into the flat layout of RECORDS.
    RECONSTRUCT_QUERY = """
        SELECT p.name, p.dob, d.name, d.specialty, a.date, a.diagnosis
        FROM appointments a
        JOIN patients p ON a.patient_id = p.id
        JOIN doctors d ON a.doctor_id = d.id
        ORDER BY p.name, a.date
    """

    # (query, expected row count): 3 unique patients, 2 unique doctors, 4 appointments.
    COUNT_CHECKS = (
        ("SELECT COUNT(*) FROM patients", 3),
        ("SELECT COUNT(*) FROM doctors", 2),
        ("SELECT COUNT(*) FROM appointments", 4),
    )

    def __init__(self):
        super().__init__(3)

    def setup_db(self, conn: sqlite3.Connection):
        c = conn.cursor()
        c.execute("""
            CREATE TABLE hospital_records (
                patient_name TEXT,
                patient_dob TEXT,
                doctor_name TEXT,
                doctor_specialty TEXT,
                appointment_date TEXT,
                diagnosis TEXT
            )
        """)
        c.executemany("INSERT INTO hospital_records VALUES (?, ?, ?, ?, ?, ?)", self.RECORDS)
        conn.commit()

    def get_goal(self) -> str:
        return (
            "Normalize the flat 'hospital_records' table into 3 tables: "
            "'patients' (id INTEGER PRIMARY KEY, name TEXT, dob TEXT), "
            "'doctors' (id INTEGER PRIMARY KEY, name TEXT, specialty TEXT), and "
            "'appointments' (id INTEGER PRIMARY KEY, patient_id INTEGER, doctor_id INTEGER, date TEXT, diagnosis TEXT). "
            "Migrate all data from 'hospital_records' correctly without duplication. "
            "Ensure foreign keys are correctly pointing to the new IDs."
        )

    def grade(self, conn: sqlite3.Connection) -> float:
        """Score the normalization.

        Scoring: 0.1 / 0.1 / 0.2 for the patients / doctors / appointments
        tables existing; once all three exist, 0.1 for each correct row count
        and 0.3 if joining them reproduces the original records exactly.
        Points are kept as integer tenths so a perfect result is exactly 1.0.
        """
        points = 0
        try:
            tables = {row[0] for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")}

            if 'patients' in tables:
                points += 1
            if 'doctors' in tables:
                points += 1
            if 'appointments' in tables:
                points += 2

            # Data checks need all three tables.
            if points < 4:
                return points / 10

            for count_query, expected_count in self.COUNT_CHECKS:
                if conn.execute(count_query).fetchone()[0] == expected_count:
                    points += 1

            # Referential integrity: the join must give back the source rows,
            # not merely the right number of rows.
            expected_rows = sorted(self.RECORDS, key=lambda record: (record[0], record[4]))
            if conn.execute(self.RECONSTRUCT_QUERY).fetchall() == expected_rows:
                points += 3
        except sqlite3.Error:
            pass  # e.g. a required column is missing; keep the partial score

        return min(1.0, points / 10)


# Registry mapping task ids to their task instances.
TASKS = {
    1: EasyTask(),
    2: MediumTask(),
    3: HardTask()
}
tmp_env.db ADDED
Binary file (16.4 kB). View file
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff