DataDetective / server /environment.py
Viani's picture
Deploy DataDetective: 9-task business investigation environment
bcd8636 verified
"""Core environment logic for DataDetective."""
import random
import uuid
from typing import Any, Optional
from openenv.core.env_server import Environment
try:
from ..models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
from .database import create_database, get_schema_info
from .tasks import TASKS, grade_answer
except (ImportError, ModuleNotFoundError):
from models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
from server.database import create_database, get_schema_info
from server.tasks import TASKS, grade_answer
class DataDetectiveEnvironment(
    Environment[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
):
    """Episode-based environment in which an agent investigates a task via SQL.

    An episode starts with :meth:`reset`, which builds a fresh database and
    selects a task from ``TASKS``. The agent then calls :meth:`step` with
    ``"query"`` actions (SQL executed against the database) and ends the
    episode with an ``"answer"`` action, which is scored by ``grade_answer``.
    Once more than ``MAX_STEPS`` actions have been taken the episode is ended
    with a reward of ``0.0``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True
    MAX_STEPS = 30

    def __init__(self):
        """Create an idle environment; :meth:`reset` must be called before use."""
        super().__init__()
        self._db = None  # database handle returned by create_database()
        self._task_id: str = ""
        self._step_count: int = 0
        self._episode_id: str = ""
        self._queries_executed: int = 0
        self._state = DataDetectiveState()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> DataDetectiveObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, seeds the module-level ``random`` generator before
                the task is chosen.
            episode_id: Identifier for the episode; a random UUID4 string is
                generated when omitted or empty.
            task_id: Key into ``TASKS``. A missing or unknown id falls back to
                a randomly chosen task.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An observation with the task description, the database schema and
            the step budget.
        """
        if seed is not None:
            # NOTE: this seeds the process-wide RNG, not a per-session one.
            random.seed(seed)

        self._episode_id = episode_id or str(uuid.uuid4())
        self._task_id = task_id if task_id in TASKS else random.choice(list(TASKS))
        self._step_count = 0
        self._queries_executed = 0

        # Release the previous episode's database first. Going through close()
        # also clears the handle, so a failure in create_database() does not
        # leave self._db pointing at an already-closed connection.
        self.close()
        self._db = create_database()

        task = TASKS[self._task_id]
        schema = get_schema_info(self._db)

        self._state = DataDetectiveState(
            episode_id=self._episode_id,
            step_count=0,
            task_id=self._task_id,
            queries_executed=0,
            max_steps=self.MAX_STEPS,
        )

        return DataDetectiveObservation(
            done=False,
            reward=None,
            output="Environment ready. Run SQL queries to investigate the issue, then submit your answer.",
            task_description=task["description"],
            schema_info=schema,
            step_number=0,
            max_steps=self.MAX_STEPS,
            message=f"Investigation: {task['title']} [{task['difficulty'].upper()}] -- {self.MAX_STEPS} steps available.",
        )

    def step(
        self,
        action: DataDetectiveAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DataDetectiveObservation:
        """Apply one agent action and return the resulting observation.

        Args:
            action: The action to apply. ``action.action_type`` selects the
                handler (``"query"`` or ``"answer"``, case-insensitive,
                surrounding whitespace ignored) and ``action.content`` carries
                the SQL text or the answer text.
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The observation produced by the action. ``done`` is True after an
            answer has been graded or once the step budget is exceeded.
        """
        # Without an active episode there is no task to describe and no
        # database to query; report that instead of failing with a KeyError
        # or a misleading SQL error, and do not consume a step.
        if self._db is None or self._task_id not in TASKS:
            return self._obs(
                done=False, reward=0.0,
                output="",
                message="Environment is not initialised. Call reset() before step().",
            )

        self._step_count += 1
        self._state.step_count = self._step_count
        remaining = self.MAX_STEPS - self._step_count

        if self._step_count > self.MAX_STEPS:
            return self._obs(
                done=True, reward=0.0,
                output="Maximum steps reached -- investigation ended with no answer submitted.",
                message="Out of steps.",
            )

        atype = (action.action_type or "").strip().lower()
        if atype == "query":
            return self._handle_query(action.content, remaining)
        if atype == "answer":
            return self._handle_answer(action.content)
        return self._obs(
            done=False, reward=0.0,
            output="",
            message=f"Unknown action_type '{action.action_type}'. Use 'query' or 'answer'. ({remaining} steps left)",
        )

    @property
    def state(self) -> DataDetectiveState:
        """Return the current episode state object."""
        return self._state

    def close(self) -> None:
        """Close the database handle, if any. Safe to call repeatedly."""
        if self._db is not None:
            self._db.close()
            self._db = None

    def _obs(
        self, *, done: bool, reward: Optional[float], output: str, message: str
    ) -> DataDetectiveObservation:
        """Build a mid-episode observation for the current task and step.

        The schema is only sent by :meth:`reset`, so ``schema_info`` is empty
        here. The task description is empty when no valid task is selected.
        """
        if self._task_id in TASKS:
            description = TASKS[self._task_id]["description"]
        else:
            description = ""
        return DataDetectiveObservation(
            done=done,
            reward=reward,
            output=output,
            task_description=description,
            schema_info="",
            step_number=self._step_count,
            max_steps=self.MAX_STEPS,
            message=message,
        )

    def _handle_query(self, sql: str, remaining: int) -> DataDetectiveObservation:
        """Execute ``sql`` and return its result (or the error) as text.

        Every call counts towards ``queries_executed``, including empty and
        failing queries. Errors raised while running the statement are
        reported to the agent in the observation rather than propagated.

        Args:
            sql: SQL statement supplied by the agent.
            remaining: Steps left after this one, used in the status message.
        """
        self._queries_executed += 1
        self._state.queries_executed = self._queries_executed

        if not sql or not sql.strip():
            return self._obs(
                done=False, reward=0.0,
                output="Empty query -- please provide a valid SQL statement.",
                message=f"{remaining} steps left.",
            )

        cur = None
        try:
            cur = self._db.cursor()
            cur.execute(sql)
            # description is None for statements that produce no result set.
            columns = [d[0] for d in cur.description] if cur.description else []
            rows = cur.fetchall()
            output = _format_table(columns, rows) if rows else "Query returned 0 rows."
        except Exception as exc:
            # Deliberately broad: any failure of agent-supplied SQL is fed
            # back to the agent so it can correct the statement.
            return self._obs(
                done=False, reward=0.0,
                output=f"SQL Error: {exc}",
                message=f"Query failed. Fix your SQL and retry. ({remaining} steps left)",
            )
        finally:
            # Assumes a DB-API style cursor exposing close().
            if cur is not None:
                cur.close()

        return self._obs(
            done=False, reward=0.0,
            output=output,
            message=f"{len(rows)} row(s) returned. ({remaining} steps left)",
        )

    def _handle_answer(self, answer_text: str) -> DataDetectiveObservation:
        """Grade the submitted answer and return the terminal observation.

        The reward is whatever ``grade_answer`` returns for the current task;
        the verdict text is chosen from the 0.8 and 0.5 thresholds.
        """
        reward = grade_answer(self._task_id, answer_text)
        if reward >= 0.8:
            verdict = "Excellent investigation!"
        elif reward >= 0.5:
            verdict = "Good findings, but some details missing."
        else:
            verdict = "Several key findings were missed."
        return self._obs(
            done=True,
            reward=reward,
            output=f"Score: {reward:.2f} / 1.00 -- {verdict}",
            message=f"Investigation complete. Final score: {reward:.2f}",
        )
def _format_table(columns: list[str], rows: list, max_rows: int = 100) -> str:
truncated = len(rows) > max_rows
display = rows[:max_rows]
widths = [len(str(c)) for c in columns]
for row in display:
for i, v in enumerate(row):
widths[i] = max(widths[i], min(len(str(v)), 60))
header = " | ".join(str(c).ljust(widths[i]) for i, c in enumerate(columns))
sep = "-+-".join("-" * w for w in widths)
lines = [header, sep]
for row in display:
lines.append(" | ".join(str(v).ljust(widths[i])[:60] for i, v in enumerate(row)))
if truncated:
lines.append(f"... (showing {max_rows} of {len(rows)} rows)")
return "\n".join(lines)