| """Tests for SQL Arena environment."""
|
|
|
| import pytest
|
| from src.sql_arena.environment import SQLArenaEnvironment
|
| from src.sql_arena.models import SQLArenaAction
|
| from src.sql_arena.tasks import list_tasks, TASK_BY_ID
|
|
|
|
|
| class TestEnvironmentBasics:
|
|
|
| def setup_method(self):
|
| self.env = SQLArenaEnvironment()
|
|
|
| def teardown_method(self):
|
| self.env.close()
|
|
|
| def test_reset_returns_observation(self):
|
| result = self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| assert result.observation is not None
|
| assert result.reward == 0.0
|
| assert result.done is False
|
|
|
| def test_step_with_correct_query(self):
|
| self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| task = self.env.current_task
|
| action = SQLArenaAction(sql_query=task.expected_sql)
|
| result = self.env.step(action)
|
| assert result.reward > 0.0
|
| assert result.info.get("score", 0) >= 0.8
|
|
|
| def test_step_with_invalid_query(self):
|
| self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| action = SQLArenaAction(sql_query="INVALID SQL QUERY")
|
| result = self.env.step(action)
|
| assert result.reward == 0.01
|
| assert result.observation.error_message is not None
|
|
|
| def test_state_tracking(self):
|
| self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| state = self.env.state()
|
| assert state.current_step == 0
|
|
|
| self.env.step(SQLArenaAction(sql_query="SELECT 1"))
|
| state = self.env.state()
|
| assert state.current_step == 1
|
|
|
| def test_episode_terminates(self):
|
| self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| task = self.env.current_task
|
| for _ in range(task.max_steps + 1):
|
| if self.env.state().done:
|
| break
|
| self.env.step(SQLArenaAction(sql_query="SELECT 1"))
|
| assert self.env.state().done is True
|
|
|
|
|
| class TestAllDifficulties:
|
|
|
| def setup_method(self):
|
| self.env = SQLArenaEnvironment()
|
|
|
| def teardown_method(self):
|
| self.env.close()
|
|
|
| def test_easy(self):
|
| result = self.env.reset(difficulty="basic_select")
|
| assert result.observation.difficulty == "basic_select"
|
|
|
| def test_medium(self):
|
| result = self.env.reset(difficulty="join_aggregate")
|
| assert result.observation.difficulty == "join_aggregate"
|
|
|
| def test_hard(self):
|
| result = self.env.reset(difficulty="complex_analysis")
|
| assert result.observation.difficulty == "complex_analysis"
|
|
|
|
|
| class TestGrading:
|
|
|
| def setup_method(self):
|
| self.env = SQLArenaEnvironment()
|
|
|
| def teardown_method(self):
|
| self.env.close()
|
|
|
| def test_scores_in_range(self):
|
| for task_id, task in TASK_BY_ID.items():
|
| self.env.reset(difficulty=task.difficulty, task_id=task_id)
|
| action = SQLArenaAction(sql_query=task.expected_sql)
|
| result = self.env.step(action)
|
| assert 0.0 <= result.reward <= 1.0
|
| assert 0.0 <= result.info.get("score", 0) <= 1.0
|
| self.env.reset(difficulty=task.difficulty, task_id=task_id)
|
|
|
| def test_varying_scores(self):
|
| scores = set()
|
| queries = [
|
| "SELECT name, salary FROM employees WHERE is_active = 1 AND salary > 80000 ORDER BY salary DESC",
|
| "SELECT * FROM employees",
|
| "INVALID",
|
| "SELECT name FROM employees",
|
| ]
|
| for q in queries:
|
| self.env.reset(difficulty="basic_select", task_id="easy_001")
|
| result = self.env.step(SQLArenaAction(sql_query=q))
|
| scores.add(round(result.info.get("score", 0), 2))
|
| assert len(scores) > 1, "Grader always returns the same score!"
|
|
|
|
|
| class TestTaskRegistry:
|
|
|
| def test_list_tasks(self):
|
| tasks = list_tasks()
|
| assert "basic_select" in tasks
|
| assert "join_aggregate" in tasks
|
| assert "complex_analysis" in tasks
|
|
|
| def test_minimum_3_tasks(self):
|
| tasks = list_tasks()
|
| for difficulty, task_ids in tasks.items():
|
| assert len(task_ids) >= 3, f"{difficulty} has fewer than 3 tasks"
|
|
|
|
|
| if __name__ == "__main__":
|
| pytest.main([__file__, "-v"]) |