# sql_env/tests/test_verifier_integration.py
# hjerpe: "Upload folder using huggingface_hub" (commit 5dd1bb4, verified)
"""Integration tests for type-aware answer verification in SQLEnvironment."""
import json
import sqlite3
import pytest
from sql_env.models import QuestionRecord, SQLAction
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer
@pytest.fixture
def env(tmp_path):
    """Provide an ``SQLEnvironment`` backed by a throwaway SQLite database.

    The fixture creates ``<tmp_path>/databases/integration_db/integration_db.sqlite``
    with two tables:

    * ``employees(id, name, dept, salary)`` holding three rows, and
    * ``departments(name)`` holding two rows.

    It also writes a ``questions.json`` file containing a single placeholder
    question. Tests in this module overwrite ``env.questions`` through
    ``_set_single_question`` before calling ``reset``.

    Args:
        tmp_path: pytest's built-in per-test temporary directory.

    Returns:
        An ``SQLEnvironment`` constructed with the questions file, the
        database root directory, and a ``MockTokenizer``.
    """
    db_id = "integration_db"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary REAL)"
        )
        cursor.execute("CREATE TABLE departments (name TEXT)")
        cursor.executemany(
            "INSERT INTO employees (id, name, dept, salary) VALUES (?, ?, ?, ?)",
            [
                (1, "Alice", "Engineering", 99.5),
                (2, "Bob", "Engineering", 100.0),
                (3, "Cara", "Sales", 100.5),
            ],
        )
        cursor.executemany(
            "INSERT INTO departments (name) VALUES (?)",
            [("Alice",), ("Bob",)],
        )
        connection.commit()
    finally:
        # Close on every path: if seeding fails, the handle on the temporary
        # database file must not leak into the rest of the test session.
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions_path.write_text(
        json.dumps(
            [
                {
                    "question": "Placeholder",
                    "db_id": db_id,
                    "query": "SELECT 1",
                }
            ]
        ),
        encoding="utf-8",
    )

    return SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_root),
        tokenizer=MockTokenizer(),
    )
def _set_single_question(env: SQLEnvironment, *, sql: str, answer_type: str | None) -> None:
    """Replace the environment's questions with one synthetic question.

    Args:
        env: Environment whose ``questions`` attribute is overwritten.
        sql: Gold SQL query stored on the question record.
        answer_type: Answer type to store on the record. When ``None``, the
            record is first built with ``"string"`` and the attribute is
            then reset to ``None`` after construction.
    """
    # NOTE(review): the "string" stand-in presumably exists because
    # QuestionRecord does not accept None at construction time -- confirm.
    record = QuestionRecord(
        question_id="q-0",
        question_text="Integration check",
        database_name="integration_db",
        gold_sql=sql,
        gold_answer="",
        answer_type="string" if answer_type is None else answer_type,
        difficulty="easy",
        tables_involved=[],
    )
    env.questions = [record]

    if answer_type is not None:
        return

    # Simulate a question whose answer type is missing.
    env.questions[0].answer_type = None
def test_integer_answer_flow(env):
    """An integer question accepts the answer ``"3.0"`` for a row count of three."""
    _set_single_question(env, sql="SELECT COUNT(*) FROM employees", answer_type="integer")
    env.reset(seed=1)

    answer = SQLAction(action_type="ANSWER", argument="3.0")
    result = env.step(answer)

    assert result.done is True
    assert result.reward == 1.0
def test_float_answer_flow(env):
    """A float question accepts ``"100.0"`` as the average of the seeded salaries."""
    _set_single_question(env, sql="SELECT AVG(salary) FROM employees", answer_type="float")
    env.reset(seed=1)

    # Seeded salaries are 99.5, 100.0 and 100.5, so the gold average is 100.0.
    answer = SQLAction(action_type="ANSWER", argument="100.0")
    result = env.step(answer)

    assert result.done is True
    assert result.reward == 1.0
def test_string_answer_flow(env):
    """A string question accepts a padded, lower-cased form of the gold value."""
    _set_single_question(
        env, sql="SELECT dept FROM employees WHERE id = 1", answer_type="string"
    )
    env.reset(seed=1)

    # The stored value is "Engineering"; the submission differs in case and padding.
    answer = SQLAction(action_type="ANSWER", argument=" engineering ")
    result = env.step(answer)

    assert result.done is True
    assert result.reward == 1.0
def test_list_answer_flow(env):
    """A list question accepts the gold items submitted in a different order."""
    _set_single_question(
        env, sql="SELECT name FROM departments ORDER BY name", answer_type="list"
    )
    env.reset(seed=1)

    # The gold query orders by name (Alice, Bob); the submission reverses that.
    answer = SQLAction(action_type="ANSWER", argument="Bob, Alice")
    result = env.step(answer)

    assert result.done is True
    assert result.reward == 1.0
def test_fallback_when_answer_type_missing(env):
    """With ``answer_type`` set to ``None``, a lower-cased answer is still rewarded."""
    _set_single_question(
        env, sql="SELECT dept FROM employees WHERE id = 1", answer_type=None
    )
    env.reset(seed=1)

    answer = SQLAction(action_type="ANSWER", argument="engineering")
    result = env.step(answer)

    assert result.done is True
    assert result.reward == 1.0
def test_type_coercion_failure_returns_zero_reward(env):
    """A non-numeric answer to an integer question ends the episode with zero reward."""
    _set_single_question(env, sql="SELECT COUNT(*) FROM employees", answer_type="integer")
    env.reset(seed=1)

    answer = SQLAction(action_type="ANSWER", argument="not-a-number")
    result = env.step(answer)

    assert result.done is True
    assert result.reward == 0.0