Eishaan committed on
Commit
6a32325
·
1 Parent(s): 71fa486

self push after fixing a few errors

README.md CHANGED
@@ -1,165 +1,164 @@
1
- ---
2
- title: SQL Migration Agent
3
- emoji: "\U0001F5C4\uFE0F"
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
8
- tags:
9
- - openenv
10
- ---
11
 
12
- # SQL Schema Migration Agent
13
 
14
- > **An OpenEnv environment for benchmarking autonomous database migration agents.**
15
- >
16
- > Built for the Meta x Hugging Face OpenEnv Hackathon.
17
 
18
- ---
 
 
 
 
19
 
20
- ## Why This Matters (Real-World Utility)
21
 
22
- Database schema migrations are among the most error-prone, high-stakes tasks in software engineering. Every production system faces them as application models evolve, yet they are extremely difficult to automate safely because data must be perfectly preserved.
23
-
24
- This environment trains AI agents to autonomously reconcile schema drift the exact way a real CI/CD pipeline would -- given a flawed current state and an ideal target state, the agent must compute and safely execute the transformation sequence using raw SQL.
25
-
26
- **Real-world analogues:** `Flyway`, `Liquibase`, Django `makemigrations`, `Terraform` state transitions. This environment models that exact problem, reduced to an agentic RL core.
27
-
28
- ---
29
-
30
- ## Evaluation Philosophy & Anti-Exploit Mechanics
31
-
32
- Unlike simplistic environments that merely string-match SQL schemas, this environment uses a **deep structural reconciliation grader** built specifically to prevent LLM gamification:
33
-
34
- 1. **Zero-Sum Exploit Protection:** Naive agents will often execute `DROP TABLE x; CREATE TABLE x (...)` to easily match the target schema, silently destroying all data. Our grader actively runs `SELECT COUNT(*)`, `SUM(id)`, and data-integrity fingerprinting. If a table's schema matches but the data is gone, the score is brutally clamped to `0.01`.
35
- 2. **PRAGMA Bypass Prevention:** The grader re-asserts `PRAGMA foreign_keys = ON` before every scoring pass, preventing agents from disabling FK constraints to cheat.
36
- 3. **Granular Partial Credit:** Multi-step migrations (like Task 7's 6-to-4 table consolidation) require 18+ steps. Binary pass/fail rewards provide zero learning signal. Our grader assigns fractional weights to individual FK constraints, data type coercions, and orphaned record audit logs, providing continuous RL reward gradients.
37
- 4. **Deterministic Adversarial Seeds:** Our injected data includes edge cases that break naive SQL: `O'Brien` (apostrophes), `$1,234.56` (comma+dollar coercion), orphaned foreign keys, NULL emails, and leading whitespace in emails.
38
-
39
- ---
 
 
 
 
 
 
 
 
 
40
 
41
  ## Tasks (2 Easy / 3 Medium / 2 Hard)
42
 
43
- | # | Name | Difficulty | Steps | Description |
44
  |---|------|-----------|-------|-------------|
45
- | 1 | `column-restructure` | Easy | 10 | Merge `first_name` + `last_name` into `full_name` without data loss. Adversarial: apostrophes (`O'Brien`), mid-caps (`McDonald`) |
46
- | 2 | `soft-delete-restoration` | Easy | 10 | Restore deleted products from `deletion_log`, add `is_deleted`/`deleted_at` columns. Adversarial: `stock=0` must not be confused with `is_deleted=1` |
47
- | 3 | `table-normalization` | Medium | 15 | Decompose flat `purchases` into `customers` + `orders` with FK. Adversarial: duplicate emails (x3), commas in item names |
48
- | 4 | `schema-version-merge` | Medium | 15 | Merge overlapping `products_v1` (TEXT prices) and `products_v2` (REAL prices) with conflict resolution and `source` tracking. Adversarial: `$XX.XX` coercion, NULL category, high ID=101 |
49
- | 5 | `multi-entity-extraction` | Medium | 15 | Decompose `sales_records` god-table into 3NF (5 tables) with 3 FKs and invalid data routing. Adversarial: leading whitespace email, empty email, comma in SKU |
50
- | 6 | `cascade-migration` | Hard | 20 | 4-table FK cascade: type coercion (`$90000` TEXT to `90000` INTEGER), orphan audit logging, NULL salary removal, full FK chain enforcement |
51
- | 7 | `dual-source-consolidation` | Hard | 20 | Merge 6 tables from two incompatible systems (Legacy CRM + Modern SaaS) into 4 unified tables with cross-system email dedup, currency coercion, orphan detection |
52
-
53
- ---
54
-
55
- ## Observation Space
56
-
57
- | Field | Type | Description |
58
- |-------|------|-------------|
59
- | `current_schema_sql` | `str` | Current database DDL extracted from `sqlite_master` |
60
- | `target_schema_sql` | `str` | Target DDL the agent must reach |
61
- | `last_execution_result` | `str` | Result of last SQL execution, or error message |
62
- | `step_number` | `int` | Current step count |
63
- | `migration_progress` | `float` | Current grader score [0.01-0.99] |
64
- | `task_name` | `str` | Name of the active task |
65
- | `done` | `bool` | Whether the episode has terminated |
66
- | `reward` | `float` | Step reward: score delta from previous step (can be negative) |
67
-
68
- ## Action Space
69
-
70
- | Field | Type | Description |
71
- |-------|------|-------------|
72
- | `sql_command` | `str` | Raw SQL statement to execute against the database |
73
- | `reasoning` | `str` | Chain-of-thought explanation (logged for review) |
74
- | `submit_final` | `bool` | Set `true` when migration is believed complete |
75
-
76
- ---
77
-
78
- ## Reward Function
79
-
80
- - **Step reward**: Delta between current and previous migration score. Strongly negative for destructive actions (e.g., wrong DROP TABLE leads to -0.4).
81
- - **Episode score**: Clamped to (0.01, 0.99). Final state wins -- regressions hurt.
82
- - **Exploit protection**: If schema matches target but tables are empty (agent deleted data), score is capped at 0.01.
83
- - **PRAGMA protection**: `PRAGMA foreign_keys = ON` is re-asserted before every grading pass.
84
- - **Auto-termination**: Episode ends immediately when score reaches 0.99, preventing post-success regression.
85
-
86
- ---
87
-
88
- ## Setup & Usage
89
-
90
  ```bash
91
- # Install dependencies
92
  pip install -r requirements.txt
 
93
 
94
- # Run baseline inference (requires HF_TOKEN)
95
- export HF_TOKEN=your_token_here
96
- export API_BASE_URL=https://router.huggingface.co/v1
 
97
  export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
98
- python inference.py
 
 
 
 
 
 
99
 
100
- # Run validation tests
101
- python test_smoke.py
102
- python test_all_tasks.py
 
103
 
104
- # Start environment server locally
 
105
  uvicorn server.app:app --host 0.0.0.0 --port 7860
106
  ```
107
 
108
- ---
109
-
110
  ## API Endpoints
111
 
112
  | Endpoint | Method | Description |
113
  |----------|--------|-------------|
114
- | `/health` | GET | Health check |
115
- | `/reset` | POST | Reset environment, returns initial observation |
116
- | `/step` | POST | Execute action, returns observation + reward |
117
  | `/state` | GET | Current environment state |
118
- | `/tasks` | GET | List all 7 tasks with descriptions |
119
- | `/grader` | POST | Run grader on all tasks, return scores |
120
- | `/schema` | GET | OpenEnv schema (action/observation types) |
121
- | `/ws` | WS | WebSocket for real-time interaction |
 
 
 
 
 
 
 
 
 
122
 
123
- ---
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  ## Deployment
126
 
 
127
  ```bash
128
- # Docker (local test)
129
  docker build -t sql-migration-env .
130
- docker run -p 7860:7860 \
131
- -e HF_TOKEN=your_token \
132
- -e API_BASE_URL=https://router.huggingface.co/v1 \
133
- -e MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \
134
- sql-migration-env
135
  ```
136
 
137
- **Hugging Face Spaces:** Push this repo to HF Spaces with your `HF_TOKEN`, `API_BASE_URL`, and `MODEL_NAME` set as Space secrets. The Dockerfile builds automatically.
138
-
139
- ---
140
-
141
- ## Baseline Scores
142
-
143
- | Task | Score | Steps | Model |
144
- |------|-------|-------|-------|
145
- | `column-restructure` | 0.99 | 4 | Qwen/Qwen2.5-72B-Instruct |
146
- | `soft-delete-restoration` | 0.99 | 5-7 | Qwen/Qwen2.5-72B-Instruct |
147
- | `table-normalization` | 0.99 | 5-8 | Qwen/Qwen2.5-72B-Instruct |
148
- | `schema-version-merge` | 0.60-0.85 | 8-12 | Qwen/Qwen2.5-72B-Instruct |
149
- | `multi-entity-extraction` | 0.40-0.70 | 12-15 | Qwen/Qwen2.5-72B-Instruct |
150
- | `cascade-migration` | 0.30-0.65 | 15-20 | Qwen/Qwen2.5-72B-Instruct |
151
- | `dual-source-consolidation` | 0.20-0.50 | 18-20 | Qwen/Qwen2.5-72B-Instruct |
152
-
153
- ---
154
-
155
- ## Pre-Submission Checklist
156
 
157
- - [x] `docker build` succeeds
158
- - [x] `curl /health` returns 200
159
- - [x] `curl /tasks` returns 7 tasks
160
- - [x] `curl -X POST /reset` returns valid observation
161
- - [x] `openenv validate` passes
162
- - [x] Baseline script completes all 7 tasks without crashing
163
- - [x] Grader scores in (0.01, 0.99) range
164
- - [x] Exploit protection: empty-table shortcuts penalized
165
- - [x] PRAGMA bypass protection enforced
 
1
+ # SQL Schema Migration Agent — OpenEnv Benchmark
 
 
 
 
 
 
 
 
 
2
 
3
+ An OpenEnv-compatible environment for evaluating AI agents on autonomous SQLite database migration tasks. The agent receives a broken/drifted schema and must write SQL to transform it to a target state without losing data.
4
 
5
+ ## Why This Benchmark?
 
 
6
 
7
+ Database schema migration is a **real-world task** that engineers perform daily. Unlike toy benchmarks, it tests:
8
+ - **Reasoning under constraints** (SQLite's limited ALTER TABLE support)
9
+ - **Data preservation** (agents must never silently drop rows)
10
+ - **Multi-step planning** (complex migrations require 5-15 coordinated SQL commands)
11
+ - **Edge case handling** (apostrophes, NULL values, empty strings, type coercion)
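SQLite's restricted `ALTER TABLE` is central here: column types and constraints cannot be changed in place, so migrations must rebuild tables. A minimal sketch of that rebuild pattern (the `users` table and its values are illustrative, not taken from a specific task):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# Hypothetical source table storing salaries as TEXT (e.g. '$90,000').
cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, salary TEXT)")
cur.execute("INSERT INTO users VALUES (1, '$90,000'), (2, '$1,234')")

# SQLite cannot ALTER a column's type, so rebuild the table:
# 1) CREATE the new shape, 2) copy with coercion, 3) DROP old, 4) RENAME.
cur.execute("CREATE TABLE users_new (id INTEGER PRIMARY KEY, salary INTEGER)")
cur.execute(
    "INSERT INTO users_new "
    "SELECT id, CAST(REPLACE(REPLACE(salary, '$', ''), ',', '') AS INTEGER) "
    "FROM users"
)
cur.execute("DROP TABLE users")
cur.execute("ALTER TABLE users_new RENAME TO users")

print(cur.execute("SELECT id, salary FROM users ORDER BY id").fetchall())
# → [(1, 90000), (2, 1234)]
```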
12
 
13
+ ## Architecture
14
 
15
+ ```
16
+ ┌─────────────────────────────────┐
17
+ │ inference.py (Baseline Agent) │
18
+ │ - LLM API calls (OpenAI fmt) │
19
+ │ - JSON mode + fallback parser  │
20
+ │ - Task-specific prompts │
21
+ └─────────┬───────────────────────┘
22
+ │ MigrationAction
23
+ ┌─────────▼───────────────────────┐
24
+ │ environment.py (OpenEnv Env) │
25
+ │ - SQLite execution engine      │
26
+ │ - SELECT result passthrough │
27
+ │ - SQL timeout (progress hdlr)  │
28
+ │ - Dangerous SQL blacklist      │
29
+ │ - Transaction awareness        │
30
+ │ - Trajectory logging │
31
+ └─────────┬───────────────────────┘
32
+ │ score()
33
+ ┌─────────▼───────────────────────┐
34
+ │ grader.py (Golden DB Engine) │
35
+ │ - Dynamic golden reference DB │
36
+ │ - Schema + data + FK scoring │
37
+ │ - Case-insensitive comparison │
38
+ │ - PRAGMA state preservation │
39
+ │ - Anti-exploit checks │
40
+ └─────────────────────────────────┘
41
+ ```
42
 
43
  ## Tasks (2 Easy / 3 Medium / 2 Hard)
44
 
45
+ | # | Task | Difficulty | Steps | Description |
46
  |---|------|-----------|-------|-------------|
47
+ | 1 | `column-restructure` | Easy | 10 | Merge first_name + last_name → full_name |
48
+ | 2 | `soft-delete-restoration` | Easy | 10 | Restore deleted products from deletion_log |
49
+ | 3 | `table-normalization` | Medium | 15 | Normalize purchases → customers + orders + FK |
50
+ | 4 | `schema-version-merge` | Medium | 15 | Merge v1/v2 product tables with price coercion |
51
+ | 5 | `multi-entity-extraction` | Medium | 15 | 3NF decomposition with invalid data routing |
52
+ | 6 | `cascade-migration` | Hard | 20 | 4-table FK cascade, type coercion, orphan audit |
53
+ | 7 | `dual-source-consolidation` | Hard | 20 | 6→4 table merge, cross-system email dedup |
54
+
55
+ ### Adversarial Edge Cases
56
+ - **O'Brien** (apostrophe in data — tests SQL escaping)
57
+ - **$90,000 salary** (TEXT→INTEGER coercion — tests string processing)
58
+ - **Empty string emails** (not NULL — tests data validation logic)
59
+ - **Leading whitespace** (` alice@company.com` — tests TRIM awareness)
60
+ - **ID conflicts** (same ID in two source tables — tests merge logic)
61
+ - **Orphaned FKs** (references to deleted entities — tests audit logging)
62
+ - **NULL currency** (must default to 'USD' — tests COALESCE)
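As a sketch of the SQL an agent needs for these cases (the `contacts` table below is hypothetical, not one of the task schemas): escape apostrophes by doubling them, TRIM whitespace, reject empty-string emails, and COALESCE a NULL currency to 'USD'.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# Hypothetical rows exhibiting the adversarial edge cases above.
cur.execute("CREATE TABLE contacts (name TEXT, email TEXT, currency TEXT)")
cur.execute("INSERT INTO contacts VALUES ('O''Brien', ' alice@company.com', NULL)")
cur.execute("INSERT INTO contacts VALUES ('Dana', '', 'EUR')")

# TRIM whitespace, treat empty-string emails as invalid, default NULL currency.
rows = cur.execute(
    "SELECT name, TRIM(email), COALESCE(currency, 'USD') "
    "FROM contacts WHERE TRIM(email) <> ''"
).fetchall()
print(rows)  # → [("O'Brien", 'alice@company.com', 'USD')]
```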
63
+
64
+ ## Dynamic Golden Database Grading
65
+
66
+ Unlike benchmarks with hardcoded expected values, our grader is **seed-independent**:
67
+
68
+ 1. At scoring time, a fresh DB is seeded and the correct migration is applied
69
+ 2. The agent's DB is compared table-by-table against this golden reference
70
+ 3. If seed data changes, the golden DB auto-updates
71
+
72
+ **Scoring breakdown (per task):**
73
+ - **Schema match (30%)**: Tables exist with correct columns
74
+ - **Data match (40%)**: Row content matches golden DB (order-independent)
75
+ - **FK & integrity (20%)**: Foreign keys enforced, PRAGMA integrity_check passes
76
+ - **Anti-exploit (10%)**: No empty tables, no schema pollution
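The order-independent data-match component can be sketched roughly as a multiset comparison of rows between the agent's database and the golden reference (a simplified illustration, not the grader's actual code):

```python
import sqlite3
from collections import Counter

def tables_match(agent_db, golden_db, table):
    """Order-independent row comparison: multisets of rows must be equal.
    (The real grader also weighs schema shape, FKs, and anti-exploit checks.)"""
    query = f"SELECT * FROM {table}"
    return Counter(agent_db.execute(query).fetchall()) == \
           Counter(golden_db.execute(query).fetchall())

agent = sqlite3.connect(":memory:")
golden = sqlite3.connect(":memory:")
for db in (agent, golden):
    db.execute("CREATE TABLE t (id INTEGER, v TEXT)")
# Same rows, different insertion order — should still match.
agent.executemany("INSERT INTO t VALUES (?, ?)", [(2, "b"), (1, "a")])
golden.executemany("INSERT INTO t VALUES (?, ?)", [(1, "a"), (2, "b")])
print(tables_match(agent, golden, "t"))  # → True
```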
77
+
78
+ ## Security & Robustness
79
+
80
+ - **SQL Timeout**: Progress-handler-based execution timeout prevents infinite CTEs
81
+ - **Dangerous SQL Blacklist**: ATTACH DATABASE, DETACH, LOAD_EXTENSION blocked
82
+ - **Transaction Awareness**: Respects BEGIN/COMMIT/ROLLBACK from agents
83
+ - **Case-Insensitive Grading**: Table/column names compared case-insensitively
84
+ - **PRAGMA Preservation**: Grader doesn't corrupt agent's FK state
85
+ - **Trajectory Logging**: Full SQL history attached to final observation
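The progress-handler timeout can be sketched as follows (an illustrative reimplementation, not the environment's exact code; the deadline and callback interval are assumptions):

```python
import sqlite3
import time

def run_with_timeout(conn, sql, seconds=1.0):
    """Abort a runaway statement via SQLite's progress handler."""
    deadline = time.monotonic() + seconds

    def check():
        # Returning non-zero from the handler aborts the running statement.
        return 1 if time.monotonic() > deadline else 0

    conn.set_progress_handler(check, 10_000)  # invoked every ~10k VM opcodes
    try:
        return conn.execute(sql).fetchall()
    except sqlite3.OperationalError as exc:
        return f"aborted: {exc}"
    finally:
        conn.set_progress_handler(None, 0)

conn = sqlite3.connect(":memory:")
print(run_with_timeout(conn, "SELECT 1"))  # → [(1,)]
# An unbounded recursive CTE gets aborted instead of hanging forever:
slow = ("WITH RECURSIVE c(x) AS (SELECT 1 UNION ALL SELECT x+1 FROM c) "
        "SELECT COUNT(*) FROM c")
print(run_with_timeout(conn, slow, seconds=0.2))
```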
86
+
87
+ ## Setup
88
+
89
+ ### Requirements
 
 
90
  ```bash
 
91
  pip install -r requirements.txt
92
+ ```
93
 
94
+ ### Environment Variables
95
+ ```bash
96
+ export HF_TOKEN=your_huggingface_token
97
+ export API_BASE_URL=https://router.huggingface.co/v1 # or Groq, etc.
98
  export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
99
+ ```
100
+
101
+ ### Run Tests
102
+ ```bash
103
+ python test_smoke.py # Quick validation
104
+ python test_all_tasks.py # All 7 tasks: golden migration + lifecycle
105
+ ```
106
 
107
+ ### Run Baseline Inference
108
+ ```bash
109
+ python inference.py # Runs all 7 tasks sequentially
110
+ ```
111
 
112
+ ### Start Server (HF Spaces)
113
+ ```bash
114
  uvicorn server.app:app --host 0.0.0.0 --port 7860
115
  ```
116
 
 
 
117
  ## API Endpoints
118
 
119
  | Endpoint | Method | Description |
120
  |----------|--------|-------------|
121
+ | `/reset` | POST | Start new migration episode |
122
+ | `/step` | POST | Execute a SQL action |
 
123
  | `/state` | GET | Current environment state |
124
+ | `/tasks` | GET | List all 7 tasks with metadata |
125
+ | `/grader` | POST | Run grader on specific/all tasks |
126
+ | `/health` | GET | Health check |
127
+ | `/docs` | GET | Interactive API documentation |
128
+
129
+ ## Action Schema
130
+ ```json
131
+ {
132
+ "sql_command": "ALTER TABLE users ADD COLUMN full_name TEXT",
133
+ "reasoning": "Add the target column before migrating data",
134
+ "submit_final": false
135
+ }
136
+ ```
137
 
138
+ ## Observation Schema
139
+ ```json
140
+ {
141
+ "current_schema_sql": "CREATE TABLE users (...);",
142
+ "target_schema_sql": "CREATE TABLE users (...);",
143
+ "last_execution_result": "Success: 5 rows affected",
144
+ "step_number": 3,
145
+ "migration_progress": 0.75,
146
+ "task_name": "column-restructure",
147
+ "done": false,
148
+ "reward": 0.15
149
+ }
150
+ ```
151
 
152
  ## Deployment
153
 
154
+ ### Docker
155
  ```bash
 
156
  docker build -t sql-migration-env .
157
+ docker run -p 7860:7860 -e HF_TOKEN=your_token sql-migration-env
 
 
 
 
158
  ```
159
 
160
+ ### Hugging Face Spaces
161
+ Push to a Space with the included Dockerfile. Set `HF_TOKEN`, `API_BASE_URL`, and `MODEL_NAME` as Space secrets.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+ ## License
164
+ MIT
 
 
 
 
 
 
 
__pycache__/inference.cpython-312.pyc ADDED
Binary file (14.6 kB).
 
__pycache__/seeds.cpython-312.pyc CHANGED
Binary files a/__pycache__/seeds.cpython-312.pyc and b/__pycache__/seeds.cpython-312.pyc differ
 
inference.py CHANGED
@@ -2,9 +2,18 @@
2
  """
3
  Baseline Inference Script for SQL Migration Environment.
4
 
5
- Runs all 3 migration tasks sequentially using an LLM via OpenAI-compatible API.
6
  Outputs structured [START]/[STEP]/[END] format for automated evaluation.
7
 
 
 
 
 
 
 
 
 
 
8
  Usage:
9
  python inference.py
10
 
@@ -16,6 +25,7 @@ Environment Variables:
16
 
17
  import json
18
  import os
 
19
  import sys
20
  import time
21
  import traceback
@@ -23,29 +33,31 @@ import traceback
23
  # Server URL for the environment
24
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
25
 
26
- # LLM Configuration — defaults required for API_BASE_URL and MODEL_NAME only
27
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
28
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
29
- HF_TOKEN = os.getenv("HF_TOKEN") # No default — must be set by user
30
- # Also support OPENAI_API_KEY as primary (per spec) and API_KEY as alias
31
  API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN or os.getenv("API_KEY")
32
 
 
33
  SYSTEM_PROMPT_TEMPLATE = """You are an autonomous SQLite database migration engine. You receive the current schema and a target schema. Write SQL to transform the current state to the target state without losing row data.
34
 
35
- CRITICAL — SQLite-specific rules (violations cause immediate errors):
36
- 1. SQLite does NOT support ALTER TABLE ADD CONSTRAINT — never use it.
37
- 2. SQLite does NOT support ALTER TABLE ALTER COLUMN — never use it.
38
- 3. SQLite does NOT support ALTER TABLE ADD PRIMARY KEY — never use it.
39
- 4. SQLite does NOT support ADD COLUMN with non-constant DEFAULT — add column as NULL then UPDATE.
40
- 5. To change column types, add NOT NULL, or add FKs: CREATE new table with correct schema, INSERT INTO new SELECT from old, DROP old, RENAME new to original name.
41
- 6. Apostrophes in data (e.g., O'Brien, O'Neill) are present — always use parameterized patterns or escape with ''.
42
- 7. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
43
- 8. For ORPHANED FK rows: before inserting into a FK-constrained table, DELETE or INSERT INTO audit_log any rows whose FK reference does not exist in the parent table. Example: DELETE FROM assets WHERE employee_id NOT IN (SELECT id FROM employees).
44
- 9. For TEXT salary columns like '$90000': use CAST(REPLACE(REPLACE(salary, '$', ''), ',', '') AS INTEGER) to convert.
45
- 10. Execute exactly ONE SQL statement per step.
46
- 11. When migration is complete (schemas match, data preserved), set submit_final to true IMMEDIATELY.
47
-
48
- TARGET SCHEMA (fixed — achieve this exactly):
 
 
49
  {target_ddl}
50
 
51
  Respond ONLY with valid JSON — no markdown, no code blocks, no text outside the object:
@@ -60,15 +72,12 @@ ALL_TASKS = [
60
  "cascade-migration",
61
  "dual-source-consolidation",
62
  ]
63
- MAX_STEPS = 20 # Global fallback; per-task limits override this
64
- MAX_PARSE_ERRORS = 5 # Higher tolerance for thinking models (Qwen3, DeepSeek-R1)
65
-
66
- # Auto-submit threshold: if migration_progress >= this, force submit_final
67
  AUTO_SUBMIT_THRESHOLD = 0.95
68
 
69
 
70
  def call_llm(messages: list, timeout: int = 90) -> str:
71
- """Call the LLM API and return the response content."""
72
  from openai import OpenAI
73
 
74
  client = OpenAI(
@@ -77,11 +86,25 @@ def call_llm(messages: list, timeout: int = 90) -> str:
77
  timeout=timeout,
78
  )
79
 
 
80
  try:
81
  response = client.chat.completions.create(
82
  model=MODEL_NAME,
83
  messages=messages,
84
- temperature=0.0, # Deterministic output — eliminates variance
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  max_tokens=1024,
86
  )
87
  return response.choices[0].message.content.strip()
@@ -93,18 +116,13 @@ def parse_action(raw_text: str) -> dict:
93
  """
94
  Parse LLM output into an action dict.
95
 
96
- Handles: raw JSON, markdown-wrapped JSON (```json ... ```),
97
- <think>...</think> reasoning tokens (Qwen3, DeepSeek-R1),
98
- and common LLM mistakes like trailing commas or extra text.
99
  """
100
- import re
101
  text = raw_text.strip()
102
 
103
- # Strip <think>...</think> blocks emitted by reasoning models (Qwen3, R1)
104
- # Must do this BEFORE any other processing
105
  text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
106
-
107
- # Also strip partial/unclosed think blocks (truncated output)
108
  text = re.sub(r"<think>.*$", "", text, flags=re.DOTALL).strip()
109
 
110
  # Strip markdown code block fences
@@ -119,7 +137,7 @@ def parse_action(raw_text: str) -> dict:
119
  except json.JSONDecodeError:
120
  pass
121
 
122
- # Try to find JSON object in the text (handles preamble text or extra trailing content)
123
  start = text.find("{")
124
  end = text.rfind("}") + 1
125
  if start >= 0 and end > start:
@@ -128,11 +146,14 @@ def parse_action(raw_text: str) -> dict:
128
  except json.JSONDecodeError:
129
  pass
130
 
131
- # Last resort: try to extract just sql_command if JSON is truncated
132
- sql_match = re.search(r'"sql_command"\s*:\s*"([^"]+)"', text)
133
  if sql_match:
 
 
 
134
  return {
135
- "sql_command": sql_match.group(1),
136
  "reasoning": "auto-extracted from malformed response",
137
  "submit_final": False,
138
  }
@@ -143,49 +164,47 @@ def parse_action(raw_text: str) -> dict:
143
  def run_task_local(task_name: str) -> dict:
144
  """
145
  Run a single task using a local environment instance (no server needed).
146
-
147
- This is the primary mode — avoids HTTP overhead and works inside Docker.
148
  """
149
- # Import environment directly
150
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
151
  from server.environment import DbMigrationEnvironment
152
  from models import MigrationAction
153
  import seeds
154
 
155
  env = DbMigrationEnvironment(task_name=task_name)
156
-
157
- # Use task-specific step budget (defaults to global MAX_STEPS)
158
- task_max_steps = seeds.TASKS.get(task_name, {}).get("max_steps", MAX_STEPS)
159
 
160
  print(f"[START] task={task_name} env=sql-migration-agent model={MODEL_NAME}", flush=True)
161
 
162
  obs = env.reset()
163
 
164
- # Build task-specific system prompt with target DDL baked in (sent ONCE)
165
- task_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(target_ddl=obs.target_schema_sql)
 
 
 
166
  history = [{"role": "system", "content": task_system_prompt}]
167
 
168
- # Initial observation — only current schema (target already in system prompt)
169
  initial_msg = (
170
  f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
171
  f"Status: {obs.last_execution_result}\n"
172
  f"Migration progress: {obs.migration_progress:.2f}\n\n"
173
- f"Write your first SQL command to begin the migration."
174
  )
175
  history.append({"role": "user", "content": initial_msg})
176
 
177
  rewards_list = []
178
- parse_errors = 0
179
  final_score = 0.0
180
  steps_taken = 0
181
  done = False
182
- peak_score = 0.0 # Track the highest score we've reached
183
 
184
  for step in range(task_max_steps):
185
  if done:
186
  break
187
 
188
- # Context truncation: system prompt + last 10 messages (5 pairs)
189
  messages = [history[0]] + history[-10:]
190
 
191
  try:
@@ -199,22 +218,21 @@ def run_task_local(task_name: str) -> dict:
199
  # Parse the action
200
  try:
201
  action_dict = parse_action(raw_response)
202
- except ValueError as e:
203
- parse_errors += 1
 
204
  print(f"[STEP] step={step+1} action=PARSE_ERROR reward=0.00 done=false error=parse_error", flush=True)
205
- if parse_errors >= MAX_PARSE_ERRORS:
206
- print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_parse_errors", flush=True)
207
  done = True
208
  break
209
  history.append({"role": "assistant", "content": raw_response})
210
  history.append({
211
  "role": "user",
212
- "content": "ERROR: Your response was not valid JSON. Respond ONLY with: {\"sql_command\": \"...\", \"reasoning\": \"...\", \"submit_final\": false}",
213
  })
214
  continue
215
 
216
- parse_errors = 0
217
-
218
  # Build the MigrationAction
219
  try:
220
  action = MigrationAction(
@@ -234,15 +252,9 @@ def run_task_local(task_name: str) -> dict:
234
  final_score = obs.migration_progress
235
  done = obs.done
236
 
237
- # Track peak score
238
- if final_score > peak_score:
239
- peak_score = final_score
240
-
241
- # AUTO-SUBMIT: If we just reached a near-perfect score, force submit
242
- # This prevents the LLM from continuing to send queries and regressing
243
  if final_score >= AUTO_SUBMIT_THRESHOLD and not done:
244
  done = True
245
- # Submit a final no-op to lock in the score
246
  submit_action = MigrationAction(
247
  sql_command="SELECT 1",
248
  reasoning="Migration complete — auto-submitting",
@@ -251,15 +263,13 @@ def run_task_local(task_name: str) -> dict:
251
  obs = env.step(submit_action)
252
  final_score = obs.migration_progress
253
 
254
- # Abbreviate SQL for logging
255
  sql_abbrev = action.sql_command[:50].replace("\n", " ")
256
  if len(action.sql_command) > 50:
257
  sql_abbrev += "..."
258
-
259
  error_str = obs.metadata.get("error", "null") if obs.metadata else "null"
260
  if error_str != "null":
261
  error_str = error_str[:80]
262
-
263
  print(
264
  f"[STEP] step={steps_taken} action={sql_abbrev} "
265
  f"reward={step_reward:.2f} done={'true' if done else 'false'} "
@@ -270,19 +280,17 @@ def run_task_local(task_name: str) -> dict:
270
  # Add to conversation history
271
  history.append({"role": "assistant", "content": json.dumps(action_dict)})
272
 
273
- # Lean feedback — target is already in the system prompt, no need to repeat
274
  feedback_msg = (
275
- f"EXECUTION RESULT: {obs.last_execution_result}\n\n"
276
- f"CURRENT SCHEMA:\n{obs.current_schema_sql}\n\n"
277
  f"Progress: {obs.migration_progress:.2f}"
278
  )
279
  if done:
280
  feedback_msg += "\n\nEpisode complete."
281
  elif obs.migration_progress >= 0.9:
282
  feedback_msg += (
283
- "\n\nMigration is nearly complete! Compare the current schema "
284
- "carefully to the target schema. If they match and data is "
285
- "preserved, set submit_final to true in your next response."
286
  )
287
  else:
288
  feedback_msg += "\n\nContinue the migration. Write your next SQL command."
@@ -309,7 +317,7 @@ def run_task_local(task_name: str) -> dict:
309
 
310
 
311
  def main():
312
- """Run all 3 tasks sequentially."""
313
  if not API_KEY:
314
  print("WARNING: No API key found. Set HF_TOKEN or API_KEY.", file=sys.stderr)
315
  sys.exit(1)
 
2
  """
3
  Baseline Inference Script for SQL Migration Environment.
4
 
5
+ Runs all 7 migration tasks sequentially using an LLM via OpenAI-compatible API.
6
  Outputs structured [START]/[STEP]/[END] format for automated evaluation.
7
 
8
+ Fixes Applied:
9
+ - D1: Task description injected into system prompt
10
+ - D2: Hardcoded system prompt traps removed (no more audit_log/INTEGER traps)
11
+ - D3: Data discovery rule added (agent runs SELECT before DDL)
12
+ - D4: Submit guard added (agent must verify before submitting)
13
+ - D5: Context window bloat fixed (schema not repeated every step)
14
+ - D6: Parse error counter tracks consecutive errors only
15
+ - D7: response_format JSON mode with fallback
16
+
17
  Usage:
18
  python inference.py
19
 
 
25
 
26
  import json
27
  import os
28
+ import re
29
  import sys
30
  import time
31
  import traceback
 
33
  # Server URL for the environment
34
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
35
 
36
+ # LLM Configuration
37
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
38
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
39
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
40
  API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN or os.getenv("API_KEY")
41
 
42
+ # --- D2: Cleaned system prompt — no hardcoded table names or type traps ---
43
  SYSTEM_PROMPT_TEMPLATE = """You are an autonomous SQLite database migration engine. You receive the current schema and a target schema. Write SQL to transform the current state to the target state without losing row data.
44
 
45
+ TASK OBJECTIVE:
46
+ {task_description}
47
+
48
+ CRITICAL SQLite-specific rules (violations cause immediate errors):
49
+ 1. SQLite does NOT support ALTER TABLE ADD CONSTRAINT, ALTER COLUMN, or ADD PRIMARY KEY.
50
+ 2. To change column types, add NOT NULL, or add FKs: CREATE new table, INSERT INTO new SELECT FROM old, DROP old, RENAME new.
51
+ 3. Apostrophes in data (O'Brien, O'Neill) are present — escape with '' in string literals.
52
+ 4. Execute exactly ONE SQL statement per step.
53
+ 5. For table normalization: create new tables first, INSERT INTO ... SELECT, then drop old tables.
54
+ 6. For orphaned FK rows: check the TARGET SCHEMA for the correct anomaly/issues table name (it varies per task). Log invalid records there before dropping.
55
+ 7. For text currency columns like '$90,000' or '$1,234.56': strip '$' and ',' then cast to the type in the target schema (INTEGER for whole numbers, REAL for decimals).
56
+ 8. IMPORTANT: Before writing any DDL, execute SELECT * FROM tablename LIMIT 5 for each source table to inspect the actual data format and identify edge cases like empty strings, leading whitespace, NULL values, and special characters.
57
+ 9. Do NOT set submit_final to true until you have run SELECT COUNT(*) on your target tables and verified the counts and data match what the task requires.
58
+ 10. When migration is complete and verified, set submit_final to true.
59
+
60
+ TARGET SCHEMA (achieve this exactly):
61
  {target_ddl}
62
 
63
  Respond ONLY with valid JSON — no markdown, no code blocks, no text outside the object:
 
72
  "cascade-migration",
73
  "dual-source-consolidation",
74
  ]
75
+ MAX_PARSE_ERRORS = 5 # Consecutive parse errors before giving up
 
 
 
76
  AUTO_SUBMIT_THRESHOLD = 0.95
77
 
78
 
79
  def call_llm(messages: list, timeout: int = 90) -> str:
80
+ """Call the LLM API with JSON mode fallback."""
81
  from openai import OpenAI
82
 
83
  client = OpenAI(
 
86
  timeout=timeout,
87
  )
88
 
89
+ # --- D7: Try JSON mode first, fallback to plain ---
90
  try:
91
  response = client.chat.completions.create(
92
  model=MODEL_NAME,
93
  messages=messages,
94
+ temperature=0.0,
95
+ max_tokens=1024,
96
+ response_format={"type": "json_object"},
97
+ )
98
+ return response.choices[0].message.content.strip()
99
+ except Exception:
100
+ pass
101
+
102
+ # Fallback: plain text mode
103
+ try:
104
+ response = client.chat.completions.create(
105
+ model=MODEL_NAME,
106
+ messages=messages,
107
+ temperature=0.0,
108
  max_tokens=1024,
109
  )
110
  return response.choices[0].message.content.strip()
 
116
  """
117
  Parse LLM output into an action dict.
118
 
119
+ Handles: raw JSON, markdown-wrapped JSON, <think>...</think> blocks,
120
+ escaped quotes in SQL, and truncated output recovery.
 
121
  """
 
122
  text = raw_text.strip()
123
 
124
+ # Strip <think>...</think> blocks (Qwen3, DeepSeek-R1)
 
125
  text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
 
 
126
  text = re.sub(r"<think>.*$", "", text, flags=re.DOTALL).strip()
127
 
128
  # Strip markdown code block fences
 
137
  except json.JSONDecodeError:
138
  pass
139
 
140
+ # Try to find JSON object in the text
141
  start = text.find("{")
142
  end = text.rfind("}") + 1
143
  if start >= 0 and end > start:
 
146
  except json.JSONDecodeError:
147
  pass
148
 
149
+ # --- D6: Improved regex that handles escaped quotes ---
150
+ sql_match = re.search(r'"sql_command"\s*:\s*"((?:[^"\\]|\\.)*)"', text)
151
  if sql_match:
152
+ sql = sql_match.group(1)
153
+ # Unescape JSON string escapes
154
+ sql = re.sub(r'\\(.)', lambda m: {"n": "\n", "t": "\t"}.get(m.group(1), m.group(1)), sql)
155
  return {
156
+ "sql_command": sql,
157
  "reasoning": "auto-extracted from malformed response",
158
  "submit_final": False,
159
  }
 
  def run_task_local(task_name: str) -> dict:
      """
      Run a single task using a local environment instance (no server needed).
      """
      sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
      from server.environment import DbMigrationEnvironment
      from models import MigrationAction
      import seeds

      env = DbMigrationEnvironment(task_name=task_name)
+     task_config = seeds.TASKS[task_name]
+     task_max_steps = task_config.get("max_steps", 20)

      print(f"[START] task={task_name} env=sql-migration-agent model={MODEL_NAME}", flush=True)

      obs = env.reset()

+     # --- D1: Inject task description into system prompt ---
+     task_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
+         task_description=task_config["description"],
+         target_ddl=obs.target_schema_sql,
+     )
      history = [{"role": "system", "content": task_system_prompt}]

+     # Initial observation
      initial_msg = (
          f"CURRENT DATABASE SCHEMA:\n{obs.current_schema_sql}\n\n"
          f"Status: {obs.last_execution_result}\n"
          f"Migration progress: {obs.migration_progress:.2f}\n\n"
+         f"Start by inspecting the source data with SELECT queries, then begin the migration."
      )
      history.append({"role": "user", "content": initial_msg})

      rewards_list = []
+     consecutive_parse_errors = 0  # D6: track consecutive parse errors only
      final_score = 0.0
      steps_taken = 0
      done = False
 
      for step in range(task_max_steps):
          if done:
              break

+         # --- D5: Context window fix — only keep last 10 messages + system ---
          messages = [history[0]] + history[-10:]

          try:

          # Parse the action
          try:
              action_dict = parse_action(raw_response)
+             consecutive_parse_errors = 0  # D6: reset on success
+         except ValueError:
+             consecutive_parse_errors += 1
              print(f"[STEP] step={step+1} action=PARSE_ERROR reward=0.00 done=false error=parse_error", flush=True)
+             if consecutive_parse_errors >= MAX_PARSE_ERRORS:
+                 print(f"[STEP] step={step+1} action=MAX_PARSE_ERRORS reward=0.00 done=true error=too_many_consecutive_parse_errors", flush=True)
                  done = True
                  break
              history.append({"role": "assistant", "content": raw_response})
              history.append({
                  "role": "user",
+                 "content": 'ERROR: Your response was not valid JSON. Respond ONLY with: {"sql_command": "...", "reasoning": "...", "submit_final": false}',
              })
              continue
 
 
          # Build the MigrationAction
          try:
              action = MigrationAction(

          final_score = obs.migration_progress
          done = obs.done

+         # AUTO-SUBMIT: If we reached near-perfect score, force submit
          if final_score >= AUTO_SUBMIT_THRESHOLD and not done:
              done = True
              submit_action = MigrationAction(
                  sql_command="SELECT 1",
                  reasoning="Migration complete — auto-submitting",

              obs = env.step(submit_action)
              final_score = obs.migration_progress

+         # Log
          sql_abbrev = action.sql_command[:50].replace("\n", " ")
          if len(action.sql_command) > 50:
              sql_abbrev += "..."
          error_str = obs.metadata.get("error", "null") if obs.metadata else "null"
          if error_str != "null":
              error_str = error_str[:80]
          print(
              f"[STEP] step={steps_taken} action={sql_abbrev} "
              f"reward={step_reward:.2f} done={'true' if done else 'false'} "

          # Add to conversation history
          history.append({"role": "assistant", "content": json.dumps(action_dict)})

+         # --- D5: Lean feedback — NO schema repetition ---
          feedback_msg = (
+             f"EXECUTION RESULT: {obs.last_execution_result}\n"
              f"Progress: {obs.migration_progress:.2f}"
          )
          if done:
              feedback_msg += "\n\nEpisode complete."
          elif obs.migration_progress >= 0.9:
              feedback_msg += (
+                 "\n\nMigration is nearly complete! Run SELECT COUNT(*) on each table "
+                 "and compare to your expectations. If everything matches, set submit_final to true."
              )
          else:
              feedback_msg += "\n\nContinue the migration. Write your next SQL command."

  def main():
+     """Run all 7 tasks sequentially."""
      if not API_KEY:
          print("WARNING: No API key found. Set HF_TOKEN or API_KEY.", file=sys.stderr)
          sys.exit(1)
seeds.py CHANGED
@@ -678,6 +678,372 @@ def seed_task7(conn: sqlite3.Connection) -> None:
      conn.commit()

  # =============================================================================
  # Task Registry
  # =============================================================================
@@ -685,50 +1051,92 @@ def seed_task7(conn: sqlite3.Connection) -> None:
  TASKS = {
      "column-restructure": {
          "seed_fn": seed_task1,
          "target_ddl": TASK1_TARGET_DDL,
-         "description": "Merge first_name and last_name into a single full_name column without data loss",
          "difficulty": "easy",
          "max_steps": 10,
      },
      "soft-delete-restoration": {
          "seed_fn": seed_task4,
          "target_ddl": TASK4_TARGET_DDL,
-         "description": "Restore deleted products from deletion_log, add is_deleted/deleted_at columns",
          "difficulty": "easy",
          "max_steps": 10,
      },
      "table-normalization": {
          "seed_fn": seed_task2,
          "target_ddl": TASK2_TARGET_DDL,
-         "description": "Decompose a flat purchases table into normalized customers and orders tables with FK",
          "difficulty": "medium",
          "max_steps": 15,
      },
      "schema-version-merge": {
          "seed_fn": seed_task5,
          "target_ddl": TASK5_TARGET_DDL,
-         "description": "Merge overlapping v1/v2 product tables with price coercion and conflict resolution",
          "difficulty": "medium",
          "max_steps": 15,
      },
      "multi-entity-extraction": {
          "seed_fn": seed_task6,
          "target_ddl": TASK6_TARGET_DDL,
-         "description": "Decompose a sales god-table into 3NF with 3 FKs and invalid data routing",
          "difficulty": "medium",
          "max_steps": 15,
      },
      "cascade-migration": {
          "seed_fn": seed_task3,
          "target_ddl": TASK3_TARGET_DDL,
-         "description": "Multi-table FK cascade with type coercion, NULL handling, and orphan audit logging",
          "difficulty": "hard",
          "max_steps": 20,
      },
      "dual-source-consolidation": {
          "seed_fn": seed_task7,
          "target_ddl": TASK7_TARGET_DDL,
-         "description": "Merge 6 tables from two incompatible systems into 4 unified tables with cross-system dedup",
          "difficulty": "hard",
          "max_steps": 20,
      },
      conn.commit()


+ # =============================================================================
+ # Golden Migration Functions
+ # =============================================================================
+ # These produce the CORRECT expected database state from any seed data.
+ # Used by the dynamic grader to compare against the agent's output.
+ # If seed data changes, the golden DB auto-updates — no hardcoded literals.
+
+
+ def golden_task1(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 1: Column Restructure."""
+     conn.execute("CREATE TABLE users_new (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)")
+     conn.execute(
+         "INSERT INTO users_new (id, full_name) "
+         "SELECT id, first_name || ' ' || last_name FROM users"
+     )
+     conn.execute("DROP TABLE users")
+     conn.execute("ALTER TABLE users_new RENAME TO users")
+     conn.commit()
+
+
+ def golden_task2(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 2: Table Normalization."""
+     conn.execute("PRAGMA foreign_keys = OFF")
+     conn.execute(
+         "CREATE TABLE customers ("
+         "id INTEGER PRIMARY KEY, name TEXT NOT NULL, email TEXT NOT NULL UNIQUE)"
+     )
+     conn.execute(
+         "INSERT INTO customers (name, email) "
+         "SELECT DISTINCT customer_name, customer_email FROM purchases"
+     )
+     conn.execute(
+         "CREATE TABLE orders ("
+         "id INTEGER PRIMARY KEY, customer_id INTEGER NOT NULL, "
+         "item_name TEXT NOT NULL, price INTEGER NOT NULL, "
+         "FOREIGN KEY (customer_id) REFERENCES customers(id))"
+     )
+     conn.execute(
+         "INSERT INTO orders (customer_id, item_name, price) "
+         "SELECT c.id, p.item_name, p.price "
+         "FROM purchases p JOIN customers c ON p.customer_email = c.email"
+     )
+     conn.execute("DROP TABLE purchases")
+     conn.execute("PRAGMA foreign_keys = ON")
+     conn.commit()
+
+
+ def golden_task3(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 3: Cascade Migration."""
+     conn.execute("PRAGMA foreign_keys = OFF")
+     # Create audit_log
+     conn.execute(
+         "CREATE TABLE audit_log (id INTEGER PRIMARY KEY, source_table TEXT NOT NULL, "
+         "original_row_json TEXT NOT NULL, reason TEXT NOT NULL)"
+     )
+     # Log orphaned assets
+     conn.execute(
+         "INSERT INTO audit_log (source_table, original_row_json, reason) "
+         "SELECT 'assets', '{\"id\":' || id || ',\"employee_id\":' || employee_id || '}', 'orphaned_record' "
+         "FROM assets WHERE employee_id NOT IN (SELECT id FROM employees)"
+     )
+     # Log NULL salary employees
+     conn.execute(
+         "INSERT INTO audit_log (source_table, original_row_json, reason) "
+         "SELECT 'employees', '{\"id\":' || id || ',\"name\":\"' || name || '\"}', 'null_salary' "
+         "FROM employees WHERE salary IS NULL"
+     )
+     # Rebuild companies
+     conn.execute("CREATE TABLE companies_new (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
+     conn.execute("INSERT INTO companies_new SELECT id, name FROM companies")
+     conn.execute("DROP TABLE companies")
+     conn.execute("ALTER TABLE companies_new RENAME TO companies")
+     # Rebuild departments
+     conn.execute(
+         "CREATE TABLE departments_new (id INTEGER PRIMARY KEY, company_id INTEGER NOT NULL, "
+         "name TEXT NOT NULL, FOREIGN KEY (company_id) REFERENCES companies(id))"
+     )
+     conn.execute("INSERT INTO departments_new SELECT id, company_id, name FROM departments")
+     conn.execute("DROP TABLE departments")
+     conn.execute("ALTER TABLE departments_new RENAME TO departments")
+     # Rebuild employees (remove NULL salary, coerce TEXT to INT)
+     conn.execute(
+         "CREATE TABLE employees_new (id INTEGER PRIMARY KEY, department_id INTEGER NOT NULL, "
+         "name TEXT NOT NULL, salary INTEGER NOT NULL, "
+         "FOREIGN KEY (department_id) REFERENCES departments(id))"
+     )
+     conn.execute(
+         "INSERT INTO employees_new (id, department_id, name, salary) "
+         "SELECT id, department_id, name, "
+         "CAST(REPLACE(REPLACE(salary, '$', ''), ',', '') AS INTEGER) "
+         "FROM employees WHERE salary IS NOT NULL"
+     )
+     conn.execute("DROP TABLE employees")
+     conn.execute("ALTER TABLE employees_new RENAME TO employees")
+     # Rebuild assets (remove orphans)
+     conn.execute(
+         "CREATE TABLE assets_new (id INTEGER PRIMARY KEY, employee_id INTEGER NOT NULL, "
+         "description TEXT NOT NULL, FOREIGN KEY (employee_id) REFERENCES employees(id))"
+     )
+     conn.execute(
+         "INSERT INTO assets_new SELECT id, employee_id, description FROM assets "
+         "WHERE employee_id IN (SELECT id FROM employees)"
+     )
+     conn.execute("DROP TABLE assets")
+     conn.execute("ALTER TABLE assets_new RENAME TO assets")
+     conn.execute("PRAGMA foreign_keys = ON")
+     conn.commit()
+
+
+ def golden_task4(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 4: Soft-Delete Restoration."""
+     conn.execute("PRAGMA foreign_keys = OFF")
+     # Create new table with extra columns
+     conn.execute(
+         "CREATE TABLE products_new (id INTEGER PRIMARY KEY, name TEXT NOT NULL, "
+         "price REAL NOT NULL, stock INTEGER NOT NULL, "
+         "is_deleted INTEGER NOT NULL DEFAULT 0, deleted_at TEXT)"
+     )
+     # Copy existing products as active
+     conn.execute(
+         "INSERT INTO products_new (id, name, price, stock, is_deleted, deleted_at) "
+         "SELECT id, name, price, stock, 0, NULL FROM products"
+     )
+     # Restore deleted products from log
+     conn.execute(
+         "INSERT INTO products_new (id, name, price, stock, is_deleted, deleted_at) "
+         "SELECT product_id, product_name, product_price, product_stock, 1, deleted_at "
+         "FROM deletion_log"
+     )
+     conn.execute("DROP TABLE products")
+     conn.execute("ALTER TABLE products_new RENAME TO products")
+     conn.execute("DROP TABLE deletion_log")
+     conn.execute("PRAGMA foreign_keys = ON")
+     conn.commit()
+
+
+ def golden_task5(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 5: Schema Version Merge."""
+     conn.execute("PRAGMA foreign_keys = OFF")
+     conn.execute(
+         "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, "
+         "price REAL NOT NULL, category TEXT, supplier TEXT, brand TEXT, "
+         "sku TEXT, source TEXT NOT NULL)"
+     )
+     # Insert v1-only rows
+     conn.execute(
+         "INSERT INTO products (id, name, price, category, supplier, brand, sku, source) "
+         "SELECT id, name, CAST(REPLACE(REPLACE(price, '$', ''), ',', '') AS REAL), "
+         "category, supplier, NULL, NULL, 'v1' "
+         "FROM products_v1 WHERE id NOT IN (SELECT id FROM products_v2)"
+     )
+     # Insert v2-only rows
+     conn.execute(
+         "INSERT INTO products (id, name, price, category, supplier, brand, sku, source) "
+         "SELECT id, name, unit_cost, category, NULL, brand, sku, 'v2' "
+         "FROM products_v2 WHERE id NOT IN (SELECT id FROM products_v1)"
+     )
+     # Insert conflict rows (v2 wins for name/price)
+     conn.execute(
+         "INSERT INTO products (id, name, price, category, supplier, brand, sku, source) "
+         "SELECT v2.id, v2.name, v2.unit_cost, v2.category, v1.supplier, v2.brand, v2.sku, 'both' "
+         "FROM products_v2 v2 JOIN products_v1 v1 ON v2.id = v1.id"
+     )
+     conn.execute("DROP TABLE products_v1")
+     conn.execute("DROP TABLE products_v2")
+     conn.execute("PRAGMA foreign_keys = ON")
+     conn.commit()
+
+
+ def golden_task6(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 6: Multi-Entity Extraction."""
+     conn.execute("PRAGMA foreign_keys = OFF")
+     # Create target tables
+     conn.execute(
+         "CREATE TABLE salespersons (id INTEGER PRIMARY KEY, name TEXT NOT NULL, "
+         "email TEXT NOT NULL UNIQUE, region TEXT NOT NULL)"
+     )
+     conn.execute(
+         "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT NOT NULL, "
+         "email TEXT NOT NULL UNIQUE, tier TEXT NOT NULL)"
+     )
+     conn.execute(
+         "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, "
+         "sku TEXT NOT NULL UNIQUE, category TEXT NOT NULL)"
+     )
+     conn.execute(
+         "CREATE TABLE sales (id INTEGER PRIMARY KEY, salesperson_id INTEGER NOT NULL, "
+         "customer_id INTEGER NOT NULL, product_id INTEGER NOT NULL, "
+         "quantity INTEGER NOT NULL, unit_price REAL NOT NULL, "
+         "discount_pct INTEGER NOT NULL DEFAULT 0, sale_date TEXT NOT NULL, "
+         "FOREIGN KEY (salesperson_id) REFERENCES salespersons(id), "
+         "FOREIGN KEY (customer_id) REFERENCES customers(id), "
+         "FOREIGN KEY (product_id) REFERENCES products(id))"
+     )
+     conn.execute(
+         "CREATE TABLE data_issues (id INTEGER PRIMARY KEY, source_table TEXT NOT NULL, "
+         "source_row_id INTEGER NOT NULL, issue_type TEXT NOT NULL, "
+         "issue_detail TEXT NOT NULL)"
+     )
+     # Populate salespersons (TRIM email)
+     conn.execute(
+         "INSERT INTO salespersons (name, email, region) "
+         "SELECT DISTINCT rep_name, TRIM(rep_email), rep_region FROM sales_records"
+     )
+     # Populate customers (exclude empty email rows)
+     conn.execute(
+         "INSERT INTO customers (name, email, tier) "
+         "SELECT DISTINCT customer_name, customer_email, customer_tier "
+         "FROM sales_records WHERE customer_email IS NOT NULL AND customer_email != ''"
+     )
+     # Populate products
+     conn.execute(
+         "INSERT INTO products (name, sku, category) "
+         "SELECT DISTINCT product_name, product_sku, product_category FROM sales_records"
+     )
+     # Populate sales (exclude rows with empty customer email)
+     conn.execute(
+         "INSERT INTO sales (salesperson_id, customer_id, product_id, quantity, "
+         "unit_price, discount_pct, sale_date) "
+         "SELECT sp.id, c.id, p.id, sr.quantity, sr.unit_price, sr.discount_pct, sr.sale_date "
+         "FROM sales_records sr "
+         "JOIN salespersons sp ON TRIM(sr.rep_email) = sp.email "
+         "JOIN customers c ON sr.customer_email = c.email "
+         "JOIN products p ON sr.product_sku = p.sku "
+         "WHERE sr.customer_email IS NOT NULL AND sr.customer_email != ''"
+     )
+     # Log data issues (empty email)
+     conn.execute(
+         "INSERT INTO data_issues (source_table, source_row_id, issue_type, issue_detail) "
+         "SELECT 'sales_records', id, 'empty_email', "
+         "'Customer email is empty for: ' || customer_name "
+         "FROM sales_records WHERE customer_email IS NULL OR customer_email = ''"
+     )
+     conn.execute("DROP TABLE sales_records")
+     conn.execute("PRAGMA foreign_keys = ON")
+     conn.commit()
+
+
+ def golden_task7(conn: sqlite3.Connection) -> None:
+     """Golden migration for Task 7: Dual-Source Consolidation."""
+     conn.execute("PRAGMA foreign_keys = OFF")
+
+     # Create unified_customers
+     conn.execute(
+         "CREATE TABLE unified_customers (id INTEGER PRIMARY KEY AUTOINCREMENT, "
+         "legacy_id INTEGER, modern_uuid TEXT, name TEXT, email TEXT, phone TEXT, "
+         "tier TEXT NOT NULL DEFAULT 'free', source TEXT NOT NULL, created_at TEXT)"
+     )
+     # Insert legacy-only customers (no email match in modern)
+     conn.execute(
+         "INSERT INTO unified_customers (legacy_id, modern_uuid, name, email, phone, tier, source, created_at) "
+         "SELECT lc.id, NULL, lc.full_name, lc.contact_email, lc.phone, lc.account_type, 'legacy', lc.join_date "
+         "FROM legacy_customers lc "
+         "WHERE lc.contact_email IS NULL OR lc.contact_email NOT IN (SELECT email_address FROM modern_users WHERE email_address IS NOT NULL)"
+     )
+     # Insert modern-only users (no email match in legacy)
+     conn.execute(
+         "INSERT INTO unified_customers (legacy_id, modern_uuid, name, email, phone, tier, source, created_at) "
+         "SELECT NULL, mu.uuid, mu.display_name, mu.email_address, NULL, "
+         "CASE mu.subscription_tier "
+         "  WHEN 1 THEN 'free' WHEN 2 THEN 'basic' WHEN 3 THEN 'premium' WHEN 4 THEN 'enterprise' "
+         "  ELSE 'free' END, "
+         "'modern', mu.created_at "
+         "FROM modern_users mu "
+         "WHERE mu.email_address NOT IN (SELECT contact_email FROM legacy_customers WHERE contact_email IS NOT NULL)"
+     )
+     # Insert matched (both) customers — legacy name + modern tier
+     conn.execute(
+         "INSERT INTO unified_customers (legacy_id, modern_uuid, name, email, phone, tier, source, created_at) "
+         "SELECT lc.id, mu.uuid, lc.full_name, lc.contact_email, lc.phone, "
+         "CASE mu.subscription_tier "
+         "  WHEN 1 THEN 'free' WHEN 2 THEN 'basic' WHEN 3 THEN 'premium' WHEN 4 THEN 'enterprise' "
+         "  ELSE 'free' END, "
+         "'both', lc.join_date "
+         "FROM legacy_customers lc "
+         "JOIN modern_users mu ON lc.contact_email = mu.email_address "
+         "WHERE lc.contact_email IS NOT NULL"
+     )
+
+     # Create unified_products
+     conn.execute(
+         "CREATE TABLE unified_products (id INTEGER PRIMARY KEY AUTOINCREMENT, "
+         "code TEXT NOT NULL UNIQUE, title TEXT NOT NULL, price REAL NOT NULL, "
+         "source TEXT NOT NULL)"
+     )
+     # Legacy products
+     conn.execute(
+         "INSERT INTO unified_products (code, title, price, source) "
+         "SELECT code, description, "
+         "CAST(REPLACE(REPLACE(unit_price, '$', ''), ',', '') AS REAL), 'legacy' "
+         "FROM legacy_products"
+     )
+     # Modern products (no code overlap expected)
+     conn.execute(
+         "INSERT INTO unified_products (code, title, price, source) "
+         "SELECT sku, title, base_price, 'modern' "
+         "FROM modern_catalog"
+     )
+
+     # Create migration_issues
+     conn.execute(
+         "CREATE TABLE migration_issues (id INTEGER PRIMARY KEY, "
+         "source_system TEXT NOT NULL, source_table TEXT NOT NULL, "
+         "source_id TEXT NOT NULL, issue_type TEXT NOT NULL, "
+         "resolution TEXT NOT NULL)"
+     )
+     # Log NULL email customer
+     conn.execute(
+         "INSERT INTO migration_issues (source_system, source_table, source_id, issue_type, resolution) "
+         "SELECT 'legacy', 'legacy_customers', CAST(id AS TEXT), 'null_email', "
+         "'Imported without email' "
+         "FROM legacy_customers WHERE contact_email IS NULL"
+     )
+     # Log orphaned transactions
+     conn.execute(
+         "INSERT INTO migration_issues (source_system, source_table, source_id, issue_type, resolution) "
+         "SELECT 'modern', 'modern_transactions', CAST(id AS TEXT), 'orphaned_record', "
+         "'User UUID not found: ' || user_uuid "
+         "FROM modern_transactions WHERE user_uuid NOT IN (SELECT uuid FROM modern_users)"
+     )
+
+     # Create unified_orders
+     conn.execute(
+         "CREATE TABLE unified_orders (id INTEGER PRIMARY KEY AUTOINCREMENT, "
+         "customer_id INTEGER NOT NULL, product_id INTEGER, amount REAL NOT NULL, "
+         "currency TEXT NOT NULL DEFAULT 'USD', status TEXT NOT NULL, "
+         "order_date TEXT, source TEXT NOT NULL, "
+         "FOREIGN KEY (customer_id) REFERENCES unified_customers(id))"
+     )
+     # Legacy orders
+     conn.execute(
+         "INSERT INTO unified_orders (customer_id, product_id, amount, currency, status, order_date, source) "
+         "SELECT uc.id, up.id, "
+         "CAST(REPLACE(REPLACE(lo.total_amount, '$', ''), ',', '') AS REAL), "
+         "'USD', lo.order_status, lo.order_date, 'legacy' "
+         "FROM legacy_orders lo "
+         "JOIN legacy_customers lc ON lo.customer_id = lc.id "
+         "JOIN unified_customers uc ON (uc.legacy_id = lc.id) "
+         "LEFT JOIN unified_products up ON lo.product_code = up.code"
+     )
+     # Modern transactions (exclude orphans)
+     conn.execute(
+         "INSERT INTO unified_orders (customer_id, product_id, amount, currency, status, order_date, source) "
+         "SELECT uc.id, up.id, mt.amount, "
+         "COALESCE(mt.currency, 'USD'), "
+         "CASE mt.tx_status "
+         "  WHEN 1 THEN 'pending' WHEN 2 THEN 'processing' WHEN 3 THEN 'complete' "
+         "  WHEN 4 THEN 'failed' WHEN 5 THEN 'refunded' ELSE 'unknown' END, "
+         "mt.created_at, 'modern' "
+         "FROM modern_transactions mt "
+         "JOIN modern_users mu ON mt.user_uuid = mu.uuid "
+         "JOIN unified_customers uc ON (uc.modern_uuid = mu.uuid OR uc.email = mu.email_address) "
+         "LEFT JOIN unified_products up ON mt.item_sku = up.code"
+     )
+
+     # Clean up source tables
+     conn.execute("DROP TABLE legacy_customers")
+     conn.execute("DROP TABLE legacy_orders")
+     conn.execute("DROP TABLE legacy_products")
+     conn.execute("DROP TABLE modern_users")
+     conn.execute("DROP TABLE modern_transactions")
+     conn.execute("DROP TABLE modern_catalog")
+     conn.execute("PRAGMA foreign_keys = ON")
+     conn.commit()
+
+
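The comparison logic itself lives in server/grader.py; the sketch below only illustrates, under assumed names (`dump_state` and `grade` are ours, not the repo's API), how a `golden_fn` enables data-level grading: rebuild a fresh golden database from the same seed, apply the known-correct migration, then diff full table contents against the agent's database.

```python
import sqlite3
from typing import Callable

def dump_state(conn: sqlite3.Connection) -> dict:
    """Snapshot every user table as a sorted list of its rows."""
    tables = [r[0] for r in conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
    )]
    return {t: sorted(conn.execute(f"SELECT * FROM {t}").fetchall()) for t in tables}

def grade(seed_fn: Callable, golden_fn: Callable, agent_conn: sqlite3.Connection) -> bool:
    """Rebuild the golden DB from scratch and compare full table contents."""
    golden = sqlite3.connect(":memory:")
    seed_fn(golden)    # same seed data the agent started from
    golden_fn(golden)  # the known-correct migration
    return dump_state(golden) == dump_state(agent_conn)
```

Because the golden state is derived rather than hardcoded, a schema-only match with missing rows fails this comparison, which is what defeats the DROP-and-recreate exploit.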
  # =============================================================================
  # Task Registry
  # =============================================================================

  TASKS = {
      "column-restructure": {
          "seed_fn": seed_task1,
+         "golden_fn": golden_task1,
          "target_ddl": TASK1_TARGET_DDL,
+         "description": "Merge first_name and last_name into a single full_name column (concatenated with a space) without data loss. Apostrophes in names (e.g., O'Brien) must be preserved.",
          "difficulty": "easy",
          "max_steps": 10,
      },
      "soft-delete-restoration": {
          "seed_fn": seed_task4,
+         "golden_fn": golden_task4,
          "target_ddl": TASK4_TARGET_DDL,
+         "description": (
+             "Restore deleted products from the deletion_log table back into the products table. "
+             "Use product_id from deletion_log (NOT the log's id column) as the product's primary key. "
+             "Add is_deleted and deleted_at columns. Original products: is_deleted=0, deleted_at=NULL. "
+             "Restored products: is_deleted=1, deleted_at copied from log. "
+             "Note: stock=0 on a product does NOT mean it was deleted."
+         ),
          "difficulty": "easy",
          "max_steps": 10,
      },
      "table-normalization": {
          "seed_fn": seed_task2,
+         "golden_fn": golden_task2,
          "target_ddl": TASK2_TARGET_DDL,
+         "description": (
+             "Decompose the flat purchases table into normalized customers and orders tables with a FK. "
+             "customers should have DISTINCT entries by email. "
+             "All 7 original purchases must be preserved as individual orders linked to the correct customer."
+         ),
          "difficulty": "medium",
          "max_steps": 15,
      },
      "schema-version-merge": {
          "seed_fn": seed_task5,
+         "golden_fn": golden_task5,
          "target_ddl": TASK5_TARGET_DDL,
+         "description": (
+             "Merge products_v1 and products_v2 into a single products table. "
+             "v1 prices are stored as TEXT ('$XX.XX') — coerce to REAL. v2 uses 'unit_cost' — rename to 'price'. "
+             "For ID conflicts (same ID in both tables), v2 values WIN for name/price. "
+             "Set source='v1' for v1-only, 'v2' for v2-only, 'both' for conflicts."
+         ),
          "difficulty": "medium",
          "max_steps": 15,
      },
      "multi-entity-extraction": {
          "seed_fn": seed_task6,
+         "golden_fn": golden_task6,
          "target_ddl": TASK6_TARGET_DDL,
+         "description": (
+             "Decompose the sales_records god-table into 3NF: salespersons, customers, products, sales, data_issues. "
+             "Route records with empty string '' customer emails to data_issues (not just NULL). "
+             "TRIM leading/trailing whitespace from all email addresses before inserting. "
+             "Each sale must link to the correct salesperson, customer, and product via FKs."
+         ),
          "difficulty": "medium",
          "max_steps": 15,
      },
      "cascade-migration": {
          "seed_fn": seed_task3,
+         "golden_fn": golden_task3,
          "target_ddl": TASK3_TARGET_DDL,
+         "description": (
+             "Multi-table FK cascade with type coercion, NULL handling, and orphan audit logging. "
+             "Convert salary from TEXT ('$90000') to INTEGER (90000) by stripping '$' and ','. "
+             "Remove employees with NULL salary and log them to audit_log with reason='null_salary'. "
+             "Remove orphaned assets (employee_id not in employees) and log them with reason='orphaned_record'. "
+             "Enforce NOT NULL and FK constraints on all tables."
+         ),
          "difficulty": "hard",
          "max_steps": 20,
      },
      "dual-source-consolidation": {
          "seed_fn": seed_task7,
+         "golden_fn": golden_task7,
          "target_ddl": TASK7_TARGET_DDL,
+         "description": (
+             "Merge 6 tables from Legacy CRM + Modern SaaS into 4 unified tables. "
+             "Cross-system customer dedup: match by email address. Set source='both' for matches, "
+             "'legacy' or 'modern' for unmatched. "
+             "Tier mapping (modern subscription_tier): 1=free, 2=basic, 3=premium, 4=enterprise. "
+             "Status mapping (modern tx_status): 1=pending, 2=processing, 3=complete, 4=failed, 5=refunded. "
+             "Legacy amounts are TEXT ('$1,234.56') — coerce to REAL. NULL currency defaults to 'USD'. "
+             "Log orphaned transactions (user_uuid not found) to migration_issues with issue_type='orphaned_record'. "
+             "Log customers with NULL email to migration_issues with issue_type='null_email'."
+         ),
          "difficulty": "hard",
          "max_steps": 20,
      },
server/__pycache__/environment.cpython-312.pyc CHANGED
Binary files a/server/__pycache__/environment.cpython-312.pyc and b/server/__pycache__/environment.cpython-312.pyc differ
 
server/__pycache__/grader.cpython-312.pyc CHANGED
Binary files a/server/__pycache__/grader.cpython-312.pyc and b/server/__pycache__/grader.cpython-312.pyc differ
 
server/environment.py CHANGED
@@ -4,11 +4,21 @@ SQL Migration Environment Server Implementation.
  This is the core environment that wraps SQLite and exposes it via the OpenEnv
  Environment interface. Each WebSocket session gets its own environment instance
  with an isolated in-memory database.
  """

  import sqlite3
  import uuid
- from typing import Any, Optional

  # Support both in-repo and standalone imports
  try:
@@ -27,6 +37,26 @@ except ImportError:
      import seeds


  class DbMigrationEnvironment(Environment):
      """
      SQL Schema Migration Environment.
@@ -44,7 +74,7 @@ class DbMigrationEnvironment(Environment):
          Initialize the migration environment.

          Args:
-             task_name: One of "column-restructure", "table-normalization", "cascade-migration"
          """
          super().__init__()

@@ -59,14 +89,17 @@
          self._conn: Optional[sqlite3.Connection] = None
          self._reconciler: Optional[StateReconciler] = None
          self._step_count = 0
          self._state = MigrationState(
              task_name=task_name,
              migration_progress=0.0,
-             max_steps=20,
          )

      def _get_current_schema(self) -> str:
-         """Get current database schema as DDL string."""
          if self._conn is None:
              return ""
          try:
@@ -79,6 +112,75 @@
          except Exception:
              return ""

      def reset(
          self,
          seed: Optional[int] = None,
@@ -101,6 +203,7 @@
          if task_name != self.task_name and task_name in seeds.TASKS:
              self.task_name = task_name
              self._task_config = seeds.TASKS[task_name]

          # Clean up previous connection
          if self._conn is not None:
@@ -122,16 +225,21 @@
          # Initialize grader
          self._reconciler = StateReconciler(self.task_name)

-         # Reset counters
          self._step_count = 0
          self._state = MigrationState(
              episode_id=episode_id or str(uuid.uuid4()),
              step_count=0,
              task_name=self.task_name,
              migration_progress=0.0,
-             max_steps=20,
          )

          return MigrationObservation(
              done=False,
              reward=0.0,
@@ -139,7 +247,7 @@
              target_schema_sql=self._task_config["target_ddl"],
              last_execution_result="Environment initialized. Ready for migration.",
              step_number=0,
-             migration_progress=0.0,
              task_name=self.task_name,
              metadata={"status": "ready"},
          )
@@ -155,7 +263,7 @@

          Args:
              action: MigrationAction with sql_command, reasoning, and submit_final
-             timeout_s: Unused
              **kwargs: Additional parameters

          Returns:
@@ -178,42 +286,87 @@ class DbMigrationEnvironment(Environment):
178
  )
179
 
180
  self._step_count += 1
 
181
 
182
- # Execute the SQL command
183
- execution_result = ""
184
- action_error = None
185
- try:
186
- cursor = self._conn.execute(action.sql_command)
187
- self._conn.commit()
188
- rows_affected = cursor.rowcount
189
- execution_result = f"Success: {rows_affected} rows affected"
190
- except sqlite3.Warning as e:
191
- # Multi-statement attempt — agent tried to combine statements
192
  execution_result = (
193
- f"Error: SQLite requires one statement per step. "
194
- f"Split your commands into separate steps. Original error: {e}"
 
195
  )
196
- action_error = "multi_statement"
197
- try:
198
- self._conn.rollback()
199
- except Exception:
200
- pass
201
- except Exception as e:
202
- # Never crash — feed the error back to the agent
203
- execution_result = str(e)
204
- action_error = str(e)
205
- # Rollback failed transaction
206
- try:
207
- self._conn.rollback()
208
- except Exception:
209
- pass

          # Compute scores
          current_score, step_reward = self._reconciler.compute_step_reward(self._conn)

          # Episode termination: submit_final, max steps, OR perfect score
-         task_max = self._task_config.get("max_steps", 20)
-         done = action.submit_final or self._step_count >= task_max or current_score >= 0.99

          # Update state
          self._state.step_count = self._step_count
@@ -227,6 +380,9 @@ class DbMigrationEnvironment(Environment):
          }
          if action_error:
              meta["error"] = action_error

          return MigrationObservation(
              done=done,

  This is the core environment that wraps SQLite and exposes it via the OpenEnv
  Environment interface. Each WebSocket session gets its own environment instance
  with an isolated in-memory database.
+
+ Architecture Fixes Applied:
+ - A1: SELECT queries return actual data rows (not just "rows affected")
+ - A2: SQL execution timeout via progress handler (prevents infinite CTEs)
+ - A3: Dangerous SQL blacklist (ATTACH, DETACH, LOAD_EXTENSION, writable_schema)
+ - A4: Transaction awareness (respects BEGIN/COMMIT/ROLLBACK from agent)
+ - A5: Trajectory logging (full SQL history in metadata on episode end)
+ - A6: Per-task max_steps from seeds registry
  """

+ import re
  import sqlite3
+ import threading
  import uuid
+ from typing import Any, Dict, List, Optional

  # Support both in-repo and standalone imports
  try:

  import seeds


+ # --- A3: Dangerous SQL Blacklist ---
+ _DANGEROUS_PATTERNS = re.compile(
+     r"\b(ATTACH\s+DATABASE|DETACH\s+DATABASE|LOAD_EXTENSION)\b"
+     r"|PRAGMA\s+writable_schema",
+     re.IGNORECASE,
+ )
+
+ # --- A4: Transaction control keywords ---
+ _TX_BEGIN = re.compile(r"^\s*(BEGIN|BEGIN\s+TRANSACTION|BEGIN\s+DEFERRED|BEGIN\s+IMMEDIATE|BEGIN\s+EXCLUSIVE)\s*;?\s*$", re.IGNORECASE)
+ _TX_END = re.compile(r"^\s*(COMMIT|END|END\s+TRANSACTION|ROLLBACK)\s*;?\s*$", re.IGNORECASE)
+
+ # --- A2: Maximum SQLite operations before timeout ---
+ _MAX_OPS = 500_000  # ~5 seconds on typical hardware
+
+
+ class _TimeoutError(Exception):
+     """Raised when SQL execution exceeds the operation budget."""
+     pass
+
+
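The A3/A4 guards added above are plain regexes. As a quick standalone check of how they classify input (patterns copied from the hunk above, with demo names; not part of the commit):

```python
import re

# Copies of the diff's _DANGEROUS_PATTERNS (A3) and _TX_BEGIN (A4) regexes.
DANGEROUS = re.compile(
    r"\b(ATTACH\s+DATABASE|DETACH\s+DATABASE|LOAD_EXTENSION)\b"
    r"|PRAGMA\s+writable_schema",
    re.IGNORECASE,
)
TX_BEGIN = re.compile(
    r"^\s*(BEGIN|BEGIN\s+TRANSACTION|BEGIN\s+DEFERRED|BEGIN\s+IMMEDIATE|BEGIN\s+EXCLUSIVE)\s*;?\s*$",
    re.IGNORECASE,
)

print(bool(DANGEROUS.search("attach database 'evil.db' as x")))  # True: blocked
print(bool(DANGEROUS.search("SELECT * FROM users")))             # False: allowed
print(bool(TX_BEGIN.match("  begin immediate; ")))               # True: explicit tx
print(bool(TX_BEGIN.match("BEGIN; COMMIT;")))                    # False: not a lone BEGIN
```

Note the `$`-anchored transaction patterns only match a statement that is *just* a BEGIN/COMMIT/ROLLBACK, so a combined `BEGIN; COMMIT;` falls through to normal execution.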
  class DbMigrationEnvironment(Environment):
      """
      SQL Schema Migration Environment.

          Initialize the migration environment.

          Args:
+             task_name: One of the registered task names in seeds.TASKS
          """
          super().__init__()

          self._conn: Optional[sqlite3.Connection] = None
          self._reconciler: Optional[StateReconciler] = None
          self._step_count = 0
+         self._trajectory: List[Dict[str, Any]] = []  # A5
+         self._in_explicit_tx = False  # A4
+         self._max_steps = self._task_config.get("max_steps", 20)  # A6
          self._state = MigrationState(
              task_name=task_name,
              migration_progress=0.0,
+             max_steps=self._max_steps,  # A6
          )

      def _get_current_schema(self) -> str:
+         """Get current database schema as DDL string, filtering internal tables."""
          if self._conn is None:
              return ""
          try:

          except Exception:
              return ""

+     def _is_read_query(self, sql: str) -> bool:
+         """Check if SQL is a read-only query (SELECT or certain PRAGMAs)."""
+         stripped = sql.strip().upper()
+         if stripped.startswith("SELECT"):
+             return True
+         # PRAGMA table_info, foreign_key_list, etc. are read-only
+         if stripped.startswith("PRAGMA") and "=" not in stripped:
+             return True
+         return False
+
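The heuristic above treats value-less PRAGMAs as reads and PRAGMA assignments as writes. A standalone copy for illustration (the function name here is hypothetical, the logic mirrors the hunk):

```python
# Mirror of _is_read_query from the hunk above, as a free function.
def is_read_query(sql: str) -> bool:
    s = sql.strip().upper()
    return s.startswith("SELECT") or (s.startswith("PRAGMA") and "=" not in s)

print(is_read_query("SELECT * FROM users"))        # True
print(is_read_query("PRAGMA table_info(users)"))   # True  (inspection, no "=")
print(is_read_query("PRAGMA foreign_keys = ON"))   # False (assignment)
print(is_read_query("DELETE FROM users"))          # False
```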
+     def _execute_with_timeout(self, sql: str) -> tuple:
+         """
+         Execute SQL with a progress-handler-based timeout.
+
+         Returns: (cursor_or_None, error_string_or_None)
+         """
+         ops_count = [0]
+
+         def _progress_callback():
+             ops_count[0] += 1
+             if ops_count[0] > _MAX_OPS:
+                 return 1  # Non-zero = abort
+             return 0
+
+         self._conn.set_progress_handler(_progress_callback, 1000)
+         try:
+             cursor = self._conn.execute(sql)
+             return cursor, None
+         except sqlite3.OperationalError as e:
+             if "interrupted" in str(e).lower() or ops_count[0] > _MAX_OPS:
+                 return None, "Error: Query exceeded execution time limit (possible infinite loop). Simplify your query."
+             return None, str(e)
+         except sqlite3.Warning as e:
+             return None, (
+                 f"Error: SQLite requires one statement per step. "
+                 f"Split your commands into separate steps. Original error: {e}"
+             )
+         except Exception as e:
+             return None, str(e)
+         finally:
+             self._conn.set_progress_handler(None, 0)
+
+ def _format_query_results(self, cursor) -> str:
158
+ """Format SELECT query results as a readable table string."""
159
+ try:
160
+ rows = cursor.fetchall()
161
+ if not rows:
162
+ return "Query returned 0 rows."
163
+
164
+ # Get column names
165
+ col_names = [desc[0] for desc in cursor.description] if cursor.description else []
166
+
167
+ # Cap at 50 rows
168
+ truncated = len(rows) > 50
169
+ display_rows = rows[:50]
170
+
171
+ # Build output
172
+ header = " | ".join(col_names) if col_names else "Results"
173
+ lines = [header, "-" * len(header)]
174
+ for row in display_rows:
175
+ lines.append(" | ".join(str(v) for v in row))
176
+ if truncated:
177
+ lines.append(f"... ({len(rows) - 50} more rows truncated)")
178
+ lines.append(f"({len(rows)} rows total)")
179
+
180
+ return "\n".join(lines)
181
+ except Exception:
182
+ return "Query executed successfully."
183
+
184
  def reset(
185
  self,
186
  seed: Optional[int] = None,
 
203
  if task_name != self.task_name and task_name in seeds.TASKS:
204
  self.task_name = task_name
205
  self._task_config = seeds.TASKS[task_name]
206
+ self._max_steps = self._task_config.get("max_steps", 20)
207
 
208
  # Clean up previous connection
209
  if self._conn is not None:
 
225
  # Initialize grader
226
  self._reconciler = StateReconciler(self.task_name)
227
 
228
+ # Reset counters and trajectory
229
  self._step_count = 0
230
+ self._trajectory = [] # A5
231
+ self._in_explicit_tx = False # A4
232
  self._state = MigrationState(
233
  episode_id=episode_id or str(uuid.uuid4()),
234
  step_count=0,
235
  task_name=self.task_name,
236
  migration_progress=0.0,
237
+ max_steps=self._max_steps, # A6
238
  )
239
 
240
+ # Compute initial score
241
+ initial_score = self._reconciler.score(self._conn)
242
+
243
  return MigrationObservation(
244
  done=False,
245
  reward=0.0,
 
247
  target_schema_sql=self._task_config["target_ddl"],
248
  last_execution_result="Environment initialized. Ready for migration.",
249
  step_number=0,
250
+ migration_progress=initial_score,
251
  task_name=self.task_name,
252
  metadata={"status": "ready"},
253
  )
 

          Args:
              action: MigrationAction with sql_command, reasoning, and submit_final
+             timeout_s: Unused (we use progress handler instead)
              **kwargs: Additional parameters

          Returns:

          )

          self._step_count += 1
+         sql_command = action.sql_command.strip()

+         # --- A3: Dangerous SQL Blacklist ---
+         if _DANGEROUS_PATTERNS.search(sql_command):
              execution_result = (
+                 "Error: This SQL command is not allowed for security reasons. "
+                 "ATTACH DATABASE, DETACH DATABASE, LOAD_EXTENSION, and "
+                 "PRAGMA writable_schema are blocked."
              )
+             action_error = "blocked_command"
+         else:
+             # --- A4: Transaction Awareness ---
+             execution_result = ""
+             action_error = None
+
+             if _TX_BEGIN.match(sql_command):
+                 # Agent wants to start a transaction
+                 try:
+                     self._conn.execute("BEGIN")
+                     self._in_explicit_tx = True
+                     execution_result = "Success: Transaction started."
+                 except Exception as e:
+                     execution_result = str(e)
+                     action_error = str(e)
+             elif _TX_END.match(sql_command):
+                 # Agent wants to commit or rollback
+                 try:
+                     if sql_command.strip().upper().startswith("ROLLBACK"):
+                         self._conn.rollback()
+                         execution_result = "Success: Transaction rolled back."
+                     else:
+                         self._conn.commit()
+                         execution_result = "Success: Transaction committed."
+                     self._in_explicit_tx = False
+                 except Exception as e:
+                     execution_result = str(e)
+                     action_error = str(e)
+                     self._in_explicit_tx = False
+             else:
+                 # --- Normal SQL execution with timeout (A1, A2) ---
+                 cursor, error = self._execute_with_timeout(sql_command)
+
+                 if error:
+                     execution_result = error
+                     action_error = error
+                     # Rollback failed transaction
+                     try:
+                         if not self._in_explicit_tx:
+                             self._conn.rollback()
+                     except Exception:
+                         pass
+                 else:
+                     # --- A1: SELECT result passthrough ---
+                     if self._is_read_query(sql_command):
+                         execution_result = self._format_query_results(cursor)
+                     else:
+                         rows_affected = cursor.rowcount
+                         execution_result = f"Success: {rows_affected} rows affected"
+                     # Only auto-commit if not in explicit transaction (A4)
+                     if not self._in_explicit_tx:
+                         try:
+                             self._conn.commit()
+                         except Exception:
+                             pass

          # Compute scores
          current_score, step_reward = self._reconciler.compute_step_reward(self._conn)

          # Episode termination: submit_final, max steps, OR perfect score
+         done = action.submit_final or self._step_count >= self._max_steps or current_score >= 0.99
+
+         # --- A5: Trajectory logging ---
+         self._trajectory.append({
+             "step": self._step_count,
+             "sql": action.sql_command,
+             "reasoning": action.reasoning,
+             "result": execution_result[:200],  # Truncate for storage
+             "score": current_score,
+             "reward": step_reward,
+             "error": action_error,
+         })

          # Update state
          self._state.step_count = self._step_count

          }
          if action_error:
              meta["error"] = action_error
+         # Include full trajectory on episode end
+         if done:
+             meta["trajectory"] = self._trajectory

          return MigrationObservation(
              done=done,
server/grader.py CHANGED
@@ -1,756 +1,366 @@
  """
- StateReconciler — The Deep Structural Grading Engine for SQL Agents.
-
- > **Hackathon Judges Note:**
- > Naive SQL agents often "solve" migration environments by executing `DROP TABLE x; CREATE TABLE x ...`
- > to forge exactly matching schemas while silently destroying all data.
- >
- > This `StateReconciler` implements robust **Anti-Exploit Protection**. It doesn't just diff schemas;
- > it recursively runs data-integrity hashing, cross-checks row counts, and verifies orphaned records.
- > If an agent drops data to match a schema, the score is brutally clamped to 0.01.
- > Furthermore, it utilizes heavily weighted fractional rewards to provide continuous learning
- > signals to the RL agent during complex, multi-step constraints (e.g., fractional points for each FK enforced).
-
- CRITICAL ARCHITECTURE RULES:
- - The grader NEVER modifies the database (SELECT and PRAGMA only)
- - The grader NEVER raises exceptions (catches everything, isolated sandbox)
- - Scores are strictly clamped to (0.0, 1.0) exclusive per validation constraints.
  """

-
  import sqlite3
- from typing import Dict, List, Optional, Set, Tuple

- from seeds import (
-     TASK1_EXPECTED_ROWS,
-     TASK2_EXPECTED_CUSTOMER_COUNT,
-     TASK2_EXPECTED_ORDER_COUNT,
-     TASK3_EXPECTED_AUDIT_COUNT,
-     TASK3_EXPECTED_AUDIT_ENTRIES,
-     TASK3_EXPECTED_EMPLOYEE_COUNT,
-     TASK3_EXPECTED_SALARIES,
-     TASK4_EXPECTED_ROW_COUNT,
-     TASK4_EXPECTED_ID_SUM,
-     TASK4_EXPECTED_DELETED_COUNT,
-     TASK4_EXPECTED_ACTIVE_COUNT,
-     TASK5_EXPECTED_ROW_COUNT,
-     TASK5_EXPECTED_PRICE_SUM,
-     TASK5_EXPECTED_BOTH_COUNT,
-     TASK6_EXPECTED_SALESPERSON_COUNT,
-     TASK6_EXPECTED_CUSTOMER_COUNT,
-     TASK6_EXPECTED_PRODUCT_COUNT,
-     TASK6_EXPECTED_SALES_COUNT,
-     TASK6_EXPECTED_DATA_ISSUES_COUNT,
-     TASK7_EXPECTED_UNIFIED_CUSTOMERS,
-     TASK7_EXPECTED_BOTH_SOURCE_COUNT,
-     TASK7_EXPECTED_UNIFIED_ORDERS,
-     TASK7_EXPECTED_MIGRATION_ISSUES,
- )


  def _get_table_names(conn: sqlite3.Connection) -> Set[str]:
-     """Get all table names in the database."""
      try:
          cursor = conn.execute(
              "SELECT name FROM sqlite_master WHERE type='table' "
              "AND name NOT LIKE 'sqlite_%' ORDER BY name"
          )
-         return {row[0] for row in cursor.fetchall()}
      except Exception:
          return set()


- def _get_column_names(conn: sqlite3.Connection, table: str) -> Set[str]:
-     """Get column names for a given table."""
      try:
          cursor = conn.execute(f"PRAGMA table_info({table})")
-         return {row[1] for row in cursor.fetchall()}
      except Exception:
-         return set()


  def _get_row_count(conn: sqlite3.Connection, table: str) -> int:
-     """Get row count of a table. Returns 0 on any error."""
      try:
-         cursor = conn.execute(f"SELECT COUNT(*) FROM {table}")
          return cursor.fetchone()[0]
      except Exception:
          return 0


  def _has_foreign_key(conn: sqlite3.Connection, table: str, ref_table: str) -> bool:
-     """Check if table has a FK referencing ref_table."""
      try:
-         cursor = conn.execute(f"PRAGMA foreign_key_list({table})")
          for row in cursor.fetchall():
-             if row[2] == ref_table:
                  return True
          return False
      except Exception:
          return False

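`_has_foreign_key` relies on `PRAGMA foreign_key_list`, which returns one row per FK column with the layout `(id, seq, table, from, to, on_update, on_delete, match)`; the helper checks index 2 (the referenced table). A quick standalone illustration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers (id INTEGER PRIMARY KEY);
CREATE TABLE orders (
    id INTEGER PRIMARY KEY,
    customer_id INTEGER REFERENCES customers(id)
);
""")
# Row layout: (id, seq, ref_table, from_col, to_col, on_update, on_delete, match)
for row in conn.execute("PRAGMA foreign_key_list(orders)"):
    print(row[2], row[3], row[4])  # customers customer_id id
```

Because the grader only reads `sqlite_master` and PRAGMAs, it can verify constraints without ever mutating the agent's database.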
- class StateReconciler:
      """
-     Scores the current database state against the target for a specific task.

-     Instantiated once per episode. Tracks previous score to compute step deltas.
      """

      def __init__(self, task_name: str):
          self.task_name = task_name
          self._last_score: float = 0.0

      def score(self, conn: sqlite3.Connection) -> float:
          """
-         Compute the current migration score [0.0, 1.0].
-
-         Routes to the appropriate task-specific scorer.
-         Never raises; returns 0.01 on any unexpected error.
          """
          try:
-             if self.task_name == "column-restructure":
-                 return self._score_task1(conn)
-             elif self.task_name == "table-normalization":
-                 return self._score_task2(conn)
-             elif self.task_name == "cascade-migration":
-                 return self._score_task3(conn)
-             elif self.task_name == "soft-delete-restoration":
-                 return self._score_task4(conn)
-             elif self.task_name == "schema-version-merge":
-                 return self._score_task5(conn)
-             elif self.task_name == "multi-entity-extraction":
-                 return self._score_task6(conn)
-             elif self.task_name == "dual-source-consolidation":
-                 return self._score_task7(conn)
-             else:
-                 return 0.01
          except Exception:
              return 0.01

      def compute_step_reward(self, conn: sqlite3.Connection) -> Tuple[float, float]:
          """
-         Compute both the current score and the step reward delta.
-
-         Returns:
-             (current_score, step_reward) where step_reward = current - previous
          """
          current_score = self.score(conn)
          step_reward = current_score - self._last_score
          self._last_score = current_score
-         return current_score, step_reward
-
143
- # =========================================================================
144
- # Task 1: Column Restructure
145
- # =========================================================================
146
- # Weights: schema=0.4, row_count=0.2, data=0.4
147
-
148
- def _score_task1(self, conn: sqlite3.Connection) -> float:
149
- score = 0.0
150
- tables = _get_table_names(conn)
151
-
152
- if "users" not in tables:
153
- return 0.0
154
-
155
- columns = _get_column_names(conn, "users")
156
-
157
- # Schema check: full_name exists, old columns gone
158
- has_full_name = "full_name" in columns
159
- old_cols_gone = "first_name" not in columns and "last_name" not in columns
160
-
161
- if has_full_name and old_cols_gone:
162
- score += 0.4 # Full schema credit
163
- elif has_full_name:
164
- score += 0.2 # Partial: full_name exists but old cols remain
165
-
166
- # Row count check
167
- row_count = _get_row_count(conn, "users")
168
- if row_count == len(TASK1_EXPECTED_ROWS):
169
- score += 0.2
170
-
171
- # Data correctness check
172
- if has_full_name:
173
- try:
174
- cursor = conn.execute("SELECT id, full_name FROM users ORDER BY id")
175
- actual_rows = cursor.fetchall()
176
- if actual_rows == TASK1_EXPECTED_ROWS:
177
- score += 0.4
178
- elif len(actual_rows) > 0:
179
- # Partial credit: fraction of correct rows
180
- correct = sum(
181
- 1 for a, e in zip(actual_rows, TASK1_EXPECTED_ROWS)
182
- if a == e
183
- )
184
- score += 0.4 * (correct / len(TASK1_EXPECTED_ROWS))
185
- except Exception:
186
- pass
187
-
188
- # Exploit check: if schema matches but table is empty, cap score
189
- if has_full_name and old_cols_gone and row_count == 0:
190
- score = min(score, 0.1)
191
-
192
- return max(0.01, min(0.99, score))
193
-
194
- # =========================================================================
195
- # Task 2: Table Normalization
196
- # =========================================================================
197
- # Weights: tables_exist=0.1, fk=0.2, customer_count=0.2,
198
- # order_count=0.2, no_null_ids=0.1, integrity=0.2
199
-
200
- def _score_task2(self, conn: sqlite3.Connection) -> float:
201
- # Re-assert FK enforcement to prevent PRAGMA bypass exploit
202
  try:
203
- conn.execute("PRAGMA foreign_keys = ON")
204
  except Exception:
205
  pass
206
- score = 0.0
207
- tables = _get_table_names(conn)
208
-
209
- # Both tables exist
210
- has_customers = "customers" in tables
211
- has_orders = "orders" in tables
212
- if has_customers and has_orders:
213
- score += 0.1
214
-
215
- # FK constraint: orders -> customers
216
- if has_orders and _has_foreign_key(conn, "orders", "customers"):
217
- score += 0.2
218
-
219
- # Correct distinct customer count
220
- if has_customers:
221
- try:
222
- cursor = conn.execute("SELECT COUNT(DISTINCT email) FROM customers")
223
- distinct_count = cursor.fetchone()[0]
224
- if distinct_count == TASK2_EXPECTED_CUSTOMER_COUNT:
225
- score += 0.2
226
- except Exception:
227
- pass
228
-
229
- # Correct order count (all original purchases preserved)
230
- if has_orders:
231
- order_count = _get_row_count(conn, "orders")
232
- if order_count == TASK2_EXPECTED_ORDER_COUNT:
233
- score += 0.2
234
-
235
- # No NULL customer_ids in orders
236
- if has_orders:
237
- try:
238
- cursor = conn.execute(
239
- "SELECT COUNT(*) FROM orders WHERE customer_id IS NULL"
240
- )
241
- null_count = cursor.fetchone()[0]
242
- if null_count == 0:
243
- score += 0.1
244
- except Exception:
245
- pass
246
-
247
- # Integrity check
248
- try:
249
- cursor = conn.execute("PRAGMA integrity_check")
250
- result = cursor.fetchone()[0]
251
- if result == "ok":
252
- score += 0.2
253
- except Exception:
254
- pass
255
-
256
- # Exploit check: tables exist but are empty
257
- if has_customers and has_orders:
258
- c_count = _get_row_count(conn, "customers")
259
- o_count = _get_row_count(conn, "orders")
260
- if c_count == 0 and o_count == 0:
261
- score = min(score, 0.1)
262
-
263
- return max(0.01, min(0.99, score))
264
-
265
- # =========================================================================
266
- # Task 3: Cascade Migration
267
- # =========================================================================
268
- # Granular partial credit for each relationship in the FK chain.
269
- # Total weights: audit=0.30, fk_chain=0.20, emp_count=0.05,
270
- # salary_coercion=0.15, no_orphans=0.10, integrity=0.10
271
- # companies_not_null=0.05 (within fk_chain)
272
- # Total max = 0.90 for all grader checks + 0.10 integrity = 1.00
273
-
274
- def _score_task3(self, conn: sqlite3.Connection) -> float:
275
- # Re-assert FK enforcement to prevent PRAGMA bypass exploit
276
- try:
277
- conn.execute("PRAGMA foreign_keys = ON")
278
- except Exception:
279
- pass
280
- score = 0.0
281
- tables = _get_table_names(conn)
282
-
283
- # --- audit_log checks (0.30 total) ---
284
- has_audit = "audit_log" in tables
285
- if has_audit:
286
- score += 0.1 # table exists
287
-
288
- if has_audit:
289
- audit_count = _get_row_count(conn, "audit_log")
290
- if audit_count >= TASK3_EXPECTED_AUDIT_COUNT:
291
- score += 0.1 # has enough rows
292
-
293
- if has_audit:
294
- try:
295
- cursor = conn.execute(
296
- "SELECT source_table, reason FROM audit_log ORDER BY source_table, reason"
297
- )
298
- actual_entries = cursor.fetchall()
299
- expected_sorted = sorted(TASK3_EXPECTED_AUDIT_ENTRIES)
300
- if actual_entries == expected_sorted:
301
- score += 0.2
302
- elif len(actual_entries) > 0:
303
- correct = sum(1 for a in actual_entries if a in TASK3_EXPECTED_AUDIT_ENTRIES)
304
- score += 0.2 * (correct / TASK3_EXPECTED_AUDIT_COUNT)
305
- except Exception:
306
- pass
307
-
308
- # --- FK chain checks (0.20 total, 0.05 each) ---
309
- # departments -> companies
310
- if "departments" in tables and _has_foreign_key(conn, "departments", "companies"):
311
- score += 0.05
312
- # employees -> departments
313
- if "employees" in tables and _has_foreign_key(conn, "employees", "departments"):
314
- score += 0.05
315
- # assets -> employees
316
- if "assets" in tables and _has_foreign_key(conn, "assets", "employees"):
317
- score += 0.05
318
- # companies.name NOT NULL
319
- if "companies" in tables:
320
- try:
321
- cursor = conn.execute("PRAGMA table_info(companies)")
322
- for row in cursor.fetchall():
323
- if row[1] == "name" and row[3] == 1: # notnull flag
324
- score += 0.05
325
- break
326
- except Exception:
327
- pass
328
-
329
- # --- Employee count (Hal Patel removed) (0.05) ---
330
- if "employees" in tables:
331
- emp_count = _get_row_count(conn, "employees")
332
- if emp_count == TASK3_EXPECTED_EMPLOYEE_COUNT:
333
- score += 0.05
334
-
335
- # --- Salary coercion: TEXT $90000 -> INTEGER 90000 (0.15) ---
336
- if "employees" in tables:
337
- try:
338
- all_correct = True
339
- for emp_id, expected_salary in TASK3_EXPECTED_SALARIES.items():
340
- cursor = conn.execute(
341
- "SELECT salary FROM employees WHERE id = ?", (emp_id,)
342
- )
343
- row = cursor.fetchone()
344
- if row is None:
345
- all_correct = False
346
- break
347
- actual = row[0]
348
- if not isinstance(actual, int):
349
- try:
350
- actual = int(actual)
351
- except (ValueError, TypeError):
352
- all_correct = False
353
- break
354
- if actual != expected_salary:
355
- all_correct = False
356
- break
357
- if all_correct:
358
- score += 0.15
359
- except Exception:
360
- pass
361
-
362
- # --- No orphaned assets (0.10) ---
363
- if "assets" in tables and "employees" in tables:
364
- try:
365
- cursor = conn.execute(
366
- "SELECT COUNT(*) FROM assets WHERE employee_id NOT IN "
367
- "(SELECT id FROM employees)"
368
- )
369
- orphan_count = cursor.fetchone()[0]
370
- if orphan_count == 0:
371
- score += 0.10
372
- except Exception:
373
- pass
374
-
375
- # --- Integrity check (0.10) ---
376
- try:
377
- cursor = conn.execute("PRAGMA integrity_check")
378
- result = cursor.fetchone()[0]
379
- if result == "ok":
380
- score += 0.10
381
- except Exception:
382
- pass
383
-
384
- # Exploit check: if employees table is empty
385
- if "employees" in tables and _get_row_count(conn, "employees") == 0:
386
- score = min(score, 0.1)
387
-
388
- return max(0.01, min(0.99, score))
389
-
390
- # =========================================================================
391
- # Task 4: Soft-Delete Restoration (Easy)
392
- # =========================================================================
393
-
394
- def _score_task4(self, conn: sqlite3.Connection) -> float:
395
- score = 0.0
396
- tables = _get_table_names(conn)
397
-
398
- if "products" not in tables:
399
- return 0.01
400
-
401
- cols = _get_column_names(conn, "products")
402
-
403
- # is_deleted column exists (+0.15)
404
- if "is_deleted" in cols:
405
- score += 0.15
406
-
407
- # deleted_at column exists (+0.10)
408
- if "deleted_at" in cols:
409
- score += 0.10
410
-
411
- # Row count = 8 (+0.20)
412
- row_count = _get_row_count(conn, "products")
413
- if row_count == TASK4_EXPECTED_ROW_COUNT:
414
- score += 0.20
415
-
416
- # Active products: is_deleted=0, deleted_at IS NULL (+0.25)
417
- if "is_deleted" in cols:
418
- try:
419
- cursor = conn.execute(
420
- "SELECT COUNT(*) FROM products WHERE is_deleted = 0 AND deleted_at IS NULL"
421
- )
422
- active = cursor.fetchone()[0]
423
- if active == TASK4_EXPECTED_ACTIVE_COUNT:
424
- score += 0.25
425
- except Exception:
426
- pass
427
-
428
- # Restored products: is_deleted=1, deleted_at IS NOT NULL (+0.20)
429
- if "is_deleted" in cols:
430
- try:
431
- cursor = conn.execute(
432
- "SELECT COUNT(*) FROM products WHERE is_deleted = 1 AND deleted_at IS NOT NULL"
433
- )
434
- restored = cursor.fetchone()[0]
435
- if restored == TASK4_EXPECTED_DELETED_COUNT:
436
- score += 0.20
437
- except Exception:
438
- pass
439
-
440
- # SUM(id) fingerprint = 36 — no phantom rows (+0.10)
441
- try:
442
- cursor = conn.execute("SELECT SUM(id) FROM products")
443
- id_sum = cursor.fetchone()[0]
444
- if id_sum == TASK4_EXPECTED_ID_SUM:
445
- score += 0.10
446
- except Exception:
447
- pass
448
-
449
- # Exploit check
450
- if row_count == 0:
451
- score = min(score, 0.1)
452
-
453
- return max(0.01, min(0.99, score))
454
-
455
- # =========================================================================
456
- # Task 5: Schema Version Merge (Medium)
457
- # =========================================================================
458
-
459
- def _score_task5(self, conn: sqlite3.Connection) -> float:
460
- # Re-assert FK enforcement
461
- try:
462
- conn.execute("PRAGMA foreign_keys = ON")
463
- except Exception:
464
- pass
465
- score = 0.0
466
- tables = _get_table_names(conn)
467
 
468
- if "products" not in tables:
 
 
469
  return 0.01
470
-
471
- cols = _get_column_names(conn, "products")
472
-
473
- # Schema completeness: all 8 columns (+0.10)
474
- expected_cols = {"id", "name", "price", "category", "supplier", "brand", "sku", "source"}
475
- if expected_cols.issubset(cols):
476
- score += 0.10
477
-
478
- # Row count = 9 (+0.15)
479
- row_count = _get_row_count(conn, "products")
480
- if row_count == TASK5_EXPECTED_ROW_COUNT:
481
- score += 0.15
482
-
483
- # PRICE_SUM fingerprint (+0.20)
484
- try:
485
- cursor = conn.execute("SELECT ROUND(SUM(price), 2) FROM products")
486
- price_sum = cursor.fetchone()[0]
487
- if price_sum is not None and abs(price_sum - TASK5_EXPECTED_PRICE_SUM) < 0.02:
488
- score += 0.20
489
- except Exception:
490
- pass
491
-
492
- # source='both' for conflicted ids 1,2 (+0.15)
493
- if "source" in cols:
494
- try:
495
- cursor = conn.execute(
496
- "SELECT COUNT(*) FROM products WHERE source = 'both'"
497
- )
498
- both_count = cursor.fetchone()[0]
499
- if both_count == TASK5_EXPECTED_BOTH_COUNT:
500
- score += 0.15
501
- except Exception:
502
- pass
503
-
504
- # v2 name wins for conflicted rows (+0.15)
505
- try:
506
- cursor = conn.execute("SELECT name FROM products WHERE id = 2")
507
- row = cursor.fetchone()
508
- if row and "Updated" in row[0]:
509
- score += 0.15
510
- except Exception:
511
- pass
512
-
513
- # No NULL prices (+0.10)
514
- try:
515
- cursor = conn.execute("SELECT COUNT(*) FROM products WHERE price IS NULL")
516
- null_count = cursor.fetchone()[0]
517
- if null_count == 0:
518
- score += 0.10
519
- except Exception:
520
- pass
521
-
522
- # PRAGMA integrity_check (+0.15)
523
- try:
524
- cursor = conn.execute("PRAGMA integrity_check")
525
- result = cursor.fetchone()[0]
526
- if result == "ok":
527
- score += 0.15
528
- except Exception:
529
- pass
530
-
531
- # Exploit check
532
- if row_count == 0:
533
- score = min(score, 0.1)
534
-
535
- return max(0.01, min(0.99, score))
536
-
-     # =========================================================================
-     # Task 6: Multi-Entity Extraction (Medium — Hard End)
-     # =========================================================================
-
-     def _score_task6(self, conn: sqlite3.Connection) -> float:
-         # Re-assert FK enforcement
-         try:
-             conn.execute("PRAGMA foreign_keys = ON")
-         except Exception:
-             pass
-         score = 0.0
-         tables = _get_table_names(conn)
-
-         # All 5 tables exist (+0.10)
-         required = {"salespersons", "customers", "products", "sales", "data_issues"}
-         if required.issubset(tables):
-             score += 0.10
-
-         # salesperson count = 3 (+0.10)
-         if "salespersons" in tables:
-             count = _get_row_count(conn, "salespersons")
-             if count == TASK6_EXPECTED_SALESPERSON_COUNT:
-                 score += 0.10
-
-         # customer count = 3 (invalid excluded) (+0.12)
-         if "customers" in tables:
-             count = _get_row_count(conn, "customers")
-             if count == TASK6_EXPECTED_CUSTOMER_COUNT:
-                 score += 0.12
-
-         # product count = 5 (+0.10)
-         if "products" in tables:
-             count = _get_row_count(conn, "products")
-             if count == TASK6_EXPECTED_PRODUCT_COUNT:
-                 score += 0.10
-
-         # sales count = 11 (bad row excluded) (+0.12)
-         if "sales" in tables:
-             count = _get_row_count(conn, "sales")
-             if count == TASK6_EXPECTED_SALES_COUNT:
-                 score += 0.12
-
-         # All 3 FKs present in sales (+0.15)
-         if "sales" in tables:
-             fk_count = 0
-             if _has_foreign_key(conn, "sales", "salespersons"): fk_count += 1
-             if _has_foreign_key(conn, "sales", "customers"): fk_count += 1
-             if _has_foreign_key(conn, "sales", "products"): fk_count += 1
-             score += 0.05 * fk_count  # 0.15 total for all 3
-
-         # data_issues count = 1, for row 6 (+0.11)
-         if "data_issues" in tables:
-             count = _get_row_count(conn, "data_issues")
-             if count == TASK6_EXPECTED_DATA_ISSUES_COUNT:
-                 score += 0.11
-
-         # alice email is trimmed (+0.10)
-         if "salespersons" in tables:
-             try:
-                 cursor = conn.execute(
-                     "SELECT email FROM salespersons WHERE name LIKE '%Alice%'"
-                 )
-                 row = cursor.fetchone()
-                 if row and row[0] == "alice@company.com":
-                     score += 0.10
-             except Exception:
-                 pass
-
-         # PRAGMA integrity_check (+0.10)
-         try:
-             cursor = conn.execute("PRAGMA integrity_check")
-             result = cursor.fetchone()[0]
-             if result == "ok":
-                 score += 0.10
-         except Exception:
-             pass
-
-         # Exploit check
-         sales_count = _get_row_count(conn, "sales") if "sales" in tables else 0
-         if sales_count == 0 and "sales" in tables:
-             score = min(score, 0.1)
-
-         return max(0.01, min(0.99, score))
-
621
- # =========================================================================
622
- # Task 7: Dual-Source Consolidation (Hard)
623
- # =========================================================================
624
-
625
- def _score_task7(self, conn: sqlite3.Connection) -> float:
626
- # Re-assert FK enforcement
627
  try:
 
628
  conn.execute("PRAGMA foreign_keys = ON")
629
- except Exception:
630
- pass
631
- score = 0.0
632
- tables = _get_table_names(conn)
633
-
634
- # All 4 tables exist (+0.05)
635
- required = {"unified_customers", "unified_products", "unified_orders", "migration_issues"}
636
- if required.issubset(tables):
637
- score += 0.05
638
-
639
- # unified_customers count = 7 (+0.08)
640
- if "unified_customers" in tables:
641
- count = _get_row_count(conn, "unified_customers")
642
- if count == TASK7_EXPECTED_UNIFIED_CUSTOMERS:
643
- score += 0.08
644
-
645
- # source='both' for email-matched records (+0.08)
646
- if "unified_customers" in tables:
647
- try:
648
- cursor = conn.execute(
649
- "SELECT COUNT(*) FROM unified_customers WHERE source = 'both'"
650
- )
651
- both = cursor.fetchone()[0]
652
- if both == TASK7_EXPECTED_BOTH_SOURCE_COUNT:
653
- score += 0.08
654
- except Exception:
655
- pass
656
-
657
- # Legacy amount coercion — check unified_orders has REAL amounts (+0.10)
658
- if "unified_orders" in tables:
659
- try:
660
- cursor = conn.execute(
661
- "SELECT COUNT(*) FROM unified_orders WHERE typeof(amount) = 'real' OR typeof(amount) = 'integer'"
662
- )
663
- real_count = cursor.fetchone()[0]
664
- order_count = _get_row_count(conn, "unified_orders")
665
- if real_count == order_count and order_count > 0:
666
- score += 0.10
667
- except Exception:
668
- pass
669
-
670
- # NULL currency → 'USD' fill (+0.07)
671
- if "unified_orders" in tables:
672
- try:
673
- cursor = conn.execute(
674
- "SELECT COUNT(*) FROM unified_orders WHERE currency IS NULL"
675
- )
676
- null_curr = cursor.fetchone()[0]
677
- if null_curr == 0:
678
- score += 0.07
679
- except Exception:
680
- pass
681
-
682
- # tx_status mapped to strings (+0.10)
683
- if "unified_orders" in tables:
684
- try:
685
- cursor = conn.execute(
686
- "SELECT COUNT(*) FROM unified_orders WHERE typeof(status) = 'text'"
687
- )
688
- text_count = cursor.fetchone()[0]
689
- order_count = _get_row_count(conn, "unified_orders")
690
- if text_count == order_count and order_count > 0:
691
- score += 0.10
692
- except Exception:
693
- pass
694
-
695
- # subscription_tier mapped to strings (+0.08)
696
- if "unified_customers" in tables:
697
- try:
698
- cursor = conn.execute(
699
- "SELECT COUNT(*) FROM unified_customers WHERE typeof(tier) = 'text'"
700
- )
701
- text_count = cursor.fetchone()[0]
702
- cust_count = _get_row_count(conn, "unified_customers")
703
- if text_count == cust_count and cust_count > 0:
704
- score += 0.08
705
- except Exception:
706
- pass
707
-
708
- # migration_issues count = 2 (+0.08)
709
- if "migration_issues" in tables:
710
- count = _get_row_count(conn, "migration_issues")
711
- if count == TASK7_EXPECTED_MIGRATION_ISSUES:
712
- score += 0.08
713
-
714
- # Orphaned transaction in issues (+0.07)
715
- if "migration_issues" in tables:
716
- try:
717
- cursor = conn.execute(
718
- "SELECT COUNT(*) FROM migration_issues WHERE issue_type = 'orphaned_record'"
719
- )
720
- orphan_issues = cursor.fetchone()[0]
721
- if orphan_issues >= 1:
722
- score += 0.07
723
- except Exception:
724
- pass
725
-
726
- # NULL email customer in issues (+0.07)
727
- if "migration_issues" in tables:
728
- try:
729
- cursor = conn.execute(
730
- "SELECT COUNT(*) FROM migration_issues WHERE issue_type = 'null_email'"
731
- )
732
- null_issues = cursor.fetchone()[0]
733
- if null_issues >= 1:
734
- score += 0.07
735
- except Exception:
736
- pass
737
-
738
- # FK integrity on unified_orders (+0.10)
739
- if "unified_orders" in tables:
740
- if _has_foreign_key(conn, "unified_orders", "unified_customers"):
741
- score += 0.10
742
-
743
- # PRAGMA integrity_check (+0.10)
744
- try:
745
  cursor = conn.execute("PRAGMA integrity_check")
746
  result = cursor.fetchone()[0]
747
- if result == "ok":
748
- score += 0.10
749
  except Exception:
750
  pass
751
-
752
- # Exploit check
753
- if "unified_orders" in tables and _get_row_count(conn, "unified_orders") == 0:
754
- score = min(score, 0.1)
755
-
756
- return max(0.01, min(0.99, score))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  """
+ StateReconciler — Dynamic Golden Database Grading Engine.
+
+ ARCHITECTURE:
+ - Instead of hardcoded expected values, we build a "golden" database by running
+   the correct migration on a fresh copy of the seed data.
+ - The agent's database is compared table-by-table against this golden reference.
+ - This makes the grader SEED-INDEPENDENT: if judges change the seed data,
+   the golden DB auto-updates and scoring remains accurate.
+
+ SCORING WEIGHTS (per-table, dynamic):
+ - Schema match (table exists, correct columns): 30%
+ - Data match (row count + content): 40%
+ - FK & constraint integrity: 20%
+ - Anti-exploit checks: 10%
+
+ ANTI-EXPLOIT PROTECTIONS:
+ - Case-insensitive table/column name comparison
+ - PRAGMA state preservation (grader doesn't corrupt agent's FK state)
+ - Phantom row detection (SUM fingerprinting)
+ - Empty table exploitation blocked
+ - Extra/leftover table penalty
  """

  import sqlite3
+ from typing import Any, Dict, List, Optional, Set, Tuple

+ # Import seeds for golden migration functions
+ try:
+     from .. import seeds
+ except ImportError:
+     import seeds


  def _get_table_names(conn: sqlite3.Connection) -> Set[str]:
+     """Get all user table names (case-normalized to lowercase)."""
      try:
          cursor = conn.execute(
              "SELECT name FROM sqlite_master WHERE type='table' "
              "AND name NOT LIKE 'sqlite_%' ORDER BY name"
          )
+         return {row[0].lower() for row in cursor.fetchall()}
      except Exception:
          return set()


+ def _get_column_info(conn: sqlite3.Connection, table: str) -> List[dict]:
+     """Get column info for a table. Returns list of {name, type, notnull, pk}."""
      try:
          cursor = conn.execute(f"PRAGMA table_info({table})")
+         return [
+             {"name": row[1].lower(), "type": row[2].upper(), "notnull": row[3], "pk": row[5]}
+             for row in cursor.fetchall()
+         ]
      except Exception:
+         return []
+
+
+ def _get_column_names(conn: sqlite3.Connection, table: str) -> Set[str]:
+     """Get column names (lowercase) for a table."""
+     return {col["name"] for col in _get_column_info(conn, table)}


  def _get_row_count(conn: sqlite3.Connection, table: str) -> int:
+     """Get row count. Returns 0 on error."""
      try:
+         cursor = conn.execute(f"SELECT COUNT(*) FROM [{table}]")
          return cursor.fetchone()[0]
      except Exception:
          return 0


+ def _get_all_rows(conn: sqlite3.Connection, table: str) -> List[Tuple]:
+     """Get all rows from a table, sorted for deterministic comparison."""
+     try:
+         cols = _get_column_names(conn, table)
+         if not cols:
+             return []
+         cursor = conn.execute(f"SELECT * FROM [{table}] ORDER BY 1")
+         return cursor.fetchall()
+     except Exception:
+         return []
+
+
  def _has_foreign_key(conn: sqlite3.Connection, table: str, ref_table: str) -> bool:
+     """Check if table has a FK referencing ref_table (case-insensitive)."""
      try:
+         cursor = conn.execute(f"PRAGMA foreign_key_list([{table}])")
          for row in cursor.fetchall():
+             if row[2].lower() == ref_table.lower():
                  return True
          return False
      except Exception:
          return False


+ def _count_foreign_keys(conn: sqlite3.Connection, table: str) -> int:
+     """Count all FK relationships for a table."""
+     try:
+         cursor = conn.execute(f"PRAGMA foreign_key_list([{table}])")
+         refs = set()
+         for row in cursor.fetchall():
+             refs.add(row[2].lower())
+         return len(refs)
+     except Exception:
+         return 0
+
+
+ def _build_golden_db(task_name: str) -> sqlite3.Connection:
+     """
+     Build a golden reference database for a task.
+
+     Seeds a fresh in-memory DB with the task's seed data, then applies
+     the golden migration to produce the expected final state.
+     """
+     task_config = seeds.TASKS[task_name]
+     conn = sqlite3.connect(":memory:")
+     conn.execute("PRAGMA foreign_keys = ON")
+
+     # Seed with same data as agent
+     task_config["seed_fn"](conn)
+
+     # Apply perfect migration
+     task_config["golden_fn"](conn)
+
+     return conn
+
+
+ def _compare_row_data(
+     agent_rows: List[Tuple],
+     golden_rows: List[Tuple],
+ ) -> float:
+     """
+     Compare row data between agent and golden databases.
+
+     Returns a similarity score between 0.0 and 1.0.
+     Handles: different row counts, partial matches, type coercion differences.
      """
+     if not golden_rows:
+         return 1.0 if not agent_rows else 0.0
+     if not agent_rows:
+         return 0.0
+
+     # Exact match
+     if agent_rows == golden_rows:
+         return 1.0
+
+     # Row count match bonus
+     count_match = 1.0 if len(agent_rows) == len(golden_rows) else (
+         min(len(agent_rows), len(golden_rows)) / max(len(agent_rows), len(golden_rows))
+     )
+
+     # Per-row comparison (order-independent for flexibility)
+     golden_set = set()
+     for row in golden_rows:
+         # Normalize: convert all values to strings for loose comparison
+         golden_set.add(tuple(str(v).strip() if v is not None else "" for v in row))
+
+     matched = 0
+     for row in agent_rows:
+         normalized = tuple(str(v).strip() if v is not None else "" for v in row)
+         if normalized in golden_set:
+             matched += 1
+             golden_set.discard(normalized)
+
+     if len(golden_rows) == 0:
+         content_match = 0.0
+     else:
+         content_match = matched / len(golden_rows)
+
+     # Penalize extra rows (data bloat)
+     if len(agent_rows) > len(golden_rows):
+         bloat_penalty = max(0, 1.0 - (len(agent_rows) - len(golden_rows)) / len(golden_rows))
+         content_match *= bloat_penalty
+
+     return 0.4 * count_match + 0.6 * content_match


+ class StateReconciler:
+     """
+     Dynamic Golden Database grading engine.
+
+     Compares the agent's database state against a dynamically-generated
+     golden reference database. No hardcoded expected values.
      """

      def __init__(self, task_name: str):
          self.task_name = task_name
          self._last_score: float = 0.0
+         self._golden_conn: Optional[sqlite3.Connection] = None
+
+         # Build golden reference DB
+         try:
+             self._golden_conn = _build_golden_db(task_name)
+             self._golden_tables = _get_table_names(self._golden_conn)
+             self._golden_table_data: Dict[str, dict] = {}
+
+             for table in self._golden_tables:
+                 self._golden_table_data[table] = {
+                     "columns": _get_column_info(self._golden_conn, table),
+                     "col_names": _get_column_names(self._golden_conn, table),
+                     "rows": _get_all_rows(self._golden_conn, table),
+                     "row_count": _get_row_count(self._golden_conn, table),
+                     "fk_count": _count_foreign_keys(self._golden_conn, table),
+                 }
+         except Exception:
+             self._golden_tables = set()
+             self._golden_table_data = {}
+
+     def __del__(self):
+         """Clean up golden DB connection."""
+         if self._golden_conn is not None:
+             try:
+                 self._golden_conn.close()
+             except Exception:
+                 pass

      def score(self, conn: sqlite3.Connection) -> float:
          """
+         Compute migration score by comparing agent DB against golden reference.
+
+         Scoring breakdown:
+         - Schema match: 0.30 (tables exist with correct columns)
+         - Data match: 0.40 (row content matches golden DB)
+         - FK/constraint integrity: 0.20 (FKs enforced, integrity OK)
+         - Anti-exploit bonus: 0.10 (no empty tables, no extra tables)
+
+         Returns: float in [0.01, 0.99]
          """
          try:
+             return self._score_dynamic(conn)
          except Exception:
              return 0.01

      def compute_step_reward(self, conn: sqlite3.Connection) -> Tuple[float, float]:
          """
+         Compute current score and step reward delta.
+
+         CRITICAL: Preserves the agent's PRAGMA foreign_keys state.
+         The grader reads FK state, does its work, then restores it.
          """
+         # A8: Preserve PRAGMA state
+         try:
+             original_fk = conn.execute("PRAGMA foreign_keys").fetchone()[0]
+         except Exception:
+             original_fk = 1
+
          current_score = self.score(conn)
          step_reward = current_score - self._last_score
          self._last_score = current_score
+
+         # A8: Restore original PRAGMA state
          try:
+             conn.execute(f"PRAGMA foreign_keys = {'ON' if original_fk else 'OFF'}")
          except Exception:
              pass
+
+         return current_score, step_reward

+     def _score_dynamic(self, conn: sqlite3.Connection) -> float:
+         """Core dynamic scoring: compare agent DB against golden DB."""
+         if not self._golden_tables:
              return 0.01
+
+         agent_tables = _get_table_names(conn)
+
+         # ---- 1. Schema Match (0.30) ----
+         schema_score = 0.0
+         tables_found = 0
+         total_col_match = 0.0
+
+         for table in self._golden_tables:
+             golden_info = self._golden_table_data[table]
+
+             if table in agent_tables:
+                 tables_found += 1
+                 # Column name comparison
+                 agent_cols = _get_column_names(conn, table)
+                 golden_cols = golden_info["col_names"]
+                 if golden_cols:
+                     col_overlap = len(agent_cols & golden_cols) / len(golden_cols)
+                     total_col_match += col_overlap
+                 else:
+                     total_col_match += 1.0
+
+         if self._golden_tables:
+             table_ratio = tables_found / len(self._golden_tables)
+             col_ratio = total_col_match / len(self._golden_tables) if self._golden_tables else 0
+             schema_score = 0.15 * table_ratio + 0.15 * col_ratio
+
+         # ---- 2. Data Match (0.40) ----
+         data_score = 0.0
+         data_checks = 0
+
+         for table in self._golden_tables:
+             golden_info = self._golden_table_data[table]
+             if table not in agent_tables:
+                 data_checks += 1
+                 continue
+
+             agent_rows = _get_all_rows(conn, table)
+             golden_rows = golden_info["rows"]
+
+             similarity = _compare_row_data(agent_rows, golden_rows)
+             data_score += similarity
+             data_checks += 1
+
+         if data_checks > 0:
+             data_score = 0.40 * (data_score / data_checks)
+
+         # ---- 3. FK & Constraint Integrity (0.20) ----
+         fk_score = 0.0
+         fk_checks = 0
+
+         for table in self._golden_tables:
+             golden_info = self._golden_table_data[table]
+             expected_fks = golden_info["fk_count"]
+
+             if expected_fks > 0 and table in agent_tables:
+                 agent_fks = _count_foreign_keys(conn, table)
+                 fk_ratio = min(agent_fks, expected_fks) / expected_fks
+                 fk_score += fk_ratio
+                 fk_checks += 1
+
+         # PRAGMA integrity check
+         integrity_ok = False
          try:
+             # Temporarily enable FK for integrity check
              conn.execute("PRAGMA foreign_keys = ON")
              cursor = conn.execute("PRAGMA integrity_check")
              result = cursor.fetchone()[0]
+             integrity_ok = (result == "ok")
          except Exception:
              pass
+
+         if fk_checks > 0:
+             fk_score = 0.10 * (fk_score / fk_checks)
+         else:
+             # No FK constraints expected — award full FK portion
+             fk_score = 0.10
+         fk_score += 0.10 if integrity_ok else 0.0
+
+         # ---- 4. Anti-Exploit Checks (0.10) ----
+         exploit_score = 0.10  # Start with full credit, deduct for violations
+
+         # Check for empty tables where golden has data
+         for table in self._golden_tables:
+             golden_info = self._golden_table_data[table]
+             if golden_info["row_count"] > 0 and table in agent_tables:
+                 agent_count = _get_row_count(conn, table)
+                 if agent_count == 0:
+                     # Agent emptied a table that should have data — heavy penalty
+                     exploit_score = 0.0
+                     # Also cap the data score for this exploit
+                     data_score = min(data_score, 0.05)
+                     break
+
+         # Penalize extra non-golden tables (schema pollution)
+         extra_tables = agent_tables - self._golden_tables
+         if extra_tables:
+             # Small penalty per extra table (some might be temp tables)
+             penalty = min(0.05, 0.01 * len(extra_tables))
+             exploit_score = max(0, exploit_score - penalty)
+
+         total = schema_score + data_score + fk_score + exploit_score
+         return max(0.01, min(0.99, total))
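The order-independent row comparison in `_compare_row_data` above is the heart of the data-match score. A standalone sketch of the same normalization idea (hypothetical helper, not part of the committed grader; it uses a list instead of a set so duplicate golden rows are counted individually):

```python
def compare_rows(agent_rows, golden_rows):
    """Loose multiset comparison: stringify and strip values, match against golden."""
    if not golden_rows:
        return 1.0 if not agent_rows else 0.0
    if not agent_rows:
        return 0.0
    # Normalize every value to a stripped string so TEXT "3" matches INTEGER 3
    norm = lambda r: tuple(str(v).strip() if v is not None else "" for v in r)
    remaining = [norm(r) for r in golden_rows]
    matched = 0
    for r in agent_rows:
        n = norm(r)
        if n in remaining:
            matched += 1
            remaining.remove(n)  # each golden row can be claimed once
    return matched / len(golden_rows)

print(compare_rows([(1, "Ada ")], [(1, "Ada")]))                      # 1.0 (whitespace normalized)
print(compare_rows([(1, "Ada"), (2, "Bob")], [(1, "Ada"), (2, "Eve")]))  # 0.5 (partial match)
```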
test_all_tasks.py CHANGED
@@ -1,49 +1,105 @@
- """Quick validation of all 7 tasks: seeds + graders."""
- import sqlite3
  import sys
  import os

  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

- from seeds import TASKS
  from server.grader import StateReconciler

- print(f"Tasks registered: {len(TASKS)}")
- assert len(TASKS) == 7, f"Expected 7 tasks, got {len(TASKS)}"
- print(f"  Names: {list(TASKS.keys())}")

- for name, cfg in TASKS.items():
-     # Seed
      conn = sqlite3.connect(":memory:")
      conn.execute("PRAGMA foreign_keys = ON")
-     cfg["seed_fn"](conn)

-     cursor = conn.execute(
-         "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
-     )
-     tables = [r[0] for r in cursor.fetchall()]
-     print(f"\n[{name}] ({cfg['difficulty']}, max_steps={cfg.get('max_steps', 20)})")
-     print(f"  Tables: {tables}")

-     # Grade
-     reconciler = StateReconciler(name)
-     score = reconciler.score(conn)
-     assert 0.01 <= score <= 0.99, f"Score {score} out of [0.01, 0.99]!"
-     print(f"  Initial score: {score:.2f} OK")

      conn.close()

- # Also test environment resets for each task
- from server.environment import DbMigrationEnvironment

- for name in TASKS:
-     env = DbMigrationEnvironment(task_name=name)
      obs = env.reset()
-     assert obs.done == False
-     assert obs.step_number == 0
-     print(f"  [{name}] Environment reset OK")
      env.close()

- print("\n" + "=" * 50)
- print("ALL 7 TASKS VALIDATED SUCCESSFULLY!")
- print("=" * 50)

+ """Test all 7 tasks: seed, golden migration, grade, reset, close."""
  import sys
  import os
+ import sqlite3

  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+ sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'OpenEnv', 'src'))

+ import seeds
  from server.grader import StateReconciler
+ from server.environment import DbMigrationEnvironment
+ from models import MigrationAction


+ def test_golden_migration(task_name: str) -> None:
+     """Test that golden migration produces a near-perfect grader score."""
+     config = seeds.TASKS[task_name]
+
+     # 1. Create DB and seed
      conn = sqlite3.connect(":memory:")
      conn.execute("PRAGMA foreign_keys = ON")
+     config["seed_fn"](conn)

+     # 2. Score before migration (should be low)
+     reconciler = StateReconciler(task_name)
+     score_before = reconciler.score(conn)

+     # 3. Run golden migration
+     config["golden_fn"](conn)
+
+     # 4. Score after migration (should be >0.90)
+     score_after = reconciler.score(conn)

      conn.close()
+
+     status = "PASS" if score_after >= 0.90 else "FAIL"
+     print(f"  [{status}] {task_name}: before={score_before:.2f} after={score_after:.2f}")
+
+     if score_after < 0.90:
+         raise AssertionError(f"{task_name}: golden migration only scored {score_after:.2f}")


+ def test_environment_lifecycle(task_name: str) -> None:
+     """Test that environment can reset, step, and close without crashes."""
+     env = DbMigrationEnvironment(task_name=task_name)
      obs = env.reset()
+
+     assert not obs.done, f"{task_name}: obs.done should be False after reset"
+     assert obs.step_number == 0, f"{task_name}: step should be 0 after reset"
+     assert obs.current_schema_sql, f"{task_name}: should have current schema"
+     assert obs.target_schema_sql, f"{task_name}: should have target schema"
+
+     # Run a SELECT to verify data passthrough
+     action = MigrationAction(
+         sql_command="SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
+         reasoning="List tables",
+         submit_final=False,
+     )
+     obs = env.step(action)
+     assert "rows total" in obs.last_execution_result or "Query returned" in obs.last_execution_result, \
+         f"{task_name}: SELECT should return formatted data, got: {obs.last_execution_result[:100]}"
+
      env.close()
+     print(f"  [PASS] {task_name}: environment lifecycle OK (SELECT data passthrough verified)")
+
+
+ def main():
+     print("=" * 60)
+     print("Testing Golden Migrations (all 7 tasks)")
+     print("=" * 60)
+
+     errors = []
+     for task_name in seeds.TASKS:
+         try:
+             test_golden_migration(task_name)
+         except Exception as e:
+             errors.append(f"Golden {task_name}: {e}")
+
+     print()
+     print("=" * 60)
+     print("Testing Environment Lifecycle (all 7 tasks)")
+     print("=" * 60)
+
+     for task_name in seeds.TASKS:
+         try:
+             test_environment_lifecycle(task_name)
+         except Exception as e:
+             errors.append(f"Lifecycle {task_name}: {e}")
+
+     print()
+     if errors:
+         print("=" * 60)
+         print(f"FAILURES ({len(errors)}):")
+         for e in errors:
+             print(f"  ✗ {e}")
+         print("=" * 60)
+         sys.exit(1)
+     else:
+         print("=" * 60)
+         print("ALL 7 TASKS PASSED!")
+         print("=" * 60)


+ if __name__ == "__main__":
+     main()
test_smoke.py CHANGED
@@ -1,4 +1,4 @@
- """Smoke test for the SQL Migration Environment."""
  import sys
  import os

@@ -42,13 +42,29 @@ assert cursor.fetchone()[0] is None
  conn.close()
  print("PASS: Task 3 seeds - 5 employees, NULL salary")

- # Test 5: Grader
  from server.grader import StateReconciler
  conn = sqlite3.connect(":memory:")
  seed_task1(conn)
  reconciler = StateReconciler("column-restructure")
  score = reconciler.score(conn)
  print(f"PASS: Grader score for unmodified Task 1: {score:.2f}")

  # Simulate correct migration
  conn.execute("CREATE TABLE users_new (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)")
@@ -58,19 +74,42 @@ conn.execute("ALTER TABLE users_new RENAME TO users")
  conn.commit()
  score = reconciler.score(conn)
  print(f"PASS: Score after correct Task 1: {score:.2f}")
- assert score == 0.99, f"Expected 0.99, got {score}"
  conn.close()

- # Test 6: Full environment
  from server.environment import DbMigrationEnvironment
  env = DbMigrationEnvironment(task_name="column-restructure")
  obs = env.reset()
  assert obs.done == False
  assert obs.step_number == 0
- assert "users" in obs.current_schema_sql
  print(f"PASS: Environment reset. Step={obs.step_number}")

  # Run a complete correct migration
  steps = [
      "CREATE TABLE users_new (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)",
      "INSERT INTO users_new (id, full_name) SELECT id, first_name || ' ' || last_name FROM users",
@@ -79,20 +118,20 @@ steps = [
  ]
  for i, sql in enumerate(steps):
      is_final = (i == len(steps) - 1)
-     action = MigrationAction(
-         sql_command=sql,
-         reasoning=f"Step {i+1}",
-         submit_final=is_final,
-     )
-     obs = env.step(action)
-     print(f"  Step {i+1}: reward={obs.reward:.2f}, progress={obs.migration_progress:.2f}, done={obs.done}")
-
- assert obs.done == True
- assert obs.migration_progress == 0.99, f"Expected 0.99, got {obs.migration_progress}"
  env.close()
- print("PASS: Full migration episode completed with score 0.99")

- # Test 7: Task 2 grader
  conn = sqlite3.connect(":memory:")
  conn.execute("PRAGMA foreign_keys = ON")
  seed_task2(conn)
@@ -101,7 +140,7 @@ score_before = reconciler2.score(conn)
  print(f"PASS: Task 2 grader before migration: {score_before:.2f}")
  conn.close()

- # Test 8: Task 3 grader
  conn = sqlite3.connect(":memory:")
  conn.execute("PRAGMA foreign_keys = ON")
  seed_task3(conn)
@@ -110,6 +149,21 @@ score_before = reconciler3.score(conn)
  print(f"PASS: Task 3 grader before migration: {score_before:.2f}")
  conn.close()

  print()
  print("=" * 50)
  print("ALL TESTS PASSED! Environment is fully working!")

+ """Smoke test for the SQL Migration Environment (updated for Golden DB grader)."""
  import sys
  import os

  conn.close()
  print("PASS: Task 3 seeds - 5 employees, NULL salary")

+ # Test 5: Golden migrations run without error
+ from seeds import golden_task1, golden_task2, golden_task3, golden_task4, golden_task5, golden_task6, golden_task7
+ for i, (seed_fn, golden_fn, name) in enumerate([
+     (seed_task1, golden_task1, "column-restructure"),
+     (seed_task2, golden_task2, "table-normalization"),
+     (seed_task3, golden_task3, "cascade-migration"),
+ ], 1):
+     conn = sqlite3.connect(":memory:")
+     conn.execute("PRAGMA foreign_keys = ON")
+     seed_fn(conn)
+     golden_fn(conn)
+     conn.close()
+     print(f"PASS: Golden migration {name} runs without error")
+
+ # Test 6: Grader with Golden DB
  from server.grader import StateReconciler
  conn = sqlite3.connect(":memory:")
+ conn.execute("PRAGMA foreign_keys = ON")
  seed_task1(conn)
  reconciler = StateReconciler("column-restructure")
  score = reconciler.score(conn)
  print(f"PASS: Grader score for unmodified Task 1: {score:.2f}")
+ assert score < 0.7, f"Expected moderate score before migration, got {score}"

  # Simulate correct migration
  conn.execute("CREATE TABLE users_new (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)")

  conn.commit()
  score = reconciler.score(conn)
  print(f"PASS: Score after correct Task 1: {score:.2f}")
+ assert score >= 0.89, f"Expected >= 0.89, got {score}"
  conn.close()

+ # Test 7: Full environment with SELECT passthrough
  from server.environment import DbMigrationEnvironment
  env = DbMigrationEnvironment(task_name="column-restructure")
  obs = env.reset()
  assert obs.done == False
  assert obs.step_number == 0
+ assert "users" in obs.current_schema_sql.lower()
  print(f"PASS: Environment reset. Step={obs.step_number}")

+ # Test SELECT returns actual data (A1 fix)
+ select_action = MigrationAction(
+     sql_command="SELECT * FROM users LIMIT 2",
+     reasoning="Inspecting data",
+     submit_final=False,
+ )
+ obs = env.step(select_action)
+ assert "O'Brien" in obs.last_execution_result, f"SELECT should return data, got: {obs.last_execution_result}"
+ print(f"PASS: SELECT returns actual data rows")
+
+ # Test dangerous SQL is blocked (A3 fix)
+ dangerous_action = MigrationAction(
+     sql_command="ATTACH DATABASE ':memory:' AS evil",
+     reasoning="Testing security",
+     submit_final=False,
+ )
+ obs = env.step(dangerous_action)
+ assert "not allowed" in obs.last_execution_result.lower() or "blocked" in obs.last_execution_result.lower(), \
+     f"ATTACH should be blocked, got: {obs.last_execution_result}"
+ print(f"PASS: Dangerous SQL is blocked")
+
  # Run a complete correct migration
+ env2 = DbMigrationEnvironment(task_name="column-restructure")
+ obs2 = env2.reset()
  steps = [
      "CREATE TABLE users_new (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)",
      "INSERT INTO users_new (id, full_name) SELECT id, first_name || ' ' || last_name FROM users",

  ]
  for i, sql in enumerate(steps):
      is_final = (i == len(steps) - 1)
+     action = MigrationAction(sql_command=sql, reasoning=f"Step {i+1}", submit_final=is_final)
+     obs2 = env2.step(action)
+     print(f"  Step {i+1}: reward={obs2.reward:.2f}, progress={obs2.migration_progress:.2f}, done={obs2.done}")
+
+ assert obs2.done == True
+ assert obs2.migration_progress >= 0.89, f"Expected >= 0.89, got {obs2.migration_progress}"
+ # Check trajectory is included in final metadata
+ assert "trajectory" in obs2.metadata, "Trajectory should be in final metadata"
+ print(f"PASS: Full migration completed with score {obs2.migration_progress:.2f}")
+
  env.close()
+ env2.close()

+ # Test 8: Task 2 grader
  conn = sqlite3.connect(":memory:")
  conn.execute("PRAGMA foreign_keys = ON")
  seed_task2(conn)

  print(f"PASS: Task 2 grader before migration: {score_before:.2f}")
  conn.close()

+ # Test 9: Task 3 grader
  conn = sqlite3.connect(":memory:")
  conn.execute("PRAGMA foreign_keys = ON")
  seed_task3(conn)

  print(f"PASS: Task 3 grader before migration: {score_before:.2f}")
  conn.close()

+ # Test 10: Case insensitivity (A7)
+ conn = sqlite3.connect(":memory:")
+ conn.execute("PRAGMA foreign_keys = ON")
+ seed_task1(conn)
+ conn.execute("CREATE TABLE USERS_NEW (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)")
+ conn.execute("INSERT INTO USERS_NEW SELECT id, first_name || ' ' || last_name FROM users")
+ conn.execute("DROP TABLE users")
+ conn.execute("ALTER TABLE USERS_NEW RENAME TO USERS")
+ conn.commit()
+ reconciler_case = StateReconciler("column-restructure")
+ score_case = reconciler_case.score(conn)
+ print(f"PASS: Case-insensitive grading score: {score_case:.2f}")
+ assert score_case >= 0.79, f"Case-insensitive should score high, got {score_case}"
+ conn.close()
+
  print()
  print("=" * 50)
  print("ALL TESTS PASSED! Environment is fully working!")
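For readers skimming the diff, the golden-database grading loop these tests exercise reduces to the following minimal sketch. The `seed` and `golden_migration` functions here are toy stand-ins for a task's `seed_fn`/`golden_fn` entries in `seeds.TASKS`, not the project's actual code:

```python
import sqlite3

def seed(conn):
    # Toy stand-in for a task's seed_fn
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, first_name TEXT, last_name TEXT)")
    conn.executemany("INSERT INTO users VALUES (?, ?, ?)",
                     [(1, "Ada", "Lovelace"), (2, "Alan", "Turing")])

def golden_migration(conn):
    # Toy stand-in for a task's golden_fn
    conn.execute("CREATE TABLE users_new (id INTEGER PRIMARY KEY, full_name TEXT NOT NULL)")
    conn.execute("INSERT INTO users_new SELECT id, first_name || ' ' || last_name FROM users")
    conn.execute("DROP TABLE users")
    conn.execute("ALTER TABLE users_new RENAME TO users")

def rows(conn, table):
    return conn.execute(f"SELECT * FROM [{table}] ORDER BY 1").fetchall()

# Build the golden reference: fresh seed + perfect migration
golden = sqlite3.connect(":memory:")
seed(golden)
golden_migration(golden)

# The agent's DB is seeded identically; here we pretend it found the right SQL
agent = sqlite3.connect(":memory:")
seed(agent)
golden_migration(agent)

# Table-by-table comparison, in the spirit of StateReconciler._score_dynamic
match = rows(agent, "users") == rows(golden, "users")
print("tables match:", match)  # tables match: True
```

Because the golden DB is rebuilt from the current seed at grader construction time, changing the seed data never requires updating expected counts by hand, which is exactly what the rewritten `StateReconciler` relies on.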