# Spaces: Running
import asyncio
import unittest

from server.env import SQLDebugEnv
from server.models import SQLDebugAction, ActionType
class TestEnv(unittest.TestCase):
    """Smoke tests for ``SQLDebugEnv`` on the ``easy_syntax_fix`` task.

    ``reset`` and ``step`` are awaited, so each test wraps its body in a
    local coroutine and drives it to completion with ``asyncio.run``.
    """

    def test_reset_and_inspect_schema(self):
        """A fresh episode is not done and INSPECT_SCHEMA returns schema info."""

        async def run():
            env = SQLDebugEnv(task_id="easy_syntax_fix")
            # reset() yields an (observation, info) pair; info is not checked.
            obs, _ = await env.reset()
            self.assertFalse(obs.is_done)

            action = SQLDebugAction(action_type=ActionType.INSPECT_SCHEMA)
            # step() yields (observation, reward, done, info); info is not checked.
            obs2, reward, done, _ = await env.step(action)
            self.assertFalse(done)
            self.assertIsNotNone(obs2.schema_info)
            self.assertGreaterEqual(reward, 0.0)

        asyncio.run(run())

    def test_submit_broken_query_does_not_finish(self):
        """Submitting the task's own broken query must not end the episode."""

        async def run():
            env = SQLDebugEnv(task_id="easy_syntax_fix")
            # Only the side effect of starting an episode is needed here.
            await env.reset()

            action = SQLDebugAction(
                action_type=ActionType.SUBMIT_QUERY,
                query=env.task.broken_query,
            )
            obs2, reward, done, _ = await env.step(action)
            self.assertFalse(done)
            # The reward for the broken query must lie within [-1.0, 0.2].
            self.assertLessEqual(reward, 0.2)
            self.assertGreaterEqual(reward, -1.0)
            # The submitted query is echoed back on the observation.
            self.assertEqual(obs2.current_query, env.task.broken_query)

        asyncio.run(run())
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()