Spaces:
Sleeping
Sleeping
| """Core environment logic for DataDetective.""" | |
| import random | |
| import uuid | |
| from typing import Any, Optional | |
| from openenv.core.env_server import Environment | |
| try: | |
| from ..models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState | |
| from .database import create_database, get_schema_info | |
| from .tasks import TASKS, grade_answer | |
| except (ImportError, ModuleNotFoundError): | |
| from models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState | |
| from server.database import create_database, get_schema_info | |
| from server.tasks import TASKS, grade_answer | |
class DataDetectiveEnvironment(
    Environment[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
):
    """Episode-based environment for SQL investigations.

    Each episode picks one entry from ``TASKS``, builds a fresh database via
    ``create_database()``, and then accepts two kinds of actions:

    * ``"query"``  -- run the SQL in ``action.content`` and return the result.
    * ``"answer"`` -- grade ``action.content`` with ``grade_answer`` and end
      the episode.

    An episode ends when an answer is graded or when ``step()`` is called
    after ``MAX_STEPS`` steps have already been consumed. Once ended, the
    episode stays ended until ``reset()`` is called again.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True
    MAX_STEPS = 30

    def __init__(self):
        super().__init__()
        self._db = None  # open database connection, or None when no episode is active
        self._task_id: str = ""
        self._step_count: int = 0
        self._episode_id: str = ""
        self._queries_executed: int = 0
        # True once a terminal observation has been returned for this episode.
        self._done: bool = False
        self._state = DataDetectiveState()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> DataDetectiveObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, seeds the ``random`` module before a task is
                chosen, making the random task choice reproducible.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted or empty.
            task_id: Key into ``TASKS``. Any value that is not a known key
                (including ``None``) falls back to a random task.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The first observation, carrying the task description and the
            schema text produced by ``get_schema_info``.
        """
        if seed is not None:
            # NOTE(review): this reseeds the process-wide RNG, which every
            # session in this process shares. Confirm whether other code
            # (e.g. create_database) relies on it before switching to a
            # per-instance random.Random.
            random.seed(seed)

        self._episode_id = episode_id or str(uuid.uuid4())
        self._task_id = task_id if task_id in TASKS else random.choice(list(TASKS))
        self._step_count = 0
        self._queries_executed = 0
        self._done = False

        # close() also clears self._db, so a failure in create_database()
        # cannot leave a closed connection behind looking like a live one.
        self.close()
        self._db = create_database()

        task = TASKS[self._task_id]
        schema = get_schema_info(self._db)
        self._state = DataDetectiveState(
            episode_id=self._episode_id,
            step_count=0,
            task_id=self._task_id,
            queries_executed=0,
            max_steps=self.MAX_STEPS,
        )
        return DataDetectiveObservation(
            done=False,
            reward=None,
            output="Environment ready. Run SQL queries to investigate the issue, then submit your answer.",
            task_description=task["description"],
            schema_info=schema,
            step_number=0,
            max_steps=self.MAX_STEPS,
            message=f"Investigation: {task['title']} [{task['difficulty'].upper()}] -- {self.MAX_STEPS} steps available.",
        )

    def step(
        self,
        action: DataDetectiveAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DataDetectiveObservation:
        """Apply one agent action and return the resulting observation.

        ``action.action_type`` is matched case-insensitively after stripping
        whitespace. ``"query"`` and ``"answer"`` are dispatched to their
        handlers; anything else consumes a step and returns a hint.

        Calls made while no episode is active (before ``reset()`` or after
        ``close()``) or after the episode has already ended return a terminal
        observation with reward ``0.0`` and do not consume a step.

        Args:
            action: The action to apply.
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        # Without this guard a step before reset() ends in a KeyError on
        # TASKS[""] (or a misleading "SQL Error" about NoneType).
        if self._db is None or self._task_id not in TASKS:
            return self._obs(
                done=True, reward=0.0,
                output="",
                message="No active episode. Call reset() before step().",
            )
        # A finished episode must stay finished; otherwise further answers
        # would keep being graded after the first terminal observation.
        if self._done:
            return self._obs(
                done=True, reward=0.0,
                output="",
                message="Episode already finished. Call reset() to start a new investigation.",
            )

        self._step_count += 1
        self._state.step_count = self._step_count
        remaining = self.MAX_STEPS - self._step_count

        # Step number MAX_STEPS is still processed (with 0 steps left); the
        # call after that one is what ends the episode.
        if self._step_count > self.MAX_STEPS:
            self._done = True
            return self._obs(
                done=True, reward=0.0,
                output="Maximum steps reached -- investigation ended with no answer submitted.",
                message="Out of steps.",
            )

        atype = (action.action_type or "").strip().lower()
        if atype == "query":
            return self._handle_query(action.content, remaining)
        if atype == "answer":
            return self._handle_answer(action.content)
        return self._obs(
            done=False, reward=0.0,
            output="",
            message=f"Unknown action_type '{action.action_type}'. Use 'query' or 'answer'. ({remaining} steps left)",
        )

    # NOTE(review): defined as a plain method; confirm whether the
    # Environment base class expects ``state`` to be a property.
    def state(self) -> DataDetectiveState:
        """Return the current episode state object."""
        return self._state

    def close(self) -> None:
        """Close the database connection, if any, and forget it."""
        if self._db is not None:
            self._db.close()
            self._db = None

    def _obs(
        self, *, done: bool, reward: Optional[float], output: str, message: str
    ) -> DataDetectiveObservation:
        """Build an observation for the current step.

        The schema text is left empty here; only ``reset()`` sends it. The
        task description is empty when no known task is selected.
        """
        if self._task_id in TASKS:
            description = TASKS[self._task_id]["description"]
        else:
            description = ""
        return DataDetectiveObservation(
            done=done,
            reward=reward,
            output=output,
            task_description=description,
            schema_info="",
            step_number=self._step_count,
            max_steps=self.MAX_STEPS,
            message=message,
        )

    def _handle_query(self, sql: str, remaining: int) -> DataDetectiveObservation:
        """Execute ``sql`` and return its result (or error text) as an observation.

        Every call counts as an executed query, including empty and failing
        ones. Errors are reported to the agent instead of being raised.
        """
        self._queries_executed += 1
        self._state.queries_executed = self._queries_executed

        if not sql or not sql.strip():
            return self._obs(
                done=False, reward=0.0,
                output="Empty query -- please provide a valid SQL statement.",
                message=f"{remaining} steps left.",
            )

        try:
            cur = self._db.cursor()
            cur.execute(sql)
            # description is falsy for statements that produce no result set.
            columns = [d[0] for d in cur.description] if cur.description else []
            rows = cur.fetchall()
            output = _format_table(columns, rows) if rows else "Query returned 0 rows."
        except Exception as exc:
            # Deliberately broad: the SQL comes from the agent, and any
            # failure should be shown to it rather than crash the episode.
            return self._obs(
                done=False, reward=0.0,
                output=f"SQL Error: {exc}",
                message=f"Query failed. Fix your SQL and retry. ({remaining} steps left)",
            )

        return self._obs(
            done=False, reward=0.0,
            output=output,
            message=f"{len(rows)} row(s) returned. ({remaining} steps left)",
        )

    def _handle_answer(self, answer_text: str) -> DataDetectiveObservation:
        """Grade the submitted answer and end the episode.

        The reward is whatever ``grade_answer`` returns; the verdict text is
        chosen by the 0.8 and 0.5 thresholds.
        """
        reward = grade_answer(self._task_id, answer_text)
        # Mark the episode finished only after grading succeeded.
        self._done = True

        if reward >= 0.8:
            verdict = "Excellent investigation!"
        elif reward >= 0.5:
            verdict = "Good findings, but some details missing."
        else:
            verdict = "Several key findings were missed."

        return self._obs(
            done=True,
            reward=reward,
            output=f"Score: {reward:.2f} / 1.00 -- {verdict}",
            message=f"Investigation complete. Final score: {reward:.2f}",
        )
def _format_table(
    columns: list[str],
    rows: list,
    max_rows: int = 100,
    max_col_width: int = 60,
) -> str:
    """Render query results as an aligned plain-text table.

    Args:
        columns: Column headers, one per value in each row.
        rows: Sequence of row tuples/lists; values are rendered with ``str``.
        max_rows: Maximum number of rows to render. When ``rows`` is longer,
            a trailing "... (showing X of Y rows)" line is appended.
        max_col_width: Maximum number of characters shown for a cell value;
            longer values are cut off. Headers are never shortened.

    Returns:
        The header line, a separator line, one line per rendered row and,
        if rows were omitted, the truncation notice -- joined by newlines.
    """
    shown = rows[:max_rows]
    headers = [str(c) for c in columns]

    # Truncate first and pad afterwards. Padding to the column width and then
    # slicing (the previous approach) cut the padding off whenever a header
    # was wider than the cell limit, leaving rows shorter than the header.
    cells = [[str(value)[:max_col_width] for value in row] for row in shown]

    widths = [len(h) for h in headers]
    for row in cells:
        for i, text in enumerate(row):
            widths[i] = max(widths[i], len(text))

    lines = [
        " | ".join(h.ljust(widths[i]) for i, h in enumerate(headers)),
        "-+-".join("-" * w for w in widths),
    ]
    lines.extend(
        " | ".join(text.ljust(widths[i]) for i, text in enumerate(row))
        for row in cells
    )
    if len(rows) > max_rows:
        lines.append(f"... (showing {max_rows} of {len(rows)} rows)")
    return "\n".join(lines)