|
|
|
|
|
|
|
|
|
|
|
|
| """SQL Sandbox Environment Client."""
|
|
|
| from typing import Dict
|
|
|
| from openenv.core import EnvClient
|
| from openenv.core.client_types import StepResult
|
| from openenv.core.env_server.types import State
|
|
|
| from models import SqlSandboxAction, SqlSandboxObservation
|
|
|
|
|
class SqlSandboxEnv(EnvClient[SqlSandboxAction, SqlSandboxObservation, State]):
    """
    Client for the SQL/Data Cleaning Sandbox.

    Converts ``SqlSandboxAction`` objects into dict payloads and converts
    dict payloads back into ``SqlSandboxObservation`` / ``State`` objects.

    Example:
        >>> with SqlSandboxEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.task_description)
        ...     result = client.step(SqlSandboxAction(tool="sql", command="SELECT * FROM sales"))
        ...     print(result.observation.output)
    """

    def _step_payload(self, action: SqlSandboxAction) -> Dict:
        """Return the dict form of ``action``: its ``tool`` and ``command``."""
        return dict(tool=action.tool, command=action.command)

    def _parse_result(self, payload: Dict) -> StepResult[SqlSandboxObservation]:
        """Build a ``StepResult`` from a response payload.

        Observation fields are read from the nested ``"observation"`` dict,
        while ``reward`` and ``done`` are read from the top level of
        ``payload`` and copied onto both the observation and the result.
        Keys that are absent fall back to the defaults written below.
        """
        raw_obs = payload.get("observation", {})

        # Top-level values shared by the observation and the StepResult.
        reward = payload.get("reward")
        done = payload.get("done", False)

        parsed_obs = SqlSandboxObservation(
            output=raw_obs.get("output", ""),
            error=raw_obs.get("error"),
            current_step=raw_obs.get("current_step", 0),
            max_steps=raw_obs.get("max_steps", 20),
            task_description=raw_obs.get("task_description", ""),
            done=done,
            reward=reward,
            metadata=raw_obs.get("metadata", {}),
        )
        return StepResult(observation=parsed_obs, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from ``payload``.

        ``episode_id`` is ``None`` when absent; ``step_count`` defaults to 0.
        """
        episode_id = payload.get("episode_id")
        step_count = payload.get("step_count", 0)
        return State(episode_id=episode_id, step_count=step_count)
|
|
|