explainer-env / server /explainer_env_environment.py
kgdrathan's picture
Upload folder using huggingface_hub
5869d56 verified
"""
Research -> Interactive Explainer Environment (multi-step, async).
Episode flow:
1. reset() → agent gets a topic + tier
2. step(explore) × 0..MAX_EXPLORE → agent calls research tools
3. step(generate) × 1 → agent produces marimo/manim code
4. step(repair) × 0..MAX_REPAIR → agent fixes lint/build errors if needed
Each step returns a per-step reward. The final generate step also includes
a generation reward that accounts for how well the code uses the research.
The environment supports async via reset_async() / step_async() overrides.
OpenEnv's HTTP server detects these and calls them directly (no thread pool).
"""
import random
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS, clamp_action_reward
from ..models import ExplainerAction, ExplainerObservation
from ..research import AVAILABLE_TOOLS, run_research_tool
from ..rewards.exploration import compute_explore_reward
from ..rewards.generation import adjust_repair_reward, compute_generate_reward
from ..rewards.sandbox import validate_code
from ..task_bank import ALL_TASKS, EASY_TASKS, HARD_TASKS, MEDIUM_TASKS, Task
except ImportError:
from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS, clamp_action_reward
from models import ExplainerAction, ExplainerObservation
from research import AVAILABLE_TOOLS, run_research_tool
from rewards.exploration import compute_explore_reward
from rewards.generation import adjust_repair_reward, compute_generate_reward
from rewards.sandbox import validate_code
from task_bank import ALL_TASKS, EASY_TASKS, HARD_TASKS, MEDIUM_TASKS, Task
# Checklist text appended to rendered sandbox errors when the "MB002" error
# code is present (see _render_errors_with_hints). The literal is split on
# sentence boundaries; the concatenated value is a single line of text.
MB002_REPAIR_HINT = (
    "MB002 repair checklist: "
    "Marimo treats every non-underscore assignment as a global notebook "
    "variable, including `for` loop variables. "
    "Audit the whole file and rename cell-local names to private names "
    "everywhere: "
    "`arr` -> `_arr`, `target` -> `_target`, `i` -> `_i`, `t` -> `_t`, "
    "`freqs` -> `_freqs`, `fig` -> `_fig`, `ax` -> `_ax`. "
    "Public names should only be used for values intentionally passed to "
    "later cells, and each public name may be defined once globally."
)
def _render_errors_with_hints(errors: str, error_codes: list[str]) -> str:
if "MB002" not in error_codes:
return errors
return f"{errors}\n\n{MB002_REPAIR_HINT}"
class ExplainerEnvironment(Environment):
    """
    Multi-step Research → Interactive Explainer environment.

    Phase 1 (explore): agent issues research-tool queries and receives results.
    Phase 2 (generate): agent produces marimo/manim code using the research.
    Phase 3 (repair): entered only when the generated code fails validation;
    the agent may resubmit until MAX_REPAIR_STEPS attempts are used.

    Whenever a step returns ``done=True`` for an active episode (successful
    generation, exhausted repairs, unknown action type, or an internal error)
    the environment also marks the episode finished internally, so later steps
    are rejected until ``reset()`` is called.

    Supports async via reset_async() / step_async() — OpenEnv's server detects
    the overrides and awaits them directly instead of using a thread pool.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task: Task | None = None
        # Pool sampled by reset() when no recognised difficulty is requested.
        self._difficulty_pool: list[Task] = EASY_TASKS
        self._clear_episode_state()

    # ------------------------------------------------------------------
    # Sync interface (fallback — OpenEnv prefers async when overridden)
    # ------------------------------------------------------------------
    def reset(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Sample a task and return the initial observation (sync)."""
        return self._do_reset(seed=seed, episode_id=episode_id, **kwargs)

    def step(self, action: ExplainerAction, timeout_s=None, **kwargs) -> ExplainerObservation:
        """Route the action to its explore / generate / repair handler (sync).

        Explore actions are coroutines and are driven with ``asyncio.run``.
        ``asyncio.run`` raises ``RuntimeError`` when the calling thread already
        has a running event loop; that error is reported like any other
        handler failure (a terminal "Environment error" observation). Use
        ``step_async`` from async code instead.

        ``timeout_s`` and extra keyword arguments are accepted for interface
        compatibility and are not used.
        """
        import asyncio

        task, early_obs = self._begin_step()
        if early_obs is not None:
            return early_obs
        try:
            if action.action_type == "explore":
                # Run async explore in a new event loop for sync callers
                return asyncio.run(self._handle_explore(action, task))
            return self._handle_non_explore(action, task)
        except Exception as e:
            return self._fail_episode(task, e)

    # ------------------------------------------------------------------
    # Async interface (preferred — OpenEnv detects these overrides)
    # ------------------------------------------------------------------
    async def reset_async(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Sample a task and return the initial observation (async)."""
        return self._do_reset(seed=seed, episode_id=episode_id, **kwargs)

    async def step_async(self, action: ExplainerAction, timeout_s=None, **kwargs) -> ExplainerObservation:
        """Route the action to its explore / generate / repair handler (async).

        ``timeout_s`` and extra keyword arguments are accepted for interface
        compatibility and are not used.
        """
        task, early_obs = self._begin_step()
        if early_obs is not None:
            return early_obs
        try:
            if action.action_type == "explore":
                return await self._handle_explore(action, task)
            return self._handle_non_explore(action, task)
        except Exception as e:
            return self._fail_episode(task, e)

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------
    def _clear_episode_state(self) -> None:
        """Reset every per-episode field to its start-of-episode value."""
        self._accumulated_context: list[str] = []
        self._explore_actions: list[str] = []
        self._used_tools: set[str] = set()
        self._explore_steps: int = 0
        self._repair_steps: int = 0
        self._phase: str = "explore"
        self._done: bool = False
        self._last_code: str = ""
        self._last_format: str = "marimo"
        self._last_narration: str = ""
        self._last_errors: str = ""
        self._last_error_codes: list[str] = []

    def _finish_episode(self) -> None:
        """Mark the current episode as finished so further steps are rejected."""
        self._phase = "done"
        self._done = True

    def _begin_step(self) -> tuple[Task | None, ExplainerObservation | None]:
        """Count the step and run the guards shared by step() and step_async().

        Returns ``(task, early_obs)``. When ``early_obs`` is not ``None`` the
        caller must return it as-is (no task set, or episode already done);
        otherwise ``task`` is the active task and dispatch may proceed.
        """
        self._state.step_count += 1
        task = self._current_task
        if task is None:
            return None, ExplainerObservation(
                feedback="Error: no task set. Call reset() first.",
                done=True,
                reward=-1.0,
            )
        if self._done:
            return task, self._make_obs(
                task,
                phase="done",
                feedback="Episode is already done. Call reset() to start a new one.",
                reward=0.0,
                done=True,
            )
        return task, None

    def _handle_non_explore(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Dispatch the synchronous action types (generate, repair, unknown)."""
        if action.action_type == "generate":
            return self._handle_generate(action, task)
        if action.action_type == "repair":
            return self._handle_repair(action, task)
        # The observation reports done=True, so the internal state must agree;
        # otherwise the episode would keep accepting steps after it "ended".
        self._finish_episode()
        return self._make_obs(
            task,
            phase="explore",
            feedback=f"Unknown action_type: {action.action_type}",
            reward=0.0,
            done=True,
        )

    def _fail_episode(self, task: Task, error: Exception) -> ExplainerObservation:
        """Turn a handler exception into a terminal observation."""
        # Keep internal state consistent with the done=True observation.
        self._finish_episode()
        return self._make_obs(
            task,
            phase="done",
            feedback=f"Environment error: {error}",
            reward=0.0,
            done=True,
        )

    def _do_reset(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Shared reset logic (no I/O, so sync is fine).

        Task selection, in priority order:
          * ``topic=...`` — the task in ALL_TASKS with that exact topic; if
            none matches, a random task from ALL_TASKS.
          * ``difficulty="easy" | "medium" | "hard"`` — a random task from the
            matching pool; any other value uses the instance default pool.
        An empty pool falls back to ALL_TASKS. ``seed`` makes the random
        choice reproducible.
        """
        self._state = State(
            episode_id=episode_id or str(uuid4()), step_count=0
        )
        self._clear_episode_state()

        rng = random.Random(seed) if seed is not None else random.Random()

        # Allow selecting a specific task by topic name
        topic = kwargs.get("topic", None)
        if topic:
            match = next((t for t in ALL_TASKS if t.topic == topic), None)
            # Fallback to random if topic not found
            self._current_task = match if match is not None else rng.choice(ALL_TASKS)
        else:
            difficulty = kwargs.get("difficulty", None)
            if difficulty == "medium":
                pool = MEDIUM_TASKS
            elif difficulty == "hard":
                pool = HARD_TASKS
            elif difficulty == "easy":
                pool = EASY_TASKS
            else:
                pool = self._difficulty_pool
            self._current_task = rng.choice(pool) if pool else rng.choice(ALL_TASKS)

        t = self._current_task
        return ExplainerObservation(
            topic=t.topic,
            content=t.content,
            tier=t.tier,
            keywords=t.keywords,
            data_available=t.data_available,
            difficulty=t.difficulty,
            phase="explore",
            feedback=(
                "Research phase: choose a tool and query relevant to the topic. "
                f"Available tools: {', '.join(AVAILABLE_TOOLS)}."
            ),
            search_results="",
            explored_context="",
            explore_steps_left=MAX_EXPLORE_STEPS,
            repair_attempts_left=MAX_REPAIR_STEPS,
            available_tools=list(AVAILABLE_TOOLS),
            done=False,
            reward=0.0,
        )

    async def _handle_explore(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process an explore action: call a research tool and score the result.

        Every accepted call consumes one explore step, including calls with an
        empty query. Once the budget is spent the phase switches to
        "generate" and further explore actions are refused with zero reward.
        """
        if self._phase not in {"explore", "generate"}:
            return self._make_obs(
                task,
                phase=self._phase,
                feedback=f"Cannot explore during phase '{self._phase}'.",
                reward=0.0,
            )
        if self._explore_steps >= MAX_EXPLORE_STEPS:
            self._phase = "generate"
            return self._make_obs(
                task,
                phase="generate",
                feedback="Max explore steps reached. You must now generate.",
                reward=0.0,
            )

        self._explore_steps += 1
        # Treat a missing (None) query/intent like an empty string so the
        # agent gets the "Empty query" feedback instead of an AttributeError
        # that would abort the whole episode.
        query = (action.query or "").strip()
        intent = (action.intent or "").strip()
        tool = action.tool or "search_wikipedia"
        if not query:
            return self._make_obs(
                task,
                phase="explore",
                feedback="Empty query. Provide a search query.",
                reward=0.0,
            )

        # Snapshot history *before* this call: the reward function receives
        # both the prior state and the post-call accumulated context.
        previous_context = list(self._accumulated_context)
        previous_actions = list(self._explore_actions)
        used_tools = set(self._used_tools)

        result = await run_research_tool(tool, query, intent)
        results_text = result.render()
        self._explore_actions.append(_explore_action_text(tool, query, intent))
        if result.ok:
            self._accumulated_context.append(result.text)
        self._used_tools.add(tool)

        # Compute per-step exploration reward
        reward, components = compute_explore_reward(
            query=query,
            tool=tool,
            intent=intent,
            result=result,
            topic=task.topic,
            keywords_csv=task.keywords,
            task_content=task.content,
            difficulty=task.difficulty,
            previous_context=previous_context,
            accumulated_context=self._accumulated_context,
            used_tools=used_tools,
            previous_actions=previous_actions,
        )

        steps_left = MAX_EXPLORE_STEPS - self._explore_steps
        if steps_left > 1:
            phase = "explore"
            hint = f"Research going well — {steps_left} more steps available. Keep searching or move to generation."
        elif steps_left == 1:
            phase = "explore"
            hint = "Last research step available. Search for any missing context, or proceed to generate."
        else:
            phase = "generate"
            hint = "Research phase complete. Time to generate your explanation."
        self._phase = phase

        top_chunks = _top_chunks_payload(result.chunks)
        return self._make_obs(
            task,
            phase=phase,
            feedback=f"{hint}\nTool: {tool}\nReward: {components}",
            search_results=results_text,
            top_chunks=top_chunks,
            reward=reward,
            metadata={
                "step": self._state.step_count,
                "phase": "explore",
                "tool": tool,
                "source_count": len(result.chunks),
                "top_chunks": top_chunks,
                "error": result.error,
                **components,
            },
        )

    def _score_submission(self, fmt, code, narration, task: Task, sandbox):
        """Run the generation reward for a validated artifact.

        Returns the ``(reward, components)`` pair from compute_generate_reward.
        """
        return compute_generate_reward(
            code=code,
            fmt=fmt,
            narration=narration,
            task=task,
            exec_success=sandbox.exec_success,
            accumulated_context=self._accumulated_context,
            static_check_passed=sandbox.check_passed,
            error_codes=sandbox.error_codes,
        )

    def _remember_submission(self, fmt, code, narration, sandbox) -> str:
        """Record the latest artifact and its errors; return the rendered errors.

        The stored values are what a later repair step compares against and
        what _make_obs reports as ``last_errors`` by default.
        """
        self._last_code = code
        self._last_format = fmt
        self._last_narration = narration
        rendered_errors = _render_errors_with_hints(sandbox.render_errors(), sandbox.error_codes)
        self._last_errors = rendered_errors
        self._last_error_codes = sandbox.error_codes
        return rendered_errors

    def _handle_generate(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process a generate action: run sandbox, maybe open repair phase.

        The episode ends when the code executes successfully or no repair
        attempts are available; otherwise the phase becomes "repair".
        """
        if self._phase not in {"explore", "generate"}:
            return self._make_obs(
                task,
                phase=self._phase,
                feedback=f"Cannot generate during phase '{self._phase}'.",
                reward=0.0,
            )

        fmt = action.format or "marimo"
        code = action.code
        narration = action.narration

        # Penalise generating without any exploration
        if self._explore_steps == 0:
            skip_penalty = -0.1
            penalty_msg = "Warning: generating without any research. -0.1 penalty."
        else:
            skip_penalty = 0.0
            penalty_msg = ""

        sandbox = validate_code(fmt, code)

        # Generation reward
        reward, components = self._score_submission(fmt, code, narration, task, sandbox)
        reward = clamp_action_reward(reward + skip_penalty)
        components["generate_total"] = round(reward, 4)

        rendered_errors = self._remember_submission(fmt, code, narration, sandbox)

        # Feedback
        parts = []
        if penalty_msg:
            parts.append(penalty_msg)
        if not sandbox.parses:
            parts.append("SYNTAX ERROR: code does not parse.")
        elif not sandbox.exec_success:
            parts.append(f"EXECUTION FAILED: {rendered_errors}")
        else:
            parts.append(f"EXECUTION OK: {sandbox.message}")
        parts.append(
            f"Reward: {', '.join(f'{k}={v}' for k, v in components.items())}"
        )

        done = sandbox.exec_success or self._repair_steps >= MAX_REPAIR_STEPS
        phase = "done" if done else "repair"
        self._phase = phase
        self._done = done
        if not done:
            parts.append(
                f"Repair phase: {MAX_REPAIR_STEPS} attempts available. "
                "Submit a revised artifact using the error feedback."
            )

        return self._make_obs(
            task,
            phase=phase,
            feedback="\n".join(parts),
            reward=reward,
            done=done,
            last_errors="" if sandbox.exec_success else rendered_errors,
            metadata={
                "step": self._state.step_count,
                "phase": "generate",
                "explore_steps_used": self._explore_steps,
                "sandbox_message": sandbox.message,
                "error_codes": sandbox.error_codes,
                **components,
            },
        )

    def _handle_repair(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process one repair attempt after a failed generate action.

        Missing ``format`` / ``narration`` fall back to the values of the
        previous submission. The episode ends when the repaired code executes
        or the repair budget is exhausted.
        """
        if self._phase != "repair":
            return self._make_obs(
                task,
                phase=self._phase,
                feedback="Repair is only available after a failed generate step.",
                reward=0.0,
                done=self._done,
            )
        if self._repair_steps >= MAX_REPAIR_STEPS:
            self._finish_episode()
            return self._make_obs(
                task,
                phase="done",
                feedback="No repair attempts left.",
                reward=0.0,
                done=True,
            )

        self._repair_steps += 1
        fmt = action.format or self._last_format or "marimo"
        code = action.code
        narration = action.narration or self._last_narration

        # Capture the previous attempt before it is overwritten below; the
        # repair adjustment compares old vs. new code and error codes.
        previous_code = self._last_code
        previous_errors = list(self._last_error_codes)

        sandbox = validate_code(fmt, code)
        base_reward, components = self._score_submission(fmt, code, narration, task, sandbox)
        repair_reward, repair_components = adjust_repair_reward(
            base_reward,
            repair_success=sandbox.exec_success,
            previous_error_codes=previous_errors,
            new_error_codes=sandbox.error_codes,
            previous_code=previous_code,
            repaired_code=code,
        )
        components.update(repair_components)

        rendered_errors = self._remember_submission(fmt, code, narration, sandbox)

        attempts_left = MAX_REPAIR_STEPS - self._repair_steps
        done = sandbox.exec_success or attempts_left <= 0
        phase = "done" if done else "repair"
        self._phase = phase
        self._done = done

        status = "REPAIR OK" if sandbox.exec_success else "REPAIR FAILED"
        feedback_parts = [
            f"{status}: {sandbox.message if sandbox.exec_success else rendered_errors}",
            f"Reward: {', '.join(f'{k}={v}' for k, v in components.items())}",
        ]
        if not done:
            feedback_parts.append(
                f"Repair phase continues: {attempts_left} repair attempts left. "
                "Submit another corrected artifact using the latest error feedback."
            )
        feedback = "\n".join(feedback_parts)

        return self._make_obs(
            task,
            phase=phase,
            feedback=feedback,
            reward=repair_reward,
            done=done,
            last_errors="" if sandbox.exec_success else rendered_errors,
            metadata={
                "step": self._state.step_count,
                "phase": "repair",
                "explore_steps_used": self._explore_steps,
                "repair_steps_used": self._repair_steps,
                "sandbox_message": sandbox.message,
                "error_codes": sandbox.error_codes,
                **components,
            },
        )

    def _make_obs(
        self,
        task: Task,
        *,
        phase: str,
        feedback: str,
        reward: float = 0.0,
        done: bool = False,
        search_results: str = "",
        top_chunks: list[dict] | None = None,
        last_errors: str | None = None,
        metadata: dict | None = None,
    ) -> ExplainerObservation:
        """Helper to build a consistent observation.

        Task fields are copied from *task*; the explored context, remaining
        explore/repair budgets and (unless *last_errors* is given) the last
        rendered errors are read from the current episode state.
        """
        return ExplainerObservation(
            topic=task.topic,
            content=task.content,
            tier=task.tier,
            keywords=task.keywords,
            data_available=task.data_available,
            difficulty=task.difficulty,
            phase=phase,
            feedback=feedback,
            search_results=search_results,
            top_chunks=top_chunks or [],
            explored_context="\n---\n".join(self._accumulated_context),
            explore_steps_left=MAX_EXPLORE_STEPS - self._explore_steps,
            repair_attempts_left=MAX_REPAIR_STEPS - self._repair_steps,
            last_errors=self._last_errors if last_errors is None else last_errors,
            available_tools=list(AVAILABLE_TOOLS),
            done=done,
            reward=reward,
            metadata=metadata or {},
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step counter)."""
        return self._state
def _explore_action_text(tool: str, query: str, intent: str) -> str:
return f"{tool} {query.strip()} {intent.strip()}".strip()
def _top_chunks_payload(chunks) -> list[dict]:
return [
{
"rank": chunk.rank,
"source": chunk.source,
"title": chunk.title,
"url": chunk.url,
"score": round(chunk.score, 4),
"snippet": chunk.text,
}
for chunk in chunks[:5]
]