# sql-debug-env: tests/test_env.py
# Author: md896
# Initial OpenEnv SQL debug environment
# Commit: 30cf758
import asyncio
import unittest
from server.env import SQLDebugEnv
from server.models import SQLDebugAction, ActionType
class TestEnv(unittest.TestCase):
    """Smoke tests for ``SQLDebugEnv`` driven through its async reset/step API.

    Every test defines its scenario as a local coroutine and executes it with
    ``asyncio.run``, so each test gets its own event loop and constructs its
    own environment instance on the ``easy_syntax_fix`` task.
    """

    def test_reset_and_inspect_schema(self):
        """A reset episode is live, and INSPECT_SCHEMA yields schema info."""

        async def scenario():
            environment = SQLDebugEnv(task_id="easy_syntax_fix")

            # reset() is unpacked as an (observation, info) pair.
            first_obs, _reset_info = await environment.reset()
            self.assertFalse(first_obs.is_done)

            inspect_action = SQLDebugAction(action_type=ActionType.INSPECT_SCHEMA)
            # step() is unpacked as (observation, reward, done, info).
            after_obs, gained, finished, _step_info = await environment.step(
                inspect_action
            )
            self.assertFalse(finished)
            self.assertIsNotNone(after_obs.schema_info)
            self.assertGreaterEqual(gained, 0.0)

        asyncio.run(scenario())

    def test_submit_broken_query_does_not_finish(self):
        """Submitting the task's unmodified broken query keeps the episode open.

        The test accepts any reward in the closed range [-1.0, 0.2] and checks
        that the returned observation reports the submitted query as
        ``current_query``.
        """

        async def scenario():
            environment = SQLDebugEnv(task_id="easy_syntax_fix")
            _first_obs, _reset_info = await environment.reset()

            submit_action = SQLDebugAction(
                action_type=ActionType.SUBMIT_QUERY,
                query=environment.task.broken_query,
            )
            after_obs, gained, finished, _step_info = await environment.step(
                submit_action
            )
            self.assertFalse(finished)
            # Upper bound first, then lower bound: reward must lie in [-1.0, 0.2].
            self.assertLessEqual(gained, 0.2)
            self.assertGreaterEqual(gained, -1.0)
            # broken_query is re-read after the step rather than cached, so the
            # comparison uses whatever the task holds at this point.
            self.assertEqual(after_obs.current_query, environment.task.broken_query)

        asyncio.run(scenario())
if __name__ == "__main__":
    # Executing this file directly hands control to unittest's command-line
    # runner, which collects the test_* methods of TestEnv in this module.
    unittest.main()